353 changes: 342 additions & 11 deletions core/query_parser.py
@@ -1,23 +1,354 @@
from core.ast.node import QueryNode
from core.ast.node import (
Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode,
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode
Copilot AI Dec 4, 2025

Import of 'SubqueryNode' is not used.
Import of 'VarNode' is not used.
Import of 'VarSetNode' is not used.

Suggested change
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode

Contributor

I also noticed that this version of the implementation does not handle subqueries. We will address them in the next iteration.

)
from core.ast.enums import JoinType, SortOrder
import mo_sql_parsing as mosql

class QueryParser:
@staticmethod
def normalize_to_list(value):
"""Normalize mo_sql_parsing output to a list format.

mo_sql_parsing returns:
- list when multiple items
- dict when single item with structure
- str when single simple value

This normalizes all cases to a list.
"""
if value is None:
return []
elif isinstance(value, list):
return value
elif isinstance(value, (dict, str)):
return [value]
else:
return [value]
Copilot AI Dec 4, 2025

[nitpick] The docstring states that the function normalizes to list format and handles dict, str, and other types, but the final else clause (lines 27-28) will wrap ANY other type (including unexpected types) into a list without validation. This could hide bugs where unexpected types are passed. Consider either being more explicit about what "other types" are expected, or raising an error for truly unexpected types.

Suggested change
return [value]
raise TypeError(
f"normalize_to_list: Unexpected type {type(value).__name__} for value {value!r}. "
"Expected None, list, dict, or str."
)

Contributor

This is a good suggestion. @HazelYuAhiru, let's adopt this suggestion.
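
With the suggestion applied, the whole helper might read roughly as follows. This is only a sketch: it assumes mo_sql_parsing never hands a bare number or other scalar to this helper, which has not been verified here. If it does, the strict branch would need to allow that type as well.

```python
def normalize_to_list(value):
    """Normalize a parser value to a list, rejecting unexpected types."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    if isinstance(value, (dict, str)):
        return [value]
    # Anything else is treated as a caller bug instead of being wrapped silently.
    raise TypeError(
        f"normalize_to_list: Unexpected type {type(value).__name__} for value {value!r}. "
        "Expected None, list, dict, or str."
    )
```

The redundant `elif` chain is flattened because every branch returns, which leaves the `raise` as the only fall-through path.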


def parse(self, query: str) -> QueryNode:
# Implement parsing logic using self.rules
pass

# [1] Call mo_sql_parsing
# str -> Any (JSON)
mosql_ast = mosql.parse(query)

# [2] Our new code
# Any (JSON) -> AST (QueryNode)
self.aliases = {}
Copilot AI Dec 4, 2025

The self.aliases instance variable is initialized within the parse() method rather than in __init__. This means the aliases dictionary persists between parse calls and could lead to stale alias references affecting subsequent parses. Consider either: 1) initializing self.aliases = {} in a proper __init__ method, or 2) passing aliases as a parameter through the helper methods instead of using instance state.

Contributor

This is a good point, which I also noticed. If we use an instance member to store aliases, the instance is not safe for multi-threading. [1] One way to solve this is to pass a helper parameter down through the subsequent function calls as a function argument, similar to the memo we use in the rewriter engine. [2] Alternatively, we can declare QueryParser as non-thread-safe and create a new parser instance wherever we need to call parse, i.e., parser = QueryParser(). With this approach, it makes more sense to initialize self.aliases = {} in __init__ instead of in parse. @HazelYuAhiru, please decide in coordination with @colinthebomb1's PR: if he uses the same approach (internal state shared by different functions as instance-level state), we can use [2]. If he already uses [1], you may switch to [1]. If you decide to use [2], please find a proper annotation tag in Python to mark the class as non-thread-safe.
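
Option [1] could look roughly like the sketch below. ParseContext and StatelessParser are hypothetical names used only for illustration, and the dict-shaped select items are a simplified stand-in for the real mo_sql_parsing output.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ParseContext:
    """Hypothetical per-call state; not part of the PR."""
    aliases: Dict[str, Any] = field(default_factory=dict)


class StatelessParser:
    """Stand-in for QueryParser that keeps no alias state on the instance."""

    def parse(self, select_items: list) -> Dict[str, Any]:
        ctx = ParseContext()  # fresh state for every call
        for item in select_items:
            self._parse_select_item(item, ctx)
        return ctx.aliases

    def _parse_select_item(self, item: dict, ctx: ParseContext) -> None:
        # Helpers read and write aliases through ctx, never through self.
        if "name" in item:
            ctx.aliases[item["name"]] = item["value"]


parser = StatelessParser()
first = parser.parse([{"value": "salary", "name": "s"}])
second = parser.parse([{"value": "dept", "name": "d"}])
```

Because nothing is stored on the parser, one instance can be shared across threads without the two calls seeing each other's aliases.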


select_clause = None
from_clause = None
where_clause = None
group_by_clause = None
having_clause = None
order_by_clause = None
limit_clause = None
offset_clause = None

if 'select' in mosql_ast:
select_clause = self.parse_select(self.normalize_to_list(mosql_ast['select']))
if 'from' in mosql_ast:
from_clause = self.parse_from(self.normalize_to_list(mosql_ast['from']))
if 'where' in mosql_ast:
where_clause = self.parse_where(mosql_ast['where'])
if 'groupby' in mosql_ast:
group_by_clause = self.parse_group_by(self.normalize_to_list(mosql_ast['groupby']))
if 'having' in mosql_ast:
having_clause = self.parse_having(mosql_ast['having'])
if 'orderby' in mosql_ast:
order_by_clause = self.parse_order_by(self.normalize_to_list(mosql_ast['orderby']))
if 'limit' in mosql_ast:
limit_clause = LimitNode(mosql_ast['limit'])
if 'offset' in mosql_ast:
offset_clause = OffsetNode(mosql_ast['offset'])

return QueryNode(
_select=select_clause,
_from=from_clause,
_where=where_clause,
_group_by=group_by_clause,
_having=having_clause,
_order_by=order_by_clause,
_limit=limit_clause,
_offset=offset_clause
)

def parse_select(self, select_list: list) -> SelectNode:
items = set()
for item in select_list:
if isinstance(item, dict) and 'value' in item:
expression = self.parse_expression(item['value'])
# Handle alias - set for any node that has alias attribute
if 'name' in item:
alias = item['name']
if hasattr(expression, 'alias'):
expression.alias = alias
self.aliases[alias] = expression

items.add(expression)
else:
# Handle direct expression (string, int, etc.)
expression = self.parse_expression(item)
items.add(expression)

return SelectNode(items)

def parse_from(self, from_list: list) -> FromNode:
sources = set()
left_source = None # Can be a table or the result of a previous join

for item in from_list:
# Check for JOIN first (before checking for 'value')
if isinstance(item, dict):
# Look for any join key
join_key = next((k for k in item.keys() if 'join' in k.lower()), None)

if join_key:
# This is a JOIN
if left_source is None:
raise ValueError("JOIN found without a left table")
Copilot AI Dec 4, 2025

The parse_from method raises a generic ValueError with message "JOIN found without a left table" without providing context about which join or what the input was. Consider including more diagnostic information in the error message, such as the join_key or item being processed, to aid debugging.

Suggested change
raise ValueError("JOIN found without a left table")
raise ValueError(f"JOIN found without a left table. join_key={join_key}, item={item}")

Contributor

This is a good suggestion. Let's adopt it.


join_info = item[join_key]
# Handle both string and dict join_info
if isinstance(join_info, str):
table_name = join_info
alias = None
else:
table_name = join_info['value'] if isinstance(join_info, dict) else join_info
alias = join_info.get('name') if isinstance(join_info, dict) else None

right_table = TableNode(table_name, alias)
# Track table alias
if alias:
self.aliases[alias] = right_table

on_condition = None
if 'on' in item:
on_condition = self.parse_expression(item['on'])

# Create join node - left_source might be a table or a previous join
join_type = self.parse_join_type(join_key)
join_node = JoinNode(left_source, right_table, join_type, on_condition)
# The result of this JOIN becomes the new left source for potential next JOIN
left_source = join_node
Comment on lines +131 to +133
Copilot AI Dec 4, 2025

The JoinNode constructor expects _left_table: 'TableNode' as its first parameter (see core/ast/node.py:163), but left_source could be a JoinNode from a previous iteration (as indicated by the comment on line 132 and assignment on line 133). This will cause a type error when chaining multiple joins. The JoinNode signature should be updated to accept Node types, or the parsing logic should be restructured to handle join chains differently.

Contributor

This is a good point. Should we change the constructor of JoinNode to allow either JoinNode or TableNode as _left_table?
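
One way to allow either kind of left operand is a union type on the constructor. The classes below are simplified stand-ins, not the real definitions in core/ast/node.py: the field names and the string join type are assumptions, and the ON condition is omitted.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass(frozen=True)
class TableNode:
    name: str
    alias: Optional[str] = None


@dataclass(frozen=True)
class JoinNode:
    # Hypothetical signature: the left side may be a table or an earlier join.
    left: Union["TableNode", "JoinNode"]
    right: TableNode
    join_type: str = "INNER"


def chain_joins(first: TableNode, rest: List[TableNode]) -> Union[TableNode, JoinNode]:
    """Fold tables into a left-deep join tree, as parse_from does."""
    left: Union[TableNode, JoinNode] = first
    for table in rest:
        left = JoinNode(left, table)
    return left


tree = chain_joins(TableNode("a"), [TableNode("b"), TableNode("c")])
```

Here `a JOIN b JOIN c` becomes a JoinNode whose left child is itself the JoinNode for `a JOIN b`, which is the shape the parsing loop already produces.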

elif 'value' in item:
# This is a table reference
table_name = item['value']
alias = item.get('name')
table_node = TableNode(table_name, alias)
# Track table alias
if alias:
self.aliases[alias] = table_node

if left_source is None:
# First table becomes the left source
left_source = table_node
else:
# Multiple tables without explicit JOIN (cross join)
sources.add(table_node)
elif isinstance(item, str):
# Simple string table name
table_node = TableNode(item)
if left_source is None:
left_source = table_node
else:
sources.add(table_node)

# Add the final left source (which might be a single table or chain of joins)
if left_source is not None:
sources.add(left_source)

return FromNode(sources)
Comment on lines +97 to +161
Copilot AI Dec 4, 2025

FromNode constructor expects a list but receives a set here. Throughout the parse_from method, sources is maintained as a set (initialized on line 97, with add() operations on lines 148, 155, 159), but the FromNode.__init__ signature expects _sources: List['Node'] (see core/ast/node.py:195). This type inconsistency needs to be fixed by making sources a list throughout this method.

Contributor

Correct. Let's fix it.
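
Once sources is a list, element order becomes observable. In the PR as written, left_source is added only after the loop, so simply swapping add() for append() would place the first table after any later comma-separated tables. A minimal sketch of order-preserving collection, using plain table names instead of TableNode and a hypothetical helper name:

```python
from typing import List, Union


def collect_sources(from_items: List[Union[str, dict]]) -> List[str]:
    """Collect table names into a list so FROM order is preserved."""
    sources: List[str] = []
    for item in from_items:
        # mo_sql_parsing-style items: a bare name or a dict with 'value'.
        name = item["value"] if isinstance(item, dict) else item
        sources.append(name)
    return sources


ordered = collect_sources(["orders", {"value": "customers", "name": "c"}, "items"])
```

Whether the expected ASTs in the tests depend on this order would need to be checked when making the change.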


def parse_where(self, where_dict: dict) -> WhereNode:
predicates = set()
predicates.add(self.parse_expression(where_dict))
return WhereNode(predicates)
Comment on lines +163 to +166
Copilot AI Dec 4, 2025

WhereNode constructor expects a list but receives a set here. The predicates variable is initialized as a set (line 164) and set operations are used (line 165), but the WhereNode.__init__ signature expects _predicates: List['Node'] (see core/ast/node.py:201). Change to use a list instead of a set.

Contributor

The same as above.


def parse_group_by(self, group_by_list: list) -> GroupByNode:
items = []
for item in group_by_list:
if isinstance(item, dict) and 'value' in item:
expr = self.parse_expression(item['value'])
# Resolve aliases
expr = self.resolve_aliases(expr)
items.append(expr)
else:
# Handle direct expression (string, int, etc.)
expr = self.parse_expression(item)
expr = self.resolve_aliases(expr)
items.append(expr)

def format(self, query: QueryNode) -> str:
# Implement formatting logic to convert AST back to SQL string
pass
return GroupByNode(items)

def parse_having(self, having_dict: dict) -> HavingNode:
predicates = set()
expr = self.parse_expression(having_dict)
# Check if this expression references an aliased function from SELECT
expr = self.resolve_aliases(expr)

predicates.add(expr)
Comment on lines +185 to +190
Copilot AI Dec 4, 2025

HavingNode constructor expects a list but receives a set here. The predicates variable is initialized as a set (line 185) and set operations are used (line 190), but the HavingNode.__init__ signature expects _predicates: List['Node'] (see core/ast/node.py:213). Change to use a list instead of a set.

Suggested change
predicates = set()
expr = self.parse_expression(having_dict)
# Check if this expression references an aliased function from SELECT
expr = self.resolve_aliases(expr)
predicates.add(expr)
predicates = []
expr = self.parse_expression(having_dict)
# Check if this expression references an aliased function from SELECT
expr = self.resolve_aliases(expr)
predicates.append(expr)

Contributor

The same as above.


# [1] Our new code
# AST (QueryNode) -> JSON
return HavingNode(predicates)

def parse_order_by(self, order_by_list: list) -> OrderByNode:
items = []
for item in order_by_list:
if isinstance(item, dict) and 'value' in item:
value = item['value']
# Check if this is an alias reference
if isinstance(value, str) and value in self.aliases:
column = self.aliases[value]
else:
# Parse normally for other cases
column = self.parse_expression(value)

# Get sort order (default is ASC)
sort_order = SortOrder.ASC
if 'sort' in item:
sort_str = item['sort'].upper()
if sort_str == 'DESC':
sort_order = SortOrder.DESC

# Wrap in OrderByItemNode
order_by_item = OrderByItemNode(column, sort_order)
items.append(order_by_item)
else:
# Handle direct expression (string, int, etc.)
column = self.parse_expression(item)
order_by_item = OrderByItemNode(column, SortOrder.ASC)
items.append(order_by_item)

# [2] Call mo_sql_format
# Any (JSON) -> str
return OrderByNode(items)

def resolve_aliases(self, expr: Node) -> Node:
if isinstance(expr, OperatorNode):
# Recursively resolve aliases in operator operands
left = self.resolve_aliases(expr.children[0])
right = self.resolve_aliases(expr.children[1])
return OperatorNode(left, expr.name, right)
Comment on lines +227 to +229
Copilot AI Dec 4, 2025

When the OperatorNode is created with unary operators (e.g., 'NOT'), only one operand is passed (line 309). The resolve_aliases method assumes binary operators and always accesses children[0] and children[1] (lines 227-228). This will cause an IndexError when resolving aliases for unary operators. Add a check for the number of children before accessing indices.

Suggested change
left = self.resolve_aliases(expr.children[0])
right = self.resolve_aliases(expr.children[1])
return OperatorNode(left, expr.name, right)
if len(expr.children) == 1:
child = self.resolve_aliases(expr.children[0])
return OperatorNode(child, expr.name)
elif len(expr.children) == 2:
left = self.resolve_aliases(expr.children[0])
right = self.resolve_aliases(expr.children[1])
return OperatorNode(left, expr.name, right)
else:
# Unexpected number of children; return as is
return expr

Contributor

This is a good point, but I doubt we can resolve it in this PR since we don't have test cases for unary operators. Let's defer this fix to a later PR where we introduce more test cases.

elif isinstance(expr, FunctionNode):
# Check if this function matches an aliased function from SELECT
if expr.alias is None:
for alias, aliased_expr in self.aliases.items():
if isinstance(aliased_expr, FunctionNode):
if (expr.name == aliased_expr.name and
len(expr.children) == len(aliased_expr.children) and
all(expr.children[i] == aliased_expr.children[i]
for i in range(len(expr.children)))):
# This function matches an aliased one, use the alias
expr.alias = alias
break
return expr
elif isinstance(expr, ColumnNode):
# Check if this column matches an aliased column from SELECT
if expr.alias is None:
for alias, aliased_expr in self.aliases.items():
if isinstance(aliased_expr, ColumnNode):
if (expr.name == aliased_expr.name and
expr.parent_alias == aliased_expr.parent_alias):
# This column matches an aliased one, use the alias
expr.alias = alias
break
return expr
else:
return expr
Comment on lines +224 to +255
Copilot AI Dec 4, 2025

[nitpick] The alias resolution in the HAVING clause iterates through all aliases to find matching FunctionNode or ColumnNode instances (lines 233-241, 246-252). For large SELECT clauses with many aliases, this could be inefficient. Consider optimizing by creating a separate lookup structure or only checking relevant aliases.

Contributor

That is OK for now.
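
If this is revisited later, the separate lookup structure could be a dict keyed by function signature. The sketch below uses a hypothetical AliasIndex helper and plain strings for arguments. It assumes the real argument nodes are hashable, which seems likely since the parser already stores nodes in sets.

```python
from typing import Dict, Hashable, Optional, Tuple

Signature = Tuple[str, Tuple[Hashable, ...]]


class AliasIndex:
    """Hypothetical helper mapping a function signature to its SELECT alias."""

    def __init__(self) -> None:
        self._by_signature: Dict[Signature, str] = {}

    def register(self, alias: str, name: str, args: Tuple[Hashable, ...]) -> None:
        # setdefault keeps the first alias registered for a signature,
        # mirroring the break-on-first-match behaviour of the linear scan.
        self._by_signature.setdefault((name, args), alias)

    def lookup(self, name: str, args: Tuple[Hashable, ...]) -> Optional[str]:
        return self._by_signature.get((name, args))


index = AliasIndex()
index.register("total", "SUM", ("amount",))
index.register("n", "COUNT", ("*",))
```

Each lookup is then a single dict access instead of a scan over every alias.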


def parse_expression(self, expr) -> Node:
if isinstance(expr, str):
# Column reference
if '.' in expr:
parts = expr.split('.', 1)
return ColumnNode(parts[1], _parent_alias=parts[0])
return ColumnNode(expr)

if isinstance(expr, (int, float, bool)):
return LiteralNode(expr)

if isinstance(expr, list):
# List literals (for IN clauses) - convert to tuple for hashability
parsed = tuple(self.parse_expression(item) for item in expr)
return LiteralNode(parsed)
Comment on lines +269 to +271
Copilot AI Dec 4, 2025

The LiteralNode constructor expects _value: str|int|float|bool|datetime|None (see core/ast/node.py:96), but a tuple is being passed here when parsing list expressions. Tuples are not in the allowed types for LiteralNode values. This could cause type errors or unexpected behavior. Consider either updating LiteralNode to support tuple types or handling list literals differently.

Suggested change
# List literals (for IN clauses) - convert to tuple for hashability
parsed = tuple(self.parse_expression(item) for item in expr)
return LiteralNode(parsed)
# List literals (for IN clauses)
parsed = [self.parse_expression(item) for item in expr]
return parsed

Contributor

This is a valid point. Can we fix it?
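
Note that returning a bare list, as in the suggested change, would no longer match parse_expression's declared `-> Node` return type. One alternative is a dedicated node for list literals. ListNode below is hypothetical and does not exist in core/ast/node.py, and LiteralNode is a simplified stand-in for the real class.

```python
from dataclasses import dataclass
from typing import Tuple, Union

Scalar = Union[str, int, float, bool, None]


@dataclass(frozen=True)
class LiteralNode:
    value: Scalar


@dataclass(frozen=True)
class ListNode:
    """Hypothetical node for IN (...) lists."""
    items: Tuple[LiteralNode, ...]


def parse_list(values: list) -> ListNode:
    # Each element stays a scalar LiteralNode; the tuple keeps the node hashable.
    return ListNode(tuple(LiteralNode(v) for v in values))


node = parse_list([1, 2, 3])
```

This keeps LiteralNode restricted to scalar values while still giving IN lists a hashable representation.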


if isinstance(expr, dict):
# Special cases first
if 'all_columns' in expr:
return ColumnNode('*')
if 'literal' in expr:
return LiteralNode(expr['literal'])

# Skip metadata keys
skip_keys = {'value', 'name', 'on', 'sort'}

# Find the operator/function key
for key in expr.keys():
if key in skip_keys:
continue

value = expr[key]
op_name = self.normalize_operator_name(key)

# Pattern 1: Binary/N-ary operator with list of operands
if isinstance(value, list):
if len(value) == 0:
return LiteralNode(None)
if len(value) == 1:
return self.parse_expression(value[0])

# Parse all operands
operands = [self.parse_expression(v) for v in value]

# Chain multiple operands with the same operator
result = operands[0]
for operand in operands[1:]:
result = OperatorNode(result, op_name, operand)
return result

# Pattern 2: Unary operator
if key == 'not':
return OperatorNode(self.parse_expression(value), 'NOT')

# Pattern 3: Function call
# Special case: COUNT(*), SUM(*), etc.
if value == '*':
return FunctionNode(op_name, [ColumnNode('*')])
Copilot AI Dec 4, 2025

[nitpick] In the test's expected AST, count_star is created with _args as a keyword argument, but in the parser (line 314), FunctionNode is called with positional arguments. While this works since _args is the second positional parameter after _name, using keyword arguments consistently would improve code clarity and match the test's style. Consider changing line 314 to: return FunctionNode(op_name, _args=[ColumnNode('*')])

Suggested change
return FunctionNode(op_name, [ColumnNode('*')])
return FunctionNode(op_name, _args=[ColumnNode('*')])

Contributor

Please check. I am not sure.
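
A quick way to check is to build the node both ways and compare. The class below is a stand-in that only assumes the parameter order described in the comment (_name first, then _args); the real FunctionNode may take further parameters.

```python
class FunctionNode:
    # Stand-in with the parameter order described in the review comment.
    def __init__(self, _name, _args=None):
        self.name = _name
        self.args = list(_args or [])


positional = FunctionNode("COUNT", ["*"])
keyword = FunctionNode("COUNT", _args=["*"])
```

If the real signature matches, both calls build the same node, so the change is purely stylistic.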


# Regular function
args = [self.parse_expression(value)]
return FunctionNode(op_name, args)

# No valid key found
import json
Copilot AI Dec 4, 2025

The json module is imported inline (line 321) rather than at the top of the file. Move this import to the top of the file with other imports to follow Python best practices (PEP 8).

Contributor

Please fix it.

return LiteralNode(json.dumps(expr, sort_keys=True))

# Other types
return LiteralNode(expr)

@staticmethod
def normalize_operator_name(key: str) -> str:
"""Convert mo_sql_parsing operator keys to SQL operator names."""
mapping = {
'eq': '=', 'neq': '!=', 'ne': '!=',
'gt': '>', 'gte': '>=',
'lt': '<', 'lte': '<=',
'and': 'AND', 'or': 'OR',
}
return mapping.get(key.lower(), key.upper())

@staticmethod
def parse_join_type(join_key: str) -> JoinType:
"""Extract JoinType from mo_sql_parsing join key."""
key_lower = join_key.lower().replace(' ', '_')

if 'inner' in key_lower:
return JoinType.INNER
elif 'left' in key_lower:
return JoinType.LEFT
elif 'right' in key_lower:
return JoinType.RIGHT
elif 'full' in key_lower:
return JoinType.FULL
elif 'cross' in key_lower:
return JoinType.CROSS

return JoinType.INNER # By default