diff --git a/core/ast/__init__.py b/core/ast/__init__.py index 804646a..1474504 100644 --- a/core/ast/__init__.py +++ b/core/ast/__init__.py @@ -11,8 +11,8 @@ SubqueryNode, ColumnNode, LiteralNode, - VarNode, - VarSetNode, + ElementVariableNode, + SetVariableNode, OperatorNode, FunctionNode, SelectNode, @@ -22,6 +22,7 @@ GroupByNode, HavingNode, OrderByNode, + OrderByItemNode, LimitNode, OffsetNode, QueryNode @@ -34,8 +35,8 @@ 'SubqueryNode', 'ColumnNode', 'LiteralNode', - 'VarNode', - 'VarSetNode', + 'ElementVariableNode', + 'SetVariableNode', 'OperatorNode', 'FunctionNode', 'SelectNode', diff --git a/core/ast/node.py b/core/ast/node.py index 52e505d..03fb78d 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -168,19 +168,35 @@ def __eq__(self, other): def __hash__(self): return hash((super().__hash__(), self.value, self.unit)) -class VarNode(Node): - """VarSQL variable node""" +class ElementVariableNode(Node): + """Rule element variable ```` (see ``VarType.ElementVariable`` in rule_parser_v2).""" def __init__(self, _name: str, **kwargs): super().__init__(NodeType.VAR, **kwargs) self.name = _name + def __eq__(self, other): + if not isinstance(other, ElementVariableNode): + return False + return super().__eq__(other) and self.name == other.name + + def __hash__(self): + return hash((super().__hash__(), self.name)) + -class VarSetNode(Node): - """VarSQL variable set node""" +class SetVariableNode(Node): + """Rule set variable ``<>`` (see ``VarType.SetVariable`` in rule_parser_v2).""" def __init__(self, _name: str, **kwargs): super().__init__(NodeType.VARSET, **kwargs) self.name = _name + def __eq__(self, other): + if not isinstance(other, SetVariableNode): + return False + return super().__eq__(other) and self.name == other.name + + def __hash__(self): + return hash((super().__hash__(), self.name)) + class OperatorNode(Node): """Operator node""" diff --git a/core/query_parser.py b/core/query_parser.py index 19494bb..8e30c59 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -4,9 +4,9 @@ CaseNode, WhenThenNode, OperatorNode, UnaryOperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, - VarNode, VarSetNode, JoinNode, ListNode + ElementVariableNode, SetVariableNode, JoinNode, ListNode ) -# TODO: implement VarNode, VarSetNode +# TODO: implement ElementVariableNode, SetVariableNode from core.ast.enums import JoinType, SortOrder import mo_sql_parsing as mosql import json diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py new file mode 100644 index 0000000..8bf392d --- /dev/null +++ b/core/rule_parser_v2.py @@ -0,0 +1,535 @@ +# Rule parser v2: self-contained rule preprocessing (duplicated from v1 on purpose), then +# QueryParser and ElementVariableNode / SetVariableNode rule AST via parse(). + +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Tuple + +from core.ast.enums import NodeType +from core.ast.node import ( + CaseNode, + ColumnNode, + FromNode, + FunctionNode, + GroupByNode, + HavingNode, + IntervalNode, + JoinNode, + LiteralNode, + LimitNode, + ListNode, + Node, + OffsetNode, + OperatorNode, + OrderByItemNode, + OrderByNode, + QueryNode, + SelectNode, + SubqueryNode, + TableNode, + UnaryOperatorNode, + ElementVariableNode, + SetVariableNode, + WhenThenNode, + WhereNode, +) +from core.query_parser import QueryParser + + +# Variable types (v2 naming; same placeholder syntax as v1). +# AST: ```` → ElementVariableNode, ``<>`` → SetVariableNode. +# +class VarType(Enum): + ElementVariable = 1 # → ElementVariableNode in rule AST + SetVariable = 2 # <> → SetVariableNode in rule AST + + +# Placeholder markers and internal token prefixes for rule variables. +# +VarTypesInfo = { + VarType.ElementVariable: { + "markerStart": "<", + "markerEnd": ">", + "internalBase": "EV", + "externalBase": "x", + }, + VarType.SetVariable: { + "markerStart": "<<", + "markerEnd": ">>", + "internalBase": "SV", + "externalBase": "y", + }, +} + + +# Scope of pattern/rewrite fragment (same as v1). +# +class Scope(Enum): + SELECT = 1 + FROM = 2 + WHERE = 3 + CONDITION = 4 + + +# Partial-SQL prefix for extendToFullSQL (same as v1). +# +ScopeExtension = { + Scope.CONDITION: "SELECT * FROM t WHERE ", + Scope.WHERE: "SELECT * FROM t ", + Scope.FROM: "SELECT * ", + Scope.SELECT: "", +} + + +# Result of RuleParserV2.parse: rule AST with external variable names restored. +# +@dataclass(frozen=True) +class RuleParseResult: + pattern_ast: Node + rewrite_ast: Node + mapping: Dict[str, str] + + +class RuleParserV2: + + # mosql parsing can report mismatching brackets at a confusing index; detect common + # wrong delimiters around rule variables (same logic as v1 RuleParser.find_malformed_brackets). + # + @staticmethod + def find_malformed_brackets(pattern: str) -> int: + CommonMistakeVarTypesInfo = { + "markerStart": [r"\(", r"\{", r"\["], + "markerEnd": [r"\)", r"\}", r"\]"], + } + + for i in range(len(CommonMistakeVarTypesInfo["markerStart"])): + regexPatternVarStart = ( + CommonMistakeVarTypesInfo["markerStart"][i] + + r"(\w+)" + + VarTypesInfo[VarType.ElementVariable]["markerEnd"] + ) + regexPatternVarEnd = ( + VarTypesInfo[VarType.ElementVariable]["markerStart"] + + r"(\w+)" + + CommonMistakeVarTypesInfo["markerEnd"][i] + ) + + varStart = re.search(regexPatternVarStart, pattern) + varEnd = re.search(regexPatternVarEnd, pattern) + + if varStart: + return varStart.start() + if varEnd: + return varEnd.start() + + return -1 + + # Extend pattern/rewrite fragment to full SQL (same as v1 RuleParser.extendToFullSQL). + # + @staticmethod + def extendToFullSQL(partialSQL: str) -> Tuple[str, Scope]: + # Special case: condition on subquery + # e.g., group_users.group_id IN (SELECT ... ) + # Remove subquery in (*) before checking SELECT / FROM / WHERE. + # + sanitisedPartialSQL = re.sub(r"\(.*\)", "(x)", partialSQL) + + # case-1: no SELECT and no FROM and no WHERE + if ( + "SELECT" not in sanitisedPartialSQL.upper() + and "FROM" not in sanitisedPartialSQL.upper() + and "WHERE" not in sanitisedPartialSQL.upper() + ): + scope = Scope.CONDITION + # case-2: no SELECT and no FROM but has WHERE + elif ( + "SELECT" not in sanitisedPartialSQL.upper() + and "FROM" not in sanitisedPartialSQL.upper() + ): + scope = Scope.WHERE + # case-3: no SELECT but has FROM + elif "SELECT" not in sanitisedPartialSQL.upper(): + scope = Scope.FROM + # case-4: has SELECT (and typically FROM) + else: + scope = Scope.SELECT + + partialSQL = ScopeExtension[scope] + partialSQL + return partialSQL, scope + + # Replace user-facing rule variables with internal tokens. + # e.g., ==> EV001, <> ==> SV001 + # + @staticmethod + def replaceVars(pattern: str, rewrite: str) -> Tuple[str, str, Dict[str, str]]: + + def _replace_one_var_type( + pattern: str, rewrite: str, varType: VarType, mapping: Dict[str, str] + ) -> Tuple[str, str]: + regexPattern = ( + VarTypesInfo[varType]["markerStart"] + + r"(\w+)" + + VarTypesInfo[varType]["markerEnd"] + ) + found = re.findall(regexPattern, pattern) + varInternalBase = VarTypesInfo[varType]["internalBase"] + varInternalCount = 1 + for var in found: + if var not in mapping: + specificRegexPattern = ( + VarTypesInfo[varType]["markerStart"] + + var + + VarTypesInfo[varType]["markerEnd"] + ) + varInternal = varInternalBase + str(varInternalCount).zfill(3) + varInternalCount += 1 + pattern = re.sub(specificRegexPattern, varInternal, pattern) + rewrite = re.sub(specificRegexPattern, varInternal, rewrite) + mapping[var] = varInternal + return pattern, rewrite + + mapping: Dict[str, str] = {} + pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.SetVariable, mapping) + pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.ElementVariable, mapping) + return pattern, rewrite, mapping + + # parse a rule into project AST nodes (ElementVariableNode / SetVariableNode for rule variables) + # + @staticmethod + def parse(pattern: str, rewrite: str) -> RuleParseResult: + + # 1. Replace user-faced variables and variable lists + # with internal representations + # + pattern_sql, rewrite_sql, mapping = RuleParserV2.replaceVars(pattern, rewrite) + + # 2. Extend partial SQL statement to full SQL statement + # for the sake of sql parser + # + pattern_full, pattern_scope = RuleParserV2.extendToFullSQL(pattern_sql) + rewrite_full, rewrite_scope = RuleParserV2.extendToFullSQL(rewrite_sql) + + # 3. Parse extended full SQL statement into AST (QueryParser) + # + qparser = QueryParser() + pattern_query = qparser.parse(pattern_full) + rewrite_query = qparser.parse(rewrite_full) + + # 4. Map internal tokens (EV00x / SV00x) to ElementVariableNode / SetVariableNode across the full query AST + # + internal_to_external = {internal: external for external, internal in mapping.items()} + pattern_after_vars = RuleParserV2._substitute_rule_vars(pattern_query, internal_to_external) + rewrite_after_vars = RuleParserV2._substitute_rule_vars(rewrite_query, internal_to_external) + + # 5. Reduce to the rule fragment for the inferred scope (CONDITION / WHERE / FROM / SELECT) + # + pattern_ast = RuleParserV2._extract_rule_fragment(pattern_after_vars, pattern_scope) + rewrite_ast = RuleParserV2._extract_rule_fragment(rewrite_after_vars, rewrite_scope) + + # 6. Return AST result + mapping + # + return RuleParseResult( + pattern_ast=pattern_ast, + rewrite_ast=rewrite_ast, + mapping=mapping, + ) + + # Find first child of query with given clause type (SELECT, FROM, WHERE, ...). + # + @staticmethod + def _get_clause(query: QueryNode, clause_type: NodeType) -> Optional[Node]: + for child in query.children: + if child.type == clause_type: + return child + return None + + # Apply internal_to_external across an entire parsed query (EV00x / SV00x -> ElementVariableNode, etc.). + # + @staticmethod + def _substitute_rule_vars( + query: QueryNode, internal_to_external: Dict[str, str] + ) -> QueryNode: + out = RuleParserV2._as_rule_ast(query, internal_to_external) + if not isinstance(out, QueryNode): + raise TypeError("expected QueryNode after substituting rule variables on full query") + return out + + # Slice a fully substituted query to the rule fragment for this scope (no variable-node pass). + # + @staticmethod + def _extract_rule_fragment(query: QueryNode, scope: Scope) -> Node: + frm = RuleParserV2._get_clause(query, NodeType.FROM) + wh = RuleParserV2._get_clause(query, NodeType.WHERE) + gb = RuleParserV2._get_clause(query, NodeType.GROUP_BY) + hav = RuleParserV2._get_clause(query, NodeType.HAVING) + ob = RuleParserV2._get_clause(query, NodeType.ORDER_BY) + lim = RuleParserV2._get_clause(query, NodeType.LIMIT) + off = RuleParserV2._get_clause(query, NodeType.OFFSET) + + # case CONDITION: predicate only + # + if scope == Scope.CONDITION: + if wh is None or not list(wh.children): + raise ValueError("CONDITION scope requires a WHERE predicate") + return list(wh.children)[0] + + # case WHERE: query without select/from lists + # + if scope == Scope.WHERE: + return QueryNode( + _select=None, + _from=None, + _where=wh, + _group_by=gb, + _having=hav, + _order_by=ob, + _limit=lim, + _offset=off, + ) + + # case FROM: from + following clauses, no select list + # + if scope == Scope.FROM: + return QueryNode( + _select=None, + _from=frm, + _where=wh, + _group_by=gb, + _having=hav, + _order_by=ob, + _limit=lim, + _offset=off, + ) + + # case SELECT: full query + # + return query + + # Run ElementVariableNode / SetVariableNode substitution on one subtree (None stays None). + # + @staticmethod + def _as_rule_ast(node: Optional[Node], internal_to_external: Dict[str, str]) -> Optional[Node]: + if node is None: + return None + return RuleParserV2._substitute_placeholders(node, internal_to_external) + + # Build ElementVariableNode or SetVariableNode from internal token prefix (EV... vs SV...). + # + @staticmethod + def _placeholder_varnode(internal_token: str, external_name: str) -> Node: + if internal_token.startswith(VarTypesInfo[VarType.SetVariable]["internalBase"]): + return SetVariableNode(external_name) + return ElementVariableNode(external_name) + + # Structural recursion: replace internal identifiers with ElementVariableNode / SetVariableNode where appropriate. + # + @staticmethod + def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: + def _replace_internal_in_string(s: str) -> str: + # Replace EV00x / SV00x occurrences inside strings (e.g., '%EV001%'). + out = s + for internal, external in rev.items(): + out = out.replace(internal, external) + return out + + if node.type == NodeType.COLUMN: + col = node + if not isinstance(col, ColumnNode): + return node + pa = col.parent_alias + nm = col.name + new_alias = _replace_internal_in_string(col.alias) if isinstance(col.alias, str) else col.alias + new_pa = _replace_internal_in_string(pa) if isinstance(pa, str) else pa + if pa is None and nm in rev: + return RuleParserV2._placeholder_varnode(nm, rev[nm]) + if pa is not None and pa in rev and nm in rev: + return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=rev[pa]) + if pa is not None and pa in rev: + return ColumnNode(nm, _alias=new_alias, _parent_alias=rev[pa]) + if pa is not None and nm in rev: + return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=new_pa) + return ColumnNode(nm, _alias=new_alias, _parent_alias=new_pa) + + if node.type == NodeType.TABLE: + t = node + if not isinstance(t, TableNode): + return node + new_name = rev.get(t.name, t.name) if isinstance(t.name, str) else t.name + if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: + new_alias = rev[t.alias] + else: + new_alias = t.alias + return TableNode(new_name, new_alias) + + if node.type == NodeType.LITERAL: + lit = node + if not isinstance(lit, LiteralNode): + return node + if isinstance(lit.value, str): + return LiteralNode(_replace_internal_in_string(lit.value)) + return LiteralNode(lit.value) + + if node.type == NodeType.QUERY: + q = node + if not isinstance(q, QueryNode): + return node + return QueryNode( + _select=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.SELECT), rev), + _from=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.FROM), rev), + _where=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.WHERE), rev), + _group_by=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.GROUP_BY), rev), + _having=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.HAVING), rev), + _order_by=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.ORDER_BY), rev), + _limit=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.LIMIT), rev), + _offset=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.OFFSET), rev), + ) + + if node.type == NodeType.SELECT: + sn = node + if not isinstance(sn, SelectNode): + return node + items: List[Node] = [] + don = sn.distinct_on + for ch in sn.children: + if don is not None and ch is don: + continue + items.append(RuleParserV2._substitute_placeholders(ch, rev)) + new_don = ( + RuleParserV2._substitute_placeholders(don, rev) if don is not None else None + ) + return SelectNode(items, _distinct=sn.distinct, _distinct_on=new_don) + + if node.type == NodeType.FROM: + fn = node + if not isinstance(fn, FromNode): + return node + return FromNode([RuleParserV2._substitute_placeholders(c, rev) for c in fn.children]) + + if node.type == NodeType.WHERE: + wn = node + if not isinstance(wn, WhereNode): + return node + return WhereNode([RuleParserV2._substitute_placeholders(c, rev) for c in wn.children]) + + if node.type == NodeType.GROUP_BY: + g = node + if not isinstance(g, GroupByNode): + return node + return GroupByNode([RuleParserV2._substitute_placeholders(c, rev) for c in g.children]) + + if node.type == NodeType.HAVING: + h = node + if not isinstance(h, HavingNode): + return node + return HavingNode([RuleParserV2._substitute_placeholders(c, rev) for c in h.children]) + + if node.type == NodeType.ORDER_BY: + o = node + if not isinstance(o, OrderByNode): + return node + return OrderByNode([RuleParserV2._substitute_placeholders(c, rev) for c in o.children]) + + if node.type == NodeType.LIMIT: + lim = node + if not isinstance(lim, LimitNode): + return node + if isinstance(lim.limit, str): + return LimitNode(_replace_internal_in_string(lim.limit)) + return LimitNode(lim.limit) + + if node.type == NodeType.OFFSET: + off = node + if not isinstance(off, OffsetNode): + return node + if isinstance(off.offset, str): + return OffsetNode(_replace_internal_in_string(off.offset)) + return OffsetNode(off.offset) + + if node.type == NodeType.ORDER_BY_ITEM: + oi = node + if not isinstance(oi, OrderByItemNode): + return node + inner = list(oi.children)[0] + return OrderByItemNode(RuleParserV2._substitute_placeholders(inner, rev), oi.sort) + + if node.type == NodeType.JOIN: + j = node + if not isinstance(j, JoinNode): + return node + ch = list(j.children) + left = RuleParserV2._substitute_placeholders(ch[0], rev) + right = RuleParserV2._substitute_placeholders(ch[1], rev) + on_expr = ( + RuleParserV2._substitute_placeholders(ch[2], rev) if len(ch) > 2 else None + ) + return JoinNode(left, right, j.join_type, on_expr) + + if node.type == NodeType.SUBQUERY: + sq = node + if not isinstance(sq, SubqueryNode): + return node + inner = list(sq.children)[0] + alias = _replace_internal_in_string(sq.alias) if isinstance(sq.alias, str) else sq.alias + return SubqueryNode(RuleParserV2._substitute_placeholders(inner, rev), alias) + + if node.type == NodeType.FUNCTION: + f = node + if not isinstance(f, FunctionNode): + return node + new_args = [RuleParserV2._substitute_placeholders(a, rev) for a in f.children] + alias = _replace_internal_in_string(f.alias) if isinstance(f.alias, str) else f.alias + return FunctionNode(f.name, _args=new_args, _alias=alias) + + if node.type == NodeType.LIST: + ln = node + if not isinstance(ln, ListNode): + return node + return ListNode([RuleParserV2._substitute_placeholders(c, rev) for c in ln.children]) + + if node.type == NodeType.INTERVAL: + inv = node + if not isinstance(inv, IntervalNode): + return node + if isinstance(inv.value, Node): + return IntervalNode( + RuleParserV2._substitute_placeholders(inv.value, rev), + inv.unit, # type: ignore[arg-type] + ) + return IntervalNode(inv.value, inv.unit) # type: ignore[arg-type] + + if node.type == NodeType.CASE: + cn = node + if not isinstance(cn, CaseNode): + return node + new_whens: List[WhenThenNode] = [] + for wt in cn.whens: + new_whens.append( + WhenThenNode( + RuleParserV2._substitute_placeholders(wt.when, rev), + RuleParserV2._substitute_placeholders(wt.then, rev), + ) + ) + new_else = ( + RuleParserV2._substitute_placeholders(cn.else_val, rev) if cn.else_val else None + ) + return CaseNode(new_whens, new_else) + + if node.type == NodeType.OPERATOR: + if isinstance(node, UnaryOperatorNode): + op = node + inner = list(op.children)[0] if op.children else op.operand + return UnaryOperatorNode(RuleParserV2._substitute_placeholders(inner, rev), op.name) + op = node + ch = list(op.children) + if len(ch) == 1: + return OperatorNode(RuleParserV2._substitute_placeholders(ch[0], rev), op.name) + return OperatorNode( + RuleParserV2._substitute_placeholders(ch[0], rev), + op.name, + RuleParserV2._substitute_placeholders(ch[1], rev), + ) + + return node diff --git a/tests/ast_util.py b/tests/ast_util.py index 274e074..07a3f54 100644 --- a/tests/ast_util.py +++ b/tests/ast_util.py @@ -7,7 +7,7 @@ Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode, SubqueryNode, - VarNode, VarSetNode + ElementVariableNode, SetVariableNode ) @@ -222,8 +222,8 @@ def _node_to_string(node: Node, indent: int = 0) -> str: for line in child_lines: result.append(line) - elif isinstance(node, (VarNode, VarSetNode)): - # VarNode/VarSetNode: VarSQL variable, display as "var: name" or "varset: name" + elif isinstance(node, (ElementVariableNode, SetVariableNode)): + # ElementVariableNode / SetVariableNode: rule variables ( / <>) result.append(f"{prefix}{node_type}: {node.name}") else: diff --git a/tests/test_ast.py b/tests/test_ast.py index deb6b2b..f9932d6 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,5 +1,5 @@ from core.ast.node import ( - TableNode, ColumnNode, LiteralNode, VarNode, VarSetNode, + TableNode, ColumnNode, LiteralNode, ElementVariableNode, SetVariableNode, OperatorNode, UnaryOperatorNode, FunctionNode, SelectNode, FromNode, WhereNode, GroupByNode, HavingNode, OrderByNode, LimitNode, OffsetNode, QueryNode ) @@ -42,9 +42,9 @@ def test_operand_nodes(): print(f" {null_literal.value} -> Type: {null_literal.type}") # Test VarSQL nodes - var_table = VarNode("V001") - var_column = VarNode("V002") - var_set = VarSetNode("VS001") + var_table = ElementVariableNode("V001") + var_column = ElementVariableNode("V002") + var_set = SetVariableNode("VS001") print(f"\nVarSQL nodes:") print(f" Variable {var_table.name} -> Type: {var_table.type}") @@ -243,11 +243,11 @@ def test_varsql_pattern_matching(): print("="*50) # Pattern: SELECT V1 FROM V2 WHERE V3 op V4 - var_select = VarNode("V1") # Any select item - var_table = VarNode("V2") # Any table - var_left = VarNode("V3") # Left operand of condition - var_op = VarNode("OP") # Any operator - var_right = VarNode("V4") # Right operand of condition + var_select = ElementVariableNode("V1") # Any select item + var_table = ElementVariableNode("V2") # Any table + var_left = ElementVariableNode("V3") # Left operand of condition + var_op = ElementVariableNode("OP") # Any operator + var_right = ElementVariableNode("V4") # Right operand of condition # Build pattern query pattern_select = SelectNode({var_select}) @@ -268,7 +268,7 @@ def test_varsql_pattern_matching(): print(f" Total pattern variables: 4 (V1, V2, V3, V4)") # Test VarSet for multiple columns - var_columns = VarSetNode("COLS") + var_columns = SetVariableNode("COLS") multi_select = SelectNode({var_columns}) print(f"\nVarSet pattern for multiple columns:") print(f" VarSet {var_columns.name} can match multiple SELECT items") diff --git a/tests/test_rule_parser.py b/tests/test_rule_parser.py index 72596cf..aa6dcee 100644 --- a/tests/test_rule_parser.py +++ b/tests/test_rule_parser.py @@ -1,5 +1,4 @@ -from core.rule_parser import RuleParser -from core.rule_parser import Scope +from core.rule_parser import RuleParser, Scope def test_extendToFullSQL(): @@ -151,63 +150,50 @@ def test_parse(): assert rewrite_json == internal_rule['rewrite_json'] -#incorrect brackets def test_brackets_1(): - - pattern = '''WHERE 11 + pattern = '''WHERE 11 AND a <= 11 ''' - - index = RuleParser.find_malformed_brackets(pattern) - assert index == 6 + index = RuleParser.find_malformed_brackets(pattern) + assert index == 6 + - #incorrect brackets - def test_brackets_2(): - +def test_brackets_2(): pattern = '''WHERE 11 AND a <= 11 ''' - index = RuleParser.find_malformed_brackets(pattern) assert index == 6 -#incorrect brackets + def test_parse_validator_3(): - - pattern = '''WHERE 11 + pattern = '''WHERE 11 AND a <= 11 ''' + index = RuleParser.find_malformed_brackets(pattern) + assert index == 6 - index = RuleParser.find_malformed_brackets(pattern) - assert index == 6 -#incorrect brackets - def test_parse_validator_4(): - +def test_parse_validator_4(): pattern = '''WHERE [x> > 11 AND a <= 11 ''' - index = RuleParser.find_malformed_brackets(pattern) assert index == 6 -#incorrect brackets - def test_parse_validator_5(): - +def test_parse_validator_5(): pattern = '''WHERE (x> > 11 AND a <= 11 ''' - index = RuleParser.find_malformed_brackets(pattern) assert index == 6 - -#incorrect brackets - def test_parse_validator_6(): - + + +def test_parse_validator_6(): pattern = '''WHERE {x> > 11 AND a <= 11 ''' index = RuleParser.find_malformed_brackets(pattern) assert index == 6 - + diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py new file mode 100644 index 0000000..a25ce79 --- /dev/null +++ b/tests/test_rule_parser_v2.py @@ -0,0 +1,774 @@ +from __future__ import annotations + +import re +from typing import Iterator, List, Optional + +import pytest + +from core.ast.enums import NodeType +from core.ast.node import ( + CaseNode, + ColumnNode, + DataTypeNode, + FromNode, + FunctionNode, + GroupByNode, + HavingNode, + JoinNode, + LimitNode, + ListNode, + LiteralNode, + Node, + OffsetNode, + OperatorNode, + OrderByItemNode, + OrderByNode, + QueryNode, + SelectNode, + SubqueryNode, + TableNode, + UnaryOperatorNode, + ElementVariableNode, + SetVariableNode, + WhenThenNode, + WhereNode, +) +from core.rule_parser_v2 import RuleParseResult, RuleParserV2, Scope, VarType, VarTypesInfo +from data.rules import rules as RULES_CATALOG + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Helpers +# ═══════════════════════════════════════════════════════════════════════════════ + +_TOKEN_RE = re.compile(r"^(EV|SV)\d{3}$") + + +def _walk(node: Optional[Node]) -> Iterator[Node]: + """Depth-first walk of the AST.""" + if node is None: + return + yield node + ch = getattr(node, "children", None) + if ch: + for child in ch: + yield from _walk(child) + + +def _walk_var_names(node: Optional[Node]) -> Iterator[str]: + """Yield names of all ElementVariableNode / SetVariableNode in the tree.""" + for n in _walk(node): + if isinstance(n, (ElementVariableNode, SetVariableNode)): + yield n.name + + +def _find_first(node: Optional[Node], cls: type) -> Optional[Node]: + """Find first node of given type in the tree.""" + for n in _walk(node): + if isinstance(n, cls): + return n + return None + + +def _find_all(node: Optional[Node], cls: type) -> List[Node]: + """Find all nodes of given type in the tree.""" + return [n for n in _walk(node) if isinstance(n, cls)] + + +def _assert_varnodes_declared(result: RuleParseResult) -> None: + """Every ElementVariableNode / SetVariableNode must use an external name in ``mapping``.""" + keys = set(result.mapping.keys()) + for tree_label, tree in [("pattern", result.pattern_ast), ("rewrite", result.rewrite_ast)]: + for name in _walk_var_names(tree): + assert name in keys, ( + f"{tree_label} AST has variable node {name!r} but mapping keys are {sorted(keys)}" + ) + + +def _assert_no_internal_tokens(result: RuleParseResult) -> None: + """No EV00x / SV00x tokens should survive in identifier-bearing AST fields.""" + internal_tokens = set(result.mapping.values()) + + for tree_label, tree in [("pattern", result.pattern_ast), ("rewrite", result.rewrite_ast)]: + for n in _walk(tree): + if isinstance(n, ColumnNode): + assert not _TOKEN_RE.match(n.name), ( + f"{tree_label} AST has raw internal token {n.name!r} as ColumnNode.name" + ) + if isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as ColumnNode.alias" + ) + if n.parent_alias in internal_tokens: + assert not _TOKEN_RE.match(n.parent_alias), ( + f"{tree_label} AST has raw internal token {n.parent_alias!r} " + f"as ColumnNode.parent_alias" + ) + + if isinstance(n, TableNode) and isinstance(n.name, str): + assert not _TOKEN_RE.match(n.name), ( + f"{tree_label} AST has raw internal token {n.name!r} as TableNode.name" + ) + if isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as TableNode.alias" + ) + + if isinstance(n, SubqueryNode) and isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as SubqueryNode.alias" + ) + + if isinstance(n, FunctionNode) and isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as FunctionNode.alias" + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# extendToFullSQL +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_extendToFullSQL(): + # CONDITION scope + pattern = "CAST(V1 AS DATE)" + rewrite = "V1" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT * FROM t WHERE CAST(V1 AS DATE)" + assert scope == Scope.CONDITION + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT * FROM t WHERE V1" + assert scope == Scope.CONDITION + + # WHERE scope + pattern = "WHERE CAST(V1 AS DATE)" + rewrite = "WHERE V1" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT * FROM t WHERE CAST(V1 AS DATE)" + assert scope == Scope.WHERE + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT * FROM t WHERE V1" + assert scope == Scope.WHERE + + # FROM scope + pattern = "FROM lineitem" + rewrite = "FROM v_lineitem" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT * FROM lineitem" + assert scope == Scope.FROM + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT * FROM v_lineitem" + assert scope == Scope.FROM + + # SELECT scope with FROM and WHERE + pattern = """ + select VL1 + from V1 V2, + V3 V4 + where V2.V6=V4.V8 + and VL2 + """ + rewrite = """ + select VL1 + from V1 V2 + where VL2 + """ + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == """ + select VL1 + from V1 V2, + V3 V4 + where V2.V6=V4.V8 + and VL2 + """ + assert scope == Scope.SELECT + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == """ + select VL1 + from V1 V2 + where VL2 + """ + assert scope == Scope.SELECT + + # SELECT scope with FROM + pattern = "SELECT VL1 FROM lineitem" + rewrite = "SELECT VL1 FROM v_lineitem" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT VL1 FROM lineitem" + assert scope == Scope.SELECT + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT VL1 FROM v_lineitem" + assert scope == Scope.SELECT + + # SELECT scope with only SELECT + pattern = "SELECT CAST(V1 AS DATE)" + rewrite = "SELECT V1" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT CAST(V1 AS DATE)" + assert scope == Scope.SELECT + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT V1" + assert scope == Scope.SELECT + + +def test_extendToFullSQL_subquery_not_confused(): + """Subquery inside parens shouldn't cause false FROM/SELECT scope detection.""" + sql, scope = RuleParserV2.extendToFullSQL( + "x IN (SELECT id FROM sub WHERE flag = 1)" + ) + assert scope == Scope.CONDITION + + +def test_extendToFullSQL_from_with_subquery_in_where(): + sql, scope = RuleParserV2.extendToFullSQL( + "FROM t WHERE x IN (SELECT id FROM sub)" + ) + assert scope == Scope.FROM + + +def test_extendToFullSQL_case_insensitive(): + sql, scope = RuleParserV2.extendToFullSQL("from my_table where x = 1") + assert scope == Scope.FROM + + +# ═══════════════════════════════════════════════════════════════════════════════ +# replaceVars +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_replaceVars(): + # Single element var + pattern = "CAST( AS DATE)" + rewrite = "" + pattern, rewrite, mapping = RuleParserV2.replaceVars(pattern, rewrite) + assert pattern == "CAST(EV001 AS DATE)" + assert rewrite == "EV001" + assert mapping == {"x": "EV001"} + + # Multiple var and varList case + pattern = """ + select <> + from , + + where .=. + and <> + """ + rewrite = """ + select <> + from + where <> + """ + pattern, rewrite, mapping = RuleParserV2.replaceVars(pattern, rewrite) + assert pattern == """ + select SV001 + from EV001 EV002, + EV003 EV004 + where EV002.EV005=EV004.EV006 + and SV002 + """ + assert rewrite == """ + select SV001 + from EV001 EV002 + where SV002 + """ + assert mapping == { + "s1": "SV001", + "p1": "SV002", + "tb1": "EV001", + "t1": "EV002", + "tb2": "EV003", + "t2": "EV004", + "a1": "EV005", + "a2": "EV006", + } + + +def test_replaceVars_distinct_names(): + """Set vars and element vars with different names get separate tokens.""" + p, r, m = RuleParserV2.replaceVars( + "SELECT <> FROM WHERE <>", + "SELECT <> FROM WHERE <>", + ) + assert m["cols"] == "SV001" + assert m["preds"] == "SV002" + assert m["tbl"] == "EV001" + + +def test_replaceVars_multiple_unique_tokens(): + p, r, m = RuleParserV2.replaceVars(" + + ", " + + ") + assert len(set(m.values())) == 3 + assert all(v.startswith("EV") for v in m.values()) + + +def test_replaceVars_same_var_in_both(): + """Same variable name in pattern and rewrite maps to the same token.""" + p, r, m = RuleParserV2.replaceVars(" = ", " = ") + assert m["x"] in p and m["x"] in r + assert m["y"] in p and m["y"] in r + + +# ═══════════════════════════════════════════════════════════════════════════════ +# find_malformed_brackets +# ═══════════════════════════════════════════════════════════════════════════════ + +@pytest.mark.parametrize("bad_pattern,expected_index", [ + ("WHERE 11 AND a <= 11", 6), + ("WHERE 11 AND a <= 11", 6), + ("WHERE 11 AND a <= 11", 6), + ("WHERE [x> > 11 AND a <= 11", 6), + ("WHERE (x> > 11 AND a <= 11", 6), + ("WHERE {x> > 11 AND a <= 11", 6), +]) +def test_find_malformed_brackets(bad_pattern, expected_index): + assert RuleParserV2.find_malformed_brackets(bad_pattern) == expected_index + + +def test_well_formed_brackets_return_negative(): + assert RuleParserV2.find_malformed_brackets(" = ") == -1 + assert RuleParserV2.find_malformed_brackets("<> AND <>") == -1 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — CONDITION scope: deep AST structure +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_parse_ast_cast_rule(): + """CAST( AS DATE) -> FunctionNode(cast, [ElementVariableNode, DataTypeNode])""" + result = RuleParserV2.parse("CAST( AS DATE)", "") + assert isinstance(result, RuleParseResult) + assert result.mapping == {"x": "EV001"} + assert isinstance(result.pattern_ast, FunctionNode) + assert result.pattern_ast.name.lower() == "cast" + cast_args = list(result.pattern_ast.children) + assert len(cast_args) == 2 + assert isinstance(cast_args[0], ElementVariableNode) and cast_args[0].name == "x" + assert isinstance(cast_args[1], DataTypeNode) + assert isinstance(result.rewrite_ast, ElementVariableNode) and result.rewrite_ast.name == "x" + + +def test_parse_ast_strpos_ilike_rule(): + """STRPOS(LOWER(), '') > 0 — deep operator / function / variable structure.""" + result = RuleParserV2.parse( + "STRPOS(LOWER(), '') > 0", + " ILIKE '%%'", + ) + assert result.mapping == {"x": "EV001", "s": "EV002"} + # Pattern: > operator + pat = result.pattern_ast + assert isinstance(pat, OperatorNode) and pat.name == ">" + ch = list(pat.children) + assert isinstance(ch[0], FunctionNode) and ch[0].name.upper() == "STRPOS" + assert isinstance(ch[1], LiteralNode) and ch[1].value == 0 + # STRPOS -> LOWER -> ElementVariableNode + strpos_args = list(ch[0].children) + lower = strpos_args[0] + assert isinstance(lower, FunctionNode) and lower.name.lower() == "lower" + assert isinstance(list(lower.children)[0], ElementVariableNode) + assert list(lower.children)[0].name == "x" + assert isinstance(strpos_args[1], LiteralNode) + # Rewrite: ILIKE + rew = result.rewrite_ast + assert isinstance(rew, FunctionNode) and rew.name.lower() == "ilike" + ilike_args = list(rew.children) + assert isinstance(ilike_args[0], ElementVariableNode) and ilike_args[0].name == "x" + assert isinstance(ilike_args[1], LiteralNode) + assert ilike_args[1].value == "%s%" + + +def test_substitute_placeholders_limit_offset_string_tokens(): + """Directly exercise LIMIT/OFFSET token replacement for string payloads.""" + lim = RuleParserV2._substitute_placeholders( # type: ignore[arg-type] + LimitNode("EV001"), {"EV001": "x"} + ) + off = RuleParserV2._substitute_placeholders( # type: ignore[arg-type] + OffsetNode("EV002"), {"EV002": "y"} + ) + assert isinstance(lim, LimitNode) and lim.limit == "x" + assert isinstance(off, OffsetNode) and off.offset == "y" + + +def test_parse_substitutes_alias_fields(): + """Column/function/subquery aliases should not leak EV/SV internal tokens.""" + result = RuleParserV2.parse( + "SELECT SUM() AS , t.c AS FROM (SELECT FROM ) AS , t", + "SELECT SUM() AS , t.c AS FROM (SELECT FROM ) AS , t", + ) + assert isinstance(result, RuleParseResult) + assert result.mapping["f_alias"].startswith("EV") + assert result.mapping["c_alias"].startswith("EV") + assert result.mapping["sq_alias"].startswith("EV") + _assert_no_internal_tokens(result) + + +def test_parse_ast_max_distinct(): + """MAX(DISTINCT ) -> MAX()""" + result = RuleParserV2.parse("MAX(DISTINCT )", "MAX()") + assert isinstance(result.pattern_ast, FunctionNode) and result.pattern_ast.name.lower() == "max" + assert isinstance(result.rewrite_ast, FunctionNode) and result.rewrite_ast.name.lower() == "max" + assert "x" in list(_walk_var_names(result.pattern_ast)) + assert "x" in list(_walk_var_names(result.rewrite_ast)) + + +def test_parse_ast_contradiction(): + """ > AND <= -> FALSE""" + result = RuleParserV2.parse(" > AND <= ", "FALSE") + assert isinstance(result.pattern_ast, OperatorNode) and result.pattern_ast.name.lower() == "and" + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + + +def test_parse_ast_combine_or_to_in(): + """ = OR = -> IN (, )""" + result = RuleParserV2.parse(" = OR = ", " IN (, )") + assert isinstance(result.pattern_ast, OperatorNode) + assert isinstance(result.rewrite_ast, (OperatorNode, FunctionNode)) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + + +def test_parse_ast_or_to_case(): + """OR chain -> CASE WHEN — verifies CaseNode with 3 whens + else.""" + result = RuleParserV2.parse( + " OR OR ", + "1 = CASE WHEN THEN 1 WHEN THEN 1 WHEN THEN 1 ELSE 0 END", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + case_nodes = _find_all(result.rewrite_ast, CaseNode) + assert len(case_nodes) >= 1, "Rewrite should contain a CaseNode" + case = case_nodes[0] + assert len(case.whens) == 3 + assert case.else_val is not None + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — WHERE scope +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_parse_ast_where_scope(): + result = RuleParserV2.parse("WHERE = 1", "WHERE = 1") + assert result.mapping == {"x": "EV001"} + assert isinstance(result.pattern_ast, QueryNode) + wh = next(c for c in result.pattern_ast.children if c.type == NodeType.WHERE) + assert isinstance(wh, WhereNode) + pred = list(wh.children)[0] + assert isinstance(pred, OperatorNode) and pred.name == "=" + lhs, rhs = list(pred.children) + assert isinstance(lhs, ElementVariableNode) and lhs.name == "x" + assert isinstance(rhs, LiteralNode) and rhs.value == 1 + + +def test_parse_where_scope_strips_select_and_from(): + """WHERE scope extraction should produce no SelectNode or FromNode.""" + result = RuleParserV2.parse("WHERE > ", "WHERE > ") + assert isinstance(result.pattern_ast, QueryNode) + assert _find_first(result.pattern_ast, SelectNode) is None + assert _find_first(result.pattern_ast, FromNode) is None + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — FROM scope +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_parse_ast_from_scope(): + result = RuleParserV2.parse("FROM li", "FROM li") + assert result.mapping == {"t": "EV001"} + assert isinstance(result.pattern_ast, QueryNode) + frm = next(c for c in result.pattern_ast.children if c.type == NodeType.FROM) + assert isinstance(frm, FromNode) + tab = list(frm.children)[0] + assert isinstance(tab, TableNode) and tab.name == "t" and tab.alias == "li" + + +def test_parse_from_scope_strips_select(): + """FROM scope extraction should produce no SelectNode.""" + result = RuleParserV2.parse("FROM ", "FROM ") + assert isinstance(result.pattern_ast, QueryNode) + assert _find_first(result.pattern_ast, SelectNode) is None + + +def test_parse_from_scope_with_where(): + """FROM WHERE ... — pattern keeps WHERE, rewrite without WHERE drops it.""" + result = RuleParserV2.parse( + "FROM WHERE > - 2", + "FROM ", + ) + assert isinstance(result.pattern_ast, QueryNode) + assert _find_first(result.pattern_ast, FromNode) is not None + assert _find_first(result.pattern_ast, WhereNode) is not None + + +def test_parse_from_scope_with_join(): + """FROM with INNER JOIN should produce JoinNode.""" + result = RuleParserV2.parse( + "FROM INNER JOIN ON . = .", + "FROM INNER JOIN ON . = .", + ) + assert isinstance(result.pattern_ast, QueryNode) + assert len(_find_all(result.pattern_ast, JoinNode)) >= 1 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — SELECT scope: complex rules +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_parse_ast_select_list_varset(): + """SetVariableNode in the SELECT list.""" + result = RuleParserV2.parse( + "select <> from lineitem where 1 = 1", + "select <> from lineitem where 1 = 1", + ) + assert isinstance(result.pattern_ast, QueryNode) + select = next(c for c in result.pattern_ast.children if c.type == NodeType.SELECT) + assert isinstance(select, SelectNode) + first = list(select.children)[0] + assert isinstance(first, SetVariableNode) and first.name == "s1" + + +def test_parse_self_join_rule(): + """Remove Self Join: 2 tables in pattern, 1 in rewrite, SetVariableNodes present.""" + result = RuleParserV2.parse( + """select <> + from , + where .=. and <>""", + """select <> + from + where 1=1 and <>""", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert len(_find_all(result.pattern_ast, TableNode)) >= 2 + assert len(_find_all(result.rewrite_ast, TableNode)) >= 1 + pat_svs = [n for n in _walk(result.pattern_ast) if isinstance(n, SetVariableNode)] + assert len(pat_svs) >= 2 # s1 and p1 + + +def test_parse_subquery_to_join_rule(): + """IN (SELECT ...) pattern has SubqueryNode; comma-join rewrite does not.""" + result = RuleParserV2.parse( + """select <> from + where in (select from where <>) + and <>""", + """select distinct <> from , + where . = . + and <> and <>""", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert len(_find_all(result.pattern_ast, SubqueryNode)) >= 1 + assert len(_find_all(result.rewrite_ast, SubqueryNode)) == 0 + + +def test_parse_join_to_filter_rule(): + """Double INNER JOIN pattern has more JoinNodes than single INNER JOIN rewrite.""" + result = RuleParserV2.parse( + """select <> + from + inner join on . = . + inner join on . = . + where . = and <>""", + """select <> + from + inner join on . = . + where . = and <>""", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert len(_find_all(result.pattern_ast, JoinNode)) > len( + _find_all(result.rewrite_ast, JoinNode) + ) + + +def test_parse_distinct_on(): + """DISTINCT ON should be preserved in the SelectNode.""" + result = RuleParserV2.parse( + "SELECT DISTINCT ON () , FROM ", + "SELECT , FROM ", + ) + _assert_varnodes_declared(result) + pat_sel = _find_first(result.pattern_ast, SelectNode) + assert pat_sel is not None + assert pat_sel.distinct or getattr(pat_sel, "distinct_on", None) is not None + + +def test_parse_order_by_and_limit(): + """ORDER BY and LIMIT should produce their respective node types.""" + result = RuleParserV2.parse( + "SELECT FROM ORDER BY ASC LIMIT ", + "SELECT FROM ORDER BY ASC LIMIT ", + ) + _assert_varnodes_declared(result) + assert _find_first(result.pattern_ast, OrderByNode) is not None + assert _find_first(result.pattern_ast, OrderByItemNode) is not None + assert _find_first(result.pattern_ast, LimitNode) is not None + + +def test_parse_distinct_to_group_by(): + """SELECT DISTINCT -> GROUP BY rewrite.""" + result = RuleParserV2.parse( + "SELECT DISTINCT <> FROM <> WHERE <>", + "SELECT <> FROM <> WHERE <> GROUP BY <>", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert _find_first(result.rewrite_ast, GroupByNode) is not None + pat_sel = _find_first(result.pattern_ast, SelectNode) + if pat_sel is not None: + assert pat_sel.distinct is True + + +def test_parse_set_variable_in_select_and_where(): + """SetVariableNode should appear in both SELECT and WHERE.""" + result = RuleParserV2.parse( + "SELECT <> FROM tbl WHERE <>", + "SELECT <> FROM tbl WHERE <>", + ) + sv_names = {n.name for n in _walk(result.pattern_ast) if isinstance(n, SetVariableNode)} + assert "cols" in sv_names + assert "preds" in sv_names + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Column + parent_alias substitution +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_qualified_column_both_parts_substituted(): + """. — both parent_alias and name should become external names.""" + result = RuleParserV2.parse(". = 1", ". = 1") + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + cols = _find_all(result.pattern_ast, ColumnNode) + qualified = [c for c in cols if c.parent_alias is not None] + assert len(qualified) >= 1 + for c in qualified: + assert c.parent_alias in result.mapping + assert c.name in result.mapping + + +def test_qualified_column_only_parent_alias_is_var(): + """.fixed_col — only the alias is a variable; the column name is a literal.""" + result = RuleParserV2.parse(".created_at = 1", ".created_at = 1") + _assert_no_internal_tokens(result) + cols = _find_all(result.pattern_ast, ColumnNode) + qualified = [c for c in cols if c.parent_alias is not None] + assert len(qualified) >= 1 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Error paths +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_invalid_sql_raises(): + """Completely invalid SQL should raise during parse.""" + with pytest.raises(Exception): + RuleParserV2.parse("!!NOT_VALID_SQL!!", "") + + +def test_deeply_nested_parens(): + """Deeply nested expressions should not confuse scope detection.""" + result = RuleParserV2.parse( + "((( + ) * ) > 0)", + "( + ) * > 0", + ) + _assert_varnodes_declared(result) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# No internal token leak — parametrized across shapes +# ═══════════════════════════════════════════════════════════════════════════════ + +@pytest.mark.parametrize("pattern,rewrite", [ + ("CAST( AS DATE)", ""), + ("STRPOS(LOWER(), '') > 0", " ILIKE '%%'"), + ("MAX(DISTINCT )", "MAX()"), + (" = OR = ", " IN (, )"), + ("WHERE = 1", "WHERE = 1"), + ("FROM ", "FROM "), +]) +def test_no_internal_tokens_survive(pattern, rewrite): + result = RuleParserV2.parse(pattern, rewrite) + _assert_no_internal_tokens(result) + _assert_varnodes_declared(result) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Mapping consistency — parse() vs replaceVars() +# ═══════════════════════════════════════════════════════════════════════════════ + +@pytest.mark.parametrize("pattern,rewrite", [ + ("CAST( AS DATE)", ""), + ("STRPOS(LOWER(), '') > 0", " ILIKE '%%'"), + ("SELECT <> FROM WHERE <>", "SELECT <> FROM WHERE <>"), +]) +def test_parse_mapping_matches_replaceVars(pattern, rewrite): + _, _, expected = RuleParserV2.replaceVars(pattern, rewrite) + result = RuleParserV2.parse(pattern, rewrite) + assert result.mapping == expected + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Variable coverage +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_rewrite_vars_subset_of_pattern(): + """For simple rules, rewrite variables are a subset of pattern variables.""" + result = RuleParserV2.parse("CAST( AS DATE)", "") + assert set(_walk_var_names(result.rewrite_ast)) <= set(_walk_var_names(result.pattern_ast)) + + +def test_identity_rule_same_vars(): + """An identity rewrite has the same variable set in both trees.""" + result = RuleParserV2.parse(" = ", " = ") + assert set(_walk_var_names(result.pattern_ast)) == set(_walk_var_names(result.rewrite_ast)) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# VarType / VarTypesInfo metadata +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_element_var_markers(): + info = VarTypesInfo[VarType.ElementVariable] + assert info["markerStart"] == "<" + assert info["markerEnd"] == ">" + assert info["internalBase"] == "EV" + + +def test_set_var_markers(): + info = VarTypesInfo[VarType.SetVariable] + assert info["markerStart"] == "<<" + assert info["markerEnd"] == ">>" + assert info["internalBase"] == "SV" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# data/rules.py catalog — parametrized over all rules +# ═══════════════════════════════════════════════════════════════════════════════ + +@pytest.mark.parametrize( + "rule", + RULES_CATALOG, + ids=[r["key"] for r in RULES_CATALOG], +) +class TestCatalogRules: + + def test_parse_succeeds(self, rule): + """Full parse pipeline completes without error.""" + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + assert isinstance(result, RuleParseResult) + assert result.pattern_ast is not None + assert result.rewrite_ast is not None + + def test_mapping_matches_replaceVars(self, rule): + """parse() returns the same mapping as replaceVars().""" + _, _, expected = RuleParserV2.replaceVars(rule["pattern"], rule["rewrite"]) + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + assert result.mapping == expected + + def test_varnodes_declared_in_mapping(self, rule): + """Every variable node in the AST uses an external name present in mapping.""" + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + _assert_varnodes_declared(result) + + def test_no_internal_tokens_leak(self, rule): + """No EV00x / SV00x tokens survive as raw identifiers.""" + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + _assert_no_internal_tokens(result) \ No newline at end of file