From 676d0e6f4190661463a5d6f4b262df4279c87513 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 15:01:15 -0700 Subject: [PATCH] Add Yardstick upstream parity CI --- .github/workflows/yardstick-upstream.yml | 53 ++ docs/compatibility/yardstick.md | 45 +- pyproject.toml | 1 + sidemantic/adapters/yardstick.py | 12 +- sidemantic/sql/aggregation_detection.py | 13 +- sidemantic/sql/query_rewriter.py | 262 +++++++++- .../queries/test_yardstick_measures_replay.py | 465 +++++++++++++++++- .../queries/test_yardstick_query_rewriter.py | 80 ++- 8 files changed, 892 insertions(+), 39 deletions(-) create mode 100644 .github/workflows/yardstick-upstream.yml diff --git a/.github/workflows/yardstick-upstream.yml b/.github/workflows/yardstick-upstream.yml new file mode 100644 index 00000000..e89ad8ab --- /dev/null +++ b/.github/workflows/yardstick-upstream.yml @@ -0,0 +1,53 @@ +name: Yardstick Upstream Parity + +on: + workflow_dispatch: + inputs: + yardstick_ref: + description: "Yardstick ref to test against" + required: true + default: "main" + schedule: + - cron: "17 10 * * *" + pull_request: + paths: + - ".github/workflows/yardstick-upstream.yml" + - "docs/compatibility/yardstick.md" + - "sidemantic/adapters/yardstick.py" + - "sidemantic/sql/aggregation_detection.py" + - "sidemantic/sql/query_rewriter.py" + - "tests/queries/test_yardstick_measures_replay.py" + - "tests/queries/test_yardstick_query_rewriter.py" + +permissions: + contents: read + +concurrency: + group: yardstick-upstream-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + upstream-parity: + name: Live upstream SQL replay + runs-on: ubuntu-latest + timeout-minutes: 15 + env: + SIDEMANTIC_YARDSTICK_UPSTREAM_TESTS: "1" + YARDSTICK_UPSTREAM_REF: ${{ github.event.inputs.yardstick_ref || 'main' }} + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - name: Set up Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync --extra dev + + - name: Replay upstream Yardstick tests + run: uv run pytest -q tests/queries/test_yardstick_measures_replay.py -m yardstick_upstream diff --git a/docs/compatibility/yardstick.md b/docs/compatibility/yardstick.md index f27e82f6..8027472f 100644 --- a/docs/compatibility/yardstick.md +++ b/docs/compatibility/yardstick.md @@ -1,6 +1,6 @@ # Yardstick Compatibility -Sidemantic's Yardstick adapter parses SQL files containing `CREATE VIEW` statements that use the `AS MEASURE` syntax from Julian Hyde's ["Measures in SQL" proposal](https://arxiv.org/abs/2307.14009). It maps Yardstick concepts to Sidemantic's semantic model (Model, Dimension, Metric) and supports the `SEMANTIC SELECT`, `AGGREGATE()`, and `AT` query modifiers for measure-aware SQL queries. +Sidemantic's Yardstick adapter parses SQL files containing `CREATE VIEW` statements that use the `AS MEASURE` syntax from Julian Hyde's ["Measures in SQL" proposal](https://arxiv.org/abs/2307.14009). It maps Yardstick concepts to Sidemantic's semantic model (Model, Dimension, Metric) and supports `SEMANTIC SELECT`, optional-prefix `AGGREGATE()`, and `AT` query modifiers for measure-aware SQL queries. Features are marked **supported**, **partial support**, or **unsupported**. Partial support entries include notes explaining the limitation. @@ -110,15 +110,16 @@ Derived measure detection works by scanning the expression's column references a | `MODE(expr) AS MEASURE name` | Supported (stored as raw SQL expression metric with `agg=None`) | | `PERCENTILE_CONT(n) WITHIN GROUP (ORDER BY expr) AS MEASURE name` | Supported (stored as raw SQL expression metric) | | `CASE WHEN AGG(...) THEN ... END AS MEASURE name` | Supported (detected as having aggregate semantics; stored as raw SQL expression metric) | -| Other aggregate functions not in the standard list | Supported (full expression preserved as `Metric.sql`) | +| `PRODUCT(expr)`, `ENTROPY(expr)`, `KURTOSIS(expr)`, `SKEWNESS(expr)`, `LIST(expr)`, and related DuckDB aggregate functions | Supported (stored as raw SQL expression metrics with aggregate semantics) | +| Other aggregate functions not in the standard list | Supported when sqlglot identifies them as aggregates; otherwise preserved as raw SQL only when aggregate semantics can be detected | -When a measure expression contains aggregate functions (detected by walking the AST for `AggFunc` nodes or known anonymous aggregations like `mode`) but doesn't match a simple aggregation pattern, the full expression is preserved as-is for query-time evaluation. +When a measure expression contains aggregate functions (detected by walking the AST for `AggFunc` nodes or known anonymous aggregations like `mode`, `product`, and `entropy`) but doesn't match a simple aggregation pattern, the full expression is preserved as-is for query-time evaluation. --- ## Query Semantics -The Yardstick adapter works in tandem with Sidemantic's query rewriter to support the `SEMANTIC SELECT`, `AGGREGATE()`, and `AT` modifiers described in the Measures in SQL proposal. +The Yardstick adapter works in tandem with Sidemantic's query rewriter to support `SEMANTIC SELECT`, optional-prefix `AGGREGATE()`, and `AT` modifiers described in the Measures in SQL proposal. ### SEMANTIC Prefix @@ -126,7 +127,10 @@ The Yardstick adapter works in tandem with Sidemantic's query rewriter to suppor |---------|--------| | `SEMANTIC SELECT ...` | Supported (enables measure-aware query rewriting) | | `SEMANTIC WITH ... SELECT ...` | Supported (CTEs within semantic queries) | -| Implicit measure detection without `SEMANTIC` prefix | Supported (queries containing `AT` modifiers or curly-brace measure references are auto-detected) | +| `SELECT ... AGGREGATE(measure)` without `SEMANTIC` prefix | Supported | +| `CREATE TABLE ... AS SELECT ... AGGREGATE(...)` | Supported | +| `INSERT INTO ... SELECT ... AGGREGATE(...)` | Supported | +| Implicit measure detection without `SEMANTIC` prefix | Supported (queries containing `AGGREGATE()`, `AT` modifiers, or curly-brace measure references are auto-detected) | ### AGGREGATE() Function @@ -139,7 +143,36 @@ The Yardstick adapter works in tandem with Sidemantic's query rewriter to suppor | `AGGREGATE()` in arithmetic expressions (`2 * AGGREGATE(revenue)`) | Supported | | `AGGREGATE(measure) / AGGREGATE(measure) AT (...)` | Supported (each AGGREGATE evaluated independently) | | Scalar `AGGREGATE()` without GROUP BY | Supported (produces a single grand-total row) | -| `AGGREGATE()` without `SEMANTIC` prefix and without `AT` | Error: raises `ValueError` requiring the `SEMANTIC` prefix | +| `AGGREGATE()` without `SEMANTIC` prefix and without `AT` | Supported | +| Native DuckDB `aggregate(list, 'function')` | Supported (falls through to DuckDB; not treated as Yardstick syntax) | + +### Upstream Parity Tests + +The default test suite replays a vendored Yardstick `measures.test` fixture for stable CI coverage. To check against the live upstream Yardstick repository without copying fixtures into Sidemantic, run: + +```bash +SIDEMANTIC_YARDSTICK_UPSTREAM_TESTS=1 uv run pytest -q tests/queries/test_yardstick_measures_replay.py -m yardstick_upstream +``` + +The same command runs in the `Yardstick Upstream Parity` GitHub Actions workflow. That workflow runs nightly, can be triggered manually with a Yardstick ref override, and runs on pull requests that touch Yardstick-specific code or tests. + +The live replay fetches `https://github.com/sidequery/yardstick.git` at `main` by default, checks all upstream `test/sql/*.test` files, and validates both facets: + +| Facet | Coverage | +|-------|----------| +| Model/metric definitions | Parses every upstream `CREATE VIEW ... AS MEASURE` statement and asserts model name, source table/base SQL, primary key, Yardstick metadata, dimension SQL/type/granularity, and metric `agg`/`sql`/`filters`/`type` | +| Query execution | Replays every upstream query block against Sidemantic's Yardstick rewriter and compares result rows | + +The live definition check covers the `CREATE VIEW ... AS MEASURE` definitions used by Yardstick's SQL tests. Sidemantic's native SQL definition parser owns `MODEL(...)`, `METRIC(...)`, and `DIMENSION(...)` files separately from the Yardstick adapter; the live upstream replay does not treat Yardstick's top-level `yardstick_definitions.sql` helper file as part of the SQL-test corpus. + +Optional environment variables: + +| Variable | Purpose | +|----------|---------| +| `YARDSTICK_UPSTREAM_PATH` | Use an existing local Yardstick checkout instead of fetching | +| `YARDSTICK_UPSTREAM_REPO` | Override the upstream Git URL | +| `YARDSTICK_UPSTREAM_REF` | Override the ref fetched from upstream | +| `YARDSTICK_UPSTREAM_CACHE_DIR` | Override the temporary checkout path | ### AT Modifiers diff --git a/pyproject.toml b/pyproject.toml index 3d4c96d3..a4b46dc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -169,6 +169,7 @@ testpaths = ["tests"] pythonpath = ["."] markers = [ "integration: marks tests as integration tests requiring external services (deselect with '-m \"not integration\"')", + "yardstick_upstream: opt-in tests that fetch and replay live upstream Yardstick SQL tests", ] addopts = "-m 'not integration' --cov=sidemantic --cov-report=term-missing" # Skip integration tests by default and show coverage diff --git a/sidemantic/adapters/yardstick.py b/sidemantic/adapters/yardstick.py index 59de56b5..321b4b82 100644 --- a/sidemantic/adapters/yardstick.py +++ b/sidemantic/adapters/yardstick.py @@ -102,7 +102,15 @@ def __init__(self, dialect: str = "duckdb"): exp.Variance: "variance", exp.VariancePop: "variance_pop", } - _ANONYMOUS_AGGREGATIONS: set[str] = {"mode"} + _ANONYMOUS_AGGREGATIONS: set[str] = { + "entropy", + "geometric_mean", + "kurtosis", + "mode", + "product", + "skewness", + "weighted_avg", + } def parse(self, source: str | Path) -> SemanticGraph: """Parse Yardstick SQL files into a semantic graph.""" @@ -344,6 +352,8 @@ def _has_aggregate_semantics(self, expression: exp.Expression) -> bool: return True for node in expression.walk(): + if isinstance(node, exp.List): + return True if isinstance(node, exp.Anonymous) and (node.name or "").lower() in self._ANONYMOUS_AGGREGATIONS: return True return False diff --git a/sidemantic/sql/aggregation_detection.py b/sidemantic/sql/aggregation_detection.py index 4b8ce581..90743d41 100644 --- a/sidemantic/sql/aggregation_detection.py +++ b/sidemantic/sql/aggregation_detection.py @@ -10,11 +10,20 @@ # SQLGlot treats some engine-specific aggregate functions as Anonymous. # Keep this focused on known aggregate forms we need to support. _ANONYMOUS_AGGREGATE_FUNCTIONS = { + "entropy", + "geometric_mean", + "kurtosis", "mode", + "product", + "skewness", + "weighted_avg", } _AGGREGATE_REGEX = re.compile( - r"\b(sum|count|avg|min|max|median|stddev|stddev_pop|variance|variance_pop|mode|quantile|percentile)\s*\(", + r"\b(" + r"sum|count|avg|min|max|median|stddev|stddev_pop|variance|variance_pop|mode|" + r"quantile|percentile|product|entropy|kurtosis|skewness|geometric_mean|weighted_avg|list" + r")\s*\(", re.IGNORECASE, ) @@ -25,6 +34,8 @@ def expression_has_aggregate(expression: exp.Expression) -> bool: return True for node in expression.walk(): + if isinstance(node, exp.List): + return True if isinstance(node, exp.Anonymous) and (node.name or "").lower() in _ANONYMOUS_AGGREGATE_FUNCTIONS: return True diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index 17bb87db..eba63b39 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -4,6 +4,7 @@ """ import os +import re from dataclasses import dataclass from typing import Any @@ -17,6 +18,8 @@ from sidemantic.sql.generator import SQLGenerator from sidemantic.sql.planner import CandidatePlan, RewriteExplanation, SemanticQueryPlan +_YARDSTICK_SYNTAX_HINT_RE = re.compile(r"\b(?:SEMANTIC|AGGREGATE|AT)\b|\{", re.IGNORECASE) + @dataclass class _YardstickAggregateCall: @@ -215,7 +218,7 @@ def cache_result(result: str) -> str: self._raise_on_user_cte_name_collision(parsed) return cache_result(self._rewrite_set_operation(parsed)) - if self._contains_implicit_yardstick_measure_query(parsed): + if self._graph_has_yardstick_models() and self._contains_implicit_yardstick_measure_query(parsed): try: return cache_result(self._rewrite_yardstick_query(sql, strict=strict, allow_plain_measures=True)) except Exception: @@ -365,7 +368,7 @@ def explain(self, sql: str, strict: bool = True) -> RewriteExplanation: self._raise_on_user_cte_name_collision(parsed) return self._explain_set_operation(sql, parsed) - if self._contains_implicit_yardstick_measure_query(parsed): + if self._graph_has_yardstick_models() and self._contains_implicit_yardstick_measure_query(parsed): try: rewritten_sql = self._rewrite_yardstick_query(sql, strict=strict, allow_plain_measures=True) except Exception as e: @@ -2438,7 +2441,7 @@ def _explanation_from_plan( rejected_rules=plan.rejected_rules, ) - def _plan_simple_query(self, parsed: exp.Select) -> SemanticQueryPlan: + def _plan_simple_query(self, parsed: exp.Select, include_candidate_details: bool = True) -> SemanticQueryPlan: """Build a behavior-preserving semantic query plan for a simple SELECT.""" explicit_join_filters = [] if parsed.args.get("joins"): @@ -2478,9 +2481,13 @@ def _plan_simple_query(self, parsed: exp.Select) -> SemanticQueryPlan: offset=offset, aliases=aliases, ) - plan.eligibility = self._plan_eligibility(plan) + plan.eligibility = ( + self._plan_eligibility(plan) + if include_candidate_details or self.use_preaggregations + else self._lightweight_plan_eligibility(plan) + ) plan.candidate_kind = self._chosen_candidate_kind(plan) - plan.candidate_plans = self._candidate_plans_for_plan(plan) + plan.candidate_plans = self._candidate_plans_for_plan(plan) if include_candidate_details else [] if plan.candidate_kind == "single_model_preaggregation": plan.applied_rules.append("preaggregation_route_selection") if plan.candidate_kind == "join_key_preaggregation": @@ -2565,6 +2572,31 @@ def _plan_eligibility(self, plan: SemanticQueryPlan) -> dict[str, dict[str, obje "single_model_preaggregation": self._single_model_preaggregation_eligibility(plan), } + def _lightweight_plan_eligibility(self, plan: SemanticQueryPlan) -> dict[str, dict[str, object]]: + return { + "window_metric": { + "eligible": False, + "metrics": [], + "reason": "not_evaluated_for_rewrite", + }, + "fanout_preaggregation": { + "eligible": False, + "reason": "not_evaluated_for_rewrite", + }, + "join_key_preaggregation": { + "eligible": False, + "reason": "preaggregations_disabled", + "enabled": False, + "requires_enablement": True, + }, + "single_model_preaggregation": { + "eligible": False, + "reason": "preaggregations_disabled", + "enabled": False, + "requires_enablement": True, + }, + } + def _join_key_preaggregation_eligibility(self, plan: SemanticQueryPlan) -> dict[str, object]: details = self.generator.explain_join_key_preaggregation( metrics=plan.metrics, @@ -2900,6 +2932,9 @@ def _dedupe(self, values: list[str]) -> list[str]: def _looks_like_yardstick_query(self, sql: str) -> bool: """Return True if query appears to use Yardstick query syntax.""" + if not _YARDSTICK_SYNTAX_HINT_RE.search(sql): + return False + try: tokens = sqlglot.tokenize(sql, read=self.dialect) except Exception: @@ -2933,6 +2968,11 @@ def _looks_like_yardstick_query(self, sql: str) -> bool: return False + def _graph_has_yardstick_models(self) -> bool: + return any( + isinstance(model.metadata, dict) and "yardstick" in model.metadata for model in self.graph.models.values() + ) + def _is_yardstick_identifier_token(self, token) -> bool: return token.token_type in {TokenType.VAR, TokenType.IDENTIFIER, TokenType.SCHEMA} @@ -3059,11 +3099,16 @@ def _is_yardstick_model(self, model_name: str) -> bool: def _rewrite_yardstick_query(self, sql: str, strict: bool = True, allow_plain_measures: bool = False) -> str: """Rewrite Yardstick-style SQL (`SEMANTIC`, `AGGREGATE`, `AT`) to plain SQL.""" + original_sql = sql + tokens = sqlglot.tokenize(sql, read=self.dialect) + has_semantic_prefix = bool(tokens and tokens[0].text.upper() == "SEMANTIC") sql = self._expand_yardstick_curly_measure_references(sql) transformed_sql, calls = self._replace_yardstick_aggregate_calls(sql) # SEMANTIC prefix without AGGREGATE: fall back to normal SQL rewrite path. if not calls and not allow_plain_measures: + if not has_semantic_prefix and transformed_sql == original_sql: + return transformed_sql return self.rewrite(transformed_sql, strict=strict) try: @@ -3071,15 +3116,16 @@ def _rewrite_yardstick_query(self, sql: str, strict: bool = True, allow_plain_me except Exception as e: raise ValueError(f"Failed to parse Yardstick SQL: {e}") from e - if not isinstance(parsed, exp.Select): - raise ValueError("Yardstick rewrite currently supports SELECT queries only") + select_scopes = list(parsed.find_all(exp.Select)) + if not select_scopes: + raise ValueError("Yardstick rewrite requires a SELECT query or statement containing SELECT") call_map = {call.placeholder: call for call in calls} placeholder_names = set(call_map) rewritten_root: exp.Expression = parsed # Rewrite innermost SELECT scopes first so nested Yardstick placeholders are # resolved in their own FROM/JOIN context before outer scopes are processed. - for select_scope in reversed(list(parsed.find_all(exp.Select))): + for select_scope in reversed(select_scopes): rewritten_scope = self._rewrite_yardstick_select_scope( select_scope, call_map=call_map, @@ -3251,8 +3297,169 @@ def replace_placeholder(node: exp.Expression) -> exp.Expression: return node rewritten = select_scope.transform(replace_placeholder) + rewritten = self._inline_yardstick_order_by_subquery_aliases(rewritten) return self._rewrite_source_model_relations(rewritten) + def _inline_yardstick_order_by_subquery_aliases(self, select_scope: exp.Select) -> exp.Select: + """Inline subquery-backed SELECT aliases inside compound ORDER BY expressions. + + DuckDB accepts ``ORDER BY alias`` for a select item whose expression is a + scalar subquery, but it rejects compound expressions such as + ``ORDER BY metric_alias / total_alias`` when either alias expands to a + scalar subquery. Yardstick upstream fixes this by inlining only the + subquery-backed aliases in non-simple ORDER BY expressions. + """ + order_clause = select_scope.args.get("order") + if not order_clause: + return select_scope + + aliases: dict[str, tuple[exp.Expression, bool]] = {} + has_subquery_alias = False + for projection in select_scope.expressions: + if not isinstance(projection, exp.Alias) or not projection.alias: + continue + has_subquery = self._expression_has_subquery(projection.this) + aliases[projection.alias.lower()] = (projection.this, has_subquery) + has_subquery_alias = has_subquery_alias or has_subquery + + if not aliases or not has_subquery_alias: + return select_scope + + table_qualifiers = self._yardstick_order_table_qualifiers(select_scope) + for order_expr in order_clause.expressions: + expr_obj = order_expr.this + if self._is_simple_yardstick_order_alias_ref(expr_obj, aliases, table_qualifiers): + continue + if not self._yardstick_order_expr_references_subquery_alias(expr_obj, aliases, table_qualifiers): + continue + + rewritten_expr, changed = self._inline_yardstick_order_alias_expr( + expr_obj, + aliases, + table_qualifiers, + ) + if changed: + order_expr.set("this", rewritten_expr) + + return select_scope + + def _expression_has_subquery(self, expression: exp.Expression) -> bool: + return any(isinstance(node, exp.Subquery) for node in expression.walk()) + + def _yardstick_order_table_qualifiers(self, select_scope: exp.Select) -> set[str]: + qualifiers: set[str] = set() + + def add_relation(relation: exp.Expression | None) -> None: + if relation is None: + return + alias = relation.alias + if alias: + qualifiers.add(alias.lower()) + return + if isinstance(relation, exp.Table): + qualifiers.add(relation.name.lower()) + + from_clause = select_scope.args.get("from") + if from_clause: + add_relation(from_clause.this) + + for join in select_scope.args.get("joins") or []: + add_relation(join.this) + + return qualifiers + + def _yardstick_order_alias_key( + self, + expression: exp.Expression, + aliases: dict[str, tuple[exp.Expression, bool]], + table_qualifiers: set[str], + ) -> str | None: + if not isinstance(expression, exp.Column): + return None + + table = expression.table + if table and not (table.lower() == "alias" and "alias" not in table_qualifiers): + return None + + alias_key = expression.name.lower() + if alias_key not in aliases: + return None + return alias_key + + def _is_simple_yardstick_order_alias_ref( + self, + expression: exp.Expression, + aliases: dict[str, tuple[exp.Expression, bool]], + table_qualifiers: set[str], + ) -> bool: + return self._yardstick_order_alias_key(expression, aliases, table_qualifiers) is not None + + def _yardstick_order_expr_references_subquery_alias( + self, + expression: exp.Expression, + aliases: dict[str, tuple[exp.Expression, bool]], + table_qualifiers: set[str], + ) -> bool: + if isinstance(expression, (exp.Select, exp.Subquery)): + return False + + alias_key = self._yardstick_order_alias_key(expression, aliases, table_qualifiers) + if alias_key is not None and aliases[alias_key][1]: + return True + + for value in expression.args.values(): + if isinstance(value, exp.Expression): + if self._yardstick_order_expr_references_subquery_alias(value, aliases, table_qualifiers): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, exp.Expression) and self._yardstick_order_expr_references_subquery_alias( + item, + aliases, + table_qualifiers, + ): + return True + return False + + def _inline_yardstick_order_alias_expr( + self, + expression: exp.Expression, + aliases: dict[str, tuple[exp.Expression, bool]], + table_qualifiers: set[str], + ) -> tuple[exp.Expression, bool]: + if isinstance(expression, (exp.Select, exp.Subquery)): + return expression, False + + alias_key = self._yardstick_order_alias_key(expression, aliases, table_qualifiers) + if alias_key is not None and aliases[alias_key][1]: + return aliases[alias_key][0].copy(), True + + rewritten = expression.copy() + changed = False + for arg_key, value in list(rewritten.args.items()): + if isinstance(value, exp.Expression): + child, child_changed = self._inline_yardstick_order_alias_expr(value, aliases, table_qualifiers) + if child_changed: + rewritten.set(arg_key, child) + changed = True + elif isinstance(value, list): + new_values = [] + list_changed = False + for item in value: + if isinstance(item, exp.Expression): + child, child_changed = self._inline_yardstick_order_alias_expr(item, aliases, table_qualifiers) + new_values.append(child) + list_changed = list_changed or child_changed + else: + new_values.append(item) + if list_changed: + rewritten.set(arg_key, new_values) + changed = True + + if not changed: + return expression, False + return rewritten, True + def _collect_yardstick_group_expressions(self, group_clause: exp.Group) -> list[exp.Expression]: expressions: list[exp.Expression] = [] expressions.extend(group_clause.expressions) @@ -3330,16 +3537,12 @@ def _replace_yardstick_aggregate_calls(self, sql: str) -> tuple[str, list[_Yards segments: list[str] = [] cursor = 0 i = 0 - has_semantic_prefix = False - has_any_at_syntax = False - plain_aggregate_calls_without_at = 0 if tokens and tokens[0].text.upper() == "SEMANTIC": cursor = tokens[0].end + 1 while cursor < len(sql) and sql[cursor].isspace(): cursor += 1 i = 1 - has_semantic_prefix = True def parse_at_chain(start_idx: int, default_end_idx: int) -> tuple[list[str], int]: modifiers: list[str] = [] @@ -3405,14 +3608,12 @@ def parse_at_chain(start_idx: int, default_end_idx: int) -> tuple[list[str], int arg_start = tokens[i + 1].end + 1 arg_end = tokens[j].start argument_sql = sql[arg_start:arg_end].strip() + if not self._is_yardstick_aggregate_argument(argument_sql): + i = j + 1 + continue modifiers, end_idx = parse_at_chain(j + 1, j) - if modifiers: - has_any_at_syntax = True - else: - plain_aggregate_calls_without_at += 1 - placeholder = f"__ysagg_{len(calls)}" calls.append( _YardstickAggregateCall( @@ -3458,7 +3659,6 @@ def parse_at_chain(start_idx: int, default_end_idx: int) -> tuple[list[str], int argument_sql = sql[tokens[measure_start].start : tokens[measure_end].end + 1].strip() modifiers, end_idx = parse_at_chain(at_index, measure_end) if modifiers: - has_any_at_syntax = True placeholder = f"__ysagg_{len(calls)}" calls.append( _YardstickAggregateCall( @@ -3477,12 +3677,16 @@ def parse_at_chain(start_idx: int, default_end_idx: int) -> tuple[list[str], int i += 1 - if plain_aggregate_calls_without_at and not has_semantic_prefix and not has_any_at_syntax: - raise ValueError("AGGREGATE(...) without AT (...) requires the SEMANTIC prefix") - segments.append(sql[cursor:]) return "".join(segments), calls + def _is_yardstick_aggregate_argument(self, argument_sql: str) -> bool: + try: + argument = sqlglot.parse_one(argument_sql, dialect=self.dialect) + except Exception: + return False + return isinstance(argument, exp.Column) + def _extract_source_models_from_select(self, select: exp.Select) -> dict[str, str]: """Map SQL source aliases to semantic model names.""" alias_to_model: dict[str, str] = {} @@ -5267,7 +5471,7 @@ def _rewrite_simple_query(self, parsed: exp.Select) -> str: if self._needs_expression_postprocess(parsed): return self._rewrite_expression_query(parsed, extra_filters=explicit_join_filters) - plan = self._plan_simple_query(parsed) + plan = self._plan_simple_query(parsed, include_candidate_details=False) return self._generate_from_plan(plan) def _validate_explicit_semantic_joins(self, select: exp.Select) -> list[str]: @@ -5762,6 +5966,9 @@ def _stage_semantic_filters(self, filters: list[str]) -> tuple[list[str], list[s return row_filters, aggregate_filters def _semantic_filter_references_metric(self, filter_expr: str) -> bool: + if not self._filter_may_reference_metric(filter_expr): + return False + try: parsed = sqlglot.parse_one(filter_expr, dialect=self.dialect) except Exception: @@ -5784,6 +5991,17 @@ def _semantic_filter_references_metric(self, filter_expr: str) -> bool: return True return False + def _filter_may_reference_metric(self, filter_expr: str) -> bool: + lower_filter = filter_expr.lower() + for metric_name in self.graph.metrics: + if metric_name.lower() in lower_filter: + return True + for model in self.graph.models.values(): + for metric in model.metrics: + if metric.name.lower() in lower_filter: + return True + return False + def _extract_compound_filters(self, condition: exp.Expression) -> list[str]: """Extract filters from compound AND/OR conditions. diff --git a/tests/queries/test_yardstick_measures_replay.py b/tests/queries/test_yardstick_measures_replay.py index fd8b8689..4146dabd 100644 --- a/tests/queries/test_yardstick_measures_replay.py +++ b/tests/queries/test_yardstick_measures_replay.py @@ -4,6 +4,8 @@ import os import re +import subprocess +import tempfile from dataclasses import dataclass from datetime import date, datetime, time from pathlib import Path @@ -34,6 +36,27 @@ class _QueryBlock: rowsort: bool +@dataclass +class _ExpectedDimension: + name: str + sql: str + type: str + granularity: str | None + + +@dataclass +class _ExpectedMetric: + name: str + agg: str | None + sql: str | None + filters: list[str] | None + type: str | None + + +def _env_flag(name: str) -> bool: + return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"} + + def _yardstick_measures_test_path() -> Path: override = os.environ.get("YARDSTICK_MEASURES_TEST_PATH") if override: @@ -46,6 +69,70 @@ def _yardstick_measures_test_path() -> Path: return Path("~/Code/yardstick/test/sql/measures.test").expanduser() +def _run_git(args: list[str], cwd: Path | None = None) -> str: + try: + result = subprocess.run( + ["git", *args], + cwd=cwd, + check=True, + capture_output=True, + text=True, + ) + except FileNotFoundError: + pytest.fail("git is required to run live upstream Yardstick parity tests") + except subprocess.CalledProcessError as exc: + pytest.fail( + "\n".join( + [ + "Failed to fetch live upstream Yardstick tests", + f"Command: git {' '.join(args)}", + f"stdout: {exc.stdout.strip()}", + f"stderr: {exc.stderr.strip()}", + ] + ) + ) + return result.stdout.strip() + + +def _yardstick_upstream_checkout() -> Path: + override = os.environ.get("YARDSTICK_UPSTREAM_PATH") + if override: + path = Path(override).expanduser() + if not path.exists(): + pytest.fail(f"YARDSTICK_UPSTREAM_PATH does not exist: {path}") + return path + + if not _env_flag("SIDEMANTIC_YARDSTICK_UPSTREAM_TESTS"): + pytest.skip("Set SIDEMANTIC_YARDSTICK_UPSTREAM_TESTS=1 to fetch and replay live upstream Yardstick tests") + + repo_url = os.environ.get("YARDSTICK_UPSTREAM_REPO", "https://github.com/sidequery/yardstick.git") + ref = os.environ.get("YARDSTICK_UPSTREAM_REF", "main") + checkout = Path( + os.environ.get( + "YARDSTICK_UPSTREAM_CACHE_DIR", + str(Path(tempfile.gettempdir()) / "sidemantic-yardstick-upstream"), + ) + ).expanduser() + + if not (checkout / ".git").exists(): + if checkout.exists() and any(checkout.iterdir()): + pytest.fail(f"YARDSTICK_UPSTREAM_CACHE_DIR exists but is not a git checkout: {checkout}") + checkout.parent.mkdir(parents=True, exist_ok=True) + _run_git(["clone", "--depth", "1", repo_url, str(checkout)]) + + _run_git(["fetch", "--depth", "1", "origin", ref], cwd=checkout) + _run_git(["checkout", "--detach", "FETCH_HEAD"], cwd=checkout) + return checkout + + +def _yardstick_upstream_sql_test_paths() -> list[Path]: + checkout = _yardstick_upstream_checkout() + paths = sorted((checkout / "test" / "sql").glob("*.test")) + if not paths: + pytest.fail(f"No upstream Yardstick SQL test files found under {checkout / 'test' / 'sql'}") + return paths + + def _parse_measures_test(path: Path) -> tuple[list[_StatementBlock], list[_QueryBlock]]: lines = path.read_text().splitlines() statements: list[_StatementBlock] = [] @@ -163,6 +250,10 @@ def _rows_to_lines(rows: list[tuple[object, ...]]) -> list[str]: def _parse_expected_cell(token: str) -> object: if token == "NULL": return None + if token.lower() == "true": + return True + if token.lower() == "false": + return False if _INT_RE.match(token): return int(token) if _FLOAT_RE.match(token): @@ -174,6 +265,17 @@ def _cell_matches(actual: object, expected: object) -> bool: if expected is None: return actual is None + if isinstance(expected, str) and isinstance(actual, (date, datetime, time)): + actual_text = _stringify_value(actual) + if actual_text == expected: + return True + if isinstance(actual, date) and not isinstance(actual, datetime): + return f"{actual.isoformat()} 00:00:00" == expected + return False + + if isinstance(expected, bool): + return actual is expected + if isinstance(expected, float): if actual is None: return False @@ -277,13 +379,349 @@ def _execute_statement_sql(layer: SemanticLayer, adapter: YardstickAdapter, sql: layer.adapter.execute(f"CREATE VIEW {model.name} AS SELECT * FROM {model.table}") return - if sql.lstrip().upper().startswith("SEMANTIC "): + statement_head = sql.lstrip().upper() + if statement_head.startswith(("SEMANTIC ", "SELECT ", "WITH ")) or "AGGREGATE" in statement_head: layer.sql(sql) return layer.adapter.execute(sql) +def _create_view_model_from_statement( + adapter: YardstickAdapter, + sql: str, +) -> tuple[exp.Create, exp.Select, object] | None: + try: + parsed = adapter._parse_statements(sql) + except Exception: + return None + + if not parsed: + return None + + statement = parsed[0] + if not isinstance(statement, exp.Create) or (statement.args.get("kind") or "").upper() != "VIEW": + return None + + select = statement.expression + if not isinstance(select, exp.Select): + return None + + model = adapter._model_from_create_view(statement, select) + if model is None: + return None + + return statement, select, model + + +def _expected_model_source(select: exp.Select, dialect: str) -> tuple[str | None, str | None]: + from_clause = select.args.get("from") + joins = select.args.get("joins") or [] + where_clause = select.args.get("where") + with_clause = select.args.get("with") + + if ( + isinstance(from_clause, exp.From) + and isinstance(from_clause.this, exp.Table) + and not joins + and where_clause is None + and with_clause is None + ): + table_expr = from_clause.this + if isinstance(table_expr.this, exp.Identifier) and table_expr.args.get("alias") is None: + return table_expr.sql(dialect=dialect), None + + if from_clause is None: + return None, None + + base_relation = exp.select("*") + if with_clause is not None: + base_relation.set("with", with_clause.copy()) + base_relation.set("from", from_clause.copy()) + if joins: + base_relation.set("joins", [join.copy() for join in joins]) + if where_clause is not None: + base_relation.set("where", where_clause.copy()) + + return None, base_relation.sql(dialect=dialect) + + +def _expected_dimension_type(expression: exp.Expression) -> tuple[str, str | None]: + if isinstance(expression, exp.Boolean): + return "boolean", None + if isinstance(expression, exp.Literal): + if expression.is_number: + return "numeric", None + return "categorical", None + if isinstance(expression, exp.Column): + column_name = expression.name.lower() + if "timestamp" in column_name: + return "time", "second" + if "date" in column_name: + return "time", "day" + if "time" in column_name: + return "time", "second" + return "categorical", None + if isinstance(expression, exp.Func): + granularity_by_func = { + "date": "day", + "date_trunc": "day", + "year": "year", + "quarter": "quarter", + "month": "month", + "week": "week", + "day": "day", + "hour": "hour", + "minute": "minute", + } + function_name = (expression.name or "").lower() + if function_name in granularity_by_func: + return "time", granularity_by_func[function_name] + return "categorical", None + + +_EXPECTED_SIMPLE_AGGREGATIONS: tuple[tuple[type[exp.Expression], str], ...] = ( + (exp.Sum, "sum"), + (exp.Avg, "avg"), + (exp.Min, "min"), + (exp.Max, "max"), + (exp.Median, "median"), + (exp.Stddev, "stddev"), + (exp.StddevPop, "stddev_pop"), + (exp.Variance, "variance"), + (exp.VariancePop, "variance_pop"), +) +_EXPECTED_SUPPORTED_FUNCTION_AGGS = { + "avg", + "max", + "median", + "min", + "stddev", + "stddev_pop", + "sum", + "variance", + "variance_pop", +} +_EXPECTED_ANONYMOUS_AGGREGATIONS = { + "entropy", + "geometric_mean", + "kurtosis", + "mode", + "product", + "skewness", + "weighted_avg", +} + + +def _expected_count_aggregation(expression: exp.Expression, dialect: str) -> tuple[str, str] | None: + count_expr = None + if isinstance(expression, exp.Count): + count_expr = expression.this + elif isinstance(expression, exp.Func) and (expression.name or "").lower() == "count": + count_expr = expression.this or (expression.expressions[0] if expression.expressions else None) + else: + return None + + if isinstance(count_expr, exp.Distinct): + if count_expr.expressions: + return "count_distinct", ", ".join(expr.sql(dialect=dialect) for expr in count_expr.expressions) + return "count_distinct", count_expr.sql(dialect=dialect) + + if count_expr is None or isinstance(count_expr, exp.Star): + return "count", "*" + return "count", count_expr.sql(dialect=dialect) + + +def _expected_supported_aggregation(expression: exp.Expression, dialect: str) -> tuple[str, str] | None: + count_aggregation = _expected_count_aggregation(expression, dialect) + if count_aggregation is not None: + return count_aggregation + + for expression_type, aggregation_name in _EXPECTED_SIMPLE_AGGREGATIONS: + if isinstance(expression, expression_type): + inner_expression = expression.this + if inner_expression is None: + return aggregation_name, "*" + return aggregation_name, inner_expression.sql(dialect=dialect) + + if isinstance(expression, exp.Func): + function_name = (expression.name or "").lower() + if function_name in _EXPECTED_SUPPORTED_FUNCTION_AGGS: + inner_expression = expression.this or (expression.expressions[0] if expression.expressions else None) + if inner_expression is None: + return function_name, "*" + return function_name, inner_expression.sql(dialect=dialect) + + return None + + +def _expected_filtered_aggregation( + expression: exp.Expression, + dialect: str, +) -> tuple[str, str, list[str] | None] | None: + if not isinstance(expression, exp.Filter): + return None + + aggregation = _expected_supported_aggregation(expression.this, dialect) + if aggregation is None: + return None + + agg, inner_sql = aggregation + where_expression = expression.args.get("expression") + if isinstance(where_expression, exp.Where): + filter_sql = where_expression.this.sql(dialect=dialect) + elif isinstance(where_expression, exp.Expression): + filter_sql = where_expression.sql(dialect=dialect) + else: + filter_sql = "" + + return agg, inner_sql, [filter_sql] if filter_sql else None + + +def _expected_has_aggregate_semantics(expression: exp.Expression) -> bool: + if any(isinstance(node, exp.AggFunc) for node in expression.walk()): + return True + + for node in expression.walk(): + if isinstance(node, exp.List): + return True + if isinstance(node, exp.Anonymous) and (node.name or "").lower() in _EXPECTED_ANONYMOUS_AGGREGATIONS: + return True + return False + + +def _expected_references_other_measures( + name: str, + expression: exp.Expression, + all_measure_names: set[str], +) -> bool: + measure_lookup = { + measure_name.lower() for measure_name in all_measure_names if measure_name.lower() != name.lower() + } + referenced_columns = {column.name.lower() for column in expression.find_all(exp.Column)} + return bool(referenced_columns & measure_lookup) + + +def _expected_metric_from_expression( + name: str, + expression: exp.Expression, + all_measure_names: set[str], + dialect: str, +) -> _ExpectedMetric: + expression_sql = expression.sql(dialect=dialect) + + if _expected_references_other_measures(name, expression, all_measure_names): + return _ExpectedMetric(name=name, agg=None, sql=expression_sql, filters=None, type="derived") + + filtered_aggregation = _expected_filtered_aggregation(expression, dialect) + if filtered_aggregation is not None: + agg, inner_sql, filters = filtered_aggregation + return _ExpectedMetric(name=name, agg=agg, sql=inner_sql, filters=filters, type=None) + + simple_aggregation = _expected_supported_aggregation(expression, dialect) + if simple_aggregation is not None: + agg, inner_sql = simple_aggregation + return _ExpectedMetric(name=name, agg=agg, sql=inner_sql, filters=None, type=None) + + if _expected_has_aggregate_semantics(expression): + return _ExpectedMetric(name=name, agg=None, sql=expression_sql, filters=None, type=None) + + return _ExpectedMetric(name=name, agg=None, sql=expression_sql, filters=None, type="derived") + + +def _expected_projections( + select: exp.Select, + dialect: str, +) -> tuple[list[_ExpectedDimension], list[_ExpectedMetric]]: + measure_aliases = { + projection.output_name + for projection in select.expressions + if isinstance(projection, exp.Alias) and projection.args.get("yardstick_measure") + } + expected_dimensions: list[_ExpectedDimension] = [] + expected_metrics: list[_ExpectedMetric] = [] + + for projection in select.expressions: + output_name = projection.output_name + if not output_name: + continue + + expression = projection.this if isinstance(projection, exp.Alias) else projection + if output_name in measure_aliases: + expected_metrics.append( + _expected_metric_from_expression( + output_name, + expression, + all_measure_names=set(measure_aliases), + dialect=dialect, + ) + ) + continue + + if isinstance(expression, exp.Star): + continue + + dimension_type, granularity = _expected_dimension_type(expression) + expected_dimensions.append( + _ExpectedDimension( + name=output_name, + sql=expression.sql(dialect=dialect), + type=dimension_type, + granularity=granularity, + ) + ) + return expected_dimensions, expected_metrics + + +def _assert_definition_blocks_match(path: Path) -> None: + statements, _queries = _parse_measures_test(path) + adapter = YardstickAdapter() + checked_models: list[str] = [] + + for statement_block in statements: + parsed_model = _create_view_model_from_statement(adapter, statement_block.sql) + if parsed_model is None: + continue + + create_statement, select, model = parsed_model + expected_source_table, expected_source_sql = _expected_model_source(select, adapter.dialect) + expected_dimensions, expected_metrics = _expected_projections(select, adapter.dialect) + view_name = create_statement.this.name if isinstance(create_statement.this, exp.Table) else None + context = f"{path}:{statement_block.line}" + if view_name: + context = f"{context} ({view_name})" + + assert model.name == view_name, context + assert model.primary_key == (expected_dimensions[0].name if expected_dimensions else "id"), context + assert model.table == (expected_source_table or (view_name if expected_source_sql is None else None)), context + assert model.sql == expected_source_sql, context + + yardstick_metadata = (model.metadata or {}).get("yardstick") + assert isinstance(yardstick_metadata, dict), context + assert yardstick_metadata.get("view_sql") == select.sql(dialect=adapter.dialect), context + assert yardstick_metadata.get("base_table") == expected_source_table, context + assert yardstick_metadata.get("base_relation_sql") == expected_source_sql, context + + assert len(model.dimensions) == len(expected_dimensions), context + for actual_dimension, expected_dimension in zip(model.dimensions, expected_dimensions, strict=True): + assert actual_dimension.name == expected_dimension.name, context + assert actual_dimension.sql == expected_dimension.sql, context + assert actual_dimension.type == expected_dimension.type, context + assert actual_dimension.granularity == expected_dimension.granularity, context + + assert len(model.metrics) == len(expected_metrics), context + for actual_metric, expected_metric in zip(model.metrics, expected_metrics, strict=True): + assert actual_metric.name == expected_metric.name, context + assert actual_metric.agg == expected_metric.agg, context + assert actual_metric.sql == expected_metric.sql, context + assert actual_metric.filters == expected_metric.filters, context + assert actual_metric.type == expected_metric.type, context + + checked_models.append(model.name) + + assert checked_models, f"No Yardstick CREATE VIEW ... AS MEASURE definitions found in {path}" + + def _apply_statement(layer: SemanticLayer, adapter: YardstickAdapter, statement: _StatementBlock) -> None: if not statement.expect_error: _execute_statement_sql(layer, adapter, statement.sql) @@ -311,13 +749,12 @@ def _apply_statement(layer: SemanticLayer, adapter: YardstickAdapter, statement: ) -def test_yardstick_measures_test_replay(): - path = _yardstick_measures_test_path() +def _replay_yardstick_sql_test(path: Path) -> None: if not path.exists(): - pytest.skip(f"Yardstick measures.test not found: {path}") + pytest.skip(f"Yardstick SQL test file not found: {path}") statements, queries = _parse_measures_test(path) - assert queries, f"No query blocks parsed from measures.test: {path}" + assert queries, f"No query blocks parsed from Yardstick SQL test: {path}" layer = SemanticLayer(connection="duckdb:///:memory:") adapter = YardstickAdapter() @@ -335,7 +772,7 @@ def test_yardstick_measures_test_replay(): pytest.fail( "\n".join( [ - f"Execution failed for query at line {query.line}: {query.header}", + f"Execution failed for query at {path}:{query.line}: {query.header}", "SQL:", query.sql, f"Error: {exc}", @@ -343,3 +780,19 @@ def test_yardstick_measures_test_replay(): ) ) _assert_query_rows_match(query, actual_rows) + + +def test_yardstick_measures_test_replay(): + _replay_yardstick_sql_test(_yardstick_measures_test_path()) + + +@pytest.mark.yardstick_upstream +def test_yardstick_upstream_create_view_definitions(): + for path in _yardstick_upstream_sql_test_paths(): + _assert_definition_blocks_match(path) + + +@pytest.mark.yardstick_upstream +def test_yardstick_upstream_sql_replay(): + for path in _yardstick_upstream_sql_test_paths(): + _replay_yardstick_sql_test(path) diff --git a/tests/queries/test_yardstick_query_rewriter.py b/tests/queries/test_yardstick_query_rewriter.py index 05c12ff9..6e7986cf 100644 --- a/tests/queries/test_yardstick_query_rewriter.py +++ b/tests/queries/test_yardstick_query_rewriter.py @@ -733,9 +733,83 @@ def test_yardstick_at_visible_without_where_is_identity(yardstick_layer): } -def test_yardstick_aggregate_without_at_requires_semantic(yardstick_layer): - with pytest.raises(ValueError, match="requires the SEMANTIC prefix"): - yardstick_layer.sql("SELECT AGGREGATE(revenue) AS revenue FROM sales_v") +def test_yardstick_aggregate_without_semantic_prefix(yardstick_layer): + result = yardstick_layer.sql( + """ +SELECT + year, + region, + AGGREGATE(revenue) AS revenue +FROM sales_v +ORDER BY year, region +""" + ) + rows = fetch_dicts(result) + assert [(row["year"], row["region"], float(row["revenue"])) for row in rows] == [ + (2022, "EU", 50.0), + (2022, "US", 100.0), + (2023, "EU", 75.0), + (2023, "US", 150.0), + ] + + +def test_yardstick_order_by_expression_can_reference_subquery_aliases(yardstick_layer): + rows = fetch_dicts( + yardstick_layer.sql( + """ +SEMANTIC SELECT + year, + region, + AGGREGATE(revenue) AS revenue, + AGGREGATE(revenue) AT (ALL region) AS year_total +FROM sales_v +ORDER BY revenue / year_total, year, region +""" + ) + ) + assert [(row["year"], row["region"], float(row["revenue"]), float(row["year_total"])) for row in rows] == [ + (2022, "EU", 50.0, 150.0), + (2023, "EU", 75.0, 225.0), + (2022, "US", 100.0, 150.0), + (2023, "US", 150.0, 225.0), + ] + + +def test_yardstick_ctas_and_insert_select_with_aggregate(yardstick_layer): + yardstick_layer.sql( + """ +CREATE TABLE ctas_result AS +SELECT year, region, AGGREGATE(revenue) AS revenue +FROM sales_v +""" + ) + ctas_rows = fetch_dicts(yardstick_layer.adapter.execute("SELECT * FROM ctas_result ORDER BY year, region")) + assert [(row["year"], row["region"], float(row["revenue"])) for row in ctas_rows] == [ + (2022, "EU", 50.0), + (2022, "US", 100.0), + (2023, "EU", 75.0), + (2023, "US", 150.0), + ] + + yardstick_layer.adapter.execute("CREATE TABLE insert_target (year INT, region TEXT, revenue DOUBLE)") + yardstick_layer.sql( + """ +INSERT INTO insert_target +SELECT year, region, AGGREGATE(revenue) +FROM sales_v +""" + ) + inserted_rows = fetch_dicts(yardstick_layer.adapter.execute("SELECT * FROM insert_target ORDER BY year, region")) + assert [(row["year"], row["region"], float(row["revenue"])) for row in inserted_rows] == [ + (2022, "EU", 50.0), + (2022, "US", 100.0), + (2023, "EU", 75.0), + (2023, "US", 150.0), + ] + + +def test_yardstick_builtin_duckdb_aggregate_function_falls_through(yardstick_layer): + assert yardstick_layer.sql("SELECT aggregate([1, 2, 3], 'sum')").fetchall() == [(6,)] def test_yardstick_scalar_aggregate_without_group_by(yardstick_layer):