diff --git a/tools/hrw4u/Makefile b/tools/hrw4u/Makefile index 7b929356af3..ba25de02d98 100644 --- a/tools/hrw4u/Makefile +++ b/tools/hrw4u/Makefile @@ -53,7 +53,9 @@ SRC_FILES_HRW4U=src/visitor.py \ src/suggestions.py \ src/procedures.py \ src/sandbox.py \ - src/kg_visitor.py + src/kg_visitor.py \ + src/ast_nodes.py \ + src/ast_visitor.py ALL_HRW4U_FILES=$(SHARED_FILES) $(UTILS_FILES) $(SRC_FILES_HRW4U) diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py new file mode 100644 index 00000000000..8667290cc15 --- /dev/null +++ b/tools/hrw4u/src/ast_nodes.py @@ -0,0 +1,167 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Union + + +class ValueKind(Enum): + STRING = "string" + IDENT = "ident" + PARAM_REF = "param_ref" + IP = "ip" + REGEX = "regex" + + +@dataclass(frozen=True, kw_only=True) +class Value: + raw: str + kind: ValueKind + + +@dataclass(frozen=True, kw_only=True) +class Node: + line: int + + +@dataclass(frozen=True) +class Target: + namespace: str | None + field: str + + @staticmethod + def from_dotted(name: str) -> Target: + # TODO: the grammar lexes dotted paths as a single IDENT token; + # ideally the grammar would split namespace/field so this + # heuristic isn't needed. + dot = name.rfind(".") + if dot == -1: + return Target(namespace=None, field=name) + return Target(namespace=name[:dot], field=name[dot + 1:]) + + +@dataclass(frozen=True, kw_only=True) +class Assignment(Node): + target: Target + operator: str # "=" or "+=" + value: Value | int | bool | tuple + + +@dataclass(frozen=True, kw_only=True) +class FunctionCall(Node): + name: str + args: tuple[Value | int | bool, ...] + + +@dataclass(frozen=True, kw_only=True) +class Break(Node): + pass + + +@dataclass(frozen=True, kw_only=True) +class Comparison(Node): + left: Value | FunctionCall + operator: str # "==", "!=", ">", "<", "~", "!~", "in", "!in" + right: Value | int | bool | tuple + modifiers: tuple[str, ...] + + +@dataclass(frozen=True, kw_only=True) +class LogicalOp(Node): + operator: str # "&&" or "||" + left: ConditionExpr + right: ConditionExpr + + +@dataclass(frozen=True, kw_only=True) +class NotOp(Node): + operand: ConditionExpr + + +@dataclass(frozen=True, kw_only=True) +class BoolLiteral(Node): + value: bool + + +@dataclass(frozen=True, kw_only=True) +class IdentCondition(Node): + name: str + + +@dataclass(frozen=True, kw_only=True) +class ElifBranch(Node): + condition: ConditionExpr + body: tuple[BodyNode, ...] + + +@dataclass(frozen=True, kw_only=True) +class IfBlock(Node): + condition: ConditionExpr + body: tuple[BodyNode, ...] + elif_branches: tuple[ElifBranch, ...] + else_body: tuple[BodyNode, ...] + + +@dataclass(frozen=True, kw_only=True) +class Section(Node): + type: str + body: tuple[BodyNode, ...] + + +@dataclass(frozen=True, kw_only=True) +class ProcParam(Node): + name: str + default: Value | int | bool | None + + +@dataclass(frozen=True, kw_only=True) +class VarDecl(Node): + name: str + type_name: str + slot: int | None + + +@dataclass(frozen=True, kw_only=True) +class VarSection(Node): + scope: str + declarations: tuple[VarDecl, ...] + + +@dataclass(frozen=True, kw_only=True) +class UseDirective(Node): + spec: str + + +@dataclass(frozen=True, kw_only=True) +class ProcedureDecl(Node): + name: str + params: tuple[ProcParam, ...] + body: tuple[BodyNode, ...] + + +@dataclass(frozen=True, kw_only=True) +class HRW4UAST: + body: tuple[TopLevelNode, ...] + + +# Type aliases: must follow all class definitions (evaluated at runtime). +ConditionExpr = Union[Comparison, LogicalOp, NotOp, BoolLiteral, IdentCondition, FunctionCall] +BodyNode = Union[Assignment, FunctionCall, IfBlock, Break] +TopLevelNode = Union[UseDirective, VarSection, ProcedureDecl, Section] diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py new file mode 100644 index 00000000000..ec8fcedf1aa --- /dev/null +++ b/tools/hrw4u/src/ast_visitor.py @@ -0,0 +1,299 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from hrw4u.hrw4uVisitor import hrw4uVisitor +from hrw4u.ast_nodes import ( + HRW4UAST, + Section, + Assignment, + FunctionCall, + Break, + Target, + IfBlock, + ElifBranch, + BoolLiteral, + Comparison, + LogicalOp, + NotOp, + IdentCondition, + ProcParam, + VarDecl, + VarSection, + UseDirective, + ProcedureDecl, + Value, + ValueKind, +) + + +class ASTVisitor(hrw4uVisitor): + """ANTLR visitor that walks an HRW4U parse tree and produces an AST for HRW4U.""" + + # Only visitProgram is overridden from the ANTLR visitor interface; + # all other traversal uses private _visit_* helpers so that each + # method has an explicit return type and full control over how + # child results are assembled into parent AST nodes. + + def visitProgram(self, ctx): + items = [] + for item in ctx.programItem(): + if item.useDirective() is not None: + items.append(self._visit_use_directive(item.useDirective())) + elif item.procedureDecl() is not None: + items.append(self._visit_procedure_decl(item.procedureDecl())) + elif item.section() is not None: + items.append(self._visit_section(item.section())) + elif item.commentLine() is not None: + pass + else: + raise ValueError(f"Unhandled programItem alternative at line {item.start.line}") + return HRW4UAST(body=tuple(items)) + + def _visit_use_directive(self, ctx): + return UseDirective( + spec=ctx.QUALIFIED_IDENT().getText(), + line=ctx.start.line, + ) + + def _visit_procedure_decl(self, ctx): + name = ctx.QUALIFIED_IDENT().getText() + params = () + if ctx.paramList(): + params = tuple( + self._visit_proc_param(p) for p in ctx.paramList().param() + ) + body = tuple(self._visit_body(ctx.block().blockItem())) + return ProcedureDecl(name=name, params=params, body=body, line=ctx.start.line) + + def _visit_proc_param(self, ctx): + name = ctx.IDENT().getText() + default = self._extract_value(ctx.value()) if ctx.value() else None + return ProcParam(name=name, default=default, line=ctx.start.line) + + def _visit_section(self, ctx): + if ctx.varSection() is not None: + return self._visit_var_section(ctx.varSection(), "txn") + if ctx.sessionVarSection() is not None: + return self._visit_var_section(ctx.sessionVarSection(), "session") + name = ctx.name.text + body = self._visit_body(ctx.sectionBody()) + return Section(type=name, body=tuple(body), line=ctx.start.line) + + def _visit_var_section(self, ctx, scope): + decls = [] + for var_item in ctx.variables().variablesItem(): + if var_item.variableDecl() is not None: + decls.append(self._visit_var_decl(var_item.variableDecl())) + return VarSection(scope=scope, declarations=tuple(decls), line=ctx.start.line) + + def _visit_var_decl(self, ctx): + return VarDecl( + name=ctx.name.text, + type_name=ctx.typeName.text, + slot=int(ctx.slot.text) if ctx.slot else None, + line=ctx.start.line, + ) + + def _visit_body(self, items): + """Shared helper for sectionBody and blockItem lists.""" + result = [] + for item in items: + if item.statement() is not None: + result.append(self._visit_statement(item.statement())) + elif item.conditional() is not None: + result.append(self._visit_conditional(item.conditional())) + return result + + def _visit_statement(self, ctx): + line = ctx.start.line + if ctx.BREAK(): + return Break(line=line) + if ctx.functionCall(): + return self._visit_function_call(ctx.functionCall()) + if ctx.EQUAL(): + target = Target.from_dotted(ctx.lhs.text) + value = self._extract_value(ctx.value()) + return Assignment(target=target, operator="=", value=value, line=line) + if ctx.PLUSEQUAL(): + target = Target.from_dotted(ctx.lhs.text) + value = self._extract_value(ctx.value()) + return Assignment(target=target, operator="+=", value=value, line=line) + if ctx.op: + return FunctionCall(name=ctx.op.text, args=(), line=line) + raise ValueError(f"Unhandled statement alternative at line {line}") + + def _visit_function_call(self, ctx): + name = ctx.funcName.text + args = () + if ctx.argumentList(): + args = tuple( + self._extract_value(v) for v in ctx.argumentList().value() + ) + return FunctionCall(name=name, args=args, line=ctx.start.line) + + def _extract_value(self, ctx): + if ctx.number is not None: + return int(ctx.number.text) + if ctx.str_ is not None: + return Value(raw=ctx.str_.text[1:-1], kind=ValueKind.STRING) + if ctx.TRUE(): + return True + if ctx.FALSE(): + return False + if ctx.ident is not None: + return Value(raw=ctx.ident.text, kind=ValueKind.IDENT) + if ctx.ip(): + return Value(raw=ctx.ip().getText(), kind=ValueKind.IP) + if ctx.iprange(): + return tuple(Value(raw=ip.getText(), kind=ValueKind.IP) for ip in ctx.iprange().ip()) + if ctx.paramRef(): + return Value(raw=ctx.paramRef().IDENT().getText(), kind=ValueKind.PARAM_REF) + return Value(raw=ctx.getText(), kind=ValueKind.IDENT) + + def _visit_conditional(self, ctx): + if_stmt = ctx.ifStatement() + condition = self._visit_condition(if_stmt.condition()) + block = if_stmt.block() + body = tuple(self._visit_body(block.blockItem())) if block else () + + elif_branches = [] + for elif_ctx in ctx.elifClause(): + elif_cond = self._visit_condition(elif_ctx.condition()) + elif_block = elif_ctx.block() + elif_body = tuple(self._visit_body(elif_block.blockItem())) if elif_block else () + elif_branches.append(ElifBranch( + condition=elif_cond, + body=elif_body, + line=elif_ctx.start.line, + )) + + else_body = () + if ctx.elseClause(): + else_block = ctx.elseClause().block() + if else_block: + else_body = tuple(self._visit_body(else_block.blockItem())) + + return IfBlock( + condition=condition, + body=body, + elif_branches=tuple(elif_branches), + else_body=else_body, + line=ctx.start.line, + ) + + def _visit_condition(self, ctx): + return self._visit_expression(ctx.expression()) + + def _visit_expression(self, ctx): + if ctx.OR(): + left = self._visit_expression(ctx.expression()) + right = self._visit_term(ctx.term()) + return LogicalOp( + operator="||", left=left, right=right, line=ctx.start.line, + ) + return self._visit_term(ctx.term()) + + def _visit_term(self, ctx): + if ctx.AND(): + left = self._visit_term(ctx.term()) + right = self._visit_factor(ctx.factor()) + return LogicalOp( + operator="&&", left=left, right=right, line=ctx.start.line, + ) + return self._visit_factor(ctx.factor()) + + def _visit_factor(self, ctx): + if ctx.getChildCount() == 2 and ctx.getChild(0).getText() == "!": + return NotOp( + operand=self._visit_factor(ctx.factor()), + line=ctx.start.line, + ) + if ctx.LPAREN(): + return self._visit_expression(ctx.expression()) + if ctx.functionCall(): + return self._visit_function_call(ctx.functionCall()) + if ctx.comparison(): + return self._visit_comparison(ctx.comparison()) + if ctx.ident is not None: + return IdentCondition(name=ctx.ident.text, line=ctx.start.line) + if ctx.TRUE(): + return BoolLiteral(value=True, line=ctx.start.line) + if ctx.FALSE(): + return BoolLiteral(value=False, line=ctx.start.line) + raise ValueError(f"Unhandled factor alternative at line {ctx.start.line}") + + def _visit_comparison(self, ctx): + line = ctx.start.line + comp = ctx.comparable() + if comp.ident is not None: + left = Value(raw=comp.ident.text, kind=ValueKind.IDENT) + else: + left = self._visit_function_call(comp.functionCall()) + + operator = self._detect_comparison_operator(ctx) + right = self._extract_comparison_rhs(ctx, operator) + modifiers = self._extract_modifiers(ctx) + + return Comparison( + left=left, operator=operator, right=right, + modifiers=modifiers, line=line, + ) + + def _detect_comparison_operator(self, ctx): + if ctx.EQUALS(): + return "==" + if ctx.NEQ(): + return "!=" + if ctx.GT(): + return ">" + if ctx.LT(): + return "<" + if ctx.TILDE(): + return "~" + if ctx.NOT_TILDE(): + return "!~" + if ctx.IN(): + for child in ctx.children: + if hasattr(child, "getText") and child.getText() == "!": + return "!in" + return "in" + raise ValueError(f"Unhandled comparison operator at line {ctx.start.line}") + + def _extract_comparison_rhs(self, ctx, operator): + if operator in ("~", "!~"): + return Value(raw=ctx.regex().getText()[1:-1], kind=ValueKind.REGEX) + if operator in ("in", "!in"): + if ctx.set_(): + return tuple( + self._extract_value(v) for v in ctx.set_().value() + ) + if ctx.iprange(): + return tuple( + Value(raw=ip.getText(), kind=ValueKind.IP) for ip in ctx.iprange().ip() + ) + if ctx.value(): + return self._extract_value(ctx.value()) + raise ValueError(f"Unhandled comparison RHS at line {ctx.start.line}") + + def _extract_modifiers(self, ctx): + if ctx.modifier(): + return tuple( + tok.text for tok in ctx.modifier().modifierList().mods + ) + return () diff --git a/tools/hrw4u/tests/test_ast_nodes.py b/tools/hrw4u/tests/test_ast_nodes.py new file mode 100644 index 00000000000..1c2320c4e53 --- /dev/null +++ b/tools/hrw4u/tests/test_ast_nodes.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hrw4u.ast_nodes import Target + + +class TestTarget: + def test_dotted_path(self): + t = Target.from_dotted("inbound.req.X-Foo") + assert t.namespace == "inbound.req" + assert t.field == "X-Foo" + + def test_two_segments(self): + t = Target.from_dotted("inbound.ip") + assert t.namespace == "inbound" + assert t.field == "ip" + + def test_no_dots(self): + t = Target.from_dotted("bool_0") + assert t.namespace is None + assert t.field == "bool_0" + + def test_deep_namespace(self): + t = Target.from_dotted("http.cntl.TXN_DEBUG") + assert t.namespace == "http.cntl" + assert t.field == "TXN_DEBUG" diff --git a/tools/hrw4u/tests/test_ast_visitor.py b/tools/hrw4u/tests/test_ast_visitor.py new file mode 100644 index 00000000000..6d3cf574b05 --- /dev/null +++ b/tools/hrw4u/tests/test_ast_visitor.py @@ -0,0 +1,758 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hrw4u.ast_nodes import ( + Target, Assignment, FunctionCall, Break, Section, HRW4UAST, + Comparison, IfBlock, ElifBranch, BoolLiteral, NotOp, LogicalOp, IdentCondition, + VarSection, VarDecl, UseDirective, ProcedureDecl, ProcParam, + Value, ValueKind, +) +from utils import parse_input_text +from hrw4u.ast_visitor import ASTVisitor + + +def _build(source: str) -> HRW4UAST: + _, tree = parse_input_text(source) + return ASTVisitor().visit(tree) + + +class TestAssignments: + def test_simple_assignment(self): + ast = _build('REMAP {\n inbound.req.X-Foo = "test";\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.target == Target.from_dotted("inbound.req.X-Foo") + assert a.operator == "=" + assert a.value == Value(raw="test", kind=ValueKind.STRING) + + def test_bool_value(self): + ast = _build('SEND_RESPONSE {\n http.cntl.TXN_DEBUG = true;\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value is True + + def test_int_value(self): + ast = _build('REMAP {\n http.cntl.INTERCEPT_RETRY = 1;\n}') + a = ast.body[0].body[0] + assert a.value == 1 + + def test_plus_equals(self): + ast = _build('REMAP {\n inbound.req.X-Foo += "extra";\n}') + a = ast.body[0].body[0] + assert a.operator == "+=" + + def test_ip_value(self): + ast = _build('REMAP {\n inbound.req.X-IP = 10.0.0.1;\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value == Value(raw="10.0.0.1", kind=ValueKind.IP) + + def test_param_ref_value(self): + src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = $tag;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value == Value(raw="tag", kind=ValueKind.PARAM_REF) + + +class TestFunctionCalls: + def test_no_args(self): + ast = _build('REMAP {\n set-debug();\n}') + fc = ast.body[0].body[0] + assert isinstance(fc, FunctionCall) + assert fc.name == "set-debug" + assert fc.args == () + + def test_with_args(self): + ast = _build('REMAP {\n set-header("X-Foo", "bar");\n}') + fc = ast.body[0].body[0] + assert fc.name == "set-header" + assert fc.args == (Value(raw="X-Foo", kind=ValueKind.STRING), Value(raw="bar", kind=ValueKind.STRING)) + + def test_standalone_operator(self): + ast = _build('REMAP {\n skip-remap;\n}') + fc = ast.body[0].body[0] + assert isinstance(fc, FunctionCall) + assert fc.name == "skip-remap" + assert fc.args == () + + def test_break(self): + ast = _build('REMAP {\n if true {\n break;\n }\n}') + body = ast.body[0].body[0].body + assert isinstance(body[0], Break) + + +class TestSections: + def test_section_type(self): + ast = _build('REMAP {\n set-debug();\n}') + s = ast.body[0] + assert isinstance(s, Section) + assert s.type == "REMAP" + + def test_multiple_sections(self): + src = 'REMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' + ast = _build(src) + sections = [i for i in ast.body if isinstance(i, Section)] + assert len(sections) == 2 + assert sections[0].type == "REMAP" + assert sections[1].type == "SEND_RESPONSE" + + def test_use_directive(self): + src = 'use test::add-debug-header\nREMAP {\n test::add-debug-header("tag");\n}' + ast = _build(src) + assert len(ast.body) == 2 + u = ast.body[0] + assert isinstance(u, UseDirective) + assert u.spec == "test::add-debug-header" + + def test_item_ordering(self): + src = 'VARS {\n x: bool;\n}\nREMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' + ast = _build(src) + assert len(ast.body) == 3 + assert isinstance(ast.body[0], VarSection) + assert isinstance(ast.body[1], Section) + assert isinstance(ast.body[2], Section) + + +class TestVarSections: + def test_txn_scope(self): + src = 'VARS {\n flag: bool;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.scope == "txn" + assert len(vs.declarations) == 1 + assert vs.declarations[0].name == "flag" + assert vs.declarations[0].type_name == "bool" + assert vs.declarations[0].slot is None + + def test_session_scope(self): + src = 'SESSION_VARS {\n counter: int;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.scope == "session" + assert vs.declarations[0].name == "counter" + + def test_slot(self): + src = 'VARS {\n x: int @3;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.declarations[0].slot == 3 + + def test_multiple_declarations(self): + src = 'VARS {\n a: bool;\n b: int;\n c: string;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert len(vs.declarations) == 3 + assert vs.declarations[0].name == "a" + assert vs.declarations[1].name == "b" + assert vs.declarations[2].name == "c" + + +class TestProcedures: + def test_basic_decl(self): + src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = "$tag";\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert pd.name == "local::stamp" + assert len(pd.params) == 1 + assert pd.params[0].name == "tag" + assert pd.params[0].default is None + + def test_default_param(self): + src = 'procedure local::cache($ttl=300) {\n set-debug();\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert pd.params[0].name == "ttl" + assert pd.params[0].default == 300 + + def test_body(self): + src = ('procedure local::multi() {\n inbound.req.X = "a";\n' + ' set-debug();\n}\nREMAP {\n set-debug();\n}') + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert len(pd.body) == 2 + assert isinstance(pd.body[0], Assignment) + assert isinstance(pd.body[1], FunctionCall) + + +class TestConditionExpressions: + def _first_condition(self, source: str): + ast = _build(source) + return ast.body[0].body[0].condition + + def test_equality_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-Foo == "bar" {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.left == Value(raw="inbound.req.X-Foo", kind=ValueKind.IDENT) + assert cond.operator == "==" + assert cond.right == Value(raw="bar", kind=ValueKind.STRING) + assert cond.modifiers == () + + def test_regex_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path ~ /\\.php$/ {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "~" + assert isinstance(cond.right, Value) + assert cond.right.kind == ValueKind.REGEX + + def test_in_set(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path in ["a", "b"] {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == (Value(raw="a", kind=ValueKind.STRING), Value(raw="b", kind=ValueKind.STRING)) + + def test_not_in_set(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path !in ["a"] {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "!in" + + def test_in_iprange(self): + cond = self._first_condition( + 'REMAP {\n if inbound.ip in {10.0.0.0/8} {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == (Value(raw="10.0.0.0/8", kind=ValueKind.IP),) + + def test_modifiers(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-Foo == "bar" with NOCASE {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.modifiers == ("NOCASE",) + + def test_function_call_comparable(self): + cond = self._first_condition( + 'REMAP {\n if url(true) ~ /pat/ {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert isinstance(cond.left, FunctionCall) + assert cond.left.name == "url" + assert cond.left.args == (True,) + + def test_bool_literal_true(self): + cond = self._first_condition( + 'REMAP {\n if true {\n set-debug();\n }\n}' + ) + assert isinstance(cond, BoolLiteral) + assert cond.value is True + + def test_ident_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.resp.All-Cache {\n set-debug();\n }\n}' + ) + assert isinstance(cond, IdentCondition) + assert cond.name == "inbound.resp.All-Cache" + + def test_not_condition(self): + cond = self._first_condition( + 'REMAP {\n if !inbound.resp.All-Cache {\n set-debug();\n }\n}' + ) + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, IdentCondition) + + def test_and_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-A == "a" && inbound.req.X-B == "b" {\n set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, Comparison) + assert isinstance(cond.right, Comparison) + + def test_or_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-A == "a" || inbound.req.X-B == "b" {\n set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + + def test_function_call_in_condition(self): + cond = self._first_condition( + 'REMAP {\n if access("/tmp/bar") {\n set-debug();\n }\n}' + ) + assert isinstance(cond, FunctionCall) + assert cond.name == "access" + assert cond.args == (Value(raw="/tmp/bar", kind=ValueKind.STRING),) + + def test_not_tilde_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path !~ /\\.jpg$/ {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "!~" + assert isinstance(cond.right, Value) + assert cond.right.kind == ValueKind.REGEX + + def test_greater_than_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.Content-Length > 1000 {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == ">" + assert cond.right == 1000 + + def test_less_than_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.Content-Length < 500 {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "<" + assert cond.right == 500 + + def test_neq_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-Foo != "bar" {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "!=" + assert cond.right == Value(raw="bar", kind=ValueKind.STRING) + + def test_parenthesized_condition(self): + cond = self._first_condition( + 'REMAP {\n if (inbound.req.X-Foo == "bar") {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "==" + assert cond.right == Value(raw="bar", kind=ValueKind.STRING) + + def test_and_binds_tighter_than_or(self): + # a || b && c should parse as a || (b && c) + cond = self._first_condition( + 'REMAP {\n' + ' if inbound.req.X-A == "a" || inbound.req.X-B == "b" && inbound.req.X-C == "c" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + assert isinstance(cond.left, Comparison) + assert cond.left.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) + assert isinstance(cond.right, LogicalOp) + assert cond.right.operator == "&&" + assert cond.right.left.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + assert cond.right.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) + + def test_not_with_and(self): + # !ident && comparison should parse as (!ident) && comparison + cond = self._first_condition( + 'REMAP {\n' + ' if !inbound.resp.All-Cache && inbound.req.X-B == "b" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, IdentCondition) + assert cond.left.operand.name == "inbound.resp.All-Cache" + assert isinstance(cond.right, Comparison) + assert cond.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + + def test_not_comparison_with_or(self): + # !(a == "x") || b == "y" should parse as (!(a == "x")) || (b == "y") + cond = self._first_condition( + 'REMAP {\n' + ' if !(inbound.req.X-A == "x") || inbound.req.X-B == "y" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, Comparison) + assert cond.left.operand.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) + assert cond.left.operand.right == Value(raw="x", kind=ValueKind.STRING) + assert isinstance(cond.right, Comparison) + assert cond.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + + def test_double_negation(self): + cond = self._first_condition( + 'REMAP {\n if !!inbound.resp.All-Cache {\n set-debug();\n }\n}' + ) + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, NotOp) + assert isinstance(cond.operand.operand, IdentCondition) + assert cond.operand.operand.name == "inbound.resp.All-Cache" + + def test_not_bool_literal(self): + cond = self._first_condition( + 'REMAP {\n if !false {\n set-debug();\n }\n}' + ) + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, BoolLiteral) + assert cond.operand.value is False + + def test_parens_override_precedence(self): + # (a || b) && c — parens force || to bind first + cond = self._first_condition( + 'REMAP {\n' + ' if (inbound.req.X-A == "a" || inbound.req.X-B == "b") && inbound.req.X-C == "c" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, LogicalOp) + assert cond.left.operator == "||" + assert cond.left.left.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) + assert cond.left.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + assert isinstance(cond.right, Comparison) + assert cond.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) + + def test_nested_parens_with_not(self): + # !(a == "x" || b == "y") && c == "z" + cond = self._first_condition( + 'REMAP {\n' + ' if !(inbound.req.X-A == "x" || inbound.req.X-B == "y") && inbound.req.X-C == "z" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, LogicalOp) + assert cond.left.operand.operator == "||" + assert isinstance(cond.right, Comparison) + assert cond.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) + + +class TestIfBlocks: + def test_simple_if(self): + ast = _build( + 'REMAP {\n if true {\n inbound.req.X = "y";\n }\n}' + ) + ib = ast.body[0].body[0] + assert isinstance(ib, IfBlock) + assert len(ib.body) == 1 + assert ib.elif_branches == () + assert ib.else_body == () + + def test_if_else(self): + src = 'REMAP {\n if true {\n inbound.req.X = "a";\n } else {\n inbound.req.X = "b";\n }\n}' + ast = _build(src) + ib = ast.body[0].body[0] + assert len(ib.else_body) == 1 + + def test_if_elif_else(self): + src = ('SEND_RESPONSE {\n if inbound.url.path == "foo" {\n' + ' inbound.resp.X = "f";\n } elif inbound.url.path == "bar" {\n' + ' inbound.resp.X = "b";\n } else {\n' + ' inbound.resp.X = "other";\n }\n}') + ast = _build(src) + ib = ast.body[0].body[0] + assert isinstance(ib, IfBlock) + assert len(ib.elif_branches) == 1 + assert isinstance(ib.elif_branches[0], ElifBranch) + assert len(ib.elif_branches[0].body) == 1 + assert len(ib.else_body) == 1 + + def test_multiple_elif(self): + src = ('SEND_RESPONSE {\n if inbound.url.path == "a" {\n set-debug();\n' + ' } elif inbound.url.path == "b" {\n set-debug();\n' + ' } elif inbound.url.path == "c" {\n set-debug();\n' + ' } else {\n set-debug();\n }\n}') + ast = _build(src) + ib = ast.body[0].body[0] + assert len(ib.elif_branches) == 2 + + def test_nested_if(self): + src = ('REMAP {\n if inbound.req.X == "a" {\n' + ' if inbound.req.Y == "b" {\n set-debug();\n }\n }\n}') + ast = _build(src) + outer = ast.body[0].body[0] + assert isinstance(outer, IfBlock) + inner = outer.body[0] + assert isinstance(inner, IfBlock) + + def test_mixed_body(self): + src = ('REMAP {\n inbound.req.X = "before";\n' + ' if true {\n set-debug();\n }\n' + ' inbound.req.Y = "after";\n}') + ast = _build(src) + body = ast.body[0].body + assert len(body) == 3 + assert isinstance(body[0], Assignment) + assert isinstance(body[1], IfBlock) + assert isinstance(body[2], Assignment) + + +class TestLineNumbers: + SRC = ( + "use test::helper\n" # line 1 + "VARS {\n" # line 2 + " flag: bool;\n" # line 3 + "}\n" # line 4 + "procedure local::stamp($tag) {\n" # line 5 + " inbound.req.X-Stamp = $tag;\n" # line 6 + "}\n" # line 7 + "REMAP {\n" # line 8 + ' inbound.req.X-Foo = "val";\n' # line 9 + " set-debug();\n" # line 10 + " skip-remap;\n" # line 11 + ' if inbound.req.X-A == "a" {\n' # line 12 + " break;\n" # line 13 + ' } elif inbound.req.X-B == "b" {\n' # line 14 + ' inbound.req.X = "elif";\n' # line 15 + " } else {\n" # line 16 + ' inbound.req.X = "else";\n' # line 17 + " }\n" # line 18 + ' if inbound.req.X-C == "c" && inbound.req.X-D == "d" {\n' # line 19 + " set-debug();\n" # line 20 + " }\n" # line 21 + " if !inbound.resp.All-Cache {\n" # line 22 + " set-debug();\n" # line 23 + " }\n" # line 24 + " if true {\n" # line 25 + " set-debug();\n" # line 26 + " }\n" # line 27 + " if inbound.resp.All-Cache {\n" # line 28 + " set-debug();\n" # line 29 + " }\n" # line 30 + "}\n" # line 31 + ) + + def setup_method(self): + self.ast = _build(self.SRC) + + def test_use_directive(self): + u = self.ast.body[0] + assert isinstance(u, UseDirective) + assert u.line == 1 + + def test_var_section(self): + vs = self.ast.body[1] + assert isinstance(vs, VarSection) + assert vs.line == 2 + + def test_var_decl(self): + vd = self.ast.body[1].declarations[0] + assert isinstance(vd, VarDecl) + assert vd.line == 3 + + def test_procedure_decl(self): + pd = self.ast.body[2] + assert isinstance(pd, ProcedureDecl) + assert pd.line == 5 + + def test_proc_param(self): + pp = self.ast.body[2].params[0] + assert isinstance(pp, ProcParam) + assert pp.line == 5 + + def test_procedure_body_assignment(self): + a = self.ast.body[2].body[0] + assert isinstance(a, Assignment) + assert a.line == 6 + + def test_section(self): + s = self.ast.body[3] + assert isinstance(s, Section) + assert s.line == 8 + + def test_assignment(self): + a = self.ast.body[3].body[0] + assert isinstance(a, Assignment) + assert a.line == 9 + + def test_function_call(self): + fc = self.ast.body[3].body[1] + assert isinstance(fc, FunctionCall) + assert fc.line == 10 + + def test_standalone_operator(self): + fc = self.ast.body[3].body[2] + assert isinstance(fc, FunctionCall) + assert fc.line == 11 + + def test_if_block(self): + ib = self.ast.body[3].body[3] + assert isinstance(ib, IfBlock) + assert ib.line == 12 + + def test_comparison_in_condition(self): + cond = self.ast.body[3].body[3].condition + assert isinstance(cond, Comparison) + assert cond.line == 12 + + def test_break(self): + brk = self.ast.body[3].body[3].body[0] + assert isinstance(brk, Break) + assert brk.line == 13 + + def test_elif_branch(self): + eb = self.ast.body[3].body[3].elif_branches[0] + assert isinstance(eb, ElifBranch) + assert eb.line == 14 + + def test_elif_condition(self): + cond = self.ast.body[3].body[3].elif_branches[0].condition + assert isinstance(cond, Comparison) + assert cond.line == 14 + + def test_logical_op(self): + cond = self.ast.body[3].body[4].condition + assert isinstance(cond, LogicalOp) + assert cond.line == 19 + + def test_not_op(self): + cond = self.ast.body[3].body[5].condition + assert isinstance(cond, NotOp) + assert cond.line == 22 + + def test_bool_literal(self): + cond = self.ast.body[3].body[6].condition + assert isinstance(cond, BoolLiteral) + assert cond.line == 25 + + def test_ident_condition(self): + cond = self.ast.body[3].body[7].condition + assert isinstance(cond, IdentCondition) + assert cond.line == 28 + + +class TestRealConfigs: + def test_nested_ifs_from_test_data(self): + """Validates AST for tests/data/conds/nested-ifs.input.txt pattern.""" + src = '''VARS { + bool_0: bool; + bool_1: bool; + bool_2: bool; +} + +REMAP { + if inbound.req.X-Foo == "bar" { + inbound.req.X-Hello = "there"; + if inbound.req.X-Fie == "fie" { + inbound.req.X-first = "1"; + if bool_0 || (bool_1 && bool_2) { + inbound.req.X-Parsed = "more"; + } else { + inbound.req.X-Parsed = "yes"; + } + } elif inbound.req.X-Fum == "bar" { + inbound.req.X-Parsed = "no"; + } else { + inbound.req.X-More = "yes"; + } + } elif inbound.req.X-Foo == "foo" with NOCASE,PRE { + inbound.req.X-Nocase = "foo"; + } else { + inbound.req.X-Something = "no-bar"; + } +}''' + ast = _build(src) + sections = [i for i in ast.body if isinstance(i, Section)] + assert len(sections) == 1 + s = sections[0] + assert s.type == "REMAP" + + # Top-level if block + outer = s.body[0] + assert isinstance(outer, IfBlock) + + # Body: assignment + nested if + assert isinstance(outer.body[0], Assignment) + assert isinstance(outer.body[1], IfBlock) + middle = outer.body[1] + + # Middle if has elif and else + assert len(middle.elif_branches) == 1 + assert len(middle.else_body) == 1 + + # Deepest nested if (3 levels) + inner = middle.body[1] + assert isinstance(inner, IfBlock) + assert isinstance(inner.condition, LogicalOp) + assert inner.condition.operator == "||" + + # Outer elif has modifiers + assert len(outer.elif_branches) == 1 + elif_cond = outer.elif_branches[0].condition + assert isinstance(elif_cond, Comparison) + assert elif_cond.modifiers == ("NOCASE", "PRE") + + def test_http_cntl_booleans(self): + """Validates value coercion for boolean-like assignments.""" + src = '''SEND_RESPONSE { + http.cntl.TXN_DEBUG = true; + http.cntl.LOGGING = FALSE; +}''' + ast = _build(src) + body = ast.body[0].body + assert body[0].value is True + assert body[1].value is False + + def test_ip_range_condition(self): + """Validates IP range handling from tests/data/conds/ip.input.txt.""" + src = '''SEND_REQUEST { + if inbound.ip in {192.168.0.0/16, 10.0.0.0/8} { + set-debug(); + } +}''' + ast = _build(src) + cond = ast.body[0].body[0].condition + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert len(cond.right) == 2 + + def test_set_membership_with_modifier(self): + """From tests/data/conds/in-sets.input.txt.""" + src = '''REMAP { + if inbound.url.path in ["php", "php3", "php4"] with EXT { + inbound.req.X-Is-PHP = "yes"; + } +}''' + ast = _build(src) + cond = ast.body[0].body[0].condition + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == (Value(raw="php", kind=ValueKind.STRING), Value(raw="php3", kind=ValueKind.STRING), Value(raw="php4", kind=ValueKind.STRING)) + assert cond.modifiers == ("EXT",) + + def test_debug_pattern_for_lint_rules(self): + """Validates the exact pattern the no-debug lint rule will match.""" + src = '''REMAP { + set-debug(); + http.cntl.TXN_DEBUG = true; + inbound.req.X-Foo = "test"; +}''' + ast = _build(src) + body = ast.body[0].body + + # set-debug() function call + assert isinstance(body[0], FunctionCall) + assert body[0].name == "set-debug" + + # TXN_DEBUG assignment with True + assert isinstance(body[1], Assignment) + assert body[1].target == Target.from_dotted("http.cntl.TXN_DEBUG") + assert body[1].value is True + + # Regular assignment (not flagged) + assert isinstance(body[2], Assignment) + assert body[2].target.namespace == "inbound.req"