From 9109863363872702e51a9f3d9ab023285cb0fad4 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 10:58:19 -0700 Subject: [PATCH 01/13] Align native contract and SQL correctness --- docs/native-fixtures.md | 12 +- docs/native-format.md | 47 +- docs/runtime-feature-matrix.md | 2 +- docs/rust-runtime.md | 2 +- scripts/generate_schema.py | 45 + sidemantic-rs/examples/parity_adapter.rs | 9 + sidemantic-rs/src/config/loader.rs | 153 +++- sidemantic-rs/src/config/mod.rs | 2 +- sidemantic-rs/src/config/schema.rs | 306 ++++++- sidemantic-rs/src/config/sql_parser.rs | 858 +++++++++++++++++- sidemantic-rs/src/core/dependency.rs | 128 ++- sidemantic-rs/src/core/graph.rs | 165 +++- sidemantic-rs/src/core/model.rs | 96 +- sidemantic-rs/src/main.rs | 3 + sidemantic-rs/src/runtime.rs | 87 +- sidemantic-rs/src/sql/generator.rs | 191 +++- sidemantic-schema.json | 220 +++++ sidemantic/adapters/sidemantic.py | 631 ++++++++++--- sidemantic/core/inheritance.py | 18 +- sidemantic/core/metric.py | 16 +- sidemantic/core/pre_aggregation.py | 2 + sidemantic/core/relationship.py | 34 +- sidemantic/core/semantic_graph.py | 90 +- sidemantic/core/semantic_layer.py | 17 +- sidemantic/core/sql_definitions.py | 28 +- sidemantic/db/base.py | 41 + sidemantic/db/bigquery.py | 3 +- sidemantic/db/clickhouse.py | 3 +- sidemantic/db/databricks.py | 3 +- sidemantic/db/postgres.py | 6 +- sidemantic/db/snowflake.py | 3 +- sidemantic/rust_bridge.py | 16 +- sidemantic/schema.py | 45 + sidemantic/sql/aggregation_detection.py | 2 +- sidemantic/sql/generator.py | 435 +++++---- sidemantic/sql/query_rewriter.py | 6 +- sidemantic/validation.py | 33 +- .../sidemantic_adapter/test_parsing.py | 569 ++++++++++++ .../test_rust_bridge_yaml_serialization.py | 24 + tests/core/test_sql_definitions.py | 12 + tests/db/test_postgres_adapter.py | 28 + tests/db/test_query_history_validation.py | 57 ++ tests/joins/test_many_to_many_joins.py | 55 ++ .../cumulative_revenue_by_month_result.json | 12 + .../expected/multi_platform_users_result.json | 10 + .../revenue_mom_by_month_region_result.json | 26 + .../signup_conversion_by_region_result.json | 10 + .../expected/signup_retention_result.json | 30 + .../queries/multi_platform_users.query.yml | 2 + .../advanced_metrics/seed/duckdb.sql | 24 + .../compact_sql_model/README.md | 4 + .../compact_sql_model/expected/result.json | 10 + .../expected/validation.json | 3 + .../compact_sql_model/models/orders.sql | 11 + .../queries/revenue_by_status.query.yml | 6 + .../compact_sql_model/seed/duckdb.sql | 19 + .../custom_relationship_sql/README.md | 6 + .../expected/result.json | 10 + .../expected/validation.json | 4 + .../custom_relationship_sql/models/orders.yml | 24 + .../queries/revenue_by_country.query.yml | 6 + .../custom_relationship_sql/seed/duckdb.sql | 20 + .../expected/validation.json | 3 + .../models/orders.yml | 8 + tests/native-fixtures/manifest.yml | 113 ++- .../many_to_many_composite_keys/README.md | 3 + .../expected/result.json | 14 + .../expected/validation.json | 3 + .../models/sales.yml | 26 + .../queries/revenue_by_category.query.yml | 6 + .../seed/duckdb.sql | 32 + .../native-fixtures/native_aliases/README.md | 8 + .../native_aliases/expected/result.json | 12 + .../native_aliases/expected/validation.json | 3 + .../native_aliases/models/models.yml | 20 + .../queries/revenue_by_status.query.yml | 7 + .../native_aliases/seed/duckdb.sql | 10 + .../relationship_default_keys/README.md | 7 + .../expected/account_region_result.json | 10 + .../expected/profile_tier_result.json | 10 + .../expected/result.json | 10 + .../expected/validation.json | 3 + .../expected/vendor_segment_result.json | 10 + .../models/models.yml | 69 ++ .../customer_count_by_order_status.query.yml | 6 + .../customer_count_by_profile_tier.query.yml | 6 + .../invoice_count_by_vendor_segment.query.yml | 6 + .../payment_count_by_account_region.query.yml | 6 + .../relationship_default_keys/seed/duckdb.sql | 63 ++ .../statistical_aggregations/README.md | 4 + .../expected/result.json | 8 + .../expected/validation.json | 3 + .../models/orders.yml | 19 + .../queries/amount_stats.query.yml | 5 + .../statistical_aggregations/seed/duckdb.sql | 9 + .../table_calculations/README.md | 7 +- .../table_calculations/expected/result.json | 4 +- .../top_level_metric_contract/README.md | 8 + .../expected/result.json | 10 + .../expected/validation.json | 4 + .../models/orders.yml | 25 + .../queries/revenue_per_order.query.yml | 10 + .../top_level_metric_contract/seed/duckdb.sql | 10 + .../native_compat/test_basic_model_fixture.py | 25 +- tests/queries/test_basic.py | 17 + tests/queries/test_sql_rewriter.py | 33 + tests/rust_layer_adapter.py | 46 +- tests/test_loaders.py | 112 +++ tests/test_metric_expressions.py | 82 ++ tests/test_relationships.py | 98 ++ tests/test_validation.py | 38 + 111 files changed, 5240 insertions(+), 551 deletions(-) create mode 100644 tests/db/test_query_history_validation.py create mode 100644 tests/native-fixtures/advanced_metrics/expected/cumulative_revenue_by_month_result.json create mode 100644 tests/native-fixtures/advanced_metrics/expected/multi_platform_users_result.json create mode 100644 tests/native-fixtures/advanced_metrics/expected/revenue_mom_by_month_region_result.json create mode 100644 tests/native-fixtures/advanced_metrics/expected/signup_conversion_by_region_result.json create mode 100644 tests/native-fixtures/advanced_metrics/expected/signup_retention_result.json create mode 100644 tests/native-fixtures/advanced_metrics/seed/duckdb.sql create mode 100644 tests/native-fixtures/compact_sql_model/README.md create mode 100644 tests/native-fixtures/compact_sql_model/expected/result.json create mode 100644 tests/native-fixtures/compact_sql_model/expected/validation.json create mode 100644 tests/native-fixtures/compact_sql_model/models/orders.sql create mode 100644 tests/native-fixtures/compact_sql_model/queries/revenue_by_status.query.yml create mode 100644 tests/native-fixtures/compact_sql_model/seed/duckdb.sql create mode 100644 tests/native-fixtures/custom_relationship_sql/README.md create mode 100644 tests/native-fixtures/custom_relationship_sql/expected/result.json create mode 100644 tests/native-fixtures/custom_relationship_sql/expected/validation.json create mode 100644 tests/native-fixtures/custom_relationship_sql/models/orders.yml create mode 100644 tests/native-fixtures/custom_relationship_sql/queries/revenue_by_country.query.yml create mode 100644 tests/native-fixtures/custom_relationship_sql/seed/duckdb.sql create mode 100644 tests/native-fixtures/invalid_unknown_native_field/expected/validation.json create mode 100644 tests/native-fixtures/invalid_unknown_native_field/models/orders.yml create mode 100644 tests/native-fixtures/many_to_many_composite_keys/README.md create mode 100644 tests/native-fixtures/many_to_many_composite_keys/expected/result.json create mode 100644 tests/native-fixtures/many_to_many_composite_keys/expected/validation.json create mode 100644 tests/native-fixtures/many_to_many_composite_keys/models/sales.yml create mode 100644 tests/native-fixtures/many_to_many_composite_keys/queries/revenue_by_category.query.yml create mode 100644 tests/native-fixtures/many_to_many_composite_keys/seed/duckdb.sql create mode 100644 tests/native-fixtures/native_aliases/README.md create mode 100644 tests/native-fixtures/native_aliases/expected/result.json create mode 100644 tests/native-fixtures/native_aliases/expected/validation.json create mode 100644 tests/native-fixtures/native_aliases/models/models.yml create mode 100644 tests/native-fixtures/native_aliases/queries/revenue_by_status.query.yml create mode 100644 tests/native-fixtures/native_aliases/seed/duckdb.sql create mode 100644 tests/native-fixtures/relationship_default_keys/README.md create mode 100644 tests/native-fixtures/relationship_default_keys/expected/account_region_result.json create mode 100644 tests/native-fixtures/relationship_default_keys/expected/profile_tier_result.json create mode 100644 tests/native-fixtures/relationship_default_keys/expected/result.json create mode 100644 tests/native-fixtures/relationship_default_keys/expected/validation.json create mode 100644 tests/native-fixtures/relationship_default_keys/expected/vendor_segment_result.json create mode 100644 tests/native-fixtures/relationship_default_keys/models/models.yml create mode 100644 tests/native-fixtures/relationship_default_keys/queries/customer_count_by_order_status.query.yml create mode 100644 tests/native-fixtures/relationship_default_keys/queries/customer_count_by_profile_tier.query.yml create mode 100644 tests/native-fixtures/relationship_default_keys/queries/invoice_count_by_vendor_segment.query.yml create mode 100644 tests/native-fixtures/relationship_default_keys/queries/payment_count_by_account_region.query.yml create mode 100644 tests/native-fixtures/relationship_default_keys/seed/duckdb.sql create mode 100644 tests/native-fixtures/statistical_aggregations/README.md create mode 100644 tests/native-fixtures/statistical_aggregations/expected/result.json create mode 100644 tests/native-fixtures/statistical_aggregations/expected/validation.json create mode 100644 tests/native-fixtures/statistical_aggregations/models/orders.yml create mode 100644 tests/native-fixtures/statistical_aggregations/queries/amount_stats.query.yml create mode 100644 tests/native-fixtures/statistical_aggregations/seed/duckdb.sql create mode 100644 tests/native-fixtures/top_level_metric_contract/README.md create mode 100644 tests/native-fixtures/top_level_metric_contract/expected/result.json create mode 100644 tests/native-fixtures/top_level_metric_contract/expected/validation.json create mode 100644 tests/native-fixtures/top_level_metric_contract/models/orders.yml create mode 100644 tests/native-fixtures/top_level_metric_contract/queries/revenue_per_order.query.yml create mode 100644 tests/native-fixtures/top_level_metric_contract/seed/duckdb.sql diff --git a/docs/native-fixtures.md b/docs/native-fixtures.md index b0b53357..61c0a6b9 100644 --- a/docs/native-fixtures.md +++ b/docs/native-fixtures.md @@ -79,7 +79,8 @@ The suite currently covers: - Parameter interpolation in query filters. - Pre-aggregation routing shape and DuckDB execution against seeded rollup tables. - Semantic SQL rewrite cases for single-model and relationship queries. -- Query-local table calculations on the Rust SQL compiler path, including Rust-only DuckDB result coverage. +- Query-local table calculations for the shared Python/Rust subset. Python applies these after fetching rows; + Rust compiles them into SQL window expressions. - Native `.sql` definition files. - Native SQL frontmatter model definitions. - YAML `sql_metrics` and `sql_segments` blocks. @@ -112,6 +113,15 @@ The default Rust runner loads every manifest fixture, asserts `expected/validati The `adbc-exec` Rust runner executes every query with `expected_result` or `rust_expected_result` through DuckDB ADBC, using the fixture seed SQL and result columns from the manifest. Any Rust-only expected output must include `rust_only_reason`. It is enabled in CI after installing the DuckDB ADBC driver. +Table-calculation fixture contract: + +- Shared table calculations may use `percent_of_total`, `percent_of_previous`, `running_total`, `rank`, `row_number`, or `moving_average`. +- Shared calculations should include deterministic query `order_by` when row order affects the result. +- Python evaluates shared calculations with `TableCalculationProcessor` after query execution. +- Rust evaluates shared calculations by compiling them into SQL expressions. +- Rust-only table calculation types (`dense_rank`, `difference`, `lead`, `lag`) must use `rust_expected_result` and `rust_only_reason`. +- Python-only post-query table calculation types (`percent_of_column_total`, `percentile`) stay out of shared native fixtures until Rust supports them. + ## Adding Fixtures Add the narrowest fixture that proves one semantic behavior. Avoid kitchen-sink fixtures unless the behavior itself is cross-feature interaction. diff --git a/docs/native-format.md b/docs/native-format.md index e8fdc3f4..854d0fe6 100644 --- a/docs/native-format.md +++ b/docs/native-format.md @@ -9,6 +9,20 @@ The native format has two source forms: The native format is the runtime contract. External formats such as LookML, MetricFlow, Hex, Rill, Malloy, Omni, Superset, GoodData, Snowflake Cortex, ThoughtSpot, Holistics, Tableau, AtScale SML, BSL, Yardstick, and Graphene GSQL should be converted into this format by Python importers before they are expected to run through the Rust native runtime. +## Rust Loader Scope + +The Rust runtime and Rust CLI directory loader intentionally have a smaller direct +input surface than Python: + +- `.yml` / `.yaml`: native Sidemantic YAML or Cube YAML. +- `.sql`: native Sidemantic SQL definition files. + +They do not auto-detect LookML, MetricFlow/dbt manifests, Hex, Rill, Malloy, +Omni, Superset, GoodData, Snowflake Cortex, ThoughtSpot, Holistics, Tableau, +AtScale SML, BSL, Yardstick, or other external source formats. Convert those +formats through the Python CLI/API first, then load the exported native YAML/SQL +with the Rust runtime. + ## Versioning Current native format version: `1`. @@ -61,6 +75,18 @@ Top-level sections: | `metrics` | No | Graph-level metrics. Rust assigns these to exactly one owning model when possible. | | `parameters` | No | Graph-level parameters for templates and query-time substitution. | +Top-level metrics are graph-scoped in the Python runtime. The Rust runtime does not +store a separate graph-metric namespace at execution time; it assigns each top-level +metric to one owning model by resolving explicit model references, metric dependencies, +entity dimensions, or a single-model project fallback. If Rust cannot infer exactly +one owner, loading fails. Portable native files should therefore make top-level metric +dependencies explicit, for example `orders.total_revenue` rather than `total_revenue` +when multiple models define the same local metric name. Dotted top-level metric names +are allowed and are resolved by exact metric name before `model.metric` parsing. + +Top-level parameters remain graph-scoped in both runtimes. Query APIs interpolate +parameter values before SQL compilation. + ## Models Models describe physical or logical query sources. @@ -94,6 +120,12 @@ At least one of `table`, `sql`, or `source_uri` should be present unless the mod | `pre_aggregations` | No | List of pre-aggregation definitions. | | `default_time_dimension` | No | Time dimension to add by default when the query needs time grouping. | | `default_grain` | No | Default time grain for the default time dimension. | +| `auto_dimensions` | No | Python auto-discovery flag. Rust accepts `false` for compatibility and rejects `true` because it does not perform schema discovery. | + +Canonical CLI-authored files should use `metrics` and `sql`. The native loaders +also accept compatibility input aliases: model-level `measures` for `metrics`, +dimension/metric `expr` for `sql`, and metric `measure` for `sql`. Exports use +canonical field names. Single-column primary key: @@ -332,8 +364,21 @@ relationships: | `primary_key_columns` | Conditional | Explicit target-column list. | | `through` | For many-to-many | Junction model. | | `through_foreign_key` | For many-to-many | Source-to-through key. | +| `through_foreign_key_columns` | For many-to-many | Explicit source-to-through key columns. | | `related_foreign_key` | For many-to-many | Through-to-target key. | -| `sql` | No | Custom join SQL using runtime placeholders where supported. | +| `related_foreign_key_columns` | For many-to-many | Explicit through-to-target key columns. | +| `sql` | No | Custom join SQL using `{from}` and `{to}` runtime placeholders. | + +For CLI-authored native files, prefer explicit `foreign_key` and `primary_key` +fields. Omitted keys are still supported for compatibility: `many_to_one` +defaults the source key to `{name}_id`, while `one_to_many` and `one_to_one` +default the related-side key to `id`; omitted `primary_key` resolves to the +target model's declared primary key when building graph joins. + +When `sql` is present, Python and Rust use it instead of the FK/PK-generated +predicate. `{from}` is replaced with the source model's runtime alias and `{to}` +with the target model's runtime alias. Reverse graph traversal swaps the +placeholders automatically. Relationship types: diff --git a/docs/runtime-feature-matrix.md b/docs/runtime-feature-matrix.md index 09fc5331..0a4672c9 100644 --- a/docs/runtime-feature-matrix.md +++ b/docs/runtime-feature-matrix.md @@ -27,7 +27,7 @@ This matrix documents current product support for native Sidemantic projects. It | Conversion metrics | Yes | Yes, fixture-covered compile | No dedicated fixture yet | No dedicated fixture yet | | Retention metrics | Yes | Yes, fixture-covered compile | No dedicated fixture yet | No dedicated fixture yet | | Cohort metrics | Yes | Yes, fixture-covered compile | No dedicated fixture yet | No dedicated fixture yet | -| Table calculations | Post-query processing | Yes, Rust fixture-covered compile and Rust-only result coverage | No dedicated fixture yet | No dedicated fixture yet | +| Table calculations | Yes, shared fixture post-query result parity | Yes, shared fixture SQL/window result parity | No dedicated fixture yet | No dedicated fixture yet | | Pre-aggregation routing | Yes | Yes, fixture-covered compile | No dedicated fixture yet | No dedicated fixture yet | | Semantic SQL rewrite | Yes | Native subset, fixture-covered | Native subset target | Narrow subset | | DuckDB execution | Yes | Via ADBC, fixture result parity in CI | Native DuckDB process | No | diff --git a/docs/rust-runtime.md b/docs/rust-runtime.md index 3ca784d6..a6a9298d 100644 --- a/docs/rust-runtime.md +++ b/docs/rust-runtime.md @@ -105,4 +105,4 @@ cd sidemantic-rs && cargo test --test native_fixtures CI runs these in the `Native Compatibility` job. -The shared fixture suite currently includes executable coverage for basic models, joins, fanout-safe symmetric aggregation, many-to-many joins, parameters in filters, embedded SQL definitions, SQL frontmatter definitions, default time dimensions, segments, derived/ratio metrics, and pre-aggregation routing. Table calculations have Rust-only DuckDB result coverage because Python does not accept `table_calculations` in the native query API yet. `source_uri` is covered as a validation-only load fixture and query compilation rejects it until a concrete table or SQL source is provided. +The shared fixture suite currently includes executable coverage for basic models, joins, fanout-safe symmetric aggregation, many-to-many joins, parameters in filters, embedded SQL definitions, SQL frontmatter definitions, default time dimensions, segments, derived/ratio metrics, table calculations, and pre-aggregation routing. `source_uri` is covered as a validation-only load fixture and query compilation rejects it until a concrete table or SQL source is provided. diff --git a/scripts/generate_schema.py b/scripts/generate_schema.py index a7b5ad29..164a49d4 100644 --- a/scripts/generate_schema.py +++ b/scripts/generate_schema.py @@ -2,11 +2,54 @@ """Generate JSON Schema from Pydantic models for YAML editor support.""" import json +from copy import deepcopy from pathlib import Path from sidemantic import Dimension, Metric, Model, Parameter, Relationship, Segment +def add_native_relationship_aliases(schema: dict) -> dict: + """Expose native YAML relationship aliases that map to Python API fields.""" + properties = schema.setdefault("properties", {}) + + if "foreign_key" in properties and "foreign_key_columns" not in properties: + foreign_key_columns = deepcopy(properties["foreign_key"]) + foreign_key_columns["title"] = "Foreign Key Columns" + foreign_key_columns["description"] = "Explicit source-column list (alias for foreign_key)" + properties["foreign_key_columns"] = foreign_key_columns + + if "primary_key" in properties and "primary_key_columns" not in properties: + primary_key_columns = deepcopy(properties["primary_key"]) + primary_key_columns["title"] = "Primary Key Columns" + primary_key_columns["description"] = "Explicit target-column list (alias for primary_key)" + properties["primary_key_columns"] = primary_key_columns + + if "sql" not in properties: + properties["sql"] = { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "Custom join SQL using {from} and {to} runtime placeholders", + "title": "Sql", + } + + return schema + + +def patch_relationship_schemas(schema: dict) -> None: + """Patch every embedded Relationship schema emitted by Pydantic.""" + if not isinstance(schema, dict): + return + if schema.get("title") == "Relationship": + add_native_relationship_aliases(schema) + for value in schema.values(): + if isinstance(value, dict): + patch_relationship_schemas(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + patch_relationship_schemas(item) + + def generate_schema() -> dict: """Generate JSON Schema for sidemantic YAML files.""" # Get schemas from pydantic models @@ -51,6 +94,8 @@ def generate_schema() -> dict: }, } + patch_relationship_schemas(schema) + return schema diff --git a/sidemantic-rs/examples/parity_adapter.rs b/sidemantic-rs/examples/parity_adapter.rs index 4c243c0a..df49d88a 100644 --- a/sidemantic-rs/examples/parity_adapter.rs +++ b/sidemantic-rs/examples/parity_adapter.rs @@ -28,6 +28,7 @@ enum Request { #[serde(default)] order_by: Vec, limit: Option, + offset: Option, #[serde(default)] ungrouped: bool, #[serde(default)] @@ -137,6 +138,7 @@ fn handle(request: Request) -> sidemantic::Result { segments, order_by, limit, + offset, ungrouped, skip_default_time_dimensions, dialect, @@ -153,6 +155,9 @@ fn handle(request: Request) -> sidemantic::Result { if let Some(limit) = limit { query = query.with_limit(limit); } + if let Some(offset) = offset { + query = query.with_offset(offset); + } let mut generator = SqlGenerator::new(&graph); if let Some(dialect) = dialect { generator = generator.with_dialect(parse_dialect(&dialect)?); @@ -654,6 +659,10 @@ fn metric_aggregation_name(aggregation: Option<&Aggregation>) -> &'static str { Some(Aggregation::Min) => "min", Some(Aggregation::Max) => "max", Some(Aggregation::Median) => "median", + Some(Aggregation::Stddev) => "stddev", + Some(Aggregation::StddevPop) => "stddev_pop", + Some(Aggregation::Variance) => "variance", + Some(Aggregation::VariancePop) => "variance_pop", Some(Aggregation::Expression) | None => "sum", } } diff --git a/sidemantic-rs/src/config/loader.rs b/sidemantic-rs/src/config/loader.rs index 6e61e4ad..4876a3a2 100644 --- a/sidemantic-rs/src/config/loader.rs +++ b/sidemantic-rs/src/config/loader.rs @@ -15,7 +15,7 @@ use crate::error::{Result, SidemanticError}; use super::schema::{CubeConfig, SidemanticConfig, NATIVE_FORMAT_VERSION}; use super::sql_parser::{ - parse_sql_definitions, parse_sql_graph_definitions_extended, parse_sql_model, + parse_sql_definitions, parse_sql_graph_definitions_extended, parse_sql_models, }; #[derive(Debug)] @@ -231,52 +231,72 @@ fn model_from_sql_frontmatter(frontmatter: serde_yaml::Mapping) -> Result } fn parse_sql_content(content: &str) -> Result { - let has_model_statement = { - let upper = content.to_ascii_uppercase(); - upper.contains("MODEL") && upper.contains("MODEL (") - }; - let mut models: Vec = Vec::new(); let mut top_level_metrics: Vec = Vec::new(); let mut top_level_parameters: Vec = Vec::new(); - if has_model_statement { - let model = parse_sql_model(content).map_err(|e| { - SidemanticError::Validation(format!("failed to parse SQL model statement: {e}")) - })?; - let model_metric_names: HashSet = model - .metrics - .iter() - .map(|metric| metric.name.clone()) - .collect(); - models.push(model); - - let (sql_metrics, _, sql_parameters, _) = parse_sql_graph_definitions_extended(content) - .map_err(|e| { - SidemanticError::Validation(format!("failed to parse SQL graph definitions: {e}")) - })?; - for metric in sql_metrics { - if !model_metric_names.contains(&metric.name) { - top_level_metrics.push(metric); + match parse_sql_models(content) { + Ok(parsed_models) => { + let model_metric_names: HashSet = parsed_models + .iter() + .flat_map(|model| model.metrics.iter().map(|metric| metric.name.clone())) + .collect(); + models.extend(parsed_models); + + let graph_definitions = parse_sql_graph_definitions_extended(content); + let (sql_metrics, _, sql_parameters, _) = match graph_definitions { + Ok(definitions) => definitions, + Err(_) + if content + .trim_start() + .to_ascii_lowercase() + .starts_with("model ") + && content.to_ascii_lowercase().contains(" from ") => + { + (Vec::new(), Vec::new(), Vec::new(), Vec::new()) + } + Err(err) => { + return Err(SidemanticError::Validation(format!( + "failed to parse SQL graph definitions: {err}" + ))); + } + }; + for metric in sql_metrics { + if !model_metric_names.contains(&metric.name) { + top_level_metrics.push(metric); + } } + top_level_parameters.extend(sql_parameters); } - top_level_parameters.extend(sql_parameters); - } else { - let (frontmatter, sql_body) = parse_sql_frontmatter_and_body(content)?; - let (sql_metrics, sql_segments, sql_parameters, sql_preaggs) = - parse_sql_graph_definitions_extended(&sql_body).map_err(|e| { - SidemanticError::Validation(format!("failed to parse SQL graph definitions: {e}")) - })?; - top_level_parameters.extend(sql_parameters); - - if let Some(frontmatter) = frontmatter { - let mut model = model_from_sql_frontmatter(frontmatter)?; - model.metrics.extend(sql_metrics); - model.segments.extend(sql_segments); - model.pre_aggregations.extend(sql_preaggs); - models.push(model); - } else { - top_level_metrics.extend(sql_metrics); + Err(model_err) + if content + .trim_start() + .to_ascii_lowercase() + .starts_with("model ") => + { + return Err(SidemanticError::Validation(format!( + "failed to parse SQL model statement: {model_err}" + ))); + } + Err(_) => { + let (frontmatter, sql_body) = parse_sql_frontmatter_and_body(content)?; + let (sql_metrics, sql_segments, sql_parameters, sql_preaggs) = + parse_sql_graph_definitions_extended(&sql_body).map_err(|e| { + SidemanticError::Validation(format!( + "failed to parse SQL graph definitions: {e}" + )) + })?; + top_level_parameters.extend(sql_parameters); + + if let Some(frontmatter) = frontmatter { + let mut model = model_from_sql_frontmatter(frontmatter)?; + model.metrics.extend(sql_metrics); + model.segments.extend(sql_segments); + model.pre_aggregations.extend(sql_preaggs); + models.push(model); + } else { + top_level_metrics.extend(sql_metrics); + } } } @@ -355,10 +375,14 @@ pub fn load_from_sql_string_with_metadata(content: &str) -> Result) -> Result { Ok(load_from_directory_with_metadata(dir)?.graph) } @@ -1016,7 +1040,9 @@ fn infer_relationships(models: &mut HashMap) { primary_key_columns: Some(target_primary_keys.clone()), through: None, through_foreign_key: None, + through_foreign_key_columns: None, related_foreign_key: None, + related_foreign_key_columns: None, sql: None, metadata: None, }, @@ -1034,7 +1060,9 @@ fn infer_relationships(models: &mut HashMap) { primary_key_columns: Some(target_primary_keys), through: None, through_foreign_key: None, + through_foreign_key_columns: None, related_foreign_key: None, + related_foreign_key_columns: None, sql: None, metadata: None, }, @@ -1233,6 +1261,47 @@ METRIC ( .contains("Unsupported native Sidemantic format version 2; supported version is 1")); } + #[test] + fn test_load_from_sql_string_supports_compact_model_syntax() { + let sql = r#" +model orders from orders ( + primary key (order_id) + status + sum(amount) as revenue +) +"#; + + let loaded = load_from_sql_string_with_metadata(sql).unwrap(); + let orders = loaded.graph.get_model("orders").unwrap(); + assert!(orders.get_dimension("status").is_some()); + assert!(orders.get_metric("revenue").is_some()); + } + + #[test] + fn test_load_from_sql_string_keeps_multiple_legacy_models_separate() { + let sql = r#" +MODEL (name orders, table orders, primary_key order_id); +METRIC order_count AS COUNT(*); + +MODEL (name customers, table customers, primary_key customer_id); +METRIC customer_count AS COUNT(*); +"#; + + let loaded = load_from_sql_string_with_metadata(sql).unwrap(); + + let orders = loaded.graph.get_model("orders").unwrap(); + assert!(orders.get_metric("order_count").is_some()); + assert!(orders.get_metric("customer_count").is_none()); + + let customers = loaded.graph.get_model("customers").unwrap(); + assert!(customers.get_metric("customer_count").is_some()); + assert!(customers.get_metric("order_count").is_none()); + assert_eq!( + loaded.model_order, + vec!["orders".to_string(), "customers".to_string()] + ); + } + #[test] fn test_sql_frontmatter_version_is_not_model_metadata() { let sql = r#" diff --git a/sidemantic-rs/src/config/mod.rs b/sidemantic-rs/src/config/mod.rs index be596577..46435b45 100644 --- a/sidemantic-rs/src/config/mod.rs +++ b/sidemantic-rs/src/config/mod.rs @@ -15,5 +15,5 @@ pub use loader::{ pub use schema::{CubeConfig, ModelConfig, SidemanticConfig}; pub use sql_parser::{ parse_sql_definitions, parse_sql_graph_definitions, parse_sql_graph_definitions_extended, - parse_sql_model, parse_sql_statement_blocks, + parse_sql_model, parse_sql_models, parse_sql_statement_blocks, }; diff --git a/sidemantic-rs/src/config/schema.rs b/sidemantic-rs/src/config/schema.rs index 7f703d0a..96ff4243 100644 --- a/sidemantic-rs/src/config/schema.rs +++ b/sidemantic-rs/src/config/schema.rs @@ -57,8 +57,10 @@ pub struct ModelConfig { #[serde(default)] pub meta: Option, #[serde(default)] - pub dimensions: Vec, + pub auto_dimensions: bool, #[serde(default)] + pub dimensions: Vec, + #[serde(default, alias = "measures")] pub metrics: Vec, #[serde(default)] pub relationships: Vec, @@ -104,6 +106,7 @@ pub struct DimensionConfig { pub name: String, #[serde(default, rename = "type")] pub dim_type: Option, + #[serde(default, alias = "expr")] pub sql: Option, pub granularity: Option, pub supported_granularities: Option>, @@ -129,6 +132,7 @@ pub struct MetricConfig { #[serde(default, rename = "type")] pub metric_type: Option, pub agg: Option, + #[serde(default, alias = "expr", alias = "measure")] pub sql: Option, pub numerator: Option, pub denominator: Option, @@ -194,7 +198,11 @@ pub struct RelationshipConfig { pub primary_key_columns: Option>, pub through: Option, pub through_foreign_key: Option, + #[serde(default)] + pub through_foreign_key_columns: Option>, pub related_foreign_key: Option, + #[serde(default)] + pub related_foreign_key_columns: Option>, /// Custom SQL join condition using {from} and {to} placeholders pub sql: Option, #[serde(default)] @@ -375,6 +383,15 @@ impl SidemanticConfig { self.validate_version()?; for model in &self.models { + if model.auto_dimensions { + return Err(crate::error::SidemanticError::validation_issue( + "unsupported_auto_dimensions", + Some(&model.name), + &format!("models.{}.auto_dimensions", model.name), + Some("true"), + "Rust native runtime does not support auto_dimensions; declare dimensions explicitly or set auto_dimensions: false", + )); + } validate_optional_enum( model.default_grain.as_deref(), &format!("models.{}.default_grain", model.name), @@ -587,6 +604,14 @@ impl DimensionConfig { impl MetricConfig { fn into_metric(self) -> Metric { + let inline_aggregation = if self.agg.is_none() { + self.sql + .as_deref() + .and_then(parse_inline_metric_aggregation) + } else { + None + }; + let metric_type = match self .metric_type .as_deref() @@ -602,7 +627,7 @@ impl MetricConfig { Some("retention") => MetricType::Retention, Some("cohort") => MetricType::Cohort, _ => { - if self.agg.is_none() && self.sql.is_some() { + if inline_aggregation.is_none() && self.agg.is_none() && self.sql.is_some() { MetricType::Derived } else { MetricType::Simple @@ -610,7 +635,15 @@ impl MetricConfig { } }; - let agg = self.agg.as_deref().map(parse_aggregation); + let agg = self + .agg + .as_deref() + .map(parse_aggregation) + .or_else(|| inline_aggregation.as_ref().map(|(agg, _)| agg.clone())); + let sql = inline_aggregation + .as_ref() + .and_then(|(_, inner_sql)| inner_sql.clone()) + .or(self.sql); let grain_to_date = self.grain_to_date.as_deref().and_then(parse_time_grain); let comparison_type = self .comparison_type @@ -636,7 +669,7 @@ impl MetricConfig { extends: self.extends, r#type: metric_type, agg, - sql: self.sql, + sql, numerator: self.numerator, denominator: self.denominator, offset_window: self.offset_window, @@ -712,7 +745,13 @@ impl RelationshipConfig { primary_key_columns, through: self.through, through_foreign_key: self.through_foreign_key, + through_foreign_key_columns: self + .through_foreign_key_columns + .filter(|columns| !columns.is_empty()), related_foreign_key: self.related_foreign_key, + related_foreign_key_columns: self + .related_foreign_key_columns + .filter(|columns| !columns.is_empty()), sql: self.sql, metadata: self.metadata, } @@ -880,6 +919,10 @@ impl CubeMeasure { Some("avg") => (MetricType::Simple, Some(Aggregation::Avg)), Some("min") => (MetricType::Simple, Some(Aggregation::Min)), Some("max") => (MetricType::Simple, Some(Aggregation::Max)), + Some("stddev") => (MetricType::Simple, Some(Aggregation::Stddev)), + Some("stddev_pop") => (MetricType::Simple, Some(Aggregation::StddevPop)), + Some("variance") => (MetricType::Simple, Some(Aggregation::Variance)), + Some("variance_pop") => (MetricType::Simple, Some(Aggregation::VariancePop)), Some("number") => (MetricType::Derived, None), // derived/calculated _ => (MetricType::Simple, Some(Aggregation::Sum)), }; @@ -1034,6 +1077,10 @@ fn validate_metric_config(metric: &MetricConfig, field_path: &str) -> crate::err "min", "max", "median", + "stddev", + "stddev_pop", + "variance", + "variance_pop", "expression", ], )?; @@ -1078,6 +1125,10 @@ fn validate_metric_config(metric: &MetricConfig, field_path: &str) -> crate::err "min", "max", "median", + "stddev", + "stddev_pop", + "variance", + "variance_pop", "expression", ], )?; @@ -1096,11 +1147,98 @@ fn parse_aggregation(s: &str) -> Aggregation { "min" => Aggregation::Min, "max" => Aggregation::Max, "median" => Aggregation::Median, + "stddev" => Aggregation::Stddev, + "stddev_pop" => Aggregation::StddevPop, + "variance" => Aggregation::Variance, + "variance_pop" | "var_pop" => Aggregation::VariancePop, "expression" => Aggregation::Expression, _ => Aggregation::Sum, } } +fn parse_inline_metric_aggregation(sql_expr: &str) -> Option<(Aggregation, Option)> { + let trimmed = sql_expr.trim(); + if trimmed.is_empty() { + return None; + } + + let open_paren = trimmed.find('(')?; + let func = trimmed[..open_paren].trim().to_ascii_lowercase(); + if !matches!( + func.as_str(), + "sum" + | "avg" + | "min" + | "max" + | "median" + | "stddev" + | "stddev_pop" + | "variance" + | "variance_pop" + | "var_pop" + | "count" + ) { + return None; + } + + let mut depth = 0i32; + let mut close_paren = None; + for (idx, ch) in trimmed.char_indices().skip(open_paren) { + match ch { + '(' => depth += 1, + ')' => { + depth -= 1; + if depth == 0 { + close_paren = Some(idx); + break; + } + if depth < 0 { + return None; + } + } + _ => {} + } + } + + let close_paren = close_paren?; + if depth != 0 || !trimmed[close_paren + 1..].trim().is_empty() { + return None; + } + + let inner = trimmed[open_paren + 1..close_paren].trim(); + match func.as_str() { + "sum" | "avg" | "min" | "max" | "median" | "stddev" | "stddev_pop" | "variance" + | "variance_pop" | "var_pop" => { + if inner.is_empty() { + None + } else { + Some((parse_aggregation(&func), Some(inner.to_string()))) + } + } + "count" => { + if inner.is_empty() { + return None; + } + if inner == "*" { + return Some((Aggregation::Count, Some("*".to_string()))); + } + + let inner_lower = inner.to_ascii_lowercase(); + if inner_lower.starts_with("distinct ") { + let distinct_expr = inner[8..].trim(); + if distinct_expr.is_empty() { + None + } else { + Some((Aggregation::CountDistinct, Some(distinct_expr.to_string()))) + } + } else { + Some((Aggregation::Count, Some(inner.to_string()))) + } + } + _ => None, + } +} + fn parse_time_grain(s: &str) -> Option { match s.to_lowercase().as_str() { "day" => Some(TimeGrain::Day), @@ -1185,6 +1323,58 @@ parameters: assert_eq!(parameters[0].name, "status"); } + #[test] + fn test_native_yaml_accepts_python_compatibility_aliases() { + let yaml = r#" +version: 1 +models: + - name: orders + table: orders + auto_dimensions: false + dimensions: + - name: status + type: categorical + expr: order_status + measures: + - name: revenue + agg: sum + expr: amount + - name: revenue_per_order + type: derived + measure: revenue / order_count + - name: order_count + agg: count +"#; + + let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); + let (models, _, _) = config.into_parts().unwrap(); + let orders = &models[0]; + + assert_eq!(orders.dimensions[0].sql.as_deref(), Some("order_status")); + assert_eq!(orders.metrics.len(), 3); + assert_eq!(orders.metrics[0].sql.as_deref(), Some("amount")); + assert_eq!( + orders.metrics[1].sql.as_deref(), + Some("revenue / order_count") + ); + } + + #[test] + fn test_native_yaml_rejects_auto_dimensions_true() { + let yaml = r#" +version: 1 +models: + - name: orders + table: orders + auto_dimensions: true +"#; + + let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); + let err = config.into_parts().unwrap_err(); + assert!(err.to_string().contains("unsupported_auto_dimensions")); + assert!(err.to_string().contains("auto_dimensions")); + } + #[test] fn test_parse_native_yaml_without_version_defaults_to_supported_contract() { let yaml = r#" @@ -1483,6 +1673,45 @@ models: assert_eq!(rel.primary_key.as_deref(), Some("product_id")); } + #[test] + fn test_parse_many_to_many_composite_junction_key_fields() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key_columns: [tenant_id, order_id] + relationships: + - name: products + type: many_to_many + through: order_items + through_foreign_key_columns: [tenant_id, order_id] + related_foreign_key_columns: [tenant_id, product_id] + - name: order_items + table: order_items + - name: products + table: products + primary_key_columns: [tenant_id, product_id] +"#; + + let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); + let (models, _, _) = config.into_parts().unwrap(); + + let orders = models.iter().find(|m| m.name == "orders").unwrap(); + let rel = orders + .relationships + .iter() + .find(|r| r.name == "products") + .unwrap(); + assert_eq!( + rel.through_foreign_key_columns.as_ref().unwrap(), + &vec!["tenant_id".to_string(), "order_id".to_string()] + ); + assert_eq!( + rel.related_foreign_key_columns.as_ref().unwrap(), + &vec!["tenant_id".to_string(), "product_id".to_string()] + ); + } + #[test] fn test_parse_native_yaml_composite_keys() { let yaml = r#" @@ -1525,6 +1754,75 @@ models: ); } + #[test] + fn test_parse_native_yaml_normalizes_inline_aggregate_metric() { + let yaml = r#" +models: + - name: orders + table: orders + metrics: + - name: revenue + sql: SUM(amount) + - name: distinct_customers + sql: COUNT(DISTINCT customer_id) + - name: revenue_stddev + agg: stddev + sql: amount + - name: revenue_variance_pop + sql: VARIANCE_POP(amount) + - name: revenue_per_order + sql: SUM(amount) / COUNT(*) +"#; + + let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); + let (models, _, _) = config.into_parts().unwrap(); + let orders = models.iter().find(|m| m.name == "orders").unwrap(); + + let revenue = orders.metrics.iter().find(|m| m.name == "revenue").unwrap(); + assert_eq!(revenue.r#type, MetricType::Simple); + assert_eq!(revenue.agg, Some(Aggregation::Sum)); + assert_eq!(revenue.sql.as_deref(), Some("amount")); + + let distinct_customers = orders + .metrics + .iter() + .find(|m| m.name == "distinct_customers") + .unwrap(); + assert_eq!(distinct_customers.r#type, MetricType::Simple); + assert_eq!(distinct_customers.agg, Some(Aggregation::CountDistinct)); + assert_eq!(distinct_customers.sql.as_deref(), Some("customer_id")); + + let revenue_stddev = orders + .metrics + .iter() + .find(|m| m.name == "revenue_stddev") + .unwrap(); + assert_eq!(revenue_stddev.r#type, MetricType::Simple); + assert_eq!(revenue_stddev.agg, Some(Aggregation::Stddev)); + assert_eq!(revenue_stddev.sql.as_deref(), Some("amount")); + + let revenue_variance_pop = orders + .metrics + .iter() + .find(|m| m.name == "revenue_variance_pop") + .unwrap(); + assert_eq!(revenue_variance_pop.r#type, MetricType::Simple); + assert_eq!(revenue_variance_pop.agg, Some(Aggregation::VariancePop)); + assert_eq!(revenue_variance_pop.sql.as_deref(), Some("amount")); + + let revenue_per_order = orders + .metrics + .iter() + .find(|m| m.name == "revenue_per_order") + .unwrap(); + assert_eq!(revenue_per_order.r#type, MetricType::Derived); + assert_eq!(revenue_per_order.agg, None); + assert_eq!( + revenue_per_order.sql.as_deref(), + Some("SUM(amount) / COUNT(*)") + ); + } + #[test] fn test_strip_cube_placeholder() { assert_eq!(strip_cube_placeholder("${CUBE}.status"), "status"); diff --git a/sidemantic-rs/src/config/sql_parser.rs b/sidemantic-rs/src/config/sql_parser.rs index 565d3129..dd590be6 100644 --- a/sidemantic-rs/src/config/sql_parser.rs +++ b/sidemantic-rs/src/config/sql_parser.rs @@ -14,7 +14,7 @@ //! DIMENSION status AS status; //! ``` -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use nom::{ branch::alt, @@ -30,6 +30,7 @@ use nom::{ #[cfg(not(target_arch = "wasm32"))] use polyglot_sql::parse as polyglot_parse; use polyglot_sql::{DialectType, Expression}; +use regex::Regex; use crate::core::{ Aggregation, CohortInnerMetric, ComparisonCalculation, ComparisonType, Dimension, @@ -322,6 +323,10 @@ fn parse_metric_aggregation(value: Option<&String>) -> Option { "min" => Some(Aggregation::Min), "max" => Some(Aggregation::Max), "median" => Some(Aggregation::Median), + "stddev" => Some(Aggregation::Stddev), + "stddev_pop" => Some(Aggregation::StddevPop), + "variance" => Some(Aggregation::Variance), + "variance_pop" | "var_pop" => Some(Aggregation::VariancePop), "expression" => Some(Aggregation::Expression), _ => None, }) @@ -556,6 +561,28 @@ fn simple_dimension(input: &str) -> IResult<&str, Statement> { Ok((input, Statement::Dimension(props))) } +/// Parse simple SEGMENT: SEGMENT name AS expr +fn simple_segment(input: &str) -> IResult<&str, Statement> { + let (input, _) = multispace0(input)?; + let (input, _) = tag_no_case("SEGMENT")(input)?; + let (input, _) = multispace1(input)?; + + let (input, name) = recognize(pair(identifier, opt(pair(char('.'), identifier))))(input)?; + + let (input, _) = multispace1(input)?; + let (input, _) = tag_no_case("AS")(input)?; + let (input, _) = multispace1(input)?; + + let (input, expr) = take_while(|c| c != ';')(input)?; + let (input, _) = opt(char(';'))(input)?; + + let mut props = HashMap::new(); + props.insert("name".to_string(), name.trim().to_string()); + props.insert("sql".to_string(), expr.trim().to_string()); + + Ok((input, Statement::Segment(props))) +} + /// Parse metric expression to extract aggregation function fn parse_metric_expression(name: &str, expr: &str) -> HashMap { let mut props = HashMap::new(); @@ -612,6 +639,10 @@ fn parse_top_level_function_metric(expr: &str) -> Option<(String, String)> { "min" => "min", "max" => "max", "median" => "median", + "stddev" => "stddev", + "stddev_pop" => "stddev_pop", + "variance" => "variance", + "variance_pop" | "var_pop" => "variance_pop", "count_distinct" | "countdistinct" => "count_distinct", "count" => { if inner_expr == "*" { @@ -761,6 +792,10 @@ fn extract_aggregation_from_function_name_polyglot( "min" => "min", "max" => "max", "median" => "median", + "stddev" => "stddev", + "stddev_pop" => "stddev_pop", + "variance" => "variance", + "variance_pop" | "var_pop" => "variance_pop", "count_distinct" => "count_distinct", _ => return None, }; @@ -856,6 +891,26 @@ fn prefixed_dimension(input: &str) -> IResult<&str, Statement> { Ok((input, Statement::Dimension(props))) } +/// Parse SEGMENT with model prefix: SEGMENT model.name (props) +fn prefixed_segment(input: &str) -> IResult<&str, Statement> { + let (input, _) = multispace0(input)?; + let (input, _) = tag_no_case("SEGMENT")(input)?; + let (input, _) = multispace1(input)?; + + let (input, model) = identifier(input)?; + let (input, _) = char('.')(input)?; + let (input, name) = identifier(input)?; + + let (input, _) = multispace0(input)?; + let (input, props) = delimited(char('('), property_list, char(')'))(input)?; + let (input, _) = multispace0(input)?; + let (input, _) = opt(char(';'))(input)?; + + let mut props = props; + props.insert("name".to_string(), format!("{model}.{name}")); + Ok((input, Statement::Segment(props))) +} + /// Parse any statement (tries simple AS syntax first, then parenthesized) fn statement(input: &str) -> IResult<&str, Statement> { let (input, _) = multispace0(input)?; @@ -865,9 +920,11 @@ fn statement(input: &str) -> IResult<&str, Statement> { // Try simple AS syntax first for METRIC and DIMENSION simple_metric, simple_dimension, + simple_segment, // Try model.name (props) syntax prefixed_metric, prefixed_dimension, + prefixed_segment, // Fall back to simple parenthesized syntax map(definition("DIMENSION"), Statement::Dimension), map(definition("METRIC"), Statement::Metric), @@ -929,16 +986,15 @@ fn parse_file(input: &str) -> IResult<&str, Vec> { Ok((remaining, statements)) } -// ============================================================================ -// Public API -// ============================================================================ - -/// Parse SQL definitions into a Model -pub fn parse_sql_model(sql: &str) -> Result { - let (_, statements) = - parse_file(sql).map_err(|e| SidemanticError::Validation(format!("Parse error: {e}")))?; +fn has_compact_model_syntax(sql: &str) -> bool { + Regex::new(r"(?i)\bmodel\s+[A-Za-z_][A-Za-z0-9_]*\s+from\b") + .expect("valid compact model regex") + .is_match(sql) +} - let mut model: Option = None; +fn parse_legacy_sql_models_from_statements(statements: Vec) -> Result> { + let mut models = Vec::new(); + let mut current_model: Option = None; let mut dimensions = Vec::new(); let mut metrics = Vec::new(); let mut segments = Vec::new(); @@ -948,7 +1004,16 @@ pub fn parse_sql_model(sql: &str) -> Result { for stmt in statements { match stmt { Statement::Model(props) => { - model = Some(build_model(&props)?); + flush_legacy_sql_model( + &mut models, + &mut current_model, + &mut dimensions, + &mut metrics, + &mut segments, + &mut relationships, + &mut pre_aggregations, + ); + current_model = Some(build_model(&props)?); } Statement::Dimension(props) => { if let Some(dim) = build_dimension(&props) { @@ -979,19 +1044,556 @@ pub fn parse_sql_model(sql: &str) -> Result { } } - let mut model = model.ok_or_else(|| { - SidemanticError::Validation("SQL definitions must include a MODEL statement".into()) + flush_legacy_sql_model( + &mut models, + &mut current_model, + &mut dimensions, + &mut metrics, + &mut segments, + &mut relationships, + &mut pre_aggregations, + ); + + if models.is_empty() { + return Err(SidemanticError::Validation( + "SQL definitions must include a MODEL statement".into(), + )); + } + + Ok(models) +} + +#[allow(clippy::too_many_arguments)] +fn flush_legacy_sql_model( + models: &mut Vec, + current_model: &mut Option, + dimensions: &mut Vec, + metrics: &mut Vec, + segments: &mut Vec, + relationships: &mut Vec, + pre_aggregations: &mut Vec, +) { + let Some(mut model) = current_model.take() else { + dimensions.clear(); + metrics.clear(); + segments.clear(); + relationships.clear(); + pre_aggregations.clear(); + return; + }; + + model.dimensions.append(dimensions); + model.metrics.append(metrics); + model.segments.append(segments); + model.relationships.append(relationships); + model.pre_aggregations.append(pre_aggregations); + models.push(model); +} + +fn strip_line_comment(line: &str) -> &str { + line.split_once("--") + .map(|(before, _)| before) + .unwrap_or(line) +} + +fn split_annotation(line: &str) -> (&str, Option<&str>) { + let Some(idx) = line.rfind(" : ") else { + return (line, None); + }; + (&line[..idx], Some(line[idx + 3..].trim())) +} + +fn split_field_alias(line: &str) -> (String, String) { + let lower = line.to_ascii_lowercase(); + if let Some(idx) = lower.rfind(" as ") { + let expr = line[..idx].trim().to_string(); + let name = line[idx + 4..].trim().to_string(); + (expr, name) + } else { + let name = line.trim().to_string(); + (name.clone(), name) + } +} + +fn split_columns(value: &str) -> Result> { + let columns = value + .split(',') + .map(str::trim) + .filter(|column| !column.is_empty()) + .map(ToString::to_string) + .collect::>(); + if columns.is_empty() { + return Err(SidemanticError::Validation( + "Primary key requires at least one column".to_string(), + )); + } + Ok(columns) +} + +fn parse_compact_annotation(annotation: Option<&str>) -> Result<(Option, Option)> { + let Some(annotation) = annotation else { + return Ok((None, None)); + }; + + let mut dim_type = None; + let mut granularity = None; + let mut parts = annotation.split_whitespace().peekable(); + while let Some(part) = parts.next() { + if part.eq_ignore_ascii_case("grain") { + let Some(grain) = parts.next() else { + return Err(SidemanticError::Validation( + "field annotation grain requires a value".to_string(), + )); + }; + granularity = Some(grain.to_ascii_lowercase()); + } else if dim_type.is_none() { + dim_type = Some(part.to_ascii_lowercase()); + } else { + return Err(SidemanticError::Validation(format!( + "Unrecognized field annotation '{annotation}'" + ))); + } + } + + if granularity.is_some() + && dim_type + .as_deref() + .is_some_and(|value| value != "time" && value != "timestamp" && value != "date") + { + return Err(SidemanticError::Validation(format!( + "field annotation cannot use grain with type '{}'", + dim_type.unwrap_or_default() + ))); + } + + Ok((dim_type, granularity)) +} + +fn compact_relationship_type(kind: &str) -> Result { + match kind.to_ascii_lowercase().as_str() { + "one" => Ok(RelationshipType::ManyToOne), + "many" => Ok(RelationshipType::OneToMany), + "one_to_one" => Ok(RelationshipType::OneToOne), + "many_to_one" => Ok(RelationshipType::ManyToOne), + "one_to_many" => Ok(RelationshipType::OneToMany), + "many_to_many" => Ok(RelationshipType::ManyToMany), + _ => Err(SidemanticError::Validation(format!( + "unsupported compact join relationship type '{kind}'" + ))), + } +} + +fn parse_compact_join(line: &str) -> Result { + let join_re = Regex::new(r"(?i)^join\s+(\w+)\s+([A-Za-z_][A-Za-z0-9_]*)\s+on\s+(.+)$") + .expect("valid compact join regex"); + let captures = join_re.captures(line).ok_or_else(|| { + SidemanticError::Validation(format!("Unrecognized compact join statement: {line}")) })?; + let rel_type = compact_relationship_type(&captures[1])?; + let target_model = captures[2].to_string(); + let mut predicate = captures[3].trim(); + if predicate.starts_with('(') && predicate.ends_with(')') { + predicate = predicate[1..predicate.len() - 1].trim(); + } + + let mut local_keys = Vec::new(); + let mut target_keys = Vec::new(); + for part in Regex::new(r"(?i)\s+and\s+") + .expect("valid and regex") + .split(predicate) + { + let Some((left, right)) = part.split_once('=') else { + return Err(SidemanticError::Validation(format!( + "compact join '{line}' must compare model columns" + ))); + }; + let left = left.trim().trim_matches(|c| c == '(' || c == ')').trim(); + let right = right.trim().trim_matches(|c| c == '(' || c == ')').trim(); + let Some((right_model, right_col)) = right.split_once('.') else { + return Err(SidemanticError::Validation(format!( + "compact join '{line}' must compare model columns" + ))); + }; + if right_model.trim() != target_model { + return Err(SidemanticError::Validation(format!( + "compact join '{line}' must compare columns from target model '{target_model}'" + ))); + } + if left.contains('.') { + return Err(SidemanticError::Validation(format!( + "compact join '{line}' must use local columns on the left side" + ))); + } + local_keys.push(left.to_string()); + target_keys.push(right_col.trim().to_string()); + } + + if local_keys.is_empty() { + return Err(SidemanticError::Validation(format!( + "compact join '{line}' must compare model columns" + ))); + } + + let (foreign_keys, primary_keys) = match rel_type { + RelationshipType::ManyToOne | RelationshipType::OneToOne => (local_keys, target_keys), + _ => (target_keys, local_keys), + }; + + let mut rel = Relationship::new(target_model); + rel.r#type = rel_type; + Ok(rel.with_key_columns(foreign_keys, primary_keys)) +} + +fn infer_compact_dimension_type(name: &str, expression: &str) -> String { + let lowered_name = name.to_ascii_lowercase(); + let lowered_expression = expression.to_ascii_lowercase(); + if lowered_name.contains("date") + || lowered_name.contains("time") + || lowered_name.ends_with("_at") + || lowered_expression.contains("date_trunc") + || lowered_expression.contains("timestamp") + || lowered_expression.contains("::date") + { + "time".to_string() + } else if lowered_expression.contains(" = ") + || lowered_expression.contains(" != ") + || lowered_expression.contains(" <> ") + || lowered_expression.contains(" > ") + || lowered_expression.contains(" < ") + { + "boolean".to_string() + } else if [" + ", " - ", " * ", " / "] + .iter() + .any(|operator| lowered_expression.contains(operator)) + { + "numeric".to_string() + } else { + infer_dimension_type(expression) + } +} + +fn compact_expression_references_metrics(expression: &str, metric_names: &HashSet) -> bool { + if metric_names.is_empty() { + return false; + } + let tokens = Regex::new(r"\b[A-Za-z_][A-Za-z0-9_]*\b") + .expect("valid identifier regex") + .find_iter(expression) + .map(|m| m.as_str()) + .collect::>(); + metric_names + .iter() + .any(|name| tokens.contains(name.as_str())) +} + +type CompactFieldDeclaration = (usize, String, String, Option, Option); + +fn build_compact_model( + name: String, + table: Option, + source_sql: Option, + body: &str, +) -> Result { + let mut model = Model::new(&name, "id"); + model.table = table; + model.sql = source_sql; + + let mut field_declarations: Vec = Vec::new(); + let mut metric_names = HashSet::new(); + let mut seen_fields = HashSet::new(); + let mut seen_segments = HashSet::new(); + + for (idx, raw_line) in body.lines().enumerate() { + let line = strip_line_comment(raw_line).trim(); + if line.is_empty() { + continue; + } + + let lower = line.to_ascii_lowercase(); + if lower.starts_with("primary key") { + let open = line.find('(').ok_or_else(|| { + SidemanticError::Validation("Primary key requires column list".to_string()) + })?; + let close = line.rfind(')').ok_or_else(|| { + SidemanticError::Validation("Primary key requires column list".to_string()) + })?; + if close <= open { + return Err(SidemanticError::Validation( + "Primary key requires at least one column".to_string(), + )); + } + let columns = split_columns(&line[open + 1..close])?; + model.primary_key = columns[0].clone(); + model.primary_key_columns = columns; + continue; + } + + if lower.starts_with("default time ") { + let rest = line["default time ".len()..].trim(); + let mut parts = rest.split_whitespace(); + let Some(dimension) = parts.next() else { + return Err(SidemanticError::Validation( + "default time requires a dimension".to_string(), + )); + }; + model.default_time_dimension = Some(dimension.to_string()); + if parts + .next() + .is_some_and(|part| part.eq_ignore_ascii_case("grain")) + { + if let Some(grain) = parts.next() { + model.default_grain = Some(grain.to_ascii_lowercase()); + } + } + continue; + } + + if lower.starts_with("segment ") { + let rest = line["segment ".len()..].trim(); + let lower_rest = rest.to_ascii_lowercase(); + let Some(as_idx) = lower_rest.find(" as ") else { + return Err(SidemanticError::Validation(format!( + "Unrecognized compact segment statement: {line}" + ))); + }; + let segment_name = rest[..as_idx].trim(); + let segment_sql = rest[as_idx + 4..].trim(); + if !seen_segments.insert(segment_name.to_string()) { + return Err(SidemanticError::Validation(format!( + "Model '{name}' defines segment '{segment_name}' more than once" + ))); + } + model + .segments + .push(Segment::new(segment_name, segment_sql.to_string())); + continue; + } + + if lower.starts_with("join ") { + model.relationships.push(parse_compact_join(line)?); + continue; + } + + if lower.starts_with("table ") { + return Err(SidemanticError::Validation(format!( + "compact model '{name}' must use `model {name} from ` instead of a table statement" + ))); + } + + let (field_line, annotation) = split_annotation(line); + let (dimension_type, granularity) = parse_compact_annotation(annotation)?; + let (expression, field_name) = split_field_alias(field_line); + if field_name.is_empty() { + return Err(SidemanticError::Validation(format!( + "Unrecognized statement in model '{name}': {line}" + ))); + } + if !seen_fields.insert(field_name.clone()) { + return Err(SidemanticError::Validation(format!( + "Model '{name}' defines field '{field_name}' more than once" + ))); + } + field_declarations.push((idx, field_name, expression, dimension_type, granularity)); + } + + let mut pending = Vec::new(); + let mut parsed_fields: Vec<(usize, bool, Dimension, Option)> = Vec::new(); + for (idx, field_name, expression, dimension_type, granularity) in field_declarations { + let metric_props = parse_metric_expression(&field_name, &expression); + if metric_props + .get("agg") + .is_some_and(|agg| agg != "expression") + { + if dimension_type.is_some() { + return Err(SidemanticError::Validation(format!( + "Field '{field_name}' in model '{name}' is a metric and cannot use dimension annotation" + ))); + } + let metric = build_metric(&metric_props).ok_or_else(|| { + SidemanticError::Validation(format!( + "failed to build compact metric '{field_name}'" + )) + })?; + metric_names.insert(field_name); + parsed_fields.push((idx, false, Dimension::new("__unused"), Some(metric))); + } else { + pending.push((idx, field_name, expression, dimension_type, granularity)); + } + } + + let mut remaining = Vec::new(); + for (idx, field_name, expression, dimension_type, granularity) in pending { + if compact_expression_references_metrics(&expression, &metric_names) { + if dimension_type.is_some() { + return Err(SidemanticError::Validation(format!( + "Field '{field_name}' in model '{name}' is a metric and cannot use dimension annotation" + ))); + } + let mut props = HashMap::new(); + props.insert("name".to_string(), field_name.clone()); + props.insert("type".to_string(), "derived".to_string()); + props.insert("sql".to_string(), expression); + let metric = build_metric(&props).ok_or_else(|| { + SidemanticError::Validation(format!( + "failed to build compact metric '{field_name}'" + )) + })?; + metric_names.insert(field_name); + parsed_fields.push((idx, false, Dimension::new("__unused"), Some(metric))); + } else { + remaining.push((idx, field_name, expression, dimension_type, granularity)); + } + } - model.dimensions.extend(dimensions); - model.metrics.extend(metrics); - model.segments.extend(segments); - model.relationships.extend(relationships); - model.pre_aggregations.extend(pre_aggregations); + for (idx, field_name, expression, dimension_type, granularity) in remaining { + let mut props = HashMap::new(); + props.insert("name".to_string(), field_name.clone()); + props.insert( + "type".to_string(), + dimension_type + .unwrap_or_else(|| infer_compact_dimension_type(&field_name, &expression)), + ); + if expression != field_name { + props.insert("sql".to_string(), expression); + } + if let Some(granularity) = granularity { + props.insert("granularity".to_string(), granularity); + } + let dimension = build_dimension(&props).ok_or_else(|| { + SidemanticError::Validation(format!("failed to build compact dimension '{field_name}'")) + })?; + parsed_fields.push((idx, true, dimension, None)); + } + + parsed_fields.sort_by_key(|(idx, _, _, _)| *idx); + for (_, is_dimension, dimension, metric) in parsed_fields { + if is_dimension { + model.dimensions.push(dimension); + } else if let Some(metric) = metric { + model.metrics.push(metric); + } + } + + if let Some(default_time) = model.default_time_dimension.as_ref() { + let Some(dimension) = model.get_dimension(default_time) else { + return Err(SidemanticError::Validation(format!( + "Default time dimension '{default_time}' in model '{name}' is not defined" + ))); + }; + if dimension.r#type != DimensionType::Time { + return Err(SidemanticError::Validation(format!( + "Default time dimension '{default_time}' in model '{name}' must be a time dimension" + ))); + } + } Ok(model) } +fn parse_compact_sql_models(sql: &str) -> Result> { + let header_re = Regex::new(r"(?is)\bmodel\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s*") + .expect("valid compact model header regex"); + let mut models = Vec::new(); + let mut remaining = sql; + + while let Some(captures) = header_re.captures(remaining) { + let matched = captures.get(0).unwrap(); + if !remaining[..matched.start()].trim().is_empty() { + return Err(SidemanticError::Validation( + "Rust compact SQL model parser does not support non-model statements before compact model blocks".to_string(), + )); + } + let name = captures[1].to_string(); + let after_from = &remaining[matched.end()..]; + let after_from = after_from.trim_start(); + + let (table, source_sql, before_body) = + if let Some(source_rest) = after_from.strip_prefix('(') { + let (source_sql, source_remainder) = parse_balanced_parens(source_rest) + .ok_or_else(|| { + SidemanticError::Validation(format!( + "compact model '{name}' has an unterminated SQL source" + )) + })?; + ( + None, + Some(source_sql.trim().to_string()), + source_remainder.trim_start(), + ) + } else { + let open = after_from.find('(').ok_or_else(|| { + SidemanticError::Validation(format!( + "compact model '{name}' must use `model {name} from
(...)`" + )) + })?; + let table = after_from[..open].trim(); + if table.is_empty() { + return Err(SidemanticError::Validation(format!( + "compact model '{name}' must use `model {name} from
(...)`" + ))); + } + ( + Some(table.to_string()), + None, + after_from[open..].trim_start(), + ) + }; + + let Some(body_rest) = before_body.strip_prefix('(') else { + return Err(SidemanticError::Validation(format!( + "compact model '{name}' must include a model body" + ))); + }; + let (body, rest) = parse_balanced_parens(body_rest).ok_or_else(|| { + SidemanticError::Validation(format!("compact model '{name}' has an unterminated body")) + })?; + models.push(build_compact_model(name, table, source_sql, body)?); + remaining = rest.trim_start(); + if let Some(after_semicolon) = remaining.strip_prefix(';') { + remaining = after_semicolon; + } + } + + if models.is_empty() { + return Err(SidemanticError::Validation( + "SQL definitions must include a compact model statement".to_string(), + )); + } + + if !remaining.trim().is_empty() { + return Err(SidemanticError::Validation( + "Rust compact SQL model parser does not support trailing graph-level definitions after compact model blocks".to_string(), + )); + } + + Ok(models) +} + +// ============================================================================ +// Public API +// ============================================================================ + +/// Parse SQL definitions into one or more models. +pub fn parse_sql_models(sql: &str) -> Result> { + if has_compact_model_syntax(sql) { + return parse_compact_sql_models(sql); + } + + let (_, statements) = + parse_file(sql).map_err(|e| SidemanticError::Validation(format!("Parse error: {e}")))?; + + parse_legacy_sql_models_from_statements(statements) +} + +/// Parse SQL definitions into the first model. +pub fn parse_sql_model(sql: &str) -> Result { + parse_sql_models(sql).and_then(|models| { + models.into_iter().next().ok_or_else(|| { + SidemanticError::Validation("SQL definitions must include a MODEL statement".into()) + }) + }) +} + /// Parse SQL into statement blocks preserving high-level statement kinds/properties. pub fn parse_sql_statement_blocks(sql: &str) -> Result> { let (_, statements) = @@ -1388,6 +1990,8 @@ fn build_relationship(props: &HashMap) -> Option { let foreign_key_columns = parse_key_columns(props, "foreign_key"); let primary_key_columns = parse_key_columns(props, "primary_key"); + let through_foreign_key_columns = parse_key_columns(props, "through_foreign_key"); + let related_foreign_key_columns = parse_key_columns(props, "related_foreign_key"); Some(Relationship { name: name.clone(), @@ -1401,8 +2005,14 @@ fn build_relationship(props: &HashMap) -> Option { .and_then(|columns| columns.first().cloned()), primary_key_columns, through: props.get("through").cloned(), - through_foreign_key: props.get("through_foreign_key").cloned(), - related_foreign_key: props.get("related_foreign_key").cloned(), + through_foreign_key: through_foreign_key_columns + .as_ref() + .and_then(|columns| columns.first().cloned()), + through_foreign_key_columns, + related_foreign_key: related_foreign_key_columns + .as_ref() + .and_then(|columns| columns.first().cloned()), + related_foreign_key_columns, sql: props.get("sql").cloned(), metadata: props.get("metadata").map(|value| parse_literal(value)), }) @@ -1707,6 +2317,10 @@ mod tests { METRIC revenue AS SUM(COALESCE(amount, 0)); METRIC unique_customers AS COUNT(DISTINCT customer_id); METRIC median_amount AS MEDIAN(amount); + METRIC amount_stddev AS STDDEV(amount); + METRIC amount_stddev_pop AS STDDEV_POP(amount); + METRIC amount_variance AS VARIANCE(amount); + METRIC amount_variance_pop AS VARIANCE_POP(amount); METRIC approximate_customers AS APPROX_COUNT_DISTINCT(customer_id); "#; @@ -1724,6 +2338,22 @@ mod tests { assert_eq!(median_amount.agg, Some(Aggregation::Median)); assert_eq!(median_amount.sql, Some("amount".to_string())); + let amount_stddev = model.get_metric("amount_stddev").unwrap(); + assert_eq!(amount_stddev.agg, Some(Aggregation::Stddev)); + assert_eq!(amount_stddev.sql, Some("amount".to_string())); + + let amount_stddev_pop = model.get_metric("amount_stddev_pop").unwrap(); + assert_eq!(amount_stddev_pop.agg, Some(Aggregation::StddevPop)); + assert_eq!(amount_stddev_pop.sql, Some("amount".to_string())); + + let amount_variance = model.get_metric("amount_variance").unwrap(); + assert_eq!(amount_variance.agg, Some(Aggregation::Variance)); + assert_eq!(amount_variance.sql, Some("amount".to_string())); + + let amount_variance_pop = model.get_metric("amount_variance_pop").unwrap(); + assert_eq!(amount_variance_pop.agg, Some(Aggregation::VariancePop)); + assert_eq!(amount_variance_pop.sql, Some("amount".to_string())); + let approximate_customers = model.get_metric("approximate_customers").unwrap(); assert_eq!(approximate_customers.agg, Some(Aggregation::Expression)); assert_eq!( @@ -1732,6 +2362,170 @@ mod tests { ); } + #[test] + fn test_parse_compact_sql_model() { + let sql = r#" +model orders from orders ( + primary key (order_id, store_id) + default time order_date grain day + + status + date_trunc('day', created_at) as order_date : time grain day + status = 'completed' as is_complete : boolean + amount - discount as net_amount : numeric + + segment completed as status = 'completed' + + join one customers on customer_id = customers.id + join many order_items on (order_id = order_items.order_id and store_id = order_items.store_id) + + revenue / order_count as average_order_value + sum(amount) as revenue + count(*) as order_count +) +"#; + + let model = parse_sql_model(sql).unwrap(); + + assert_eq!(model.name, "orders"); + assert_eq!(model.table.as_deref(), Some("orders")); + assert_eq!( + model.primary_key_columns, + vec!["order_id".to_string(), "store_id".to_string()] + ); + assert_eq!(model.default_time_dimension.as_deref(), Some("order_date")); + assert_eq!(model.default_grain.as_deref(), Some("day")); + + let status = model.get_dimension("status").unwrap(); + assert_eq!(status.r#type, DimensionType::Categorical); + assert_eq!(status.sql, None); + + let order_date = model.get_dimension("order_date").unwrap(); + assert_eq!(order_date.r#type, DimensionType::Time); + assert_eq!( + order_date.sql.as_deref(), + Some("date_trunc('day', created_at)") + ); + assert_eq!(order_date.granularity.as_deref(), Some("day")); + + let is_complete = model.get_dimension("is_complete").unwrap(); + assert_eq!(is_complete.r#type, DimensionType::Boolean); + assert_eq!(is_complete.sql.as_deref(), Some("status = 'completed'")); + + let net_amount = model.get_dimension("net_amount").unwrap(); + assert_eq!(net_amount.r#type, DimensionType::Numeric); + assert_eq!(net_amount.sql.as_deref(), Some("amount - discount")); + + let completed = model.get_segment("completed").unwrap(); + assert_eq!(completed.sql, "status = 'completed'"); + + let customers = model.get_relationship("customers").unwrap(); + assert_eq!(customers.r#type, RelationshipType::ManyToOne); + assert_eq!( + customers.foreign_key_columns(), + vec!["customer_id".to_string()] + ); + assert_eq!(customers.primary_key_columns(), vec!["id".to_string()]); + + let order_items = model.get_relationship("order_items").unwrap(); + assert_eq!(order_items.r#type, RelationshipType::OneToMany); + assert_eq!( + order_items.foreign_key_columns(), + vec!["order_id".to_string(), "store_id".to_string()] + ); + assert_eq!( + order_items.primary_key_columns(), + vec!["order_id".to_string(), "store_id".to_string()] + ); + + let revenue = model.get_metric("revenue").unwrap(); + assert_eq!(revenue.agg, Some(Aggregation::Sum)); + assert_eq!(revenue.sql.as_deref(), Some("amount")); + + let order_count = model.get_metric("order_count").unwrap(); + assert_eq!(order_count.agg, Some(Aggregation::Count)); + assert_eq!(order_count.sql, None); + + let average_order_value = model.get_metric("average_order_value").unwrap(); + assert_eq!(average_order_value.r#type, MetricType::Derived); + assert_eq!( + average_order_value.sql.as_deref(), + Some("revenue / order_count") + ); + } + + #[test] + fn test_parse_compact_sql_models_multiple_and_derived_source() { + let sql = r#" +model completed_orders from ( + select * + from raw.orders + where status = 'completed' +) ( + primary key (order_id) + created_at as order_date : time grain day + sum(amount) as revenue +) + +model customers from public.customers ( + primary key (id) + region +) +"#; + + let models = parse_sql_models(sql).unwrap(); + + assert_eq!( + models.iter().map(|m| m.name.as_str()).collect::>(), + vec!["completed_orders", "customers"] + ); + assert!(models[0].table.is_none()); + assert_eq!( + models[0].sql.as_deref(), + Some("select *\n from raw.orders\n where status = 'completed'") + ); + assert_eq!( + models[0].get_metric("revenue").unwrap().agg, + Some(Aggregation::Sum) + ); + assert_eq!(models[1].table.as_deref(), Some("public.customers")); + assert!(models[1].get_dimension("region").is_some()); + } + + #[test] + fn test_parse_legacy_sql_models_multiple() { + let sql = r#" +MODEL (name orders, table orders, primary_key order_id); +METRIC order_count AS COUNT(*); + +MODEL (name customers, table customers, primary_key customer_id); +METRIC customer_count AS COUNT(*); +"#; + + let models = parse_sql_models(sql).unwrap(); + + assert_eq!( + models.iter().map(|m| m.name.as_str()).collect::>(), + vec!["orders", "customers"] + ); + assert!(models[0].get_metric("order_count").is_some()); + assert!(models[0].get_metric("customer_count").is_none()); + assert!(models[1].get_metric("customer_count").is_some()); + } + + #[test] + fn test_parse_compact_sql_model_rejects_bad_join() { + let sql = r#" +model orders from orders ( + primary key (order_id) + join one customers on customer_id = 1 +) +"#; + + let err = parse_sql_model(sql).unwrap_err(); + assert!(err.to_string().contains("must compare model columns")); + } + #[test] fn test_simple_dimension_syntax() { let sql = r#" @@ -1750,6 +2544,30 @@ mod tests { assert_eq!(order_date.sql, Some("created_at".to_string())); } + #[test] + fn test_simple_segment_syntax() { + let sql = r#" + MODEL (name orders, table orders); + SEGMENT completed AS status = 'completed'; + "#; + + let model = parse_sql_model(sql).unwrap(); + assert_eq!(model.segments.len(), 1); + + let completed = model.get_segment("completed").unwrap(); + assert_eq!(completed.sql, "status = 'completed'"); + } + + #[test] + fn test_simple_segment_graph_definition_syntax() { + let sql = "SEGMENT completed AS status = 'completed';"; + + let (_metrics, segments) = parse_sql_definitions(sql).unwrap(); + assert_eq!(segments.len(), 1); + assert_eq!(segments[0].name, "completed"); + assert_eq!(segments[0].sql, "status = 'completed'"); + } + #[test] fn test_mixed_syntax() { let sql = r#" diff --git a/sidemantic-rs/src/core/dependency.rs b/sidemantic-rs/src/core/dependency.rs index 5cd12910..f8f10a2b 100644 --- a/sidemantic-rs/src/core/dependency.rs +++ b/sidemantic-rs/src/core/dependency.rs @@ -49,6 +49,14 @@ pub fn extract_dependencies_with_context( // Resolve references using graph if available if let Some(g) = graph { for ref_name in refs { + if has_inline_aggregation(sql) { + if let Some(resolved) = + resolve_metric_reference(&ref_name, g, model_context) + { + deps.insert(resolved); + } + continue; + } let resolved = resolve_reference(&ref_name, g, model_context); deps.insert(resolved); } @@ -96,6 +104,46 @@ fn has_operators(s: &str) -> bool { .any(|&op| s.contains(op)) } +fn has_inline_aggregation(sql: &str) -> bool { + let lower = sql.to_ascii_lowercase(); + let bytes = lower.as_bytes(); + let aggregate_names = [ + "sum", + "avg", + "count", + "min", + "max", + "median", + "stddev", + "stddev_pop", + "variance", + "variance_pop", + ]; + + for name in aggregate_names { + let mut start = 0; + while let Some(offset) = lower[start..].find(name) { + let name_start = start + offset; + let name_end = name_start + name.len(); + let before_is_ident = name_start > 0 + && (bytes[name_start - 1].is_ascii_alphanumeric() || bytes[name_start - 1] == b'_'); + let after_is_ident = name_end < bytes.len() + && (bytes[name_end].is_ascii_alphanumeric() || bytes[name_end] == b'_'); + if before_is_ident || after_is_ident { + start = name_end; + continue; + } + + if lower[name_end..].trim_start().starts_with('(') { + return true; + } + start = name_end; + } + } + + false +} + /// Extract column references from a SQL expression /// /// Uses polyglot-sql to parse the expression and find all column identifiers. @@ -103,6 +151,10 @@ fn extract_column_references(sql: &str) -> HashSet { let mut refs = HashSet::new(); let normalized_sql = sql.replace("${CUBE}.", "").replace("${CUBE}", ""); + if has_inline_aggregation(&normalized_sql) { + return extract_simple_references(&normalized_sql); + } + // polyglot-sql traversal can recurse indefinitely on some PostgreSQL cast // forms (expr::type). Fall back to the tokenizer path for these expressions. if normalized_sql.contains("::") { @@ -272,6 +324,45 @@ fn resolve_reference(ref_name: &str, graph: &SemanticGraph, model_context: Optio ref_name.to_string() } +fn resolve_metric_reference( + ref_name: &str, + graph: &SemanticGraph, + model_context: Option<&str>, +) -> Option { + if graph.get_metric(ref_name).is_some() { + return Some(ref_name.to_string()); + } + + if let Some((model_name, metric_name)) = ref_name.rsplit_once('.') { + if graph + .get_model(model_name) + .and_then(|model| model.get_metric(metric_name)) + .is_some() + { + return Some(ref_name.to_string()); + } + return None; + } + + if let Some(context_model_name) = model_context { + if graph + .get_model(context_model_name) + .and_then(|model| model.get_metric(ref_name)) + .is_some() + { + return Some(format!("{context_model_name}.{ref_name}")); + } + } + + for model in graph.models() { + if model.get_metric(ref_name).is_some() { + return Some(format!("{}.{}", model.name, ref_name)); + } + } + + None +} + /// Build a dependency graph for all metrics and check for cycles pub fn check_circular_dependencies( metrics: &[(&str, &Metric)], @@ -331,7 +422,7 @@ pub fn check_circular_dependencies( mod tests { use super::*; #[allow(unused_imports)] - use crate::core::model::Aggregation; + use crate::core::model::{Aggregation, Dimension, Model}; #[test] fn test_ratio_dependencies() { @@ -367,6 +458,41 @@ mod tests { assert!(deps.is_empty()); } + #[test] + fn test_inline_aggregation_skips_raw_field_references_with_graph() { + let mut graph = SemanticGraph::new(); + graph + .add_model( + Model::new("orders", "id") + .with_table("orders") + .with_dimension(Dimension::categorical("status")) + .with_metric(Metric::sum("revenue", "amount")), + ) + .unwrap(); + let metric = Metric::derived("computed_revenue", "SUM(orders.amount) * 2"); + + let deps = extract_dependencies(&metric, Some(&graph)); + + assert!(deps.is_empty()); + } + + #[test] + fn test_inline_aggregation_keeps_metric_references_with_graph() { + let mut graph = SemanticGraph::new(); + graph + .add_model( + Model::new("orders", "id") + .with_table("orders") + .with_metric(Metric::sum("revenue", "amount")), + ) + .unwrap(); + let metric = Metric::derived("computed_revenue", "SUM(orders.revenue) * 2"); + + let deps = extract_dependencies(&metric, Some(&graph)); + + assert_eq!(deps, HashSet::from(["orders.revenue".to_string()])); + } + #[test] fn test_extract_column_references() { let refs = extract_column_references("(revenue - cost) / revenue"); diff --git a/sidemantic-rs/src/core/graph.rs b/sidemantic-rs/src/core/graph.rs index 4cfbe37c..30028a1a 100644 --- a/sidemantic-rs/src/core/graph.rs +++ b/sidemantic-rs/src/core/graph.rs @@ -78,7 +78,7 @@ type AdjacencyEdge = ( ); /// The semantic graph holds all models and their relationships -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct SemanticGraph { models: HashMap, metrics: HashMap, @@ -411,13 +411,19 @@ impl SemanticGraph { continue; } - let (source_fk_opt, target_fk_opt) = rel.junction_keys(); - let (Some(source_fk), Some(target_fk)) = (source_fk_opt, target_fk_opt) - else { + let (source_fks, target_fks) = rel.junction_key_columns(); + if source_fks.is_empty() || target_fks.is_empty() { continue; - }; + } - let source_pk = model.primary_keys(); + let source_pk = { + let keys = model.primary_keys(); + if keys.is_empty() { + vec!["id".to_string()] + } else { + keys + } + }; let target_pk = if rel.primary_key.is_some() || rel.primary_key_columns.is_some() { rel.primary_key_columns() @@ -427,20 +433,17 @@ impl SemanticGraph { .map(|target_model| target_model.primary_keys()) .unwrap_or_else(|| vec!["id".to_string()]) }; - let source_pk_first = source_pk - .first() - .cloned() - .unwrap_or_else(|| "id".to_string()); - let target_pk_first = target_pk - .first() - .cloned() - .unwrap_or_else(|| "id".to_string()); + let target_pk = if target_pk.is_empty() { + vec!["id".to_string()] + } else { + target_pk + }; // source -> through (one_to_many) self.adjacency.entry(model.name.clone()).or_default().push(( through_name.clone(), - vec![source_pk_first.clone()], - vec![source_fk.clone()], + source_pk.clone(), + source_fks.clone(), RelationshipType::OneToMany, None, )); @@ -450,8 +453,8 @@ impl SemanticGraph { .or_default() .push(( model.name.clone(), - vec![source_fk], - vec![source_pk_first], + source_fks, + source_pk, RelationshipType::ManyToOne, None, )); @@ -462,16 +465,16 @@ impl SemanticGraph { .or_default() .push(( rel.name.clone(), - vec![target_fk.clone()], - vec![target_pk_first.clone()], + target_fks.clone(), + target_pk.clone(), RelationshipType::ManyToOne, None, )); // target -> through (one_to_many) self.adjacency.entry(rel.name.clone()).or_default().push(( through_name.clone(), - vec![target_pk_first], - vec![target_fk], + target_pk, + target_fks, RelationshipType::OneToMany, None, )); @@ -854,6 +857,63 @@ mod tests { assert_eq!(path.steps.len(), 1); } + #[test] + fn test_one_to_many_omitted_key_defaults_to_id() { + let mut graph = SemanticGraph::new(); + + let customers = Model::new("customers", "id") + .with_table("customers") + .with_relationship(Relationship::one_to_many("orders")); + let orders = Model::new("orders", "id").with_table("orders"); + + graph.add_model(customers).unwrap(); + graph.add_model(orders).unwrap(); + + let path = graph.find_join_path("customers", "orders").unwrap(); + assert_eq!(path.steps.len(), 1); + assert_eq!(path.steps[0].from_keys, vec!["id".to_string()]); + assert_eq!(path.steps[0].to_keys, vec!["id".to_string()]); + } + + #[test] + fn test_many_to_one_omitted_keys_use_name_id_and_target_primary_key() { + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "order_id") + .with_table("orders") + .with_relationship(Relationship::many_to_one("customers")); + let customers = Model::new("customers", "customer_uid").with_table("customers"); + + graph.add_model(orders).unwrap(); + graph.add_model(customers).unwrap(); + + let path = graph.find_join_path("orders", "customers").unwrap(); + assert_eq!(path.steps.len(), 1); + assert_eq!(path.steps[0].from_keys, vec!["customers_id".to_string()]); + assert_eq!(path.steps[0].to_keys, vec!["customer_uid".to_string()]); + } + + #[test] + fn test_one_to_one_omitted_key_defaults_to_id() { + let mut graph = SemanticGraph::new(); + + let mut relationship = Relationship::new("profiles"); + relationship.r#type = RelationshipType::OneToOne; + + let users = Model::new("users", "id") + .with_table("users") + .with_relationship(relationship); + let profiles = Model::new("profiles", "id").with_table("profiles"); + + graph.add_model(users).unwrap(); + graph.add_model(profiles).unwrap(); + + let path = graph.find_join_path("users", "profiles").unwrap(); + assert_eq!(path.steps.len(), 1); + assert_eq!(path.steps[0].from_keys, vec!["id".to_string()]); + assert_eq!(path.steps[0].to_keys, vec!["id".to_string()]); + } + #[test] fn test_parse_reference() { let graph = create_test_graph(); @@ -929,7 +989,9 @@ mod tests { primary_key_columns: None, through: None, through_foreign_key: None, + through_foreign_key_columns: None, related_foreign_key: None, + related_foreign_key_columns: None, sql: None, metadata: None, }); @@ -962,7 +1024,9 @@ mod tests { primary_key_columns: None, through: Some("order_items".to_string()), through_foreign_key: Some("order_id".to_string()), + through_foreign_key_columns: None, related_foreign_key: Some("product_id".to_string()), + related_foreign_key_columns: None, sql: None, metadata: None, }); @@ -995,6 +1059,63 @@ mod tests { assert_eq!(path.steps[1].relationship_type, RelationshipType::ManyToOne); } + #[test] + fn test_many_to_many_through_preserves_composite_primary_keys() { + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "tenant_id") + .with_primary_key_columns(vec!["tenant_id".to_string(), "order_id".to_string()]) + .with_table("orders") + .with_relationship(Relationship { + name: "products".to_string(), + r#type: RelationshipType::ManyToMany, + foreign_key: None, + foreign_key_columns: None, + primary_key: None, + primary_key_columns: None, + through: Some("order_items".to_string()), + through_foreign_key: Some("order_id".to_string()), + through_foreign_key_columns: Some(vec![ + "tenant_id".to_string(), + "order_id".to_string(), + ]), + related_foreign_key: Some("product_id".to_string()), + related_foreign_key_columns: Some(vec![ + "tenant_id".to_string(), + "product_id".to_string(), + ]), + sql: None, + metadata: None, + }); + let order_items = Model::new("order_items", "id").with_table("order_items"); + let products = Model::new("products", "tenant_id") + .with_primary_key_columns(vec!["tenant_id".to_string(), "product_id".to_string()]) + .with_table("products"); + + graph.add_model(orders).unwrap(); + graph.add_model(order_items).unwrap(); + graph.add_model(products).unwrap(); + + let path = graph.find_join_path("orders", "products").unwrap(); + assert_eq!(path.steps.len(), 2); + assert_eq!( + path.steps[0].from_keys, + vec!["tenant_id".to_string(), "order_id".to_string()] + ); + assert_eq!( + path.steps[0].to_keys, + vec!["tenant_id".to_string(), "order_id".to_string()] + ); + assert_eq!( + path.steps[1].from_keys, + vec!["tenant_id".to_string(), "product_id".to_string()] + ); + assert_eq!( + path.steps[1].to_keys, + vec!["tenant_id".to_string(), "product_id".to_string()] + ); + } + #[test] fn test_find_join_path_with_composite_keys() { let mut graph = SemanticGraph::new(); diff --git a/sidemantic-rs/src/core/model.rs b/sidemantic-rs/src/core/model.rs index 383bc60c..b0c11697 100644 --- a/sidemantic-rs/src/core/model.rs +++ b/sidemantic-rs/src/core/model.rs @@ -128,6 +128,10 @@ pub enum Aggregation { Min, Max, Median, + Stddev, + StddevPop, + Variance, + VariancePop, /// Raw expression that already contains aggregation (e.g., SUM(amount) * 2) Expression, } @@ -142,6 +146,10 @@ impl Aggregation { Aggregation::Min => "MIN", Aggregation::Max => "MAX", Aggregation::Median => "MEDIAN", + Aggregation::Stddev => "STDDEV", + Aggregation::StddevPop => "STDDEV_POP", + Aggregation::Variance => "VARIANCE", + Aggregation::VariancePop => "VAR_POP", Aggregation::Expression => "", // Not used - expression stored in sql field } } @@ -605,7 +613,8 @@ pub struct Relationship { pub name: String, #[serde(default)] pub r#type: RelationshipType, - /// Foreign key column (defaults to {name}_id) + /// Foreign key column. + /// Defaults to `{name}_id` for `many_to_one`, and `id` for one-to-one/one-to-many compatibility. pub foreign_key: Option, /// Foreign key columns for composite relationships #[serde(default)] @@ -621,9 +630,15 @@ pub struct Relationship { /// Foreign key in junction model pointing to this model #[serde(default)] pub through_foreign_key: Option, + /// Foreign key columns in junction model pointing to this model + #[serde(default)] + pub through_foreign_key_columns: Option>, /// Foreign key in junction model pointing to related model #[serde(default)] pub related_foreign_key: Option, + /// Foreign key columns in junction model pointing to related model + #[serde(default)] + pub related_foreign_key_columns: Option>, /// Custom SQL join condition (overrides FK/PK) /// Use {from} and {to} placeholders for table aliases #[serde(default)] @@ -644,7 +659,9 @@ impl Relationship { primary_key_columns: None, through: None, through_foreign_key: None, + through_foreign_key_columns: None, related_foreign_key: None, + related_foreign_key_columns: None, sql: None, metadata: None, } @@ -716,7 +733,13 @@ impl Relationship { .clone() .filter(|columns| !columns.is_empty()) .or_else(|| self.foreign_key.clone().map(|key| vec![key])) - .unwrap_or_else(|| vec![format!("{}_id", self.name)]) + .unwrap_or_else(|| { + if self.r#type == RelationshipType::ManyToOne { + vec![format!("{}_id", self.name)] + } else { + vec!["id".to_string()] + } + }) } pub fn primary_key_columns(&self) -> Vec { @@ -738,13 +761,40 @@ impl Relationship { if self.r#type != RelationshipType::ManyToMany { return (None, None); } + let (source_keys, target_keys) = self.junction_key_columns(); ( - self.through_foreign_key - .clone() - .or_else(|| self.foreign_key.clone()), - self.related_foreign_key.clone(), + source_keys.into_iter().next(), + target_keys.into_iter().next(), ) } + + /// Get junction key columns for many-to-many relationships. + pub fn junction_key_columns(&self) -> (Vec, Vec) { + if self.r#type != RelationshipType::ManyToMany { + return (Vec::new(), Vec::new()); + } + + let source_keys = self + .through_foreign_key_columns + .clone() + .filter(|columns| !columns.is_empty()) + .or_else(|| self.through_foreign_key.clone().map(|key| vec![key])) + .or_else(|| { + self.foreign_key_columns + .clone() + .filter(|columns| !columns.is_empty()) + }) + .or_else(|| self.foreign_key.clone().map(|key| vec![key])) + .unwrap_or_default(); + let target_keys = self + .related_foreign_key_columns + .clone() + .filter(|columns| !columns.is_empty()) + .or_else(|| self.related_foreign_key.clone().map(|key| vec![key])) + .unwrap_or_default(); + + (source_keys, target_keys) + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] @@ -1049,6 +1099,40 @@ mod tests { let metric = Metric::count_distinct("unique_customers", "customer_id"); assert_eq!(metric.to_sql(Some("o")), "COUNT(DISTINCT o.customer_id)"); + + let metric = Metric { + name: "revenue_stddev".to_string(), + extends: None, + agg: Some(Aggregation::Stddev), + sql: Some("amount".to_string()), + ..Metric::new("revenue_stddev") + }; + assert_eq!(metric.to_sql(Some("o")), "STDDEV(o.amount)"); + + let metric = Metric { + name: "revenue_variance_pop".to_string(), + extends: None, + agg: Some(Aggregation::VariancePop), + sql: Some("amount".to_string()), + ..Metric::new("revenue_variance_pop") + }; + assert_eq!(metric.to_sql(None), "VAR_POP(amount)"); + } + + #[test] + fn test_relationship_default_foreign_keys_match_native_contract() { + let rel = Relationship::many_to_one("customers"); + assert_eq!(rel.foreign_key_columns(), vec!["customers_id".to_string()]); + assert_eq!(rel.primary_key_columns(), vec!["id".to_string()]); + + let rel = Relationship::one_to_many("orders"); + assert_eq!(rel.foreign_key_columns(), vec!["id".to_string()]); + assert_eq!(rel.primary_key_columns(), vec!["id".to_string()]); + + let mut rel = Relationship::new("profile"); + rel.r#type = RelationshipType::OneToOne; + assert_eq!(rel.foreign_key_columns(), vec!["id".to_string()]); + assert_eq!(rel.primary_key_columns(), vec!["id".to_string()]); } #[test] diff --git a/sidemantic-rs/src/main.rs b/sidemantic-rs/src/main.rs index 95e1accf..bf491720 100644 --- a/sidemantic-rs/src/main.rs +++ b/sidemantic-rs/src/main.rs @@ -174,6 +174,9 @@ fn print_help() { sidemantic preagg refresh --models ./models --model orders --name daily_revenue --mode full\n\ sidemantic serve --models ./models --bind 127.0.0.1:5544\n\ \n\ + Model loading: Rust --models accepts native Sidemantic YAML/SQL and Cube YAML only.\n\ + Convert LookML, MetricFlow, Hex, Rill, Malloy, and other external formats with the Python CLI/API first.\n\ + \n\ Use ' --help' for command-specific usage." ); } diff --git a/sidemantic-rs/src/runtime.rs b/sidemantic-rs/src/runtime.rs index 55cb5c32..83d5b28a 100644 --- a/sidemantic-rs/src/runtime.rs +++ b/sidemantic-rs/src/runtime.rs @@ -153,7 +153,11 @@ struct GraphPathRelationshipPayload { has_primary_key: bool, through: Option, through_foreign_key: Option, + #[serde(default)] + through_foreign_key_columns: Vec, related_foreign_key: Option, + #[serde(default)] + related_foreign_key_columns: Vec, } #[derive(Debug, Deserialize)] @@ -622,6 +626,10 @@ fn parse_metric_agg_for_dependencies(agg: Option<&str>) -> Option { Some("min") => Some(Aggregation::Min), Some("max") => Some(Aggregation::Max), Some("median") => Some(Aggregation::Median), + Some("stddev") => Some(Aggregation::Stddev), + Some("stddev_pop") => Some(Aggregation::StddevPop), + Some("variance") => Some(Aggregation::Variance), + Some("variance_pop") | Some("var_pop") => Some(Aggregation::VariancePop), Some("expression") => Some(Aggregation::Expression), _ => None, } @@ -1599,6 +1607,10 @@ fn resolve_wasm_aggregate_projection( "min" => Aggregation::Min, "max" => Aggregation::Max, "median" => Aggregation::Median, + "stddev" => Aggregation::Stddev, + "stddev_pop" => Aggregation::StddevPop, + "variance" => Aggregation::Variance, + "variance_pop" => Aggregation::VariancePop, "count" => Aggregation::Count, "count_distinct" => Aggregation::CountDistinct, _ => { @@ -4168,7 +4180,17 @@ pub fn parse_simple_metric_aggregation(sql_expr: &str) -> Option<(String, Option let func = trimmed[..open_paren].trim().to_lowercase(); if !matches!( func.as_str(), - "sum" | "avg" | "min" | "max" | "median" | "count" + "sum" + | "avg" + | "min" + | "max" + | "median" + | "stddev" + | "stddev_pop" + | "variance" + | "variance_pop" + | "var_pop" + | "count" ) { return None; } @@ -4204,11 +4226,17 @@ pub fn parse_simple_metric_aggregation(sql_expr: &str) -> Option<(String, Option let inner = trimmed[open_paren + 1..close_paren].trim(); match func.as_str() { - "sum" | "avg" | "min" | "max" | "median" => { + "sum" | "avg" | "min" | "max" | "median" | "stddev" | "stddev_pop" | "variance" + | "variance_pop" | "var_pop" => { if inner.is_empty() { None } else { - Some((func, Some(inner.to_string()))) + let agg = if func == "var_pop" { + "variance_pop".to_string() + } else { + func + }; + Some((agg, Some(inner.to_string()))) } } "count" => { @@ -4790,6 +4818,10 @@ fn catalog_aggregation_name(aggregation: Option<&Aggregation>) -> Option<&'stati Some(Aggregation::Min) => Some("min"), Some(Aggregation::Max) => Some("max"), Some(Aggregation::Median) => Some("median"), + Some(Aggregation::Stddev) => Some("stddev"), + Some(Aggregation::StddevPop) => Some("stddev_pop"), + Some(Aggregation::Variance) => Some("variance"), + Some(Aggregation::VariancePop) => Some("variance_pop"), Some(Aggregation::Expression) => Some("expression"), None => None, } @@ -4798,7 +4830,10 @@ fn catalog_aggregation_name(aggregation: Option<&Aggregation>) -> Option<&'stati fn catalog_metric_data_type(aggregation: Option<&str>) -> &'static str { match aggregation { Some("count" | "count_distinct") => "BIGINT", - Some("sum" | "avg" | "min" | "max" | "median" | "percentile") => "NUMERIC", + Some( + "sum" | "avg" | "min" | "max" | "median" | "stddev" | "stddev_pop" | "variance" + | "variance_pop" | "percentile", + ) => "NUMERIC", _ => "NUMERIC", } } @@ -5015,6 +5050,20 @@ impl SidemanticRuntime { Some(Aggregation::Median) => { select_exprs.push(format!("MEDIAN({sql_expr}) as {measure_name}_raw")); } + Some(Aggregation::Stddev) => { + select_exprs.push(format!("STDDEV({sql_expr}) as {measure_name}_raw")); + } + Some(Aggregation::StddevPop) => { + select_exprs + .push(format!("STDDEV_POP({sql_expr}) as {measure_name}_raw")); + } + Some(Aggregation::Variance) => { + select_exprs + .push(format!("VARIANCE({sql_expr}) as {measure_name}_raw")); + } + Some(Aggregation::VariancePop) => { + select_exprs.push(format!("VAR_POP({sql_expr}) as {measure_name}_raw")); + } Some(Aggregation::Expression) | None => { select_exprs.push(format!("SUM({sql_expr}) as {measure_name}_raw")); } @@ -5520,7 +5569,23 @@ fn semantic_graph_from_graph_path_payload( } }, ), + through_foreign_key_columns: if relationship_payload + .through_foreign_key_columns + .is_empty() + { + None + } else { + Some(relationship_payload.through_foreign_key_columns.clone()) + }, related_foreign_key: relationship_payload.related_foreign_key.clone(), + related_foreign_key_columns: if relationship_payload + .related_foreign_key_columns + .is_empty() + { + None + } else { + Some(relationship_payload.related_foreign_key_columns.clone()) + }, sql: None, metadata: None, }); @@ -6043,6 +6108,10 @@ type: one_to_many Some("customer_id".to_string()) )) ); + assert_eq!( + parse_simple_metric_aggregation("VARIANCE_POP(amount)"), + Some(("variance_pop".to_string(), Some("amount".to_string()))) + ); assert_eq!(parse_simple_metric_aggregation("revenue + cost"), None); let simple_metric_yaml = r#" @@ -6054,6 +6123,16 @@ sql: amount assert_eq!(metric_sql_expr(simple_metric_yaml).unwrap(), "amount"); assert!(metric_is_simple_aggregation(simple_metric_yaml).unwrap()); + let stats_metric_yaml = r#" +name: amount_stddev_pop +agg: stddev_pop +sql: amount +"#; + assert_eq!( + metric_to_sql(stats_metric_yaml).unwrap(), + "STDDEV_POP(amount)" + ); + let count_metric_yaml = r#" name: orders agg: count diff --git a/sidemantic-rs/src/sql/generator.rs b/sidemantic-rs/src/sql/generator.rs index 9bcf6172..17b12dc1 100644 --- a/sidemantic-rs/src/sql/generator.rs +++ b/sidemantic-rs/src/sql/generator.rs @@ -444,7 +444,11 @@ impl<'a> SqlGenerator<'a> { }; let output_alias = self.output_alias(&dim_ref.model, &dim_ref.alias, &alias_collisions); - select_parts.push(format!(" {} AS {}", sql_expr, output_alias)); + select_parts.push(format!( + " {} AS {}", + sql_expr, + self.quote_identifier(&output_alias) + )); } // Add metrics to SELECT @@ -463,7 +467,7 @@ impl<'a> SqlGenerator<'a> { let output_alias = self.output_alias(&metric_ref.model, &metric_ref.alias, &alias_collisions); let raw_alias = format!("{}_raw", metric_ref.name); - let raw_col = format!("{alias}.{raw_alias}"); + let raw_col = format!("{alias}.{}", self.quote_identifier(&raw_alias)); let sql_expr = match metric.r#type { MetricType::Simple if query.ungrouped => raw_col.clone(), @@ -541,7 +545,11 @@ impl<'a> SqlGenerator<'a> { MetricType::Conversion => metric.to_sql(Some(&alias)), }; - select_parts.push(format!(" {} AS {}", sql_expr, output_alias)); + select_parts.push(format!( + " {} AS {}", + sql_expr, + self.quote_identifier(&output_alias) + )); } // Add table calculations to SELECT @@ -704,7 +712,9 @@ impl<'a> SqlGenerator<'a> { let mut refs = Vec::new(); for metric in metrics { - let (model, name) = if metric.contains('.') { + let (model, name) = if let Some((model, name)) = self.exact_metric_reference(metric)? { + (model, name) + } else if metric.contains('.') { let (model, name, _) = self.graph.parse_reference(metric)?; (model, name) } else { @@ -774,6 +784,23 @@ impl<'a> SqlGenerator<'a> { Ok(()) } + fn exact_metric_reference(&self, reference: &str) -> Result> { + let mut owners = Vec::new(); + for model in self.graph.models() { + if model.get_metric(reference).is_some() { + owners.push(model.name.clone()); + } + } + + match owners.len() { + 0 => Ok(None), + 1 => Ok(Some((owners[0].clone(), reference.to_string()))), + _ => Err(SidemanticError::InvalidReference { + reference: reference.to_string(), + }), + } + } + /// Find all models required by the query fn find_required_models( &self, @@ -1190,6 +1217,9 @@ impl<'a> SqlGenerator<'a> { default_model: &str, ) -> Result> { if reference.contains('.') { + if let Some((model_name, metric_name)) = self.exact_metric_reference(reference)? { + return Ok(Some((model_name, metric_name))); + } let (model_name, metric_name, _) = self.graph.parse_reference(reference)?; let Some(model) = self.graph.get_model(&model_name) else { return Ok(None); @@ -1256,6 +1286,14 @@ impl<'a> SqlGenerator<'a> { let head_len = trimmed.find(char::is_whitespace).unwrap_or(trimmed.len()); let (head, suffix) = trimmed.split_at(head_len); + for metric_ref in metric_refs { + if metric_ref.name == head || metric_ref.alias == head { + let alias = + self.output_alias(&metric_ref.model, &metric_ref.alias, alias_collisions); + return format!("{}{}", self.quote_identifier(&alias), suffix); + } + } + let Ok((model, field, granularity)) = self.graph.parse_reference(head) else { return trimmed.to_string(); }; @@ -1266,7 +1304,7 @@ impl<'a> SqlGenerator<'a> { && dim_ref.granularity.as_deref() == granularity.as_deref() { let alias = self.output_alias(&model, &dim_ref.alias, alias_collisions); - return format!("{alias}{suffix}"); + return format!("{}{}", self.quote_identifier(&alias), suffix); } } @@ -1274,7 +1312,7 @@ impl<'a> SqlGenerator<'a> { for metric_ref in metric_refs { if metric_ref.model == model && metric_ref.name == field { let alias = self.output_alias(&model, &metric_ref.alias, alias_collisions); - return format!("{alias}{suffix}"); + return format!("{}{}", self.quote_identifier(&alias), suffix); } } } @@ -3129,10 +3167,15 @@ impl<'a> SqlGenerator<'a> { let mut seen_models = HashSet::new(); for metric_ref in metrics { - if !metric_ref.contains('.') { - continue; - }; - let model_name = self.graph.parse_reference(metric_ref)?.0; + let model_name = + if let Some((model_name, _)) = self.exact_metric_reference(metric_ref)? { + model_name + } else { + if !metric_ref.contains('.') { + continue; + } + self.graph.parse_reference(metric_ref)?.0 + }; if !seen_models.insert(model_name.clone()) { continue; @@ -3316,7 +3359,13 @@ impl<'a> SqlGenerator<'a> { Some(Aggregation::Avg) => self .find_count_measure_for_avg(&metric.name, preagg_measures) .is_some(), - Some(Aggregation::CountDistinct) => false, + Some( + Aggregation::CountDistinct + | Aggregation::Stddev + | Aggregation::StddevPop + | Aggregation::Variance + | Aggregation::VariancePop, + ) => false, Some(Aggregation::Median | Aggregation::Expression) => true, } } @@ -3963,6 +4012,7 @@ impl<'a> SqlGenerator<'a> { | "SUM" | "THEN" | "TRUE" + | "VAR_POP" | "VARIANCE" | "VARIANCE_POP" | "VARCHAR" @@ -3997,8 +4047,12 @@ impl<'a> SqlGenerator<'a> { visited: &mut HashSet<(String, String)>, ) -> Result> { let (model_name, metric_name) = if reference.contains('.') { - let (m, n, _) = self.graph.parse_reference(reference)?; - (m, n) + if let Some((model_name, metric_name)) = self.exact_metric_reference(reference)? { + (model_name, metric_name) + } else { + let (m, n, _) = self.graph.parse_reference(reference)?; + (m, n) + } } else { let mut owners = Vec::new(); for model in self.graph.models() { @@ -4129,7 +4183,7 @@ impl<'a> SqlGenerator<'a> { fn is_inline_aggregate_expression(expr: &str) -> bool { let aggregate_re = regex::Regex::new( - r"(?i)\b(SUM|AVG|COUNT|MIN|MAX|MEDIAN|MODE|PERCENTILE_CONT|PERCENTILE_DISC|QUANTILE_CONT|QUANTILE_DISC|STDDEV|STDDEV_POP|VARIANCE|VARIANCE_POP)\s*\(", + r"(?i)\b(SUM|AVG|COUNT|MIN|MAX|MEDIAN|MODE|PERCENTILE_CONT|PERCENTILE_DISC|QUANTILE_CONT|QUANTILE_DISC|STDDEV|STDDEV_POP|VARIANCE|VARIANCE_POP|VAR_POP)\s*\(", ) .expect("valid aggregate regex"); aggregate_re.is_match(expr) @@ -4452,6 +4506,57 @@ mod tests { assert!(sql.contains("GROUP BY 1")); } + #[test] + fn test_statistical_aggregation_metrics_render_supported_sql() { + let mut graph = SemanticGraph::new(); + let orders = Model::new("orders", "order_id") + .with_table("orders") + .with_metric(Metric { + name: "amount_stddev".to_string(), + extends: None, + agg: Some(Aggregation::Stddev), + sql: Some("amount".to_string()), + ..Metric::new("amount_stddev") + }) + .with_metric(Metric { + name: "amount_stddev_pop".to_string(), + extends: None, + agg: Some(Aggregation::StddevPop), + sql: Some("amount".to_string()), + ..Metric::new("amount_stddev_pop") + }) + .with_metric(Metric { + name: "amount_variance".to_string(), + extends: None, + agg: Some(Aggregation::Variance), + sql: Some("amount".to_string()), + ..Metric::new("amount_variance") + }) + .with_metric(Metric { + name: "amount_variance_pop".to_string(), + extends: None, + agg: Some(Aggregation::VariancePop), + sql: Some("amount".to_string()), + ..Metric::new("amount_variance_pop") + }); + graph.add_model(orders).unwrap(); + + let generator = SqlGenerator::new(&graph); + let query = SemanticQuery::new().with_metrics(vec![ + "orders.amount_stddev".into(), + "orders.amount_stddev_pop".into(), + "orders.amount_variance".into(), + "orders.amount_variance_pop".into(), + ]); + + let sql = generator.generate(&query).unwrap(); + + assert!(sql.contains("STDDEV(orders_cte.amount_stddev_raw) AS amount_stddev")); + assert!(sql.contains("STDDEV_POP(orders_cte.amount_stddev_pop_raw) AS amount_stddev_pop")); + assert!(sql.contains("VARIANCE(orders_cte.amount_variance_raw) AS amount_variance")); + assert!(sql.contains("VAR_POP(orders_cte.amount_variance_pop_raw) AS amount_variance_pop")); + } + #[test] fn test_count_builder_uses_valid_raw_cte_expression() { let graph = create_test_graph(); @@ -4883,6 +4988,64 @@ mod tests { assert!(sql.contains(" AND ")); } + #[test] + fn test_query_with_composite_many_to_many_through_join() { + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "tenant_id") + .with_primary_key_columns(vec!["tenant_id".to_string(), "order_id".to_string()]) + .with_table("orders") + .with_metric(Metric::sum("revenue", "amount")) + .with_relationship(Relationship { + name: "products".to_string(), + r#type: RelationshipType::ManyToMany, + foreign_key: None, + foreign_key_columns: None, + primary_key: None, + primary_key_columns: None, + through: Some("order_items".to_string()), + through_foreign_key: Some("order_id".to_string()), + through_foreign_key_columns: Some(vec![ + "tenant_id".to_string(), + "order_id".to_string(), + ]), + related_foreign_key: Some("product_id".to_string()), + related_foreign_key_columns: Some(vec![ + "tenant_id".to_string(), + "product_id".to_string(), + ]), + sql: None, + metadata: None, + }); + let order_items = Model::new("order_items", "tenant_id") + .with_primary_key_columns(vec![ + "tenant_id".to_string(), + "order_id".to_string(), + "product_id".to_string(), + ]) + .with_table("order_items"); + let products = Model::new("products", "tenant_id") + .with_primary_key_columns(vec!["tenant_id".to_string(), "product_id".to_string()]) + .with_table("products") + .with_dimension(Dimension::categorical("name")); + + graph.add_model(orders).unwrap(); + graph.add_model(order_items).unwrap(); + graph.add_model(products).unwrap(); + + let generator = SqlGenerator::new(&graph); + let query = SemanticQuery::new() + .with_metrics(vec!["orders.revenue".into()]) + .with_dimensions(vec!["products.name".into()]); + + let sql = generator.generate(&query).unwrap(); + + assert!(sql.contains("products_cte.tenant_id = order_items_cte.tenant_id")); + assert!(sql.contains("products_cte.product_id = order_items_cte.product_id")); + assert!(sql.contains("order_items_cte.tenant_id = orders_cte.tenant_id")); + assert!(sql.contains("order_items_cte.order_id = orders_cte.order_id")); + } + #[test] fn test_query_with_filter() { let graph = create_test_graph(); diff --git a/sidemantic-schema.json b/sidemantic-schema.json index 591dd48c..ae8885f2 100644 --- a/sidemantic-schema.json +++ b/sidemantic-schema.json @@ -988,6 +988,20 @@ "description": "Measures to pre-aggregate (e.g., ['count', 'revenue'])", "title": "Measures" }, + "meta": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Adapter-specific metadata payload", + "title": "Meta" + }, "name": { "description": "Unique pre-aggregation name", "title": "Name", @@ -1031,6 +1045,19 @@ "title": "Scheduled Refresh", "type": "boolean" }, + "sql": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "SQL for original_sql or custom pre-aggregation definitions", + "title": "Sql" + }, "time_dimension": { "anyOf": [ { @@ -1137,6 +1164,25 @@ "description": "Foreign key column(s) (defaults to {name}_id for many_to_one)", "title": "Foreign Key" }, + "foreign_key_columns": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Explicit source-column list (alias for foreign_key)", + "title": "Foreign Key Columns" + }, "metadata": { "anyOf": [ { @@ -1175,6 +1221,25 @@ "description": "Primary/unique key column(s): related model key for many_to_one, local model key for one_to_many", "title": "Primary Key" }, + "primary_key_columns": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Explicit target-column list (alias for primary_key)", + "title": "Primary Key Columns" + }, "related_foreign_key": { "anyOf": [ { @@ -1188,6 +1253,35 @@ "description": "Foreign key in junction model pointing to related model", "title": "Related Foreign Key" }, + "related_foreign_key_columns": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Foreign key columns in junction model pointing to related model", + "title": "Related Foreign Key Columns" + }, + "sql": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Custom join SQL using {from} and {to} runtime placeholders", + "title": "Sql" + }, "through": { "anyOf": [ { @@ -1214,6 +1308,22 @@ "description": "Foreign key in junction model pointing to this model", "title": "Through Foreign Key" }, + "through_foreign_key_columns": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Foreign key columns in junction model pointing to this model", + "title": "Through Foreign Key Columns" + }, "type": { "description": "Type of relationship", "enum": [ @@ -2781,6 +2891,20 @@ "description": "Measures to pre-aggregate (e.g., ['count', 'revenue'])", "title": "Measures" }, + "meta": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Adapter-specific metadata payload", + "title": "Meta" + }, "name": { "description": "Unique pre-aggregation name", "title": "Name", @@ -2824,6 +2948,19 @@ "title": "Scheduled Refresh", "type": "boolean" }, + "sql": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "SQL for original_sql or custom pre-aggregation definitions", + "title": "Sql" + }, "time_dimension": { "anyOf": [ { @@ -2930,6 +3067,25 @@ "description": "Foreign key column(s) (defaults to {name}_id for many_to_one)", "title": "Foreign Key" }, + "foreign_key_columns": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Explicit source-column list (alias for foreign_key)", + "title": "Foreign Key Columns" + }, "metadata": { "anyOf": [ { @@ -2968,6 +3124,25 @@ "description": "Primary/unique key column(s): related model key for many_to_one, local model key for one_to_many", "title": "Primary Key" }, + "primary_key_columns": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Explicit target-column list (alias for primary_key)", + "title": "Primary Key Columns" + }, "related_foreign_key": { "anyOf": [ { @@ -2981,6 +3156,35 @@ "description": "Foreign key in junction model pointing to related model", "title": "Related Foreign Key" }, + "related_foreign_key_columns": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Foreign key columns in junction model pointing to related model", + "title": "Related Foreign Key Columns" + }, + "sql": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Custom join SQL using {from} and {to} runtime placeholders", + "title": "Sql" + }, "through": { "anyOf": [ { @@ -3007,6 +3211,22 @@ "description": "Foreign key in junction model pointing to this model", "title": "Through Foreign Key" }, + "through_foreign_key_columns": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Foreign key columns in junction model pointing to this model", + "title": "Through Foreign Key Columns" + }, "type": { "description": "Type of relationship", "enum": [ diff --git a/sidemantic/adapters/sidemantic.py b/sidemantic/adapters/sidemantic.py index e08266e5..ab045ba4 100644 --- a/sidemantic/adapters/sidemantic.py +++ b/sidemantic/adapters/sidemantic.py @@ -22,6 +22,159 @@ ) NATIVE_FORMAT_VERSION = 1 +ROOT_FIELDS = { + "version", + "connection", + "models", + "metrics", + "parameters", + "sql_metrics", + "sql_segments", +} +MODEL_FIELDS = { + "name", + "extends", + "table", + "sql", + "source_uri", + "primary_key", + "primary_key_columns", + "unique_keys", + "description", + "label", + "metadata", + "meta", + "auto_dimensions", + "dimensions", + "metrics", + "measures", + "relationships", + "segments", + "pre_aggregations", + "default_time_dimension", + "default_grain", + "sql_metrics", + "sql_segments", +} +DIMENSION_FIELDS = { + "name", + "type", + "sql", + "expr", + "granularity", + "supported_granularities", + "description", + "label", + "metadata", + "meta", + "format", + "value_format_name", + "parent", + "window", + "public", +} +METRIC_FIELDS = { + "name", + "extends", + "type", + "agg", + "sql", + "expr", + "measure", + "numerator", + "denominator", + "offset_window", + "window", + "grain_to_date", + "window_expression", + "window_frame", + "window_order", + "base_metric", + "comparison_type", + "time_offset", + "calculation", + "entity", + "base_event", + "conversion_event", + "conversion_window", + "steps", + "cohort_event", + "activity_event", + "periods", + "retention_granularity", + "granularity", + "inner_metrics", + "entity_dimensions", + "having", + "fill_nulls_with", + "format", + "value_format_name", + "drill_fields", + "non_additive_dimension", + "filters", + "description", + "label", + "metadata", + "meta", + "public", +} +RELATIONSHIP_FIELDS = { + "name", + "type", + "foreign_key", + "foreign_key_columns", + "primary_key", + "primary_key_columns", + "through", + "through_foreign_key", + "through_foreign_key_columns", + "related_foreign_key", + "related_foreign_key_columns", + "sql", + "metadata", +} +SEGMENT_FIELDS = { + "name", + "sql", + "description", + "public", +} +PRE_AGGREGATION_FIELDS = { + "name", + "type", + "sql", + "measures", + "dimensions", + "time_dimension", + "granularity", + "partition_granularity", + "build_range_start", + "build_range_end", + "scheduled_refresh", + "refresh_key", + "indexes", + "meta", +} +REFRESH_KEY_FIELDS = { + "every", + "sql", + "incremental", + "update_window", +} +INDEX_FIELDS = { + "name", + "columns", + "type", +} +PARAMETER_FIELDS = { + "name", + "type", + "description", + "label", + "default_value", + "allowed_values", + "default_to_today", +} def substitute_env_vars(content: str) -> str: @@ -87,6 +240,25 @@ def validate_native_format_version(data: dict) -> None: ) +def reject_unknown_fields( + mapping: dict, + allowed_fields: set[str], + context: str, + *, + source_path: Path | None = None, +) -> None: + """Reject misspelled native fields before constructing permissive Pydantic models.""" + if not isinstance(mapping, dict): + location = f"{source_path}: " if source_path else "" + raise ValueError(f"{location}{context} must be a mapping") + + unknown = sorted(set(mapping) - allowed_fields) + if unknown: + location = f"{source_path}: " if source_path else "" + fields = ", ".join(unknown) + raise ValueError(f"{location}unknown native field(s) in {context}: {fields}") + + def normalize_sql_frontmatter(frontmatter: dict) -> dict: validate_native_format_version(frontmatter) normalized = dict(frontmatter) @@ -141,7 +313,10 @@ def parse(self, source: str | Path) -> SemanticGraph: if models: for model in models: graph.add_model(model) - sql_metrics, sql_segments, sql_parameters = parse_sql_graph_definitions(content) + try: + sql_metrics, sql_segments, sql_parameters = parse_sql_graph_definitions(content) + except Exception as exc: + raise ValueError(f"{source_path}: invalid SQL graph definitions: {exc}") from exc model_metric_names = {metric.name for model in models for metric in model.metrics} for metric in sql_metrics: if metric.name not in model_metric_names: @@ -150,15 +325,18 @@ def parse(self, source: str | Path) -> SemanticGraph: graph.add_parameter(param) else: # YAML frontmatter + SQL metrics/segments - frontmatter, sql_metrics, sql_segments, sql_parameters, sql_preaggs = ( - parse_sql_file_with_frontmatter_extended(source_path) - ) + try: + frontmatter, sql_metrics, sql_segments, sql_parameters, sql_preaggs = ( + parse_sql_file_with_frontmatter_extended(source_path) + ) + except Exception as exc: + raise ValueError(f"{source_path}: invalid SQL definitions: {exc}") from exc # Parse frontmatter as a model only when it still contains model fields # after native contract metadata such as `version` is removed. normalized_frontmatter = normalize_sql_frontmatter(frontmatter) if frontmatter else {} if normalized_frontmatter: - model = self._parse_model(normalized_frontmatter) + model = self._parse_model(normalized_frontmatter, source_path=source_path) if model: # Add SQL-defined metrics/segments to the model model.metrics.extend(sql_metrics) @@ -189,37 +367,77 @@ def parse(self, source: str | Path) -> SemanticGraph: return graph validate_native_format_version(data) + reject_unknown_fields(data, ROOT_FIELDS, "root", source_path=source_path) # Parse models for model_def in data.get("models") or []: - model = self._parse_model(model_def) + model = self._parse_model(model_def, source_path=source_path) if model: graph.add_model(model) # Parse metrics for metric_def in data.get("metrics") or []: - metric = self._parse_metric(metric_def) + metric = self._parse_metric(metric_def, source_path=source_path, context="metric") if metric: graph.add_metric(metric) # Parse parameters for parameter_def in data.get("parameters") or []: - parameter = self._parse_parameter(parameter_def) + parameter = self._parse_parameter(parameter_def, source_path=source_path, context="parameter") if parameter: graph.add_parameter(parameter) # Parse SQL-defined metrics/segments if present if "sql_metrics" in data: - sql_metrics, _ = parse_sql_definitions(data["sql_metrics"]) + sql_metrics, _ = self._parse_embedded_sql_definitions( + data["sql_metrics"], source_path=source_path, block_name="sql_metrics" + ) for metric in sql_metrics: graph.add_metric(metric) if "sql_segments" in data: - _, sql_segments = parse_sql_definitions(data["sql_segments"]) + _, sql_segments = self._parse_embedded_sql_definitions( + data["sql_segments"], source_path=source_path, block_name="sql_segments" + ) # Note: segments need to be attached to models # For now, skip graph-level segments + self._resolve_inheritance(graph) + return graph + def _parse_embedded_sql_definitions( + self, + sql: str, + *, + source_path: Path | None = None, + block_name: str, + model_name: str | None = None, + ) -> tuple[list[Metric], list[Segment]]: + try: + return parse_sql_definitions(sql) + except Exception as exc: + scope = f"model '{model_name}' {block_name}" if model_name else block_name + location = f"{source_path}: " if source_path else "" + raise ValueError(f"{location}invalid {scope}: {exc}") from exc + + def _resolve_inheritance(self, graph: SemanticGraph) -> None: + from sidemantic.core.inheritance import resolve_metric_inheritance, resolve_model_inheritance + + if any(model.extends for model in graph.models.values()): + missing_parent = any(model.extends and model.extends not in graph.models for model in graph.models.values()) + if not missing_parent: + graph.models = resolve_model_inheritance(graph.models) + graph._mark_dirty() + + for model in graph.models.values(): + if any(metric.extends for metric in model.metrics): + resolved_metrics = resolve_metric_inheritance({metric.name: metric for metric in model.metrics}) + model.metrics = list(resolved_metrics.values()) + + if any(metric.extends for metric in graph.metrics.values()): + graph.metrics = resolve_metric_inheritance(graph.metrics) + graph._mark_dirty() + def export(self, graph: SemanticGraph, output_path: str | Path) -> None: """Export semantic graph to Sidemantic YAML. @@ -237,12 +455,15 @@ def export(self, graph: SemanticGraph, output_path: str | Path) -> None: if graph.metrics: data["metrics"] = [self._export_metric(metric, graph) for metric in graph.metrics.values()] + if graph.parameters: + data["parameters"] = [self._export_parameter(parameter) for parameter in graph.parameters.values()] + output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w") as f: yaml.dump(data, f, sort_keys=False, default_flow_style=False) - def _parse_model(self, model_def: dict) -> Model | None: + def _parse_model(self, model_def: dict, *, source_path: Path | None = None) -> Model | None: """Parse model definition. Args: @@ -251,6 +472,8 @@ def _parse_model(self, model_def: dict) -> Model | None: Returns: Model instance or None """ + reject_unknown_fields(model_def, MODEL_FIELDS, "model", source_path=source_path) + name = model_def.get("name") if not name: return None @@ -258,21 +481,43 @@ def _parse_model(self, model_def: dict) -> Model | None: # Parse joins joins = [] for relationship_def in model_def.get("relationships") or []: + reject_unknown_fields( + relationship_def, + RELATIONSHIP_FIELDS, + f"model '{name}' relationship", + source_path=source_path, + ) + relationship_sql = relationship_def.get("sql") + if relationship_sql is not None and ("{from}" not in relationship_sql or "{to}" not in relationship_sql): + location = f"{source_path}: " if source_path else "" + raise ValueError( + f"{location}model '{name}' relationship '{relationship_def.get('name')}' sql must include " + "both {from} and {to} placeholders" + ) join = Relationship( name=relationship_def.get("name"), type=relationship_def.get("type"), - foreign_key=relationship_def.get("foreign_key"), - primary_key=relationship_def.get("primary_key"), + foreign_key=relationship_def.get("foreign_key_columns") or relationship_def.get("foreign_key"), + primary_key=relationship_def.get("primary_key_columns") or relationship_def.get("primary_key"), metadata=relationship_def.get("metadata"), through=relationship_def.get("through"), through_foreign_key=relationship_def.get("through_foreign_key"), + through_foreign_key_columns=relationship_def.get("through_foreign_key_columns"), related_foreign_key=relationship_def.get("related_foreign_key"), + related_foreign_key_columns=relationship_def.get("related_foreign_key_columns"), + sql=relationship_sql, ) joins.append(join) # Parse dimensions dimensions = [] for dim_def in model_def.get("dimensions") or []: + reject_unknown_fields( + dim_def, + DIMENSION_FIELDS, + f"model '{name}' dimension", + source_path=source_path, + ) dimension = Dimension( name=dim_def.get("name"), type=dim_def.get("type", "categorical"), # Default to categorical @@ -285,63 +530,27 @@ def _parse_model(self, model_def: dict) -> Model | None: value_format_name=dim_def.get("value_format_name"), parent=dim_def.get("parent"), metadata=dim_def.get("metadata"), + meta=dim_def.get("meta"), window=dim_def.get("window"), + public=dim_def.get("public", True), ) dimensions.append(dimension) # Parse measures/metrics (support both field names for backwards compatibility) measures = [] for measure_def in model_def.get("metrics", model_def.get("measures") or []): - measure = Metric( - name=measure_def.get("name"), - extends=measure_def.get("extends"), - agg=measure_def.get("agg"), - sql=measure_def.get("sql") or measure_def.get("expr"), - type=measure_def.get("type"), - filters=measure_def.get("filters"), - fill_nulls_with=measure_def.get("fill_nulls_with"), - description=measure_def.get("description"), - label=measure_def.get("label"), - format=measure_def.get("format"), - value_format_name=measure_def.get("value_format_name"), - drill_fields=measure_def.get("drill_fields"), - non_additive_dimension=measure_def.get("non_additive_dimension"), - metadata=measure_def.get("metadata"), - base_metric=measure_def.get("base_metric"), - comparison_type=measure_def.get("comparison_type"), - time_offset=measure_def.get("time_offset"), - calculation=measure_def.get("calculation"), - numerator=measure_def.get("numerator"), - denominator=measure_def.get("denominator"), - entity=measure_def.get("entity"), - base_event=measure_def.get("base_event"), - conversion_event=measure_def.get("conversion_event"), - conversion_window=measure_def.get("conversion_window"), - steps=measure_def.get("steps"), - offset_window=measure_def.get("offset_window"), - # Retention parameters - cohort_event=measure_def.get("cohort_event"), - activity_event=measure_def.get("activity_event"), - periods=measure_def.get("periods"), - retention_granularity=(measure_def.get("retention_granularity") or measure_def.get("granularity")) - if measure_def.get("type") == "retention" - else None, - # Cumulative/window parameters - window=measure_def.get("window"), - grain_to_date=measure_def.get("grain_to_date"), - window_expression=measure_def.get("window_expression"), - window_frame=measure_def.get("window_frame"), - window_order=measure_def.get("window_order"), - # Cohort parameters - inner_metrics=measure_def.get("inner_metrics"), - entity_dimensions=measure_def.get("entity_dimensions"), - having=measure_def.get("having"), + measure = self._parse_metric( + measure_def, + source_path=source_path, + context=f"model '{name}' metric", ) - measures.append(measure) + if measure: + measures.append(measure) # Parse segments segments = [] for seg_def in model_def.get("segments") or []: + reject_unknown_fields(seg_def, SEGMENT_FIELDS, f"model '{name}' segment", source_path=source_path) segment = Segment( name=seg_def.get("name"), sql=seg_def.get("sql"), @@ -352,11 +561,15 @@ def _parse_model(self, model_def: dict) -> Model | None: # Parse SQL-defined metrics/segments if present if "sql_metrics" in model_def: - sql_metrics, _ = parse_sql_definitions(model_def["sql_metrics"]) + sql_metrics, _ = self._parse_embedded_sql_definitions( + model_def["sql_metrics"], source_path=source_path, block_name="sql_metrics", model_name=name + ) measures.extend(sql_metrics) if "sql_segments" in model_def: - _, sql_segments = parse_sql_definitions(model_def["sql_segments"]) + _, sql_segments = self._parse_embedded_sql_definitions( + model_def["sql_segments"], source_path=source_path, block_name="sql_segments", model_name=name + ) segments.extend(sql_segments) # Parse pre-aggregations @@ -364,11 +577,23 @@ def _parse_model(self, model_def: dict) -> Model | None: pre_aggregations = [] for preagg_def in model_def.get("pre_aggregations") or []: + reject_unknown_fields( + preagg_def, + PRE_AGGREGATION_FIELDS, + f"model '{name}' pre_aggregation", + source_path=source_path, + ) # Parse refresh_key if present refresh_key = None if "refresh_key" in preagg_def: refresh_key_def = preagg_def["refresh_key"] if isinstance(refresh_key_def, dict): + reject_unknown_fields( + refresh_key_def, + REFRESH_KEY_FIELDS, + f"model '{name}' pre_aggregation refresh_key", + source_path=source_path, + ) refresh_key = RefreshKey( every=refresh_key_def.get("every"), sql=refresh_key_def.get("sql"), @@ -376,9 +601,19 @@ def _parse_model(self, model_def: dict) -> Model | None: update_window=refresh_key_def.get("update_window"), ) + for index_def in preagg_def.get("indexes") or []: + if isinstance(index_def, dict): + reject_unknown_fields( + index_def, + INDEX_FIELDS, + f"model '{name}' pre_aggregation index", + source_path=source_path, + ) + preagg = PreAggregation( name=preagg_def.get("name"), type=preagg_def.get("type", "rollup"), + sql=preagg_def.get("sql"), measures=preagg_def.get("measures") or [], dimensions=preagg_def.get("dimensions") or [], time_dimension=preagg_def.get("time_dimension"), @@ -389,29 +624,48 @@ def _parse_model(self, model_def: dict) -> Model | None: indexes=preagg_def.get("indexes"), build_range_start=preagg_def.get("build_range_start"), build_range_end=preagg_def.get("build_range_end"), + meta=preagg_def.get("meta"), ) pre_aggregations.append(preagg) - return Model( - name=name, - table=model_def.get("table"), - sql=model_def.get("sql"), - source_uri=model_def.get("source_uri"), - description=model_def.get("description"), - extends=model_def.get("extends"), - primary_key=model_def.get("primary_key", "id"), - relationships=joins, - dimensions=dimensions, - metrics=measures, - segments=segments, - pre_aggregations=pre_aggregations, - default_time_dimension=model_def.get("default_time_dimension"), - default_grain=model_def.get("default_grain"), - metadata=model_def.get("metadata"), - auto_dimensions=model_def.get("auto_dimensions", False), - ) - - def _parse_metric(self, metric_def: dict) -> Metric | None: + model_kwargs = { + "name": name, + "relationships": joins, + "dimensions": dimensions, + "metrics": measures, + "segments": segments, + "pre_aggregations": pre_aggregations, + } + for field in [ + "table", + "sql", + "source_uri", + "description", + "extends", + "unique_keys", + "default_time_dimension", + "default_grain", + "metadata", + "auto_dimensions", + "meta", + ]: + if field in model_def: + model_kwargs[field] = model_def.get(field) + + if "primary_key_columns" in model_def: + model_kwargs["primary_key"] = model_def.get("primary_key_columns") + elif "primary_key" in model_def: + model_kwargs["primary_key"] = model_def.get("primary_key") + + return Model(**model_kwargs) + + def _parse_metric( + self, + metric_def: dict, + *, + source_path: Path | None = None, + context: str = "metric", + ) -> Metric | None: """Parse measure definition. Args: @@ -420,56 +674,74 @@ def _parse_metric(self, metric_def: dict) -> Metric | None: Returns: Measure instance or None """ + reject_unknown_fields(metric_def, METRIC_FIELDS, context, source_path=source_path) + name = metric_def.get("name") metric_type = metric_def.get("type") if not name: return None - return Metric( - name=name, - extends=metric_def.get("extends"), - type=metric_type, - description=metric_def.get("description"), - label=metric_def.get("label"), - metadata=metric_def.get("metadata"), - sql=metric_def.get("sql") or metric_def.get("expr") or metric_def.get("measure"), - agg=metric_def.get("agg"), - numerator=metric_def.get("numerator"), - denominator=metric_def.get("denominator"), - base_metric=metric_def.get("base_metric"), - comparison_type=metric_def.get("comparison_type"), - time_offset=metric_def.get("time_offset"), - calculation=metric_def.get("calculation"), - entity=metric_def.get("entity"), - base_event=metric_def.get("base_event"), - conversion_event=metric_def.get("conversion_event"), - conversion_window=metric_def.get("conversion_window"), - steps=metric_def.get("steps"), - offset_window=metric_def.get("offset_window"), - cohort_event=metric_def.get("cohort_event"), - activity_event=metric_def.get("activity_event"), - periods=metric_def.get("periods"), - retention_granularity=(metric_def.get("retention_granularity") or metric_def.get("granularity")) - if metric_type == "retention" - else None, - inner_metrics=metric_def.get("inner_metrics"), - entity_dimensions=metric_def.get("entity_dimensions"), - having=metric_def.get("having"), - window=metric_def.get("window"), - grain_to_date=metric_def.get("grain_to_date"), - window_expression=metric_def.get("window_expression"), - window_frame=metric_def.get("window_frame"), - window_order=metric_def.get("window_order"), - filters=metric_def.get("filters"), - fill_nulls_with=metric_def.get("fill_nulls_with"), - format=metric_def.get("format"), - value_format_name=metric_def.get("value_format_name"), - drill_fields=metric_def.get("drill_fields"), - non_additive_dimension=metric_def.get("non_additive_dimension"), - ) + metric_kwargs = {"name": name} + for field in [ + "extends", + "type", + "description", + "label", + "metadata", + "agg", + "numerator", + "denominator", + "base_metric", + "comparison_type", + "time_offset", + "calculation", + "entity", + "base_event", + "conversion_event", + "conversion_window", + "steps", + "offset_window", + "cohort_event", + "activity_event", + "periods", + "inner_metrics", + "entity_dimensions", + "having", + "window", + "grain_to_date", + "window_expression", + "window_frame", + "window_order", + "filters", + "fill_nulls_with", + "format", + "value_format_name", + "drill_fields", + "non_additive_dimension", + "meta", + "public", + ]: + if field in metric_def: + metric_kwargs[field] = metric_def.get(field) + + if "sql" in metric_def or "expr" in metric_def or "measure" in metric_def: + metric_kwargs["sql"] = metric_def.get("sql") or metric_def.get("expr") or metric_def.get("measure") + + if metric_type == "retention" and ("retention_granularity" in metric_def or "granularity" in metric_def): + metric_kwargs["retention_granularity"] = metric_def.get("retention_granularity") or metric_def.get( + "granularity" + ) + + return Metric(**metric_kwargs) - def _parse_parameter(self, parameter_def: dict) -> Parameter | None: + def _parse_parameter( + self, + parameter_def: dict, + *, + source_path: Path | None = None, + context: str = "parameter", + ) -> Parameter | None: """Parse parameter definition. Args: @@ -478,6 +750,8 @@ def _parse_parameter(self, parameter_def: dict) -> Parameter | None: Returns: Parameter instance or None """ + reject_unknown_fields(parameter_def, PARAMETER_FIELDS, context, source_path=source_path) + name = parameter_def.get("name") param_type = parameter_def.get("type") @@ -515,6 +789,8 @@ def _export_model(self, model: Model) -> dict: result["description"] = model.description if model.metadata: result["metadata"] = model.metadata + if model.meta: + result["meta"] = model.meta # Export joins if model.relationships: @@ -531,11 +807,22 @@ def _export_model(self, model: Model) -> dict: if relationship.through_foreign_key else {} ), + **( + {"through_foreign_key_columns": relationship.through_foreign_key_columns} + if relationship.through_foreign_key_columns + else {} + ), **( {"related_foreign_key": relationship.related_foreign_key} if relationship.related_foreign_key else {} ), + **( + {"related_foreign_key_columns": relationship.related_foreign_key_columns} + if relationship.related_foreign_key_columns + else {} + ), + **({"sql": relationship.sql} if relationship.sql else {}), **({"metadata": relationship.metadata} if relationship.metadata else {}), } for relationship in model.relationships @@ -557,12 +844,16 @@ def _export_model(self, model: Model) -> dict: dim_def["sql"] = dim.sql if dim.granularity: dim_def["granularity"] = dim.granularity + if dim.supported_granularities: + dim_def["supported_granularities"] = dim.supported_granularities if dim.description: dim_def["description"] = dim.description if dim.label: dim_def["label"] = dim.label if dim.metadata: dim_def["metadata"] = dim.metadata + if dim.meta: + dim_def["meta"] = dim.meta if dim.format: dim_def["format"] = dim.format if dim.value_format_name: @@ -571,6 +862,8 @@ def _export_model(self, model: Model) -> dict: dim_def["parent"] = dim.parent if dim.window: dim_def["window"] = dim.window + if not dim.public: + dim_def["public"] = dim.public result["dimensions"].append(dim_def) # Export metrics (model-level aggregations) @@ -591,6 +884,10 @@ def _export_model(self, model: Model) -> dict: measure_def["label"] = measure.label if measure.metadata: measure_def["metadata"] = measure.metadata + if measure.meta: + measure_def["meta"] = measure.meta + if not measure.public: + measure_def["public"] = measure.public if measure.format: measure_def["format"] = measure.format if measure.value_format_name: @@ -670,6 +967,11 @@ def _export_model(self, model: Model) -> dict: seg_def["public"] = segment.public result["segments"].append(seg_def) + if model.pre_aggregations: + result["pre_aggregations"] = [ + self._export_pre_aggregation(pre_aggregation) for pre_aggregation in model.pre_aggregations + ] + return result def _export_metric(self, measure: Metric, graph) -> dict: @@ -694,6 +996,10 @@ def _export_metric(self, measure: Metric, graph) -> dict: result["label"] = measure.label if measure.metadata: result["metadata"] = measure.metadata + if measure.meta: + result["meta"] = measure.meta + if not measure.public: + result["public"] = measure.public # Type-specific fields if measure.numerator: @@ -736,11 +1042,6 @@ def _export_metric(self, measure: Metric, graph) -> dict: result["having"] = measure.having if measure.sql: result["sql"] = measure.sql - # Auto-detect and export dependencies for derived measures - if measure.type == "derived": - dependencies = measure.get_dependencies(graph) - if dependencies: - result["metrics"] = list(dependencies) if measure.agg: result["agg"] = measure.agg if measure.window: @@ -749,3 +1050,77 @@ def _export_metric(self, measure: Metric, graph) -> dict: result["filters"] = measure.filters return result + + def _export_pre_aggregation(self, pre_aggregation) -> dict: + result = { + "name": pre_aggregation.name, + "type": pre_aggregation.type, + } + + if pre_aggregation.sql: + result["sql"] = pre_aggregation.sql + if pre_aggregation.measures: + result["measures"] = pre_aggregation.measures + if pre_aggregation.dimensions: + result["dimensions"] = pre_aggregation.dimensions + if pre_aggregation.time_dimension: + result["time_dimension"] = pre_aggregation.time_dimension + if pre_aggregation.granularity: + result["granularity"] = pre_aggregation.granularity + if pre_aggregation.partition_granularity: + result["partition_granularity"] = pre_aggregation.partition_granularity + if pre_aggregation.build_range_start: + result["build_range_start"] = pre_aggregation.build_range_start + if pre_aggregation.build_range_end: + result["build_range_end"] = pre_aggregation.build_range_end + if pre_aggregation.scheduled_refresh is False: + result["scheduled_refresh"] = False + if pre_aggregation.refresh_key: + result["refresh_key"] = self._export_refresh_key(pre_aggregation.refresh_key) + if pre_aggregation.indexes: + result["indexes"] = [self._export_index(index) for index in pre_aggregation.indexes] + if pre_aggregation.meta: + result["meta"] = pre_aggregation.meta + + return result + + def _export_refresh_key(self, refresh_key) -> dict: + result = {} + if refresh_key.every: + result["every"] = refresh_key.every + if refresh_key.sql: + result["sql"] = refresh_key.sql + if refresh_key.incremental: + result["incremental"] = refresh_key.incremental + if refresh_key.update_window: + result["update_window"] = refresh_key.update_window + return result + + def _export_index(self, index) -> dict: + result = { + "name": index.name, + "columns": index.columns, + } + if index.type != "regular": + result["type"] = index.type + return result + + def _export_parameter(self, parameter: Parameter) -> dict: + """Export parameter to dictionary.""" + result = { + "name": parameter.name, + "type": parameter.type, + } + + if parameter.description: + result["description"] = parameter.description + if parameter.label: + result["label"] = parameter.label + if parameter.default_value is not None: + result["default_value"] = parameter.default_value + if parameter.allowed_values is not None: + result["allowed_values"] = parameter.allowed_values + if parameter.default_to_today: + result["default_to_today"] = parameter.default_to_today + + return result diff --git a/sidemantic/core/inheritance.py b/sidemantic/core/inheritance.py index 2d770e4a..f6cfab86 100644 --- a/sidemantic/core/inheritance.py +++ b/sidemantic/core/inheritance.py @@ -32,7 +32,7 @@ def merge_model(child: Model, parent: Model) -> Model: # Using include=model_fields_set instead of exclude_none so that # a child can explicitly set a field to None to clear a parent value. child_fields = child.model_fields_set - {"extends"} - child_data = child.model_dump(include=child_fields) + child_data = child.model_dump(include=child_fields, exclude_unset=True) # Merge lists (dimensions, metrics, relationships, segments) # Child's items are added to parent's items @@ -49,7 +49,19 @@ def merge_model(child: Model, parent: Model) -> Model: merged_data[field] = list(parent_by_name.values()) # Override scalar fields with child values - for field in ["table", "sql", "description", "primary_key", "meta"]: + for field in [ + "table", + "sql", + "source_uri", + "description", + "primary_key", + "unique_keys", + "default_time_dimension", + "default_grain", + "metadata", + "auto_dimensions", + "meta", + ]: if field in child_data: merged_data[field] = child_data[field] @@ -100,7 +112,7 @@ def merge_metric(child: Metric, parent: Metric) -> Metric: # Override with child's explicitly set data (excluding extends). child_fields = child.model_fields_set - {"extends"} - child_data = child.model_dump(include=child_fields) + child_data = child.model_dump(include=child_fields, exclude_unset=True) # Handle list fields - merge arrays for field in ["filters", "drill_fields"]: diff --git a/sidemantic/core/metric.py b/sidemantic/core/metric.py index 78f0d93c..e8ecc3b4 100644 --- a/sidemantic/core/metric.py +++ b/sidemantic/core/metric.py @@ -102,6 +102,15 @@ def handle_expr_and_parse_agg(cls, data): exp.Max: "max", exp.Median: "median", } + for agg_class_name, agg_name in { + "Stddev": "stddev", + "StddevPop": "stddev_pop", + "Variance": "variance", + "VariancePop": "variance_pop", + }.items(): + agg_class = getattr(exp, agg_class_name, None) + if agg_class is not None: + agg_map[agg_class] = agg_name agg_func = None inner_expr = None @@ -143,6 +152,11 @@ def handle_expr_and_parse_agg(cls, data): "min": "min", "max": "max", "median": "median", + "stddev": "stddev", + "stddev_pop": "stddev_pop", + "variance": "variance", + "variance_pop": "variance_pop", + "var_pop": "variance_pop", "count": "count", } if func_name in func_map: @@ -355,7 +369,7 @@ def to_sql(self) -> str: if not self.agg: raise ValueError(f"Cannot convert complex metric '{self.name}' to SQL - use type-specific logic") - agg_func = self.agg.upper() + agg_func = {"variance_pop": "VAR_POP"}.get(self.agg, self.agg.upper()) if agg_func == "COUNT_DISTINCT": agg_func = "COUNT(DISTINCT" return f"{agg_func} {self.sql_expr})" diff --git a/sidemantic/core/pre_aggregation.py b/sidemantic/core/pre_aggregation.py index 7871990d..c8b19ada 100644 --- a/sidemantic/core/pre_aggregation.py +++ b/sidemantic/core/pre_aggregation.py @@ -51,6 +51,7 @@ class PreAggregation(BaseModel): type: Literal["rollup", "original_sql", "rollup_join", "lambda"] = Field( "rollup", description="Pre-aggregation type" ) + sql: str | None = Field(None, description="SQL for original_sql or custom pre-aggregation definitions") # Rollup configuration measures: list[str] | None = Field(None, description="Measures to pre-aggregate (e.g., ['count', 'revenue'])") @@ -75,6 +76,7 @@ class PreAggregation(BaseModel): # Build range (for historical data) build_range_start: str | None = Field(None, description="SQL expression for start of data range to aggregate") build_range_end: str | None = Field(None, description="SQL expression for end of data range to aggregate") + meta: dict[str, Any] | None = Field(None, description="Adapter-specific metadata payload") def get_table_name(self, model_name: str, database: str | None = None, schema: str | None = None) -> str: """Generate the physical table name for this pre-aggregation. diff --git a/sidemantic/core/relationship.py b/sidemantic/core/relationship.py index 7dfe9168..26b75ba1 100644 --- a/sidemantic/core/relationship.py +++ b/sidemantic/core/relationship.py @@ -33,9 +33,16 @@ class Relationship(BaseModel): through_foreign_key: str | None = Field( default=None, description="Foreign key in junction model pointing to this model" ) + through_foreign_key_columns: list[str] | None = Field( + default=None, description="Foreign key columns in junction model pointing to this model" + ) related_foreign_key: str | None = Field( default=None, description="Foreign key in junction model pointing to related model" ) + related_foreign_key_columns: list[str] | None = Field( + default=None, description="Foreign key columns in junction model pointing to related model" + ) + sql: str | None = Field(default=None, description="Custom join SQL using {from} and {to} runtime placeholders") metadata: dict[str, Any] | None = Field(None, description="Adapter-specific metadata payload") @property @@ -90,4 +97,29 @@ def junction_keys(self) -> tuple[str | None, str | None]: """Get junction keys for many_to_many relationships.""" if self.type != "many_to_many": return None, None - return self.through_foreign_key or self.foreign_key, self.related_foreign_key + source_keys, target_keys = self.junction_key_columns() + return ( + source_keys[0] if source_keys else None, + target_keys[0] if target_keys else None, + ) + + def junction_key_columns(self) -> tuple[list[str], list[str]]: + """Get junction key columns for many_to_many relationships.""" + if self.type != "many_to_many": + return [], [] + + if self.through_foreign_key_columns: + source_keys = self.through_foreign_key_columns + elif self.through_foreign_key: + source_keys = [self.through_foreign_key] + else: + source_keys = self.foreign_key_columns if self.foreign_key else [] + + if self.related_foreign_key_columns: + target_keys = self.related_foreign_key_columns + elif self.related_foreign_key: + target_keys = [self.related_foreign_key] + else: + target_keys = [] + + return source_keys, target_keys diff --git a/sidemantic/core/semantic_graph.py b/sidemantic/core/semantic_graph.py index ced92634..48dff5bb 100644 --- a/sidemantic/core/semantic_graph.py +++ b/sidemantic/core/semantic_graph.py @@ -10,6 +10,21 @@ from sidemantic.core.table_calculation import TableCalculation +def _reverse_custom_join_condition(sql: str | None) -> str | None: + if sql is None: + return None + return sql.replace("{from}", "__SIDEMANTIC_FROM__").replace("{to}", "{from}").replace("__SIDEMANTIC_FROM__", "{to}") + + +def _custom_join_condition(sql: str | None) -> str | None: + """Return custom join SQL only for the placeholder-based native contract.""" + if not sql: + return None + if "{from}" in sql or "{to}" in sql: + return sql + return None + + @dataclass class JoinPath: """Represents a join between two models.""" @@ -19,6 +34,7 @@ class JoinPath: from_columns: list[str] # Foreign key column(s) in from_model to_columns: list[str] # Primary/unique key column(s) in to_model relationship: str # many_to_one, one_to_many, one_to_one + custom_condition: str | None = None # Backwards compatibility properties (return first column) @property @@ -47,9 +63,7 @@ def __init__(self): self.metadata: dict[str, Any] = {} self._version = 0 self._adjacency_dirty = True - self._adjacency: dict[ - str, list[tuple[str, list[str], list[str], str]] - ] = {} # model -> [(to_model, from_keys, to_keys, rel_type)] + self._adjacency: dict[str, list[tuple[str, list[str], list[str], str, str | None]]] = {} def _mark_dirty(self) -> None: self._version += 1 @@ -194,6 +208,28 @@ def get_metric(self, name: str) -> Metric: raise KeyError(f"Measure {name} not found") return self.metrics[name] + def resolve_metric_reference(self, reference: str) -> tuple[str | None, Metric]: + """Resolve a query metric reference using exact graph metric names first. + + Returns: + Tuple of (model_name, metric). model_name is None for graph-level metrics. + + Raises: + KeyError: If the reference does not resolve to a graph or model metric. + """ + if reference in self.metrics: + return None, self.metrics[reference] + + if "." in reference: + model_name, metric_name = reference.split(".", 1) + model = self.models.get(model_name) + if model: + metric = model.get_metric(metric_name) + if metric: + return model_name, metric + + raise KeyError(f"Metric reference {reference} not found") + def build_adjacency(self) -> None: """Build adjacency list for join path discovery. @@ -207,11 +243,16 @@ def build_adjacency(self) -> None: self._adjacency.clear() def add_edge( - from_model: str, to_model: str, from_keys: list[str], to_keys: list[str], relationship_type: str + from_model: str, + to_model: str, + from_keys: list[str], + to_keys: list[str], + relationship_type: str, + custom_condition: str | None = None, ) -> None: if from_model not in self._adjacency: self._adjacency[from_model] = [] - self._adjacency[from_model].append((to_model, from_keys, to_keys, relationship_type)) + self._adjacency[from_model].append((to_model, from_keys, to_keys, relationship_type, custom_condition)) def invert_relationship(relationship_type: str) -> str: if relationship_type == "many_to_one": @@ -239,12 +280,20 @@ def invert_relationship(relationship_type: str) -> str: continue local_keys = model.primary_key_columns remote_keys = relationship.foreign_key_columns - add_edge(model_name, related_model, local_keys, remote_keys, "one_to_many") - add_edge(related_model, model_name, remote_keys, local_keys, "many_to_one") + custom_condition = _custom_join_condition(relationship.sql) + add_edge(model_name, related_model, local_keys, remote_keys, "one_to_many", custom_condition) + add_edge( + related_model, + model_name, + remote_keys, + local_keys, + "many_to_one", + _reverse_custom_join_condition(custom_condition), + ) continue - junction_self_fk, junction_related_fk = relationship.junction_keys() - if not junction_self_fk or not junction_related_fk: + junction_self_fks, junction_related_fks = relationship.junction_key_columns() + if not junction_self_fks or not junction_related_fks: continue base_pk = model.primary_key_columns @@ -254,11 +303,11 @@ def invert_relationship(relationship_type: str) -> str: else self.models[related_model].primary_key_columns ) - add_edge(model_name, junction_model, base_pk, [junction_self_fk], "one_to_many") - add_edge(junction_model, model_name, [junction_self_fk], base_pk, "many_to_one") + add_edge(model_name, junction_model, base_pk, junction_self_fks, "one_to_many") + add_edge(junction_model, model_name, junction_self_fks, base_pk, "many_to_one") - add_edge(junction_model, related_model, [junction_related_fk], related_pk, "many_to_one") - add_edge(related_model, junction_model, related_pk, [junction_related_fk], "one_to_many") + add_edge(junction_model, related_model, junction_related_fks, related_pk, "many_to_one") + add_edge(related_model, junction_model, related_pk, junction_related_fks, "one_to_many") continue # Get the join key names @@ -279,8 +328,16 @@ def invert_relationship(relationship_type: str) -> str: ) remote_keys = relationship.foreign_key_columns # [customer_id] (in orders) - add_edge(model_name, related_model, local_keys, remote_keys, relationship.type) - add_edge(related_model, model_name, remote_keys, local_keys, invert_relationship(relationship.type)) + custom_condition = _custom_join_condition(relationship.sql) + add_edge(model_name, related_model, local_keys, remote_keys, relationship.type, custom_condition) + add_edge( + related_model, + model_name, + remote_keys, + local_keys, + invert_relationship(relationship.type), + _reverse_custom_join_condition(custom_condition), + ) def find_relationship_path(self, from_model: str, to_model: str) -> list[JoinPath]: """Find join path between two models using BFS. @@ -317,7 +374,7 @@ def find_relationship_path(self, from_model: str, to_model: str) -> list[JoinPat if current not in self._adjacency: continue - for next_model, from_keys, to_keys, relationship_type in self._adjacency[current]: + for next_model, from_keys, to_keys, relationship_type, custom_condition in self._adjacency[current]: if next_model in visited: continue @@ -330,6 +387,7 @@ def find_relationship_path(self, from_model: str, to_model: str) -> list[JoinPat from_columns=from_keys, to_columns=to_keys, relationship=relationship_type, + custom_condition=custom_condition, ) ] diff --git a/sidemantic/core/semantic_layer.py b/sidemantic/core/semantic_layer.py index 60b00b25..e401810c 100644 --- a/sidemantic/core/semantic_layer.py +++ b/sidemantic/core/semantic_layer.py @@ -1118,6 +1118,9 @@ def _connection_dict_to_url(config: dict) -> str: """ from urllib.parse import quote, urlencode + def quote_userinfo(value) -> str: + return quote(str(value), safe="") + conn_type = config.get("type", "duckdb").lower() if conn_type == "duckdb": @@ -1134,9 +1137,9 @@ def _connection_dict_to_url(config: dict) -> str: password = config.get("password", "") if user and password: - return f"postgres://{quote(user)}:{quote(password)}@{host}:{port}/{database}" + return f"postgres://{quote_userinfo(user)}:{quote_userinfo(password)}@{host}:{port}/{database}" elif user: - return f"postgres://{quote(user)}@{host}:{port}/{database}" + return f"postgres://{quote_userinfo(user)}@{host}:{port}/{database}" else: return f"postgres://{host}:{port}/{database}" @@ -1162,9 +1165,9 @@ def _connection_dict_to_url(config: dict) -> str: path += f"/{schema}" if user and password: - return f"snowflake://{quote(user)}:{quote(password)}@{account}{path}" + return f"snowflake://{quote_userinfo(user)}:{quote_userinfo(password)}@{account}{path}" elif user: - return f"snowflake://{quote(user)}@{account}{path}" + return f"snowflake://{quote_userinfo(user)}@{account}{path}" else: return f"snowflake://{account}{path}" @@ -1176,9 +1179,9 @@ def _connection_dict_to_url(config: dict) -> str: password = config.get("password", "") if user and password: - return f"clickhouse://{quote(user)}:{quote(password)}@{host}:{port}/{database}" + return f"clickhouse://{quote_userinfo(user)}:{quote_userinfo(password)}@{host}:{port}/{database}" elif user: - return f"clickhouse://{quote(user)}@{host}:{port}/{database}" + return f"clickhouse://{quote_userinfo(user)}@{host}:{port}/{database}" else: return f"clickhouse://{host}:{port}/{database}" @@ -1192,7 +1195,7 @@ def _connection_dict_to_url(config: dict) -> str: if not http_path: raise ValueError("Databricks connection requires 'http_path' field") - return f"databricks://{token}@{server}/{http_path}" + return f"databricks://{quote_userinfo(token)}@{server}/{http_path}" elif conn_type == "spark": host = config.get("host", "localhost") diff --git a/sidemantic/core/sql_definitions.py b/sidemantic/core/sql_definitions.py index 125f0080..7128b7ec 100644 --- a/sidemantic/core/sql_definitions.py +++ b/sidemantic/core/sql_definitions.py @@ -625,7 +625,11 @@ def _parse_statement_defs( pre_aggregations: list[PreAggregation] = [] for stmt in statements: - if isinstance(stmt, ModelDef): + if stmt is None: + continue + if isinstance(stmt, TableBlockModelDef): + model_def = _parse_table_block_model_def(stmt) + elif isinstance(stmt, ModelDef): model_def = _parse_model_def(stmt) elif isinstance(stmt, DimensionDef): dimension = _parse_dimension_def(stmt) @@ -651,6 +655,8 @@ def _parse_statement_defs( preagg = _parse_pre_aggregation_def(stmt) if preagg: pre_aggregations.append(preagg) + else: + raise ValueError(f"Unsupported SQL definition statement: {stmt.__class__.__name__}") return model_def, dimensions, relationships, metrics, segments, parameters, pre_aggregations @@ -664,11 +670,7 @@ def parse_sql_definitions(sql: str) -> tuple[list[Metric], list[Segment]]: Returns: Tuple of (metrics, segments) """ - try: - metrics, segments, _ = parse_sql_graph_definitions(sql) - except Exception: - return [], [] - + metrics, segments, _ = parse_sql_graph_definitions(sql) return metrics, segments @@ -681,11 +683,7 @@ def parse_sql_graph_definitions(sql: str) -> tuple[list[Metric], list[Segment], Returns: Tuple of (metrics, segments, parameters) """ - try: - _, _, _, metrics, segments, parameters, _ = _parse_sql_statements(sql) - except Exception: - return [], [], [] - + _, _, _, metrics, segments, parameters, _ = _parse_sql_statements(sql) return metrics, segments, parameters @@ -733,13 +731,7 @@ def parse_sql_file_with_frontmatter_extended( if frontmatter_text: frontmatter = yaml.safe_load(frontmatter_text) or {} - metrics, segments, parameters = parse_sql_graph_definitions(sql_body) - - pre_aggregations: list[PreAggregation] = [] - try: - _, _, _, _, _, _, pre_aggregations = _parse_sql_statements(sql_body) - except Exception: - pre_aggregations = [] + _, _, _, metrics, segments, parameters, pre_aggregations = _parse_sql_statements(sql_body) return frontmatter, metrics, segments, parameters, pre_aggregations diff --git a/sidemantic/db/base.py b/sidemantic/db/base.py index 742edb02..b95165ba 100644 --- a/sidemantic/db/base.py +++ b/sidemantic/db/base.py @@ -40,6 +40,47 @@ def validate_identifier(value: str, name: str = "identifier") -> str: return value +def _coerce_positive_int(value: Any, name: str) -> int: + if isinstance(value, bool): + raise ValueError(f"{name} must be a positive integer") + + if isinstance(value, int): + coerced = value + elif isinstance(value, str): + stripped = value.strip() + if not stripped: + raise ValueError(f"{name} must be a positive integer") + try: + coerced = int(stripped, 10) + except ValueError: + raise ValueError(f"{name} must be a positive integer") + else: + raise ValueError(f"{name} must be a positive integer") + + if coerced < 1: + raise ValueError(f"{name} must be a positive integer") + return coerced + + +def validate_query_history_params( + days_back: int, + limit: int, + *, + max_days_back: int = 365, + max_limit: int = 10_000, +) -> tuple[int, int]: + """Validate query-history lookback and row limit values for SQL interpolation.""" + days_back_int = _coerce_positive_int(days_back, "days_back") + limit_int = _coerce_positive_int(limit, "limit") + + if days_back_int > max_days_back: + raise ValueError(f"days_back must be <= {max_days_back}") + if limit_int > max_limit: + raise ValueError(f"limit must be <= {max_limit}") + + return days_back_int, limit_int + + class BaseDatabaseAdapter(ABC): """Abstract base class for database adapters. diff --git a/sidemantic/db/bigquery.py b/sidemantic/db/bigquery.py index 45b9b6da..ba33dab8 100644 --- a/sidemantic/db/bigquery.py +++ b/sidemantic/db/bigquery.py @@ -2,7 +2,7 @@ from typing import Any -from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier +from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier, validate_query_history_params class BigQueryResult: @@ -176,6 +176,7 @@ def get_query_history(self, days_back: int = 7, limit: int = 1000) -> list[str]: Returns: List of SQL query strings containing '-- sidemantic:' comments """ + days_back, limit = validate_query_history_params(days_back, limit) sql = f""" SELECT query FROM `{self.project_id}.region-{self.client.location}.INFORMATION_SCHEMA.JOBS_BY_PROJECT` diff --git a/sidemantic/db/clickhouse.py b/sidemantic/db/clickhouse.py index 3dd3b11e..6eb99b6d 100644 --- a/sidemantic/db/clickhouse.py +++ b/sidemantic/db/clickhouse.py @@ -3,7 +3,7 @@ from typing import Any from urllib.parse import parse_qs, unquote, urlparse -from sidemantic.db.base import BaseDatabaseAdapter +from sidemantic.db.base import BaseDatabaseAdapter, validate_query_history_params class ClickHouseResult: @@ -223,6 +223,7 @@ def get_query_history(self, days_back: int = 7, limit: int = 1000) -> list[str]: Returns: List of SQL query strings containing '-- sidemantic:' comments """ + days_back, limit = validate_query_history_params(days_back, limit) sql = f""" SELECT query FROM system.query_log diff --git a/sidemantic/db/databricks.py b/sidemantic/db/databricks.py index 386b7ac8..89ebfbed 100644 --- a/sidemantic/db/databricks.py +++ b/sidemantic/db/databricks.py @@ -3,7 +3,7 @@ from typing import Any from urllib.parse import parse_qs, unquote, urlparse -from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier +from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier, validate_query_history_params class DatabricksResult: @@ -185,6 +185,7 @@ def get_query_history(self, days_back: int = 7, limit: int = 1000) -> list[str]: Note: Requires Unity Catalog and appropriate permissions to query system.query.history """ + days_back, limit = validate_query_history_params(days_back, limit) sql = f""" SELECT statement_text FROM system.query.history diff --git a/sidemantic/db/postgres.py b/sidemantic/db/postgres.py index a74ca1af..20dfe7cc 100644 --- a/sidemantic/db/postgres.py +++ b/sidemantic/db/postgres.py @@ -1,7 +1,7 @@ """PostgreSQL database adapter.""" from typing import Any -from urllib.parse import parse_qs, urlparse +from urllib.parse import parse_qs, unquote, urlparse from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier @@ -198,7 +198,7 @@ def from_url(cls, url: str) -> "PostgreSQLAdapter": host=parsed.hostname or "localhost", port=parsed.port or 5432, database=parsed.path.lstrip("/") if parsed.path else "postgres", - user=parsed.username, - password=parsed.password, + user=unquote(parsed.username) if parsed.username is not None else None, + password=unquote(parsed.password) if parsed.password is not None else None, **params, ) diff --git a/sidemantic/db/snowflake.py b/sidemantic/db/snowflake.py index 03636efd..b8f374f6 100644 --- a/sidemantic/db/snowflake.py +++ b/sidemantic/db/snowflake.py @@ -3,7 +3,7 @@ from typing import Any from urllib.parse import parse_qs, unquote, urlparse -from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier +from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier, validate_query_history_params class SnowflakeResult: @@ -232,6 +232,7 @@ def get_query_history(self, days_back: int = 7, limit: int = 1000) -> list[str]: Returns: List of SQL query strings containing '-- sidemantic:' comments """ + days_back, limit = validate_query_history_params(days_back, limit, max_days_back=7) sql = f""" SELECT query_text FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY( diff --git a/sidemantic/rust_bridge.py b/sidemantic/rust_bridge.py index 757a1737..c8b86887 100644 --- a/sidemantic/rust_bridge.py +++ b/sidemantic/rust_bridge.py @@ -1529,12 +1529,16 @@ def _serialize_relationship(relationship, source_model, target_model) -> dict | through_foreign_key = getattr(relationship, "through_foreign_key", None) related_foreign_key = getattr(relationship, "related_foreign_key", None) + through_foreign_key_columns = getattr(relationship, "through_foreign_key_columns", None) + related_foreign_key_columns = getattr(relationship, "related_foreign_key_columns", None) if relationship.type == "many_to_many": - junction_keys_fn = getattr(relationship, "junction_keys", None) - if callable(junction_keys_fn): - junction_self_fk, junction_related_fk = junction_keys_fn() - through_foreign_key = through_foreign_key or junction_self_fk - related_foreign_key = related_foreign_key or junction_related_fk + junction_key_columns_fn = getattr(relationship, "junction_key_columns", None) + if callable(junction_key_columns_fn): + junction_self_fks, junction_related_fks = junction_key_columns_fn() + through_foreign_key_columns = through_foreign_key_columns or junction_self_fks + related_foreign_key_columns = related_foreign_key_columns or junction_related_fks + through_foreign_key = through_foreign_key or (junction_self_fks[0] if junction_self_fks else None) + related_foreign_key = related_foreign_key or (junction_related_fks[0] if junction_related_fks else None) return { "name": relationship.name, @@ -1545,7 +1549,9 @@ def _serialize_relationship(relationship, source_model, target_model) -> dict | "primary_key_columns": primary_keys, "through": getattr(relationship, "through", None), "through_foreign_key": through_foreign_key, + "through_foreign_key_columns": through_foreign_key_columns, "related_foreign_key": related_foreign_key, + "related_foreign_key_columns": related_foreign_key_columns, "sql": sql, "metadata": getattr(relationship, "metadata", None), } diff --git a/sidemantic/schema.py b/sidemantic/schema.py index 0428de5d..0f7dc554 100644 --- a/sidemantic/schema.py +++ b/sidemantic/schema.py @@ -1,6 +1,7 @@ """Generate JSON Schema from Pydantic models for YAML editor completion.""" import json +from copy import deepcopy from pathlib import Path from sidemantic.core.dimension import Dimension @@ -10,6 +11,48 @@ from sidemantic.core.relationship import Relationship +def add_native_relationship_aliases(schema: dict) -> dict: + """Expose native YAML relationship aliases that map to Python API fields.""" + properties = schema.setdefault("properties", {}) + + if "foreign_key" in properties and "foreign_key_columns" not in properties: + foreign_key_columns = deepcopy(properties["foreign_key"]) + foreign_key_columns["title"] = "Foreign Key Columns" + foreign_key_columns["description"] = "Explicit source-column list (alias for foreign_key)" + properties["foreign_key_columns"] = foreign_key_columns + + if "primary_key" in properties and "primary_key_columns" not in properties: + primary_key_columns = deepcopy(properties["primary_key"]) + primary_key_columns["title"] = "Primary Key Columns" + primary_key_columns["description"] = "Explicit target-column list (alias for primary_key)" + properties["primary_key_columns"] = primary_key_columns + + if "sql" not in properties: + properties["sql"] = { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "Custom join SQL using {from} and {to} runtime placeholders", + "title": "Sql", + } + + return schema + + +def patch_relationship_schemas(schema: dict) -> None: + """Patch every embedded Relationship schema emitted by Pydantic.""" + if not isinstance(schema, dict): + return + if schema.get("title") == "Relationship": + add_native_relationship_aliases(schema) + for value in schema.values(): + if isinstance(value, dict): + patch_relationship_schemas(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + patch_relationship_schemas(item) + + def generate_yaml_schema() -> dict: """Generate JSON Schema for Sidemantic YAML format. @@ -51,6 +94,8 @@ def generate_yaml_schema() -> dict: }, } + patch_relationship_schemas(schema) + return schema diff --git a/sidemantic/sql/aggregation_detection.py b/sidemantic/sql/aggregation_detection.py index 90743d41..70e31743 100644 --- a/sidemantic/sql/aggregation_detection.py +++ b/sidemantic/sql/aggregation_detection.py @@ -21,7 +21,7 @@ _AGGREGATE_REGEX = re.compile( r"\b(" - r"sum|count|avg|min|max|median|stddev|stddev_pop|variance|variance_pop|mode|" + r"sum|count|avg|min|max|median|stddev|stddev_pop|variance|variance_pop|var_pop|mode|" r"quantile|percentile|product|entropy|kurtosis|skewness|geometric_mean|weighted_avg|list" r")\s*\(", re.IGNORECASE, diff --git a/sidemantic/sql/generator.py b/sidemantic/sql/generator.py index 143b198a..239be481 100644 --- a/sidemantic/sql/generator.py +++ b/sidemantic/sql/generator.py @@ -36,6 +36,11 @@ def __init__( self._generate_cache: dict[tuple[object, ...], str] = {} self._generate_cache_limit = 256 + @staticmethod + def _agg_sql_name(agg: str) -> str: + """Return the SQL function name for a normalized metric aggregation.""" + return {"variance_pop": "VAR_POP", "var_pop": "VAR_POP"}.get(agg, agg.upper()) + @staticmethod def _model_from_clause(model) -> str: if model.sql: @@ -298,6 +303,12 @@ def _cte_ref(self, model_name: str, column_name: str) -> str: """Build a quoted reference to a CTE column.""" return f"{self._quote_identifier(self._cte_name(model_name))}.{self._quote_identifier(column_name)}" + def _custom_join_condition(self, join_path) -> str: + """Render a relationship custom SQL join condition for generated CTE names.""" + from_alias = self._quote_identifier(self._cte_name(join_path.from_model)) + to_alias = self._quote_identifier(self._cte_name(join_path.to_model)) + return join_path.custom_condition.replace("{from}", from_alias).replace("{to}", to_alias) + def _apply_default_time_dimensions(self, metrics: list[str], dimensions: list[str]) -> list[str]: """Auto-include default_time_dimension from models if not already present. @@ -328,28 +339,28 @@ def _apply_default_time_dimensions(self, metrics: list[str], dimensions: list[st added_dims = [] models_checked = set() for metric_ref in metrics: - if "." in metric_ref: - model_name, _ = metric_ref.split(".") - if model_name in models_checked: - continue - models_checked.add(model_name) + try: + model_name, _ = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + continue + if not model_name or model_name in models_checked: + continue + models_checked.add(model_name) - # Try to get model - may not exist if this is a graph-level metric - # with a dotted name (not model.measure format) - try: - model = self.graph.get_model(model_name) - except KeyError: - model = None - if model and model.default_time_dimension: - # Only add if this model doesn't already have a time dimension - if model_name not in models_with_time_dims: - time_dim_ref = f"{model_name}.{model.default_time_dimension}" - # Apply default_grain if specified - if model.default_grain: - time_dim_ref = f"{time_dim_ref}__{model.default_grain}" - if time_dim_ref not in dimensions and time_dim_ref not in added_dims: - added_dims.append(time_dim_ref) - models_with_time_dims.add(model_name) + try: + model = self.graph.get_model(model_name) + except KeyError: + model = None + if model and model.default_time_dimension: + # Only add if this model doesn't already have a time dimension + if model_name not in models_with_time_dims: + time_dim_ref = f"{model_name}.{model.default_time_dimension}" + # Apply default_grain if specified + if model.default_grain: + time_dim_ref = f"{time_dim_ref}__{model.default_grain}" + if time_dim_ref not in dimensions and time_dim_ref not in added_dims: + added_dims.append(time_dim_ref) + models_with_time_dims.add(model_name) return dimensions + added_dims @@ -495,22 +506,12 @@ def generate( def metric_needs_window(m): # Try to get metric - could be model.measure or just metric name metric = None - if "." in m: - # model.measure format - model_name, measure_name = m.split(".") - try: - model = self.graph.get_model(model_name) - if model: - metric = model.get_metric(measure_name) - except KeyError: - pass - # Fall back to graph-level metric with dotted name - if not metric: - try: - metric = self.graph.get_metric(m) - except KeyError: - pass - else: + try: + _, metric = self.graph.resolve_metric_reference(m) + except KeyError: + pass + + if not metric and "." not in m: # Just metric name - try graph-level metric try: metric = self.graph.get_metric(m) @@ -855,57 +856,49 @@ def add_model(model_name: str): def collect_models_from_metric(metric_ref: str): """Recursively collect models needed from a metric.""" - if "." in metric_ref: - # Direct measure reference (model.measure) - model_name, measure_name = metric_ref.split(".", 1) - add_model(model_name) - try: - model = self.graph.get_model(model_name) - measure = model.get_metric(measure_name) if model else None - except KeyError: - measure = None + try: + model_name, metric = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + return - if measure: - if measure.type == "ratio": - if measure.numerator: - collect_models_from_metric(measure.numerator) - if measure.denominator: - collect_models_from_metric(measure.denominator) - elif measure.type == "derived" or (not measure.type and not measure.agg and measure.sql): - for ref_metric in measure.get_dependencies(self.graph, model_name): - collect_models_from_metric(ref_metric) - if measure.sql and "." in measure.sql: - for ref_model_name in self._extract_models_from_sql(measure.sql): - add_model(ref_model_name) - elif measure.agg and measure.sql and "." in measure.sql: - for ref_model_name in self._extract_models_from_sql(measure.sql): + if model_name: + add_model(model_name) + if metric.type == "ratio": + if metric.numerator: + collect_models_from_metric(metric.numerator) + if metric.denominator: + collect_models_from_metric(metric.denominator) + elif metric.type == "derived" or (not metric.type and not metric.agg and metric.sql): + for ref_metric in metric.get_dependencies(self.graph, model_name): + collect_models_from_metric(ref_metric) + if metric.sql: + for ref_model_name in self._extract_models_from_sql(metric.sql): add_model(ref_model_name) - else: - # It's a metric, need to resolve its dependencies - try: - metric = self.graph.get_metric(metric_ref) - if metric: - if metric.type == "ratio": - if metric.numerator: - collect_models_from_metric(metric.numerator) - if metric.denominator: - collect_models_from_metric(metric.denominator) - elif metric.type == "derived" or (not metric.type and not metric.agg and metric.sql): - # Derived or untyped metrics with sql - auto-detect dependencies - for ref_metric in metric.get_dependencies(self.graph): - collect_models_from_metric(ref_metric) - # Inline SQL expression metrics (e.g., SUM(orders.amount)) - # can have empty dependencies, so also parse model refs directly. - if metric.sql: - for model_name in self._extract_models_from_sql(metric.sql): - add_model(model_name) - elif metric.agg and metric.sql: - # Graph-level simple aggregations can qualify fields - # (e.g., SUM(orders.amount)); include those models. - for model_name in self._extract_models_from_sql(metric.sql): - add_model(model_name) - except KeyError: - pass + elif metric.agg and metric.sql: + for ref_model_name in self._extract_models_from_sql(metric.sql): + add_model(ref_model_name) + return + + # It's a graph-level metric, need to resolve its dependencies. + if metric.type == "ratio": + if metric.numerator: + collect_models_from_metric(metric.numerator) + if metric.denominator: + collect_models_from_metric(metric.denominator) + elif metric.type == "derived" or (not metric.type and not metric.agg and metric.sql): + # Derived or untyped metrics with sql - auto-detect dependencies + for ref_metric in metric.get_dependencies(self.graph): + collect_models_from_metric(ref_metric) + # Inline SQL expression metrics (e.g., SUM(orders.amount)) + # can have empty dependencies, so also parse model refs directly. + if metric.sql: + for model_name in self._extract_models_from_sql(metric.sql): + add_model(model_name) + elif metric.agg and metric.sql: + # Graph-level simple aggregations can qualify fields + # (e.g., SUM(orders.amount)); include those models. + for model_name in self._extract_models_from_sql(metric.sql): + add_model(model_name) # Collect from dimensions first (since they define the grain) for dim in dimensions: @@ -1093,7 +1086,10 @@ def extract_from_measure_ref(metric_ref: str): """Extract filter columns from a model.measure reference.""" if "." not in metric_ref: return - model_name, measure_name = metric_ref.split(".") + if metric_ref in self.graph.metrics: + extract_from_metric(self.graph.metrics[metric_ref]) + return + model_name, measure_name = metric_ref.split(".", 1) model = self.graph.get_model(model_name) if model: measure = model.get_metric(measure_name) @@ -1135,8 +1131,11 @@ def extract_from_metric(metric): if metric.filters: deps = metric.get_dependencies(self.graph) for dep in deps: - if "." in dep: - dep_model_name = dep.split(".")[0] + try: + dep_model_name, _ = self.graph.resolve_metric_reference(dep) + except KeyError: + dep_model_name = dep.split(".", 1)[0] if "." in dep else None + if dep_model_name: add_filter_columns(dep_model_name, metric.filters) break @@ -1167,16 +1166,16 @@ def extract_from_metric(metric): add_sql_columns(metric.sql) for metric_ref in metrics: - if "." in metric_ref: + try: + model_name, metric = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + continue + if model_name: # model.measure format - extract directly extract_from_measure_ref(metric_ref) else: # Graph-level metric - recursively extract - try: - metric = self.graph.get_metric(metric_ref) - extract_from_metric(metric) - except KeyError: - pass + extract_from_metric(metric) return columns_by_model @@ -1333,8 +1332,8 @@ def add_passthrough_column(column: str) -> None: for other_join in other_model.relationships: if other_join.type != "many_to_many" or other_join.through != model_name: continue - junction_self_fk, junction_related_fk = other_join.junction_keys() - for fk in (junction_self_fk, junction_related_fk): + junction_self_fks, junction_related_fks = other_join.junction_key_columns() + for fk in (*junction_self_fks, *junction_related_fks): if fk and fk not in columns_added: select_cols.append(f"{self._quote_identifier(fk)} AS {self._quote_alias(fk)}") columns_added.add(fk) @@ -1417,30 +1416,42 @@ def collect_measures_from_metric(metric_ref: str, visited: set[str] | None = Non return visited.add(metric_ref) - if "." in metric_ref: + try: + ref_model_name, resolved_metric = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + ref_model_name = None + resolved_metric = None + + if resolved_metric and ref_model_name is None: + for dep in resolved_metric.get_dependencies(self.graph, model_name): + collect_measures_from_metric(dep, visited) + if resolved_metric.sql and sql_has_aggregate(resolved_metric.sql, self.dialect): + collect_sql_columns_for_model(resolved_metric.sql) + return + + if resolved_metric and ref_model_name: # It's a qualified reference (model.measure) - ref_model_name, measure_name = metric_ref.split(".", 1) - if ref_model_name == model_name: - # It's for this model - check if it's a derived measure - measure = model.get_metric(measure_name) - if measure: - if ( - not measure.type - and not measure.agg - and measure.sql - and sql_has_aggregate(measure.sql, self.dialect) - ): - collect_sql_columns_for_model(measure.sql) - return - if measure.type in ("derived", "ratio") or ( - not measure.type and not measure.agg and measure.sql - ): - # Derived/ratio measure - get its dependencies - for dep in measure.get_dependencies(self.graph, ref_model_name): - collect_measures_from_metric(dep, visited) - elif measure.agg: - # Simple aggregation measure - add it - measures_needed.add(measure_name) + if ref_model_name != model_name: + return + measure_name = metric_ref.split(".", 1)[1] + measure = resolved_metric + if ( + not measure.type + and not measure.agg + and measure.sql + and sql_has_aggregate(measure.sql, self.dialect) + ): + collect_sql_columns_for_model(measure.sql) + return + if measure.type in ("derived", "ratio") or (not measure.type and not measure.agg and measure.sql): + # Derived/ratio measure - get its dependencies + for dep in measure.get_dependencies(self.graph, ref_model_name): + collect_measures_from_metric(dep, visited) + elif measure.agg: + # Simple aggregation measure - add it + measures_needed.add(measure_name) + elif "." in metric_ref: + return else: # Unqualified reference - could be: # 1. A graph-level metric @@ -1599,6 +1610,12 @@ def _model_needs_keyed_join_columns(self, model_name: str, all_models: set[str]) for relationship in other_model.relationships: if relationship.name == model_name and relationship.type != "cross": return True + if ( + relationship.type == "many_to_many" + and relationship.through == model_name + and relationship.name in all_models + ): + return True return False @@ -1714,8 +1731,11 @@ def _needs_preaggregation_for_fanout(self, metrics: list[str], dimensions: list[ # Get unique metric models metric_models = set() for metric_ref in metrics: - if "." in metric_ref: - model_name = metric_ref.split(".")[0] + try: + model_name, _ = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + model_name = None + if model_name: metric_models.add(model_name) if len(metric_models) < 2: @@ -1788,8 +1808,11 @@ def _generate_with_preaggregation( # Group metrics by their model metrics_by_model: dict[str, list[str]] = {} for metric_ref in metrics: - if "." in metric_ref: - model_name = metric_ref.split(".")[0] + try: + model_name, _ = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + model_name = None + if model_name: if model_name not in metrics_by_model: metrics_by_model[model_name] = [] metrics_by_model[model_name].append(metric_ref) @@ -1899,14 +1922,14 @@ def metric_source_name(metric_ref: str, metric_name: str) -> str: metric_name_counts: dict[str, int] = {} for model_metrics in metrics_by_model.values(): for metric_ref in model_metrics: - metric_name = metric_ref.split(".")[1] if "." in metric_ref else metric_ref + metric_name = metric_ref.split(".", 1)[1] if "." in metric_ref else metric_ref metric_name_counts[metric_name] = metric_name_counts.get(metric_name, 0) + 1 # Add metrics from each CTE for model_name, model_metrics in metrics_by_model.items(): cte_name = f"{model_name}_preagg" for metric_ref in model_metrics: - metric_name = metric_ref.split(".")[1] if "." in metric_ref else metric_ref + metric_name = metric_ref.split(".", 1)[1] if "." in metric_ref else metric_ref source_name = metric_source_name(metric_ref, metric_name) # Check for custom alias first if metric_ref in aliases: @@ -1992,9 +2015,9 @@ def metric_source_name(metric_ref: str, metric_name: str) -> str: final_query += f"\nORDER BY {', '.join(order_clauses)}" # Add LIMIT and OFFSET - if limit: + if limit is not None: final_query += f"\nLIMIT {limit}" - if offset: + if offset is not None: final_query += f"\nOFFSET {offset}" # Combine CTEs and main query @@ -2037,20 +2060,24 @@ def _add_join_paths_to_query( joined_models.add(jp.to_model) continue - if len(jp.from_columns) != len(jp.to_columns): - raise ValueError( - f"Join between {jp.from_model} and {jp.to_model} has mismatched key columns: " - f"from_columns has {len(jp.from_columns)}, to_columns has {len(jp.to_columns)}" - ) + if jp.custom_condition: + join_cond = self._custom_join_condition(jp) + else: + if len(jp.from_columns) != len(jp.to_columns): + raise ValueError( + f"Join between {jp.from_model} and {jp.to_model} has mismatched key columns: " + f"from_columns has {len(jp.from_columns)}, to_columns has {len(jp.to_columns)}" + ) - join_conditions = [ - self._cte_ref(jp.from_model, fk) + " = " + self._cte_ref(jp.to_model, pk) - for fk, pk in zip(jp.from_columns, jp.to_columns) - ] + join_conditions = [ + self._cte_ref(jp.from_model, fk) + " = " + self._cte_ref(jp.to_model, pk) + for fk, pk in zip(jp.from_columns, jp.to_columns) + ] + join_cond = " AND ".join(join_conditions) join_type = self._explicit_join_type_for_path(jp) if join_type is None: join_type = "inner" if jp.to_model in models_with_filters else "left" - query = query.join(right_table, on=" AND ".join(join_conditions), join_type=join_type) + query = query.join(right_table, on=join_cond, join_type=join_type) joined_models.add(jp.to_model) return query @@ -2180,7 +2207,14 @@ def _build_main_select( symmetric_agg_needed = self._has_fanout_joins(base_model_name, other_models) dimension_models = {dim_ref.split(".")[0] for dim_ref, _ in parsed_dims if "." in dim_ref} - metric_models = {metric_ref.split(".")[0] for metric_ref in metrics if "." in metric_ref} + metric_models = set() + for metric_ref in metrics: + try: + metric_model_name, _ = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + metric_model_name = None + if metric_model_name: + metric_models.add(metric_model_name) for metric_model in metric_models: for dimension_model in dimension_models: if metric_model == dimension_model: @@ -2204,11 +2238,20 @@ def _build_main_select( field_names[field_key].append(model_name) for metric_ref in metrics: - if "." in metric_ref: - model_name, measure_name = metric_ref.split(".") + try: + model_name, resolved_metric = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + model_name = None + resolved_metric = None + if model_name: + measure_name = metric_ref.split(".", 1)[1] if measure_name not in field_names: field_names[measure_name] = [] field_names[measure_name].append(model_name) + elif resolved_metric: + if resolved_metric.name not in field_names: + field_names[resolved_metric.name] = [] + field_names[resolved_metric.name].append("") # Determine which fields have collisions has_collision = {name: len(models) > 1 for name, models in field_names.items()} @@ -2245,11 +2288,22 @@ def _build_main_select( # Add metrics for metric_ref in metrics: - if "." in metric_ref: + try: + resolved_model_name, resolved_metric = self.graph.resolve_metric_reference(metric_ref) + except KeyError: + resolved_model_name = None + resolved_metric = None + + if resolved_metric and resolved_model_name is None: + metric_expr = self._build_metric_sql(resolved_metric) + metric_expr = self._wrap_with_fill_nulls(metric_expr, resolved_metric) + alias = aliases.get(metric_ref, resolved_metric.name) + select_exprs.append(f"{metric_expr} AS {self._quote_alias(alias)}") + elif resolved_metric and resolved_model_name: # It's a measure reference (model.measure) - model_name, measure_name = metric_ref.split(".") - model = self.graph.get_model(model_name) - measure = model.get_metric(measure_name) + model_name = resolved_model_name + measure_name = metric_ref.split(".", 1)[1] + measure = resolved_metric if measure: # Check for custom alias first @@ -2293,19 +2347,20 @@ def _build_main_select( pk_cols = model_obj.primary_key_columns # For composite keys, concatenate columns for hashing if len(pk_cols) == 1: - pk = self._quote_identifier(pk_cols[0]) + pk = self._cte_ref(model_name, pk_cols[0]) else: pk = ( "CONCAT(" - + ", '|', ".join(f"CAST({self._quote_identifier(c)} AS VARCHAR)" for c in pk_cols) + + ", '|', ".join( + f"CAST({self._cte_ref(model_name, c)} AS VARCHAR)" for c in pk_cols + ) + ")" ) agg_expr = build_symmetric_aggregate_sql( - measure_expr=f"{measure_name}_raw", + measure_expr=self._cte_ref(model_name, f"{measure_name}_raw"), primary_key=pk, agg_type=measure.agg, - model_alias=f"{model_name}_cte", dialect=self.dialect, ) else: @@ -2393,9 +2448,9 @@ def replace_metric_ref(match): query = query.order_by(*order_by_aliases) # Add LIMIT and OFFSET - if limit: + if limit is not None: query = query.limit(limit) - if offset: + if offset is not None: query = query.offset(offset) return query.sql(dialect=self.dialect, pretty=True) @@ -2514,7 +2569,7 @@ def _build_measure_aggregation_sql(self, model_name: str, measure) -> str: Returns: SQL aggregation expression string """ - agg_func = measure.agg.upper() + agg_func = self._agg_sql_name(measure.agg) raw_col = self._cte_ref(model_name, f"{measure.name}_raw") # Simple aggregation - filters are already applied in CTE's raw column @@ -2641,22 +2696,7 @@ def _build_metric_sql(self, metric, model_context: str | None = None) -> str: raise ValueError(f"Ratio metric {metric.name} requires numerator and denominator") def resolve_ratio_ref(ref: str) -> str: - # First try model-scoped references (qualified or model_context-qualified). - if "." in ref: - ref_model, ref_name = ref.split(".", 1) - try: - ref_model_obj = self.graph.get_model(ref_model) - except KeyError: - ref_model_obj = None - - if ref_model_obj: - ref_metric = ref_model_obj.get_metric(ref_name) - if ref_metric: - if ref_metric.agg: - return self._build_measure_aggregation_sql(ref_model, ref_metric) - return self._build_metric_sql(ref_metric, ref_model) - - elif model_context: + if "." not in ref and model_context: try: context_model = self.graph.get_model(model_context) except KeyError: @@ -2669,12 +2709,16 @@ def resolve_ratio_ref(ref: str) -> str: return self._build_measure_aggregation_sql(model_context, ref_metric) return self._build_metric_sql(ref_metric, model_context) - # Fallback to graph-level metrics (including dotted metric names). try: - ref_metric = self.graph.get_metric(ref) + ref_model, ref_metric = self.graph.resolve_metric_reference(ref) except KeyError as exc: raise ValueError(f"Metric {ref} not found") from exc + if ref_model: + if ref_metric.agg: + return self._build_measure_aggregation_sql(ref_model, ref_metric) + return self._build_metric_sql(ref_metric, ref_model) + return self._build_metric_sql(ref_metric, model_context) num_expr = resolve_ratio_ref(metric.numerator) @@ -2695,7 +2739,7 @@ def resolve_ratio_ref(ref: str) -> str: return f"COUNT({inner_expr})" if metric.agg == "count_distinct": return f"COUNT(DISTINCT {inner_expr})" - return f"{metric.agg.upper()}({inner_expr})" + return f"{self._agg_sql_name(metric.agg)}({inner_expr})" elif metric.type == "derived" or (not metric.type and not metric.agg and metric.sql): # Parse formula and replace metric references (handles both typed "derived" and untyped metrics with sql) @@ -2755,20 +2799,20 @@ def resolve_ratio_ref(ref: str) -> str: # Replace each metric reference with its SQL expression for metric_name in sorted_deps: - # Check if it's a measure reference (model.measure) first - if "." in metric_name: - model_name, measure_name = metric_name.split(".") - model = self.graph.get_model(model_name) - measure = model.get_metric(measure_name) + try: + ref_model_name, ref_metric = self.graph.resolve_metric_reference(metric_name) + except KeyError: + ref_model_name = None + ref_metric = None - if measure: - if measure.agg: - # Use helper that applies metric-level filters - metric_sql = self._build_measure_aggregation_sql(model_name, measure) - else: - metric_sql = self._build_metric_sql(measure, model_name) + if ref_metric and ref_model_name: + if ref_metric.agg: + # Use helper that applies metric-level filters + metric_sql = self._build_measure_aggregation_sql(ref_model_name, ref_metric) else: - raise ValueError(f"Measure {metric_name} not found") + metric_sql = self._build_metric_sql(ref_metric, ref_model_name) + elif ref_metric: + metric_sql = self._build_metric_sql(ref_metric, model_context) else: # Try as graph-level metric try: @@ -2983,7 +3027,7 @@ def _replace_model_placeholder(expr: str) -> str: inner_metric_selects = [] for im in metric.inner_metrics: im_name = im["name"] - im_agg = im.get("agg", "count").upper() + im_agg = self._agg_sql_name(im.get("agg", "count")) im_sql = im.get("sql") if im_sql: @@ -3011,7 +3055,7 @@ def _replace_model_placeholder(expr: str) -> str: having_clause = _replace_model_placeholder(metric.having) # Build outer query - outer_agg = metric.agg.upper() + outer_agg = self._agg_sql_name(metric.agg) outer_sql = metric.sql if outer_sql: # Outer SQL references columns from the inner subquery (aliased as @@ -4286,17 +4330,14 @@ def build_time_comparison_base_expression( pass if base_metric and base_metric.sql: - # Use the underlying measure name - if "." in base_metric.sql: - base_alias = base_metric.sql.split(".")[1] - else: - base_alias = base_metric.sql + # Window over the base query's metric alias, not the raw column behind that metric. + base_alias = base_metric.name else: # Fallback to the metric name itself base_alias = base_ref # Determine aggregation function (default to SUM for backwards compatibility) - agg_func = (metric.agg or "sum").upper() + agg_func = self._agg_sql_name(metric.agg or "sum") if agg_func == "COUNT_DISTINCT": agg_func = "COUNT" base_col = f"DISTINCT base.{base_alias}" @@ -4567,9 +4608,9 @@ def build_time_comparison_base_expression( outer_query += f"\nORDER BY {', '.join(order_clauses)}" # Add LIMIT and OFFSET if specified - if limit: + if limit is not None: outer_query += f"\nLIMIT {limit}" - if offset: + if offset is not None: outer_query += f"\nOFFSET {offset}" return outer_query @@ -4923,9 +4964,9 @@ def dimension_output_name(dim_ref: str, dim_name: str, granularity: str | None) order_by_clause = "\nORDER BY " + ", ".join(order_clauses) limit_clause = "" - if limit: + if limit is not None: limit_clause = f"\nLIMIT {limit}" - if offset: + if offset is not None: limit_clause += f"\nOFFSET {offset}" select_exprs_str = ",\n ".join(select_exprs) @@ -5216,9 +5257,9 @@ def metric_output_name(metric_ref: str, metric_name: str) -> str: # Build LIMIT/OFFSET clause limit_clause = "" - if limit: + if limit is not None: limit_clause = f"\nLIMIT {limit}" - if offset: + if offset is not None: limit_clause += f"\nOFFSET {offset}" # Combine into final query diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index eba63b39..b6b0c2a4 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -4297,7 +4297,7 @@ def _build_yardstick_aggregation_expr(self, measure, model_alias: str, model_nam "stddev": "STDDEV", "stddev_pop": "STDDEV_POP", "variance": "VARIANCE", - "variance_pop": "VARIANCE_POP", + "variance_pop": "VAR_POP", } if agg not in agg_map: raise ValueError(f"Unsupported Yardstick aggregation '{measure.agg}'") @@ -5780,9 +5780,9 @@ def replace_node(node: exp.Expression) -> exp.Expression: limit = self._extract_limit(parsed) offset = self._extract_offset(parsed) - if limit: + if limit is not None: outer_sql += f"\nLIMIT {limit}" - if offset: + if offset is not None: outer_sql += f"\nOFFSET {offset}" return outer_sql diff --git a/sidemantic/validation.py b/sidemantic/validation.py index a0609b5b..c05d0f63 100644 --- a/sidemantic/validation.py +++ b/sidemantic/validation.py @@ -213,8 +213,8 @@ def validate_metric(measure: "Metric", graph: "SemanticGraph") -> list[str]: ("numerator", measure.numerator), ("denominator", measure.denominator), ]: - if ref and "." in ref: - model_name, measure_name = ref.split(".") + if ref and "." in ref and ref not in graph.metrics: + model_name, measure_name = ref.split(".", 1) model = graph.models.get(model_name) if not model: errors.append(f"Ratio measure '{measure.name}': {ref_type} model '{model_name}' not found") @@ -309,9 +309,15 @@ def validate_query(metrics: list[str], dimensions: list[str], graph: "SemanticGr # Validate metric references for metric_ref in metrics: + try: + graph.resolve_metric_reference(metric_ref) + continue + except KeyError: + pass + if "." in metric_ref: # Direct measure reference - model_name, measure_name = metric_ref.split(".") + model_name, measure_name = metric_ref.split(".", 1) model = graph.models.get(model_name) if not model: errors.append(f"Model '{model_name}' not found (referenced in '{metric_ref}')") @@ -340,7 +346,7 @@ def validate_query(metrics: list[str], dimensions: list[str], graph: "SemanticGr dim_ref = dim_ref_base if "." in dim_ref: - model_name, dim_name = dim_ref.split(".") + model_name, dim_name = dim_ref.split(".", 1) model = graph.models.get(model_name) if not model: errors.append(f"Model '{model_name}' not found (referenced in '{dim_ref}')") @@ -352,21 +358,20 @@ def validate_query(metrics: list[str], dimensions: list[str], graph: "SemanticGr # Check for join paths model_names = set() for metric_ref in metrics: - if "." in metric_ref: - model_names.add(metric_ref.split(".")[0]) - else: - try: - measure = graph.get_metric(metric_ref) - if measure and measure.sql and "." in measure.sql: - model_names.add(measure.sql.split(".")[0]) - except KeyError: - pass # Already reported as error above + try: + metric_model_name, measure = graph.resolve_metric_reference(metric_ref) + except KeyError: + continue # Already reported as error above + if metric_model_name: + model_names.add(metric_model_name) + elif measure and measure.sql and "." in measure.sql: + model_names.add(measure.sql.split(".", 1)[0]) for dim_ref in dimensions: if "__" in dim_ref: dim_ref = dim_ref.rsplit("__", 1)[0] if "." in dim_ref: - model_names.add(dim_ref.split(".")[0]) + model_names.add(dim_ref.split(".", 1)[0]) # Check that all model pairs can be joined # Only check models that exist in the graph (errors for missing models already reported above) diff --git a/tests/adapters/sidemantic_adapter/test_parsing.py b/tests/adapters/sidemantic_adapter/test_parsing.py index d4eba4f4..2283af21 100644 --- a/tests/adapters/sidemantic_adapter/test_parsing.py +++ b/tests/adapters/sidemantic_adapter/test_parsing.py @@ -101,6 +101,575 @@ def test_parse_native_yaml_accepts_version_one(tmp_path): assert "orders" in graph.models +def test_parse_native_yaml_accepts_compatibility_aliases(tmp_path): + """Python compatibility aliases are accepted as native input.""" + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: orders + table: orders + auto_dimensions: false + dimensions: + - name: status + type: categorical + expr: order_status + measures: + - name: total_revenue + agg: sum + expr: amount + - name: revenue_per_order + type: derived + measure: total_revenue / order_count + - name: order_count + agg: count +""" + ) + + graph = adapter.parse(yaml_path) + orders = graph.models["orders"] + + assert orders.auto_dimensions is False + assert orders.dimensions[0].sql == "order_status" + assert [metric.name for metric in orders.metrics] == ["total_revenue", "revenue_per_order", "order_count"] + assert orders.metrics[0].sql == "amount" + assert orders.metrics[1].sql == "total_revenue / order_count" + + +@pytest.mark.parametrize( + ("yaml_body", "error_text"), + [ + ( + """ +models: + - name: orders + table: orders + metrcs: [] +""", + "unknown native field(s) in model: metrcs", + ), + ( + """ +models: + - name: orders + table: orders + dimensions: + - name: status + type: categorical + sqll: status +""", + "unknown native field(s) in model 'orders' dimension: sqll", + ), + ( + """ +models: + - name: orders + table: orders + metrics: + - name: total_revenue + agg: sum + sqll: amount +""", + "unknown native field(s) in model 'orders' metric: sqll", + ), + ( + """ +models: + - name: orders + table: orders + relationships: + - name: customers + type: many_to_one + foreign_keys: customer_id +""", + "unknown native field(s) in model 'orders' relationship: foreign_keys", + ), + ( + """ +models: + - name: orders + table: orders + pre_aggregations: + - name: daily + measures: [total_revenue] + time_dimensions: created_at +""", + "unknown native field(s) in model 'orders' pre_aggregation: time_dimensions", + ), + ], +) +def test_parse_native_yaml_rejects_unknown_nested_fields(tmp_path, yaml_body, error_text): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text(f"version: 1\n{yaml_body}") + + with pytest.raises(ValueError) as exc_info: + adapter.parse(yaml_path) + assert error_text in str(exc_info.value) + + +def test_parse_export_preserves_native_metadata_visibility_and_granularity(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: orders + table: orders + meta: + owner: analytics + dimensions: + - name: created_at + type: time + sql: created_at + granularity: day + supported_granularities: [day, week, month] + meta: + role: event_time + public: false + metrics: + - name: total_revenue + agg: sum + sql: amount + meta: + unit: usd + public: false +metrics: + - name: revenue_per_order + type: derived + sql: orders.total_revenue / orders.order_count + meta: + owner: finance + public: false +""" + ) + + graph = adapter.parse(yaml_path) + orders = graph.models["orders"] + created_at = orders.dimensions[0] + total_revenue = orders.metrics[0] + revenue_per_order = graph.metrics["revenue_per_order"] + + assert orders.meta == {"owner": "analytics"} + assert created_at.supported_granularities == ["day", "week", "month"] + assert created_at.meta == {"role": "event_time"} + assert created_at.public is False + assert total_revenue.meta == {"unit": "usd"} + assert total_revenue.public is False + assert revenue_per_order.meta == {"owner": "finance"} + assert revenue_per_order.public is False + + export_path = tmp_path / "exported.yml" + adapter.export(graph, export_path) + exported = yaml.safe_load(export_path.read_text()) + exported_model = exported["models"][0] + exported_dimension = exported_model["dimensions"][0] + exported_metric = exported_model["metrics"][0] + exported_graph_metric = exported["metrics"][0] + + assert exported_model["meta"] == {"owner": "analytics"} + assert exported_dimension["supported_granularities"] == ["day", "week", "month"] + assert exported_dimension["meta"] == {"role": "event_time"} + assert exported_dimension["public"] is False + assert exported_metric["meta"] == {"unit": "usd"} + assert exported_metric["public"] is False + assert exported_graph_metric["meta"] == {"owner": "finance"} + assert exported_graph_metric["public"] is False + + graph2 = adapter.parse(export_path) + assert graph2.models["orders"].dimensions[0].supported_granularities == ["day", "week", "month"] + assert graph2.models["orders"].dimensions[0].meta == {"role": "event_time"} + assert graph2.models["orders"].dimensions[0].public is False + assert graph2.models["orders"].metrics[0].meta == {"unit": "usd"} + assert graph2.models["orders"].metrics[0].public is False + assert graph2.metrics["revenue_per_order"].meta == {"owner": "finance"} + assert graph2.metrics["revenue_per_order"].public is False + + +def test_parse_export_preserves_top_level_parameters(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +parameters: + - name: status + type: string + description: Order status + label: Status + default_value: paid + allowed_values: [paid, refunded] + - name: report_date + type: date + default_to_today: true +models: + - name: orders + table: orders +""" + ) + + graph = adapter.parse(yaml_path) + assert graph.parameters["status"].default_value == "paid" + assert graph.parameters["status"].allowed_values == ["paid", "refunded"] + assert graph.parameters["report_date"].default_to_today is True + + export_path = tmp_path / "exported.yml" + adapter.export(graph, export_path) + exported = yaml.safe_load(export_path.read_text()) + + assert exported["parameters"] == [ + { + "name": "status", + "type": "string", + "description": "Order status", + "label": "Status", + "default_value": "paid", + "allowed_values": ["paid", "refunded"], + }, + { + "name": "report_date", + "type": "date", + "default_to_today": True, + }, + ] + + graph2 = adapter.parse(export_path) + assert graph2.parameters["status"].default_value == "paid" + assert graph2.parameters["status"].allowed_values == ["paid", "refunded"] + assert graph2.parameters["report_date"].default_to_today is True + + +def test_parse_export_preserves_relationship_custom_sql(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: orders + table: orders + relationships: + - name: customers + type: many_to_one + foreign_key_columns: [customer_id, tenant_id] + primary_key_columns: [customer_id, tenant_id] + sql: "{from}.customer_id = {to}.customer_id AND {from}.tenant_id IS NOT DISTINCT FROM {to}.tenant_id" + - name: customers + table: customers + primary_key_columns: [customer_id, tenant_id] +""" + ) + + graph = adapter.parse(yaml_path) + relationship = graph.models["orders"].relationships[0] + assert relationship.sql == ( + "{from}.customer_id = {to}.customer_id AND {from}.tenant_id IS NOT DISTINCT FROM {to}.tenant_id" + ) + + export_path = tmp_path / "exported.yml" + adapter.export(graph, export_path) + exported = yaml.safe_load(export_path.read_text()) + + assert exported["models"][0]["relationships"][0]["sql"] == relationship.sql + + graph2 = adapter.parse(export_path) + assert graph2.models["orders"].relationships[0].sql == relationship.sql + + +def test_parse_native_yaml_rejects_relationship_sql_without_placeholders(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: orders + table: orders + relationships: + - name: customers + type: many_to_one + foreign_key: customer_id + sql: customer_id +""" + ) + + with pytest.raises(ValueError, match=r"relationship 'customers' sql must include both \{from\} and \{to\}"): + adapter.parse(yaml_path) + + +def test_parse_export_preserves_pre_aggregations(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: orders + table: orders + dimensions: + - name: status + type: categorical + - name: created_at + type: time + metrics: + - name: total_revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + type: rollup + sql: "select status, sum(amount) as total_revenue from orders group by 1" + measures: [total_revenue] + dimensions: [status] + time_dimension: created_at + granularity: day + partition_granularity: month + build_range_start: "date '2026-01-01'" + build_range_end: current_date + scheduled_refresh: false + refresh_key: + every: 1 hour + sql: "select max(updated_at) from orders" + incremental: true + update_window: 7 days + indexes: + - name: by_status + columns: [status] + type: aggregate + meta: + owner: analytics +""" + ) + + graph = adapter.parse(yaml_path) + preagg = graph.models["orders"].pre_aggregations[0] + assert preagg.sql == "select status, sum(amount) as total_revenue from orders group by 1" + assert preagg.partition_granularity == "month" + assert preagg.scheduled_refresh is False + assert preagg.refresh_key.every == "1 hour" + assert preagg.refresh_key.sql == "select max(updated_at) from orders" + assert preagg.refresh_key.incremental is True + assert preagg.refresh_key.update_window == "7 days" + assert preagg.indexes[0].name == "by_status" + assert preagg.indexes[0].type == "aggregate" + assert preagg.meta == {"owner": "analytics"} + + export_path = tmp_path / "exported.yml" + adapter.export(graph, export_path) + exported = yaml.safe_load(export_path.read_text()) + exported_preagg = exported["models"][0]["pre_aggregations"][0] + + assert exported_preagg == { + "name": "daily_revenue", + "type": "rollup", + "sql": "select status, sum(amount) as total_revenue from orders group by 1", + "measures": ["total_revenue"], + "dimensions": ["status"], + "time_dimension": "created_at", + "granularity": "day", + "partition_granularity": "month", + "build_range_start": "date '2026-01-01'", + "build_range_end": "current_date", + "scheduled_refresh": False, + "refresh_key": { + "every": "1 hour", + "sql": "select max(updated_at) from orders", + "incremental": True, + "update_window": "7 days", + }, + "indexes": [{"name": "by_status", "columns": ["status"], "type": "aggregate"}], + "meta": {"owner": "analytics"}, + } + + graph2 = adapter.parse(export_path) + preagg2 = graph2.models["orders"].pre_aggregations[0] + assert preagg2.sql == preagg.sql + assert preagg2.refresh_key.incremental is True + assert preagg2.indexes[0].type == "aggregate" + assert preagg2.meta == {"owner": "analytics"} + + +def test_parse_native_yaml_explicit_key_columns(tmp_path): + """Explicit *_columns key fields are part of the native YAML contract.""" + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: order_items + table: order_items + primary_key_columns: [order_id, item_id] + unique_keys: + - [order_id, item_id] + metrics: + - name: count + agg: count + - name: shipments + table: shipments + primary_key: shipment_id + dimensions: + - name: carrier + type: categorical + relationships: + - name: order_items + type: many_to_one + foreign_key_columns: [order_id, item_id] + primary_key_columns: [order_id, item_id] +""" + ) + + graph = adapter.parse(yaml_path) + + order_items = graph.models["order_items"] + assert order_items.primary_key == ["order_id", "item_id"] + assert order_items.primary_key_columns == ["order_id", "item_id"] + assert order_items.unique_keys == [["order_id", "item_id"]] + + relationship = graph.models["shipments"].relationships[0] + assert relationship.foreign_key == ["order_id", "item_id"] + assert relationship.foreign_key_columns == ["order_id", "item_id"] + assert relationship.primary_key == ["order_id", "item_id"] + assert relationship.primary_key_columns == ["order_id", "item_id"] + + layer = SemanticLayer(auto_register=False) + for model in graph.models.values(): + layer.add_model(model) + + sql = layer.compile(metrics=["order_items.count"], dimensions=["shipments.carrier"]) + assert "shipments_cte.order_id = order_items_cte.order_id" in sql + assert "shipments_cte.item_id = order_items_cte.item_id" in sql + + +def test_parse_native_yaml_resolves_model_and_metric_inheritance(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: base_orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + sql: status + metrics: + - name: revenue + agg: sum + sql: amount + - name: completed_orders + extends: base_orders + dimensions: + - name: completed_at + type: time + sql: completed_at + granularity: day + metrics: + - name: completed_revenue + extends: revenue + filters: + - status = 'completed' +metrics: + - name: base_revenue + sql: base_orders.revenue + - name: display_revenue + extends: base_revenue + label: Revenue +""" + ) + + graph = adapter.parse(yaml_path) + + completed_orders = graph.models["completed_orders"] + assert completed_orders.table == "orders" + assert completed_orders.primary_key == "order_id" + assert completed_orders.extends is None + assert completed_orders.get_dimension("status") is not None + assert completed_orders.get_dimension("completed_at") is not None + + completed_revenue = completed_orders.get_metric("completed_revenue") + assert completed_revenue.agg == "sum" + assert completed_revenue.sql == "amount" + assert completed_revenue.filters == ["status = 'completed'"] + assert completed_revenue.extends is None + + display_revenue = graph.metrics["display_revenue"] + assert display_revenue.sql == "base_orders.revenue" + assert display_revenue.label == "Revenue" + assert display_revenue.extends is None + + +def test_parse_native_yaml_rejects_invalid_top_level_sql_metric_block(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +sql_metrics: | + SELECT 1; +""" + ) + + with pytest.raises(ValueError, match=r"orders\.yml: invalid sql_metrics"): + adapter.parse(yaml_path) + + +def test_parse_native_yaml_rejects_invalid_model_sql_metric_block(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: orders + table: orders + sql_metrics: | + SELECT 1; +""" + ) + + with pytest.raises(ValueError, match=r"orders\.yml: invalid model 'orders' sql_metrics"): + adapter.parse(yaml_path) + + +def test_parse_native_yaml_rejects_invalid_top_level_sql_segment_block(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +sql_segments: | + SELECT 1; +""" + ) + + with pytest.raises(ValueError, match=r"orders\.yml: invalid sql_segments"): + adapter.parse(yaml_path) + + +def test_parse_native_yaml_rejects_invalid_model_sql_segment_block(tmp_path): + adapter = SidemanticAdapter() + yaml_path = tmp_path / "orders.yml" + yaml_path.write_text( + """ +version: 1 +models: + - name: orders + table: orders + sql_segments: | + SELECT 1; +""" + ) + + with pytest.raises(ValueError, match=r"orders\.yml: invalid model 'orders' sql_segments"): + adapter.parse(yaml_path) + + def test_parse_native_yaml_rejects_unsupported_version(tmp_path): """Test unsupported native YAML versions fail early.""" adapter = SidemanticAdapter() diff --git a/tests/core/test_rust_bridge_yaml_serialization.py b/tests/core/test_rust_bridge_yaml_serialization.py index 81f49022..6e85f0c3 100644 --- a/tests/core/test_rust_bridge_yaml_serialization.py +++ b/tests/core/test_rust_bridge_yaml_serialization.py @@ -6,6 +6,7 @@ from sidemantic.core.metric import Metric from sidemantic.core.model import Model from sidemantic.core.pre_aggregation import Index, PreAggregation, RefreshKey +from sidemantic.core.relationship import Relationship from sidemantic.core.semantic_graph import SemanticGraph from sidemantic.rust_bridge import graph_to_rust_yaml, models_to_rust_yaml @@ -20,6 +21,21 @@ def test_models_to_rust_yaml_preserves_extended_core_metadata(): unique_keys=[["order_id", "tenant_id"]], default_time_dimension="order_date", default_grain="day", + relationships=[ + Relationship( + name="customers", + type="many_to_one", + foreign_key=["customer_id", "tenant_id"], + primary_key=["customer_id", "tenant_id"], + ), + Relationship( + name="products", + type="many_to_many", + through="order_products", + through_foreign_key_columns=["order_id", "tenant_id"], + related_foreign_key_columns=["product_id", "tenant_id"], + ), + ], dimensions=[ Dimension( name="order_date", @@ -57,13 +73,21 @@ def test_models_to_rust_yaml_preserves_extended_core_metadata(): payload = yaml.safe_load(models_to_rust_yaml([model], include_extends=True)) model_payload = payload["models"][0] + relationship_payload = model_payload["relationships"][0] + many_to_many_payload = model_payload["relationships"][1] dimension_payload = model_payload["dimensions"][0] metric_payload = model_payload["metrics"][0] preagg_payload = model_payload["pre_aggregations"][0] assert model_payload["source_uri"] == "s3://warehouse/orders" assert model_payload["extends"] == "base_orders" + assert model_payload["primary_key_columns"] == ["order_id", "tenant_id"] assert model_payload["unique_keys"] == [["order_id", "tenant_id"]] + assert relationship_payload["foreign_key_columns"] == ["customer_id", "tenant_id"] + assert relationship_payload["primary_key_columns"] == ["customer_id", "tenant_id"] + assert many_to_many_payload["through"] == "order_products" + assert many_to_many_payload["through_foreign_key_columns"] == ["order_id", "tenant_id"] + assert many_to_many_payload["related_foreign_key_columns"] == ["product_id", "tenant_id"] assert dimension_payload["supported_granularities"] == ["day", "week", "month"] assert dimension_payload["format"] == "yyyy-mm-dd" diff --git a/tests/core/test_sql_definitions.py b/tests/core/test_sql_definitions.py index 2c9b057f..c4ddd58f 100644 --- a/tests/core/test_sql_definitions.py +++ b/tests/core/test_sql_definitions.py @@ -773,6 +773,18 @@ def test_parse_graph_definitions_after_table_block(): assert parameters == [] +def test_parse_graph_definitions_rejects_plain_sql(): + """Graph definition blocks should not silently ignore unsupported SQL.""" + with pytest.raises(ValueError, match="Unsupported SQL definition statement: Select"): + parse_sql_graph_definitions("SELECT 1;") + + +def test_parse_sql_definitions_propagates_parse_errors(): + """Malformed embedded definition syntax should surface as a parse failure.""" + with pytest.raises(Exception): + parse_sql_definitions("NOT_A_DEF (name x);") + + def test_parse_table_block_multiline_field_expression(): """Test compact field expressions can span lines before their alias.""" sql_content = """ diff --git a/tests/db/test_postgres_adapter.py b/tests/db/test_postgres_adapter.py index c9a5cc63..caebe7fd 100644 --- a/tests/db/test_postgres_adapter.py +++ b/tests/db/test_postgres_adapter.py @@ -5,6 +5,7 @@ import pytest +from sidemantic.core.semantic_layer import SemanticLayer from sidemantic.db.postgres import PostgreSQLAdapter, PostgresResult @@ -69,6 +70,7 @@ def test_postgres_adapter_import_error_message(monkeypatch): def test_postgres_from_url_matrix(monkeypatch): cases = [ ("postgres://u:p@host:5432/db", {"host": "host", "port": 5432, "database": "db", "user": "u", "password": "p"}), + ("postgres://u%40corp:p%40ss%2Fword@host/db", {"user": "u@corp", "password": "p@ss/word"}), ("postgresql://u@host/db", {"host": "host", "port": 5432, "database": "db", "user": "u", "password": None}), ("postgres://host/db", {"host": "host", "port": 5432, "database": "db", "user": None, "password": None}), ("postgres://host", {"host": "host", "port": 5432, "database": "postgres", "user": None, "password": None}), @@ -104,6 +106,32 @@ def fake_init(self, **kwargs): assert captured[key] == value +def test_postgres_connection_dict_encoded_credentials_round_trip(monkeypatch): + captured = {} + + def fake_init(self, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr(PostgreSQLAdapter, "__init__", fake_init) + + url = SemanticLayer._connection_dict_to_url( + { + "type": "postgres", + "host": "host", + "port": 5432, + "database": "db", + "user": "u@corp", + "password": "p@ss/word", + } + ) + + assert url == "postgres://u%40corp:p%40ss%2Fword@host:5432/db" + adapter = PostgreSQLAdapter.from_url(url) + assert isinstance(adapter, PostgreSQLAdapter) + assert captured["user"] == "u@corp" + assert captured["password"] == "p@ss/word" + + def test_postgres_from_url_invalid(): with pytest.raises(ValueError, match="Invalid PostgreSQL URL"): PostgreSQLAdapter.from_url("mysql://host/db") diff --git a/tests/db/test_query_history_validation.py b/tests/db/test_query_history_validation.py new file mode 100644 index 00000000..4beda306 --- /dev/null +++ b/tests/db/test_query_history_validation.py @@ -0,0 +1,57 @@ +"""Shared query-history parameter validation coverage.""" + +import pytest + +from sidemantic.db.base import validate_query_history_params +from sidemantic.db.bigquery import BigQueryAdapter +from sidemantic.db.clickhouse import ClickHouseAdapter +from sidemantic.db.databricks import DatabricksAdapter +from sidemantic.db.snowflake import SnowflakeAdapter + + +def test_validate_query_history_params_coerces_safe_integer_strings(): + assert validate_query_history_params("4", "9") == (4, 9) + + +@pytest.mark.parametrize( + ("days_back", "limit", "match"), + [ + ("1; DROP TABLE jobs", 10, "days_back must be a positive integer"), + (-1, 10, "days_back must be a positive integer"), + (0, 10, "days_back must be a positive integer"), + (True, 10, "days_back must be a positive integer"), + (366, 10, "days_back must be <= 365"), + (7, "10; DROP TABLE jobs", "limit must be a positive integer"), + (7, -1, "limit must be a positive integer"), + (7, 0, "limit must be a positive integer"), + (7, True, "limit must be a positive integer"), + (7, 10_001, "limit must be <= 10000"), + ], +) +def test_validate_query_history_params_rejects_unsafe_values(days_back, limit, match): + with pytest.raises(ValueError, match=match): + validate_query_history_params(days_back, limit) + + +@pytest.mark.parametrize( + "adapter", + [ + BigQueryAdapter.__new__(BigQueryAdapter), + ClickHouseAdapter.__new__(ClickHouseAdapter), + DatabricksAdapter.__new__(DatabricksAdapter), + SnowflakeAdapter.__new__(SnowflakeAdapter), + ], +) +def test_query_history_adapters_reject_interpolated_values_before_execution(adapter): + with pytest.raises(ValueError, match="days_back must be a positive integer"): + adapter.get_query_history(days_back="1; DROP TABLE query_log", limit=10) + + with pytest.raises(ValueError, match="limit must be a positive integer"): + adapter.get_query_history(days_back=1, limit="10; DROP TABLE query_log") + + +def test_snowflake_query_history_enforces_information_schema_lookback_limit(): + adapter = SnowflakeAdapter.__new__(SnowflakeAdapter) + + with pytest.raises(ValueError, match="days_back must be <= 7"): + adapter.get_query_history(days_back=8, limit=10) diff --git a/tests/joins/test_many_to_many_joins.py b/tests/joins/test_many_to_many_joins.py index fa3feecf..8f2a17fb 100644 --- a/tests/joins/test_many_to_many_joins.py +++ b/tests/joins/test_many_to_many_joins.py @@ -71,3 +71,58 @@ def test_many_to_many_join_path(layer): "order_items_cte.product_id = products_cte.product_id" in sql or "products_cte.product_id = order_items_cte.product_id" in sql ) + + +def test_many_to_many_through_composite_junction_keys(layer): + """Composite many-to-many joins use every source and target junction key.""" + orders = Model( + name="orders", + table="orders", + primary_key=["tenant_id", "order_id"], + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + relationships=[ + Relationship( + name="products", + type="many_to_many", + through="order_items", + through_foreign_key_columns=["tenant_id", "order_id"], + related_foreign_key_columns=["tenant_id", "product_id"], + ) + ], + ) + products = Model( + name="products", + table="products", + primary_key=["tenant_id", "product_id"], + dimensions=[Dimension(name="name", type="categorical")], + ) + order_items = Model( + name="order_items", + table="order_items", + primary_key=["tenant_id", "order_id", "product_id"], + ) + + layer.add_model(orders) + layer.add_model(products) + layer.add_model(order_items) + + path = layer.graph.find_relationship_path("orders", "products") + assert len(path) == 2 + assert path[0].from_columns == ["tenant_id", "order_id"] + assert path[0].to_columns == ["tenant_id", "order_id"] + assert path[1].from_columns == ["tenant_id", "product_id"] + assert path[1].to_columns == ["tenant_id", "product_id"] + + sql = layer.compile(metrics=["orders.revenue"], dimensions=["products.name"]) + assert "orders_cte.tenant_id = order_items_cte.tenant_id" in sql or ( + "order_items_cte.tenant_id = orders_cte.tenant_id" in sql + ) + assert "orders_cte.order_id = order_items_cte.order_id" in sql or ( + "order_items_cte.order_id = orders_cte.order_id" in sql + ) + assert "order_items_cte.tenant_id = products_cte.tenant_id" in sql or ( + "products_cte.tenant_id = order_items_cte.tenant_id" in sql + ) + assert "order_items_cte.product_id = products_cte.product_id" in sql or ( + "products_cte.product_id = order_items_cte.product_id" in sql + ) diff --git a/tests/native-fixtures/advanced_metrics/expected/cumulative_revenue_by_month_result.json b/tests/native-fixtures/advanced_metrics/expected/cumulative_revenue_by_month_result.json new file mode 100644 index 00000000..9541f690 --- /dev/null +++ b/tests/native-fixtures/advanced_metrics/expected/cumulative_revenue_by_month_result.json @@ -0,0 +1,12 @@ +[ + { + "event_date__month": "2024-01-01", + "total_revenue": 175, + "cumulative_revenue": 175 + }, + { + "event_date__month": "2024-02-01", + "total_revenue": 400, + "cumulative_revenue": 575 + } +] diff --git a/tests/native-fixtures/advanced_metrics/expected/multi_platform_users_result.json b/tests/native-fixtures/advanced_metrics/expected/multi_platform_users_result.json new file mode 100644 index 00000000..767802d0 --- /dev/null +++ b/tests/native-fixtures/advanced_metrics/expected/multi_platform_users_result.json @@ -0,0 +1,10 @@ +[ + { + "region": "eu", + "multi_platform_users": 2 + }, + { + "region": "us", + "multi_platform_users": 2 + } +] diff --git a/tests/native-fixtures/advanced_metrics/expected/revenue_mom_by_month_region_result.json b/tests/native-fixtures/advanced_metrics/expected/revenue_mom_by_month_region_result.json new file mode 100644 index 00000000..5dfab1d3 --- /dev/null +++ b/tests/native-fixtures/advanced_metrics/expected/revenue_mom_by_month_region_result.json @@ -0,0 +1,26 @@ +[ + { + "event_date__month": "2024-01-01", + "region": "eu", + "total_revenue": 75, + "events.revenue_mom": null + }, + { + "event_date__month": "2024-01-01", + "region": "us", + "total_revenue": 100, + "events.revenue_mom": null + }, + { + "event_date__month": "2024-02-01", + "region": "eu", + "total_revenue": 150, + "events.revenue_mom": 100.0 + }, + { + "event_date__month": "2024-02-01", + "region": "us", + "total_revenue": 250, + "events.revenue_mom": 150.0 + } +] diff --git a/tests/native-fixtures/advanced_metrics/expected/signup_conversion_by_region_result.json b/tests/native-fixtures/advanced_metrics/expected/signup_conversion_by_region_result.json new file mode 100644 index 00000000..3bcfced3 --- /dev/null +++ b/tests/native-fixtures/advanced_metrics/expected/signup_conversion_by_region_result.json @@ -0,0 +1,10 @@ +[ + { + "region": "eu", + "signup_conversion": 1.0 + }, + { + "region": "us", + "signup_conversion": 0.5 + } +] diff --git a/tests/native-fixtures/advanced_metrics/expected/signup_retention_result.json b/tests/native-fixtures/advanced_metrics/expected/signup_retention_result.json new file mode 100644 index 00000000..80a6e94b --- /dev/null +++ b/tests/native-fixtures/advanced_metrics/expected/signup_retention_result.json @@ -0,0 +1,30 @@ +[ + { + "cohort_date": "2024-01-01", + "days_since": 1, + "active_users": 1, + "cohort_size": 1, + "retention_pct": 100.0 + }, + { + "cohort_date": "2024-01-01", + "days_since": 4, + "active_users": 1, + "cohort_size": 1, + "retention_pct": 100.0 + }, + { + "cohort_date": "2024-01-05", + "days_since": 2, + "active_users": 1, + "cohort_size": 1, + "retention_pct": 100.0 + }, + { + "cohort_date": "2024-02-01", + "days_since": 6, + "active_users": 1, + "cohort_size": 2, + "retention_pct": 50.0 + } +] diff --git a/tests/native-fixtures/advanced_metrics/queries/multi_platform_users.query.yml b/tests/native-fixtures/advanced_metrics/queries/multi_platform_users.query.yml index c16505a9..f485317f 100644 --- a/tests/native-fixtures/advanced_metrics/queries/multi_platform_users.query.yml +++ b/tests/native-fixtures/advanced_metrics/queries/multi_platform_users.query.yml @@ -1,2 +1,4 @@ metrics: - events.multi_platform_users +order_by: + - region diff --git a/tests/native-fixtures/advanced_metrics/seed/duckdb.sql b/tests/native-fixtures/advanced_metrics/seed/duckdb.sql new file mode 100644 index 00000000..f9c1ed98 --- /dev/null +++ b/tests/native-fixtures/advanced_metrics/seed/duckdb.sql @@ -0,0 +1,24 @@ +create table events ( + event_id integer, + user_id varchar, + event_type varchar, + event_date date, + region varchar, + raw_platform varchar, + amount integer +); + +insert into events values + (1, 'u1', 'signup', '2024-01-01', 'us', 'ios', 0), + (2, 'u1', 'purchase', '2024-01-03', 'us', 'web', 100), + (3, 'u1', 'active', '2024-01-02', 'us', 'ios', 0), + (4, 'u1', 'active', '2024-01-05', 'us', 'web', 0), + (5, 'u2', 'signup', '2024-01-05', 'eu', 'android', 0), + (6, 'u2', 'purchase', '2024-01-07', 'eu', 'web', 75), + (7, 'u2', 'active', '2024-01-07', 'eu', 'android', 0), + (8, 'u3', 'signup', '2024-02-01', 'us', 'ios', 0), + (9, 'u3', 'purchase', '2024-02-10', 'us', 'web', 200), + (10, 'u4', 'signup', '2024-02-01', 'eu', 'ios', 0), + (11, 'u4', 'purchase', '2024-02-05', 'eu', 'web', 150), + (12, 'u4', 'active', '2024-02-07', 'eu', 'web', 0), + (13, 'u5', 'purchase', '2024-02-10', 'us', 'web', 50); diff --git a/tests/native-fixtures/compact_sql_model/README.md b/tests/native-fixtures/compact_sql_model/README.md new file mode 100644 index 00000000..d2a15606 --- /dev/null +++ b/tests/native-fixtures/compact_sql_model/README.md @@ -0,0 +1,4 @@ +# Compact SQL Model + +Verifies that Python and Rust both load SQL-first compact model blocks in +`model name from source (...)` form, including multiple models in one `.sql` file. diff --git a/tests/native-fixtures/compact_sql_model/expected/result.json b/tests/native-fixtures/compact_sql_model/expected/result.json new file mode 100644 index 00000000..d96ff131 --- /dev/null +++ b/tests/native-fixtures/compact_sql_model/expected/result.json @@ -0,0 +1,10 @@ +[ + { + "status": "paid", + "total_revenue": 250 + }, + { + "status": "refunded", + "total_revenue": 50 + } +] diff --git a/tests/native-fixtures/compact_sql_model/expected/validation.json b/tests/native-fixtures/compact_sql_model/expected/validation.json new file mode 100644 index 00000000..39be4476 --- /dev/null +++ b/tests/native-fixtures/compact_sql_model/expected/validation.json @@ -0,0 +1,3 @@ +{ + "valid": true +} diff --git a/tests/native-fixtures/compact_sql_model/models/orders.sql b/tests/native-fixtures/compact_sql_model/models/orders.sql new file mode 100644 index 00000000..a3e48b47 --- /dev/null +++ b/tests/native-fixtures/compact_sql_model/models/orders.sql @@ -0,0 +1,11 @@ +model orders from orders ( + primary key (order_id) + status + created_at as created_at : time grain day + sum(amount) as total_revenue +) + +model customers from customers ( + primary key (customer_id) + country +) diff --git a/tests/native-fixtures/compact_sql_model/queries/revenue_by_status.query.yml b/tests/native-fixtures/compact_sql_model/queries/revenue_by_status.query.yml new file mode 100644 index 00000000..e1fcbe86 --- /dev/null +++ b/tests/native-fixtures/compact_sql_model/queries/revenue_by_status.query.yml @@ -0,0 +1,6 @@ +metrics: + - orders.total_revenue +dimensions: + - orders.status +order_by: + - orders.status diff --git a/tests/native-fixtures/compact_sql_model/seed/duckdb.sql b/tests/native-fixtures/compact_sql_model/seed/duckdb.sql new file mode 100644 index 00000000..4d3bfc50 --- /dev/null +++ b/tests/native-fixtures/compact_sql_model/seed/duckdb.sql @@ -0,0 +1,19 @@ +create table orders ( + order_id integer, + status varchar, + amount integer, + created_at timestamp +); + +create table customers ( + customer_id integer, + country varchar +); + +insert into orders values + (1, 'paid', 100, timestamp '2026-01-01 10:00:00'), + (2, 'paid', 150, timestamp '2026-01-02 10:00:00'), + (3, 'refunded', 50, timestamp '2026-01-03 10:00:00'); + +insert into customers values + (1, 'US'); diff --git a/tests/native-fixtures/custom_relationship_sql/README.md b/tests/native-fixtures/custom_relationship_sql/README.md new file mode 100644 index 00000000..db0d6eee --- /dev/null +++ b/tests/native-fixtures/custom_relationship_sql/README.md @@ -0,0 +1,6 @@ +# Custom Relationship SQL + +Validates that native relationship `sql` is honored by Python and Rust query +generation. The fixture uses `IS NOT DISTINCT FROM` in the custom join so the +result set differs from the default composite equality join when tenant IDs are +NULL. diff --git a/tests/native-fixtures/custom_relationship_sql/expected/result.json b/tests/native-fixtures/custom_relationship_sql/expected/result.json new file mode 100644 index 00000000..d703573d --- /dev/null +++ b/tests/native-fixtures/custom_relationship_sql/expected/result.json @@ -0,0 +1,10 @@ +[ + { + "country": "Global", + "total_revenue": 70 + }, + { + "country": "US", + "total_revenue": 50 + } +] diff --git a/tests/native-fixtures/custom_relationship_sql/expected/validation.json b/tests/native-fixtures/custom_relationship_sql/expected/validation.json new file mode 100644 index 00000000..99ac6c9e --- /dev/null +++ b/tests/native-fixtures/custom_relationship_sql/expected/validation.json @@ -0,0 +1,4 @@ +{ + "valid": true, + "errors": [] +} diff --git a/tests/native-fixtures/custom_relationship_sql/models/orders.yml b/tests/native-fixtures/custom_relationship_sql/models/orders.yml new file mode 100644 index 00000000..1ade31ca --- /dev/null +++ b/tests/native-fixtures/custom_relationship_sql/models/orders.yml @@ -0,0 +1,24 @@ +version: 1 +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: amount + type: numeric + relationships: + - name: customers + type: many_to_one + foreign_key_columns: [customer_id, tenant_id] + primary_key_columns: [customer_id, tenant_id] + sql: "{from}.customer_id = {to}.customer_id AND {from}.tenant_id IS NOT DISTINCT FROM {to}.tenant_id" + metrics: + - name: total_revenue + agg: sum + sql: amount + - name: customers + table: customers + primary_key_columns: [customer_id, tenant_id] + dimensions: + - name: country + type: categorical diff --git a/tests/native-fixtures/custom_relationship_sql/queries/revenue_by_country.query.yml b/tests/native-fixtures/custom_relationship_sql/queries/revenue_by_country.query.yml new file mode 100644 index 00000000..cd4d7e47 --- /dev/null +++ b/tests/native-fixtures/custom_relationship_sql/queries/revenue_by_country.query.yml @@ -0,0 +1,6 @@ +metrics: + - orders.total_revenue +dimensions: + - customers.country +order_by: + - customers.country diff --git a/tests/native-fixtures/custom_relationship_sql/seed/duckdb.sql b/tests/native-fixtures/custom_relationship_sql/seed/duckdb.sql new file mode 100644 index 00000000..e2e8b55e --- /dev/null +++ b/tests/native-fixtures/custom_relationship_sql/seed/duckdb.sql @@ -0,0 +1,20 @@ +create table orders ( + order_id integer, + customer_id integer, + tenant_id integer, + amount integer +); + +create table customers ( + customer_id integer, + tenant_id integer, + country varchar +); + +insert into orders values + (1, 100, 1, 50), + (2, 100, null, 70); + +insert into customers values + (100, 1, 'US'), + (100, null, 'Global'); diff --git a/tests/native-fixtures/invalid_unknown_native_field/expected/validation.json b/tests/native-fixtures/invalid_unknown_native_field/expected/validation.json new file mode 100644 index 00000000..f71325bf --- /dev/null +++ b/tests/native-fixtures/invalid_unknown_native_field/expected/validation.json @@ -0,0 +1,3 @@ +{ + "valid": false +} diff --git a/tests/native-fixtures/invalid_unknown_native_field/models/orders.yml b/tests/native-fixtures/invalid_unknown_native_field/models/orders.yml new file mode 100644 index 00000000..9a0f3744 --- /dev/null +++ b/tests/native-fixtures/invalid_unknown_native_field/models/orders.yml @@ -0,0 +1,8 @@ +version: 1 + +models: + - name: orders + table: orders + metrcs: + - name: order_count + agg: count diff --git a/tests/native-fixtures/manifest.yml b/tests/native-fixtures/manifest.yml index 58d7c3bd..01e21510 100644 --- a/tests/native-fixtures/manifest.yml +++ b/tests/native-fixtures/manifest.yml @@ -51,6 +51,54 @@ fixtures: result_columns: [status, total_revenue] sql_contains: [raw_orders, is_deleted, total_revenue] + - name: compact_sql_model + valid: true + seed: seed/duckdb.sql + expected_validation: expected/validation.json + queries: + - name: revenue_by_status_from_compact_sql_model + file: queries/revenue_by_status.query.yml + expected_result: expected/result.json + result_columns: [status, total_revenue] + sql_contains: [SUM, orders, total_revenue] + + - name: relationship_default_keys + valid: true + seed: seed/duckdb.sql + expected_validation: expected/validation.json + queries: + - name: customer_count_by_order_status + file: queries/customer_count_by_order_status.query.yml + expected_result: expected/result.json + result_columns: [status, customer_count] + sql_contains: [JOIN, customers, orders, customers_cte.id, orders_cte.id] + - name: customer_count_by_profile_tier + file: queries/customer_count_by_profile_tier.query.yml + expected_result: expected/profile_tier_result.json + result_columns: [tier, customer_count] + sql_contains: [JOIN, customers, profiles, customers_cte.id, profiles_cte.id] + - name: payment_count_by_account_region + file: queries/payment_count_by_account_region.query.yml + expected_result: expected/account_region_result.json + result_columns: [region, payment_count] + sql_contains: [JOIN, payments, accounts, payments_cte.accounts_id, accounts_cte.account_uid] + - name: invoice_count_by_vendor_segment + file: queries/invoice_count_by_vendor_segment.query.yml + expected_result: expected/vendor_segment_result.json + result_columns: [segment, invoice_count] + sql_contains: [JOIN, invoices, vendors, invoices_cte.vendor_ref, vendors_cte.vendor_uid] + + - name: native_aliases + valid: true + seed: seed/duckdb.sql + expected_validation: expected/validation.json + queries: + - name: revenue_by_status + file: queries/revenue_by_status.query.yml + expected_result: expected/result.json + result_columns: [status, total_revenue, revenue_per_order] + sql_contains: [order_status, total_revenue, revenue_per_order] + - name: default_time_dimension valid: true seed: seed/duckdb.sql @@ -106,6 +154,28 @@ fixtures: result_columns: [category, total_revenue] sql_contains: [JOIN, order_items, products] + - name: many_to_many_composite_keys + valid: true + seed: seed/duckdb.sql + expected_validation: expected/validation.json + queries: + - name: revenue_by_category + file: queries/revenue_by_category.query.yml + expected_result: expected/result.json + result_columns: [category, total_revenue] + sql_contains: [JOIN, order_items, products, tenant_id, order_id, product_id, AND] + + - name: statistical_aggregations + valid: true + seed: seed/duckdb.sql + expected_validation: expected/validation.json + queries: + - name: amount_stats + file: queries/amount_stats.query.yml + expected_result: expected/result.json + result_columns: [amount_stddev, amount_stddev_pop, amount_variance, amount_variance_pop] + sql_contains: [STDDEV, STDDEV_POP, VARIANCE, VAR_POP] + - name: fanout_symmetric_aggregation valid: true seed: seed/duckdb.sql @@ -130,8 +200,7 @@ fixtures: queries: - name: revenue_with_window_calcs file: queries/revenue_with_window_calcs.query.yml - rust_expected_result: expected/result.json - rust_only_reason: Python native query API does not accept table_calculations yet. + expected_result: expected/result.json result_columns: [status, total_revenue, running_revenue, revenue_pct_of_total] sql_contains: [total_revenue, status] rust_sql_contains: [running_revenue, revenue_pct_of_total, "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", NULLIF] @@ -202,24 +271,57 @@ fixtures: result_columns: [status, total_revenue] sql_contains: ["'paid'", ">= 100", total_revenue] + - name: top_level_metric_contract + valid: true + seed: seed/duckdb.sql + expected_validation: expected/validation.json + queries: + - name: revenue_per_order + file: queries/revenue_per_order.query.yml + expected_result: expected/result.json + result_columns: [status, finance.revenue_per_order] + sql_contains: [finance.revenue_per_order, ">= 10", NULLIF] + + - name: custom_relationship_sql + valid: true + seed: seed/duckdb.sql + expected_validation: expected/validation.json + queries: + - name: revenue_by_country + file: queries/revenue_by_country.query.yml + expected_result: expected/result.json + result_columns: [country, total_revenue] + sql_contains: ["IS NOT DISTINCT FROM", customers, total_revenue] + - name: advanced_metrics valid: true + seed: seed/duckdb.sql expected_validation: expected/validation.json queries: - name: cumulative_revenue_by_month file: queries/cumulative_revenue_by_month.query.yml + expected_result: expected/cumulative_revenue_by_month_result.json + result_columns: [event_date__month, total_revenue, cumulative_revenue] sql_contains: [SUM, OVER, cumulative_revenue] - name: revenue_mom_by_month_region file: queries/revenue_mom_by_month_region.query.yml + expected_result: expected/revenue_mom_by_month_region_result.json + result_columns: [event_date__month, region, total_revenue, events.revenue_mom] sql_contains: [LAG, "PARTITION BY", revenue_mom] - name: signup_conversion_by_region file: queries/signup_conversion_by_region.query.yml + expected_result: expected/signup_conversion_by_region_result.json + result_columns: [region, signup_conversion] sql_contains: [base_events, conversion_events, signup_conversion] - name: signup_retention file: queries/signup_retention.query.yml + expected_result: expected/signup_retention_result.json + result_columns: [cohort_date, days_since, active_users, cohort_size, retention_pct] sql_contains: ["WITH cohorts", retention_pct, days_since] - name: multi_platform_users file: queries/multi_platform_users.query.yml + expected_result: expected/multi_platform_users_result.json + result_columns: [region, multi_platform_users] sql_contains: [cohort_sub, platform_count, HAVING, multi_platform_users] - name: invalid_duplicate_dimension @@ -240,6 +342,13 @@ fixtures: error_contains: - references unknown measure 'missing_revenue' + - name: invalid_unknown_native_field + valid: false + expected_validation: expected/validation.json + error_contains: + - unknown + - metrcs + - name: unsupported_version valid: false expected_validation: expected/validation.json diff --git a/tests/native-fixtures/many_to_many_composite_keys/README.md b/tests/native-fixtures/many_to_many_composite_keys/README.md new file mode 100644 index 00000000..d65ecc7d --- /dev/null +++ b/tests/native-fixtures/many_to_many_composite_keys/README.md @@ -0,0 +1,3 @@ +# Many-To-Many Composite Keys + +Exercises a many-to-many relationship through a junction model where both the source and target models use composite primary keys. diff --git a/tests/native-fixtures/many_to_many_composite_keys/expected/result.json b/tests/native-fixtures/many_to_many_composite_keys/expected/result.json new file mode 100644 index 00000000..b205fef4 --- /dev/null +++ b/tests/native-fixtures/many_to_many_composite_keys/expected/result.json @@ -0,0 +1,14 @@ +[ + { + "category": "hardware", + "total_revenue": 100 + }, + { + "category": "services", + "total_revenue": 50 + }, + { + "category": "software", + "total_revenue": 200 + } +] diff --git a/tests/native-fixtures/many_to_many_composite_keys/expected/validation.json b/tests/native-fixtures/many_to_many_composite_keys/expected/validation.json new file mode 100644 index 00000000..39be4476 --- /dev/null +++ b/tests/native-fixtures/many_to_many_composite_keys/expected/validation.json @@ -0,0 +1,3 @@ +{ + "valid": true +} diff --git a/tests/native-fixtures/many_to_many_composite_keys/models/sales.yml b/tests/native-fixtures/many_to_many_composite_keys/models/sales.yml new file mode 100644 index 00000000..5cbae95e --- /dev/null +++ b/tests/native-fixtures/many_to_many_composite_keys/models/sales.yml @@ -0,0 +1,26 @@ +version: 1 +models: + - name: orders + table: orders + primary_key_columns: [tenant_id, order_id] + relationships: + - name: products + type: many_to_many + through: order_items + through_foreign_key_columns: [tenant_id, order_id] + related_foreign_key_columns: [tenant_id, product_id] + metrics: + - name: total_revenue + agg: sum + sql: amount + + - name: order_items + table: order_items + primary_key_columns: [tenant_id, order_id, product_id] + + - name: products + table: products + primary_key_columns: [tenant_id, product_id] + dimensions: + - name: category + type: categorical diff --git a/tests/native-fixtures/many_to_many_composite_keys/queries/revenue_by_category.query.yml b/tests/native-fixtures/many_to_many_composite_keys/queries/revenue_by_category.query.yml new file mode 100644 index 00000000..1316b13f --- /dev/null +++ b/tests/native-fixtures/many_to_many_composite_keys/queries/revenue_by_category.query.yml @@ -0,0 +1,6 @@ +metrics: + - orders.total_revenue +dimensions: + - products.category +order_by: + - products.category diff --git a/tests/native-fixtures/many_to_many_composite_keys/seed/duckdb.sql b/tests/native-fixtures/many_to_many_composite_keys/seed/duckdb.sql new file mode 100644 index 00000000..be6e4350 --- /dev/null +++ b/tests/native-fixtures/many_to_many_composite_keys/seed/duckdb.sql @@ -0,0 +1,32 @@ +create table orders ( + tenant_id integer, + order_id integer, + amount integer +); + +create table order_items ( + tenant_id integer, + order_id integer, + product_id integer +); + +create table products ( + tenant_id integer, + product_id integer, + category varchar +); + +insert into orders values + (1, 100, 100), + (1, 101, 50), + (2, 100, 200); + +insert into order_items values + (1, 100, 10), + (1, 101, 11), + (2, 100, 10); + +insert into products values + (1, 10, 'hardware'), + (1, 11, 'services'), + (2, 10, 'software'); diff --git a/tests/native-fixtures/native_aliases/README.md b/tests/native-fixtures/native_aliases/README.md new file mode 100644 index 00000000..968f83fc --- /dev/null +++ b/tests/native-fixtures/native_aliases/README.md @@ -0,0 +1,8 @@ +# Native Compatibility Aliases + +Verifies that Python compatibility input aliases are accepted by both native runtimes: + +- model `measures` as an alias for `metrics` +- dimension and metric `expr` as aliases for `sql` +- metric `measure` as an alias for `sql` +- model `auto_dimensions: false` as accepted compatibility input diff --git a/tests/native-fixtures/native_aliases/expected/result.json b/tests/native-fixtures/native_aliases/expected/result.json new file mode 100644 index 00000000..72a1f900 --- /dev/null +++ b/tests/native-fixtures/native_aliases/expected/result.json @@ -0,0 +1,12 @@ +[ + { + "status": "paid", + "total_revenue": 150, + "revenue_per_order": 75 + }, + { + "status": "refunded", + "total_revenue": 30, + "revenue_per_order": 30 + } +] diff --git a/tests/native-fixtures/native_aliases/expected/validation.json b/tests/native-fixtures/native_aliases/expected/validation.json new file mode 100644 index 00000000..39be4476 --- /dev/null +++ b/tests/native-fixtures/native_aliases/expected/validation.json @@ -0,0 +1,3 @@ +{ + "valid": true +} diff --git a/tests/native-fixtures/native_aliases/models/models.yml b/tests/native-fixtures/native_aliases/models/models.yml new file mode 100644 index 00000000..62e779cd --- /dev/null +++ b/tests/native-fixtures/native_aliases/models/models.yml @@ -0,0 +1,20 @@ +version: 1 + +models: + - name: orders + table: orders + primary_key: order_id + auto_dimensions: false + dimensions: + - name: status + type: categorical + expr: order_status + measures: + - name: total_revenue + agg: sum + expr: amount + - name: order_count + agg: count + - name: revenue_per_order + type: derived + measure: total_revenue / NULLIF(order_count, 0) diff --git a/tests/native-fixtures/native_aliases/queries/revenue_by_status.query.yml b/tests/native-fixtures/native_aliases/queries/revenue_by_status.query.yml new file mode 100644 index 00000000..72c156dc --- /dev/null +++ b/tests/native-fixtures/native_aliases/queries/revenue_by_status.query.yml @@ -0,0 +1,7 @@ +metrics: + - orders.total_revenue + - orders.revenue_per_order +dimensions: + - orders.status +order_by: + - orders.status diff --git a/tests/native-fixtures/native_aliases/seed/duckdb.sql b/tests/native-fixtures/native_aliases/seed/duckdb.sql new file mode 100644 index 00000000..76a99248 --- /dev/null +++ b/tests/native-fixtures/native_aliases/seed/duckdb.sql @@ -0,0 +1,10 @@ +create table orders ( + order_id integer, + order_status varchar, + amount integer +); + +insert into orders values + (1, 'paid', 100), + (2, 'paid', 50), + (3, 'refunded', 30); diff --git a/tests/native-fixtures/relationship_default_keys/README.md b/tests/native-fixtures/relationship_default_keys/README.md new file mode 100644 index 00000000..dcdf95be --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/README.md @@ -0,0 +1,7 @@ +# Relationship Default Keys + +Verifies relationship default-key compatibility between Python and Rust: + +- omitted `one_to_many` and `one_to_one` keys use `id` +- omitted `many_to_one` foreign keys use `{name}_id` +- omitted relationship `primary_key` resolves to the target model's declared primary key diff --git a/tests/native-fixtures/relationship_default_keys/expected/account_region_result.json b/tests/native-fixtures/relationship_default_keys/expected/account_region_result.json new file mode 100644 index 00000000..ba2df519 --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/expected/account_region_result.json @@ -0,0 +1,10 @@ +[ + { + "region": "east", + "payment_count": 1 + }, + { + "region": "west", + "payment_count": 1 + } +] diff --git a/tests/native-fixtures/relationship_default_keys/expected/profile_tier_result.json b/tests/native-fixtures/relationship_default_keys/expected/profile_tier_result.json new file mode 100644 index 00000000..d83edfc4 --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/expected/profile_tier_result.json @@ -0,0 +1,10 @@ +[ + { + "tier": "gold", + "customer_count": 1 + }, + { + "tier": "silver", + "customer_count": 1 + } +] diff --git a/tests/native-fixtures/relationship_default_keys/expected/result.json b/tests/native-fixtures/relationship_default_keys/expected/result.json new file mode 100644 index 00000000..9c93df0a --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/expected/result.json @@ -0,0 +1,10 @@ +[ + { + "status": "paid", + "customer_count": 1 + }, + { + "status": "refunded", + "customer_count": 1 + } +] diff --git a/tests/native-fixtures/relationship_default_keys/expected/validation.json b/tests/native-fixtures/relationship_default_keys/expected/validation.json new file mode 100644 index 00000000..39be4476 --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/expected/validation.json @@ -0,0 +1,3 @@ +{ + "valid": true +} diff --git a/tests/native-fixtures/relationship_default_keys/expected/vendor_segment_result.json b/tests/native-fixtures/relationship_default_keys/expected/vendor_segment_result.json new file mode 100644 index 00000000..db45943c --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/expected/vendor_segment_result.json @@ -0,0 +1,10 @@ +[ + { + "segment": "enterprise", + "invoice_count": 1 + }, + { + "segment": "midmarket", + "invoice_count": 1 + } +] diff --git a/tests/native-fixtures/relationship_default_keys/models/models.yml b/tests/native-fixtures/relationship_default_keys/models/models.yml new file mode 100644 index 00000000..32a2859c --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/models/models.yml @@ -0,0 +1,69 @@ +version: 1 + +models: + - name: customers + table: customers + primary_key: id + dimensions: + - name: country + type: categorical + metrics: + - name: customer_count + agg: count + relationships: + - name: orders + type: one_to_many + - name: profiles + type: one_to_one + + - name: orders + table: orders + primary_key: id + dimensions: + - name: status + type: categorical + + - name: profiles + table: profiles + primary_key: id + dimensions: + - name: tier + type: categorical + + - name: payments + table: payments + primary_key: payment_id + dimensions: + - name: amount + type: numeric + metrics: + - name: payment_count + agg: count + relationships: + - name: accounts + type: many_to_one + + - name: accounts + table: accounts + primary_key: account_uid + dimensions: + - name: region + type: categorical + + - name: invoices + table: invoices + primary_key: invoice_id + metrics: + - name: invoice_count + agg: count + relationships: + - name: vendors + type: many_to_one + foreign_key: vendor_ref + + - name: vendors + table: vendors + primary_key: vendor_uid + dimensions: + - name: segment + type: categorical diff --git a/tests/native-fixtures/relationship_default_keys/queries/customer_count_by_order_status.query.yml b/tests/native-fixtures/relationship_default_keys/queries/customer_count_by_order_status.query.yml new file mode 100644 index 00000000..a9c3be1f --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/queries/customer_count_by_order_status.query.yml @@ -0,0 +1,6 @@ +metrics: + - customers.customer_count +dimensions: + - orders.status +order_by: + - orders.status diff --git a/tests/native-fixtures/relationship_default_keys/queries/customer_count_by_profile_tier.query.yml b/tests/native-fixtures/relationship_default_keys/queries/customer_count_by_profile_tier.query.yml new file mode 100644 index 00000000..5cc69b34 --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/queries/customer_count_by_profile_tier.query.yml @@ -0,0 +1,6 @@ +metrics: + - customers.customer_count +dimensions: + - profiles.tier +order_by: + - profiles.tier diff --git a/tests/native-fixtures/relationship_default_keys/queries/invoice_count_by_vendor_segment.query.yml b/tests/native-fixtures/relationship_default_keys/queries/invoice_count_by_vendor_segment.query.yml new file mode 100644 index 00000000..b000e9b0 --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/queries/invoice_count_by_vendor_segment.query.yml @@ -0,0 +1,6 @@ +metrics: + - invoices.invoice_count +dimensions: + - vendors.segment +order_by: + - vendors.segment diff --git a/tests/native-fixtures/relationship_default_keys/queries/payment_count_by_account_region.query.yml b/tests/native-fixtures/relationship_default_keys/queries/payment_count_by_account_region.query.yml new file mode 100644 index 00000000..9ba10d6a --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/queries/payment_count_by_account_region.query.yml @@ -0,0 +1,6 @@ +metrics: + - payments.payment_count +dimensions: + - accounts.region +order_by: + - accounts.region diff --git a/tests/native-fixtures/relationship_default_keys/seed/duckdb.sql b/tests/native-fixtures/relationship_default_keys/seed/duckdb.sql new file mode 100644 index 00000000..cfdc81a3 --- /dev/null +++ b/tests/native-fixtures/relationship_default_keys/seed/duckdb.sql @@ -0,0 +1,63 @@ +create table customers ( + id integer, + country varchar +); + +create table orders ( + id integer, + status varchar +); + +create table profiles ( + id integer, + tier varchar +); + +create table accounts ( + account_uid integer, + region varchar +); + +create table payments ( + payment_id integer, + accounts_id integer, + amount integer +); + +create table vendors ( + vendor_uid integer, + segment varchar +); + +create table invoices ( + invoice_id integer, + vendor_ref integer +); + +insert into customers values + (1, 'US'), + (2, 'CA'); + +insert into orders values + (1, 'paid'), + (2, 'refunded'); + +insert into profiles values + (1, 'gold'), + (2, 'silver'); + +insert into accounts values + (101, 'east'), + (102, 'west'); + +insert into payments values + (1001, 101, 40), + (1002, 102, 60); + +insert into vendors values + (201, 'enterprise'), + (202, 'midmarket'); + +insert into invoices values + (3001, 201), + (3002, 202); diff --git a/tests/native-fixtures/statistical_aggregations/README.md b/tests/native-fixtures/statistical_aggregations/README.md new file mode 100644 index 00000000..5e40d38a --- /dev/null +++ b/tests/native-fixtures/statistical_aggregations/README.md @@ -0,0 +1,4 @@ +# Statistical Aggregations + +Verifies that Python and Rust accept and compile the native statistical aggregation +contract: `stddev`, `stddev_pop`, `variance`, and `variance_pop`. diff --git a/tests/native-fixtures/statistical_aggregations/expected/result.json b/tests/native-fixtures/statistical_aggregations/expected/result.json new file mode 100644 index 00000000..86fd4ca1 --- /dev/null +++ b/tests/native-fixtures/statistical_aggregations/expected/result.json @@ -0,0 +1,8 @@ +[ + { + "amount_stddev": 50.0, + "amount_stddev_pop": 40.824829046386306, + "amount_variance": 2500.0, + "amount_variance_pop": 1666.6666666666667 + } +] diff --git a/tests/native-fixtures/statistical_aggregations/expected/validation.json b/tests/native-fixtures/statistical_aggregations/expected/validation.json new file mode 100644 index 00000000..39be4476 --- /dev/null +++ b/tests/native-fixtures/statistical_aggregations/expected/validation.json @@ -0,0 +1,3 @@ +{ + "valid": true +} diff --git a/tests/native-fixtures/statistical_aggregations/models/orders.yml b/tests/native-fixtures/statistical_aggregations/models/orders.yml new file mode 100644 index 00000000..acb2d87f --- /dev/null +++ b/tests/native-fixtures/statistical_aggregations/models/orders.yml @@ -0,0 +1,19 @@ +version: 1 + +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: amount_stddev + agg: stddev + sql: amount + - name: amount_stddev_pop + agg: stddev_pop + sql: amount + - name: amount_variance + agg: variance + sql: amount + - name: amount_variance_pop + agg: variance_pop + sql: amount diff --git a/tests/native-fixtures/statistical_aggregations/queries/amount_stats.query.yml b/tests/native-fixtures/statistical_aggregations/queries/amount_stats.query.yml new file mode 100644 index 00000000..3a428191 --- /dev/null +++ b/tests/native-fixtures/statistical_aggregations/queries/amount_stats.query.yml @@ -0,0 +1,5 @@ +metrics: + - orders.amount_stddev + - orders.amount_stddev_pop + - orders.amount_variance + - orders.amount_variance_pop diff --git a/tests/native-fixtures/statistical_aggregations/seed/duckdb.sql b/tests/native-fixtures/statistical_aggregations/seed/duckdb.sql new file mode 100644 index 00000000..9aed704a --- /dev/null +++ b/tests/native-fixtures/statistical_aggregations/seed/duckdb.sql @@ -0,0 +1,9 @@ +create table orders ( + order_id integer, + amount integer +); + +insert into orders values + (1, 100), + (2, 150), + (3, 50); diff --git a/tests/native-fixtures/table_calculations/README.md b/tests/native-fixtures/table_calculations/README.md index f5fa8050..c5d90c5f 100644 --- a/tests/native-fixtures/table_calculations/README.md +++ b/tests/native-fixtures/table_calculations/README.md @@ -1,5 +1,8 @@ # Table Calculations -Valid native fixture that proves Rust can compile query-local table calculations into SQL window expressions. +Valid native fixture that proves query-local table calculations produce the same +result in Python and Rust for the shared subset. -Python currently applies table calculations after query execution, so the Python fixture runner compiles the base query while the Rust runner asserts the table-calculation SQL shape. +Python applies table calculations after query execution with +`TableCalculationProcessor`. Rust compiles the same calculations into SQL window +expressions, and the Rust runner also asserts the generated SQL shape. diff --git a/tests/native-fixtures/table_calculations/expected/result.json b/tests/native-fixtures/table_calculations/expected/result.json index ea8bcc8b..d16616bc 100644 --- a/tests/native-fixtures/table_calculations/expected/result.json +++ b/tests/native-fixtures/table_calculations/expected/result.json @@ -3,12 +3,12 @@ "status": "paid", "total_revenue": 250, "running_revenue": 250, - "revenue_pct_of_total": 83.33333333333333 + "revenue_pct_of_total": 83.33333333333334 }, { "status": "refunded", "total_revenue": 50, "running_revenue": 300, - "revenue_pct_of_total": 16.666666666666668 + "revenue_pct_of_total": 16.666666666666664 } ] diff --git a/tests/native-fixtures/top_level_metric_contract/README.md b/tests/native-fixtures/top_level_metric_contract/README.md new file mode 100644 index 00000000..a43e7883 --- /dev/null +++ b/tests/native-fixtures/top_level_metric_contract/README.md @@ -0,0 +1,8 @@ +# Top-Level Metric Contract + +Validates the portable native contract for top-level metrics and parameters. + +Top-level metrics may have graph-style names, including dotted names, but the Rust +runtime must be able to infer exactly one owning model before query compilation. +Top-level parameters remain graph-scoped and must round-trip through Python and +interpolate in Python and Rust query paths. diff --git a/tests/native-fixtures/top_level_metric_contract/expected/result.json b/tests/native-fixtures/top_level_metric_contract/expected/result.json new file mode 100644 index 00000000..57ad7ffa --- /dev/null +++ b/tests/native-fixtures/top_level_metric_contract/expected/result.json @@ -0,0 +1,10 @@ +[ + { + "status": "paid", + "finance.revenue_per_order": 75.0 + }, + { + "status": "refunded", + "finance.revenue_per_order": 25.0 + } +] diff --git a/tests/native-fixtures/top_level_metric_contract/expected/validation.json b/tests/native-fixtures/top_level_metric_contract/expected/validation.json new file mode 100644 index 00000000..99ac6c9e --- /dev/null +++ b/tests/native-fixtures/top_level_metric_contract/expected/validation.json @@ -0,0 +1,4 @@ +{ + "valid": true, + "errors": [] +} diff --git a/tests/native-fixtures/top_level_metric_contract/models/orders.yml b/tests/native-fixtures/top_level_metric_contract/models/orders.yml new file mode 100644 index 00000000..b80c2acd --- /dev/null +++ b/tests/native-fixtures/top_level_metric_contract/models/orders.yml @@ -0,0 +1,25 @@ +version: 1 +parameters: + - name: min_amount + type: number + default_value: 0 +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + - name: amount + type: numeric + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count +metrics: + - name: finance.revenue_per_order + type: ratio + numerator: orders.revenue + denominator: orders.order_count diff --git a/tests/native-fixtures/top_level_metric_contract/queries/revenue_per_order.query.yml b/tests/native-fixtures/top_level_metric_contract/queries/revenue_per_order.query.yml new file mode 100644 index 00000000..9eb4c981 --- /dev/null +++ b/tests/native-fixtures/top_level_metric_contract/queries/revenue_per_order.query.yml @@ -0,0 +1,10 @@ +metrics: + - finance.revenue_per_order +dimensions: + - orders.status +filters: + - orders.amount >= {{ min_amount }} +order_by: + - orders.status +parameter_values: + min_amount: 10 diff --git a/tests/native-fixtures/top_level_metric_contract/seed/duckdb.sql b/tests/native-fixtures/top_level_metric_contract/seed/duckdb.sql new file mode 100644 index 00000000..7d82d9d2 --- /dev/null +++ b/tests/native-fixtures/top_level_metric_contract/seed/duckdb.sql @@ -0,0 +1,10 @@ +create table orders ( + order_id integer, + status varchar, + amount integer +); + +insert into orders values + (1, 'paid', 100), + (2, 'paid', 50), + (3, 'refunded', 25); diff --git a/tests/native_compat/test_basic_model_fixture.py b/tests/native_compat/test_basic_model_fixture.py index 2dbcb2c3..d77e420d 100644 --- a/tests/native_compat/test_basic_model_fixture.py +++ b/tests/native_compat/test_basic_model_fixture.py @@ -10,8 +10,10 @@ import yaml from sidemantic.core.semantic_layer import SemanticLayer +from sidemantic.core.table_calculation import TableCalculation from sidemantic.loaders import load_from_directory from sidemantic.sql.query_rewriter import QueryRewriter +from sidemantic.sql.table_calc_processor import TableCalculationProcessor FIXTURE_SUITE_ROOT = Path(__file__).parents[1] / "native-fixtures" @@ -111,12 +113,11 @@ def test_native_fixture_loads_compiles_and_executes(fixture, query_manifest): parameter_values = query_kwargs.pop("parameter_values", None) if parameter_values is not None: query_kwargs["parameters"] = parameter_values - table_calculations = query_kwargs.pop("table_calculations", None) - if table_calculations: - assert query_manifest.get("rust_only_reason"), ( - "fixture table_calculations must declare rust_only_reason while Python strips " - "table_calculations from native query execution" - ) + table_calculation_defs = query_kwargs.pop("table_calculations", None) + table_calculations = [ + value if isinstance(value, TableCalculation) else TableCalculation(**value) + for value in (table_calculation_defs or []) + ] compiled = layer.compile(**query_kwargs) @@ -130,8 +131,16 @@ def test_native_fixture_loads_compiles_and_executes(fixture, query_manifest): return layer.adapter.execute((fixture_root / fixture["seed"]).read_text()) - rows = layer.query(**query_kwargs).fetchall() - result_columns = query_manifest["result_columns"] + relation = layer.query(**query_kwargs) + rows = relation.fetchall() + if table_calculations: + base_columns = list(getattr(relation, "columns", []) or []) + if not base_columns: + base_columns = [column[0] for column in getattr(relation, "description", []) or []] + rows, result_columns = TableCalculationProcessor(table_calculations).process(rows, base_columns) + assert result_columns == query_manifest["result_columns"] + else: + result_columns = query_manifest["result_columns"] actual = [ {column: normalize_value(value) for column, value in zip(result_columns, row, strict=True)} for row in rows ] diff --git a/tests/queries/test_basic.py b/tests/queries/test_basic.py index 5fc3e0d3..0553a6d2 100644 --- a/tests/queries/test_basic.py +++ b/tests/queries/test_basic.py @@ -250,6 +250,23 @@ def test_sql_compilation(layer): assert "GROUP BY" in sql +def test_sql_compilation_preserves_zero_limit_and_offset(layer): + orders = Model( + name="orders", + table="public.orders", + primary_key="order_id", + dimensions=[Dimension(name="status", type="categorical")], + metrics=[Metric(name="revenue", agg="sum", sql="order_amount")], + ) + + layer.add_model(orders) + + sql = layer.compile(metrics=["orders.revenue"], dimensions=["orders.status"], limit=0, offset=0) + + assert "\nLIMIT 0" in sql + assert "\nOFFSET 0" in sql + + def test_multi_model_query(layer): """Test query across multiple models.""" orders = Model( diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index ca4f0f36..04d1bd0e 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -162,6 +162,19 @@ def test_limit(semantic_layer): assert len(rows) == 1 +def test_zero_limit_and_offset_are_preserved(semantic_layer): + """Test rewriting query with explicit zero pagination values.""" + sql = "SELECT orders.revenue, orders.status FROM orders ORDER BY orders.status LIMIT 0 OFFSET 0" + + rewritten = QueryRewriter(semantic_layer.graph).rewrite(sql) + assert "\nLIMIT 0" in rewritten + assert "\nOFFSET 0" in rewritten + + result = semantic_layer.sql(sql) + rows = _rows(result) + assert rows == [] + + def test_join_query(semantic_layer): """Test query that requires join.""" sql = "SELECT orders.revenue, customers.region FROM orders" @@ -1435,6 +1448,26 @@ def test_postprocess_limit_in_outer(semantic_layer): assert rows[0]["revenue"] == 250.00 +def test_postprocess_zero_limit_and_offset_in_outer(semantic_layer): + """Test zero pagination in outer query over semantic results.""" + sql = """ + SELECT status, revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + ORDER BY revenue DESC + LIMIT 0 OFFSET 0 + """ + + rewritten = QueryRewriter(semantic_layer.graph).rewrite(sql) + assert "\nLIMIT 0" in rewritten + assert "\nOFFSET 0" in rewritten + + result = semantic_layer.sql(sql) + rows = _rows(result) + assert rows == [] + + def test_postprocess_cross_model_subquery(semantic_layer): """Test post-processing over cross-model semantic subquery.""" sql = """ diff --git a/tests/rust_layer_adapter.py b/tests/rust_layer_adapter.py index 5b1aac84..a9f10937 100644 --- a/tests/rust_layer_adapter.py +++ b/tests/rust_layer_adapter.py @@ -116,8 +116,6 @@ def compile( active_dialect = dialect or self.dialect if active_dialect not in {"duckdb", "bigquery"}: raise NotImplementedError(f"pure Rust test adapter does not support dialect '{active_dialect}' yet") - if offset is not None: - raise NotImplementedError("pure Rust test adapter does not support offset yet") if parameters: raise NotImplementedError("pure Rust test adapter does not support template parameters yet") effective_preaggregations = self.use_preaggregations if use_preaggregations is None else use_preaggregations @@ -134,6 +132,7 @@ def compile( "segments": segments or [], "order_by": order_by or [], "limit": limit, + "offset": offset, "ungrouped": ungrouped, "skip_default_time_dimensions": skip_default_time_dimensions, "dialect": active_dialect, @@ -491,8 +490,6 @@ def generate( aliases: dict[str, str] | None = None, skip_default_time_dimensions: bool = False, ) -> str: - if offset is not None: - raise NotImplementedError("pure Rust test adapter does not support offset yet") if parameters: raise NotImplementedError("pure Rust test adapter does not support template parameters yet") if use_preaggregations: @@ -510,6 +507,7 @@ def generate( "segments": segments or [], "order_by": order_by or [], "limit": limit, + "offset": offset, "ungrouped": ungrouped, "skip_default_time_dimensions": skip_default_time_dimensions, "dialect": self.dialect, @@ -577,7 +575,7 @@ def rewrite(self, sql: str, strict: bool = True) -> str: if "parse" in error.lower(): raise ValueError(f"Failed to parse SQL: {error}") from None raise ValueError(error) - return response["sql"] + return _python_style_rewrite_sql(response["sql"]) def rust_build_symmetric_aggregate_sql( @@ -708,19 +706,26 @@ def _metric_to_rust_dict(metric: Metric) -> dict[str, Any]: def _relationship_to_rust_dict(relationship) -> dict[str, Any]: - return _drop_none( - { - "name": relationship.name, - "type": relationship.type, - "foreign_key": relationship.foreign_key, - "primary_key": relationship.primary_key, - "through": getattr(relationship, "through", None), - "through_foreign_key": getattr(relationship, "through_foreign_key", None), - "related_foreign_key": getattr(relationship, "related_foreign_key", None), - "sql": getattr(relationship, "sql", None), - "metadata": relationship.metadata, - } - ) + payload = { + "name": relationship.name, + "type": relationship.type, + "through": getattr(relationship, "through", None), + "through_foreign_key": getattr(relationship, "through_foreign_key", None), + "through_foreign_key_columns": getattr(relationship, "through_foreign_key_columns", None), + "related_foreign_key": getattr(relationship, "related_foreign_key", None), + "related_foreign_key_columns": getattr(relationship, "related_foreign_key_columns", None), + "sql": getattr(relationship, "sql", None), + "metadata": relationship.metadata, + } + if isinstance(relationship.foreign_key, list): + payload["foreign_key_columns"] = relationship.foreign_key + else: + payload["foreign_key"] = relationship.foreign_key + if isinstance(relationship.primary_key, list): + payload["primary_key_columns"] = relationship.primary_key + else: + payload["primary_key"] = relationship.primary_key + return _drop_none(payload) def _segment_to_rust_dict(segment) -> dict[str, Any]: @@ -757,6 +762,11 @@ def _drop_none(value: dict[str, Any]) -> dict[str, Any]: return {key: item for key, item in value.items() if item is not None} +def _python_style_rewrite_sql(sql: str) -> str: + sql = re.sub(r"\s+(LIMIT\s+\d+\b)", r"\n\1", sql, flags=re.IGNORECASE) + return re.sub(r"\s+(OFFSET\s+\d+\b)", r"\n\1", sql, flags=re.IGNORECASE) + + def _single_model_yaml(model: Model) -> str: return _graph_yaml({model.name: model}, {}) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 491ea407..8c811d29 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -4,6 +4,8 @@ import sys from pathlib import Path +import pytest + from sidemantic import SemanticLayer from sidemantic.loaders import load_from_directory @@ -33,6 +35,116 @@ def blocked_antlr4_import(name, *args, **kwargs): assert "orders" in layer.graph.models +def test_load_from_directory_strict_raises_on_detected_parse_error(tmp_path): + """Strict loading fails instead of returning a partial graph.""" + (tmp_path / "good.yml").write_text( + """ +models: + - name: orders + table: orders + primary_key: id +""" + ) + (tmp_path / "bad.yml").write_text( + """ +models: + - name: broken + table: [ +""" + ) + + layer = SemanticLayer() + with pytest.raises(ValueError, match="Could not parse .*bad.yml"): + load_from_directory(layer, tmp_path) + + assert not layer.graph.models + + +def test_load_from_directory_lenient_mode_skips_detected_parse_error(tmp_path): + """Lenient loading remains available as an explicit opt-in.""" + (tmp_path / "good.yml").write_text( + """ +models: + - name: orders + table: orders + primary_key: id +""" + ) + (tmp_path / "bad.yml").write_text( + """ +models: + - name: broken + table: [ +""" + ) + + layer = SemanticLayer() + load_from_directory(layer, tmp_path, strict=False) + + assert set(layer.graph.models) == {"orders"} + + +def test_load_from_directory_resolves_native_inheritance_across_files(tmp_path): + (tmp_path / "base.yml").write_text( + """ +version: 1 +models: + - name: base_orders + table: orders + primary_key: id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +""" + ) + (tmp_path / "child.yml").write_text( + """ +version: 1 +models: + - name: paid_orders + extends: base_orders + dimensions: + - name: paid_at + type: time + sql: paid_at + granularity: day +""" + ) + + layer = SemanticLayer() + load_from_directory(layer, tmp_path) + + paid_orders = layer.graph.models["paid_orders"] + assert paid_orders.table == "orders" + assert paid_orders.primary_key == "id" + assert paid_orders.extends is None + assert paid_orders.get_dimension("status") is not None + assert paid_orders.get_dimension("paid_at") is not None + assert paid_orders.get_metric("revenue") is not None + + +def test_load_from_directory_strict_raises_on_missing_native_parent(tmp_path): + (tmp_path / "child.yml").write_text( + """ +version: 1 +models: + - name: paid_orders + extends: missing_base + table: orders +""" + ) + + layer = SemanticLayer() + with pytest.raises(ValueError, match="Native model 'paid_orders' extends unknown model 'missing_base'"): + load_from_directory(layer, tmp_path) + + assert not layer.graph.models + + def test_native_inheritance_does_not_register_model_metrics_globally(tmp_path): (tmp_path / "models.yml").write_text( """ diff --git a/tests/test_metric_expressions.py b/tests/test_metric_expressions.py index 5ec879ee..2ba02cb4 100644 --- a/tests/test_metric_expressions.py +++ b/tests/test_metric_expressions.py @@ -1,6 +1,9 @@ """Test simplified metric expression syntax.""" +from sidemantic.core.dimension import Dimension from sidemantic.core.metric import Metric +from sidemantic.core.model import Model +from sidemantic.core.semantic_layer import SemanticLayer from sidemantic.core.sql_definitions import parse_sql_definitions @@ -41,6 +44,27 @@ def test_metric_with_full_expression(): assert m7.agg == "median" assert m7.sql == "price" + # Statistical aggregations + m8 = Metric(name="stddev_price", sql="STDDEV(price)") + assert m8.agg == "stddev" + assert m8.sql == "price" + + m9 = Metric(name="stddev_pop_price", sql="STDDEV_POP(price)") + assert m9.agg == "stddev_pop" + assert m9.sql == "price" + + m10 = Metric(name="variance_price", sql="VARIANCE(price)") + assert m10.agg == "variance" + assert m10.sql == "price" + + m11 = Metric(name="variance_pop_price", sql="VAR_POP(price)") + assert m11.agg == "variance_pop" + assert m11.sql == "price" + + m12 = Metric(name="variance_pop_price", sql="VARIANCE_POP(price)") + assert m12.agg == "variance_pop" + assert m12.sql == "price" + def test_metric_expression_case_insensitive(): """Test that aggregation function parsing is case-insensitive.""" @@ -66,6 +90,28 @@ def test_metric_old_syntax_still_works(): assert m.agg == "sum" assert m.sql == "amount" + variance_pop = Metric(name="variance_pop_price", agg="variance_pop", sql="price") + assert variance_pop.to_sql() == "VAR_POP(price)" + + +def test_graph_level_variance_pop_metric_compiles_to_var_pop(): + """Graph-level statistical aggregations should use DuckDB's VAR_POP spelling.""" + layer = SemanticLayer() + layer.add_model( + Model( + name="orders", + table="orders", + primary_key="id", + metrics=[Metric(name="order_count", agg="count", sql="id")], + ) + ) + layer.graph.add_metric(Metric(name="amount_variance_pop", agg="variance_pop", sql="orders.amount")) + + sql = layer.compile(metrics=["amount_variance_pop"]) + + assert "VAR_POP(" in sql + assert "VARIANCE_POP(" not in sql + def test_metric_expr_alias(): """Test that expr can be used as alias for sql.""" @@ -118,3 +164,39 @@ def test_metric_complex_expression_not_parsed(): assert m.agg is None assert m.sql == "revenue - cost" assert m.type == "derived" + + +def test_ratio_prefers_exact_graph_metric_with_dotted_name(): + """Dotted ratio refs can name graph metrics and must not be split first.""" + layer = SemanticLayer(auto_register=False) + layer.adapter.execute(""" + CREATE TABLE orders ( + id INTEGER, + status VARCHAR, + amount INTEGER + ) + """) + layer.adapter.execute(""" + INSERT INTO orders VALUES + (1, 'paid', 100), + (2, 'paid', 50), + (3, 'open', 25) + """) + layer.add_model( + Model( + name="orders", + table="orders", + primary_key="id", + dimensions=[Dimension(name="status", type="categorical")], + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + ) + ) + layer.add_metric(Metric(name="orders.revenue", type="derived", sql="SUM(orders.amount) * 2")) + layer.add_metric(Metric(name="exact_ratio", type="ratio", numerator="orders.revenue", denominator="orders.revenue")) + + sql = layer.compile(metrics=["exact_ratio"], dimensions=["orders.status"]) + assert "orders_cte.revenue_raw" not in sql + assert "SUM(orders_cte.amount) * 2" in sql + + rows = layer.query(metrics=["exact_ratio"], dimensions=["orders.status"], order_by=["orders.status"]).fetchall() + assert rows == [("open", 1.0), ("paid", 1.0)] diff --git a/tests/test_relationships.py b/tests/test_relationships.py index db311a05..f8ead4c4 100644 --- a/tests/test_relationships.py +++ b/tests/test_relationships.py @@ -1,6 +1,8 @@ """Test relationship property methods and edge cases.""" +from sidemantic.core.model import Model from sidemantic.core.relationship import Relationship +from sidemantic.core.semantic_graph import SemanticGraph def test_relationship_sql_expr_with_explicit_foreign_key(): @@ -64,3 +66,99 @@ def test_relationship_all_fields(): assert rel.primary_key == "organization_id" assert rel.sql_expr == "org_id" assert rel.related_key == "organization_id" + + +def test_relationship_default_column_lists_match_native_contract(): + many_to_one = Relationship(name="customers", type="many_to_one") + assert many_to_one.foreign_key_columns == ["customers_id"] + assert many_to_one.primary_key_columns == ["id"] + + one_to_many = Relationship(name="orders", type="one_to_many") + assert one_to_many.foreign_key_columns == ["id"] + assert one_to_many.primary_key_columns == ["id"] + + one_to_one = Relationship(name="profile", type="one_to_one") + assert one_to_one.foreign_key_columns == ["id"] + assert one_to_one.primary_key_columns == ["id"] + + +def test_graph_many_to_one_omitted_keys_use_name_id_and_target_primary_key(): + graph = SemanticGraph() + graph.add_model( + Model( + name="orders", + table="orders", + primary_key="order_id", + relationships=[Relationship(name="customers", type="many_to_one")], + ) + ) + graph.add_model(Model(name="customers", table="customers", primary_key="customer_uid")) + + path = graph.find_relationship_path("orders", "customers") + + assert [(step.from_columns, step.to_columns, step.relationship) for step in path] == [ + (["customers_id"], ["customer_uid"], "many_to_one") + ] + + +def test_graph_one_to_many_omitted_keys_default_to_id_columns(): + graph = SemanticGraph() + graph.add_model( + Model( + name="customers", + table="customers", + primary_key="id", + relationships=[Relationship(name="orders", type="one_to_many")], + ) + ) + graph.add_model(Model(name="orders", table="orders", primary_key="id")) + + path = graph.find_relationship_path("customers", "orders") + + assert [(step.from_columns, step.to_columns, step.relationship) for step in path] == [ + (["id"], ["id"], "one_to_many") + ] + + +def test_graph_one_to_one_omitted_keys_default_to_id_columns(): + graph = SemanticGraph() + graph.add_model( + Model( + name="users", + table="users", + primary_key="id", + relationships=[Relationship(name="profiles", type="one_to_one")], + ) + ) + graph.add_model(Model(name="profiles", table="profiles", primary_key="id")) + + path = graph.find_relationship_path("users", "profiles") + + assert [(step.from_columns, step.to_columns, step.relationship) for step in path] == [ + (["id"], ["id"], "one_to_one") + ] + + +def test_graph_explicit_foreign_key_omitted_primary_key_uses_target_primary_key(): + graph = SemanticGraph() + graph.add_model( + Model( + name="invoices", + table="invoices", + primary_key="invoice_id", + relationships=[ + Relationship( + name="vendors", + type="many_to_one", + foreign_key="vendor_ref", + ) + ], + ) + ) + graph.add_model(Model(name="vendors", table="vendors", primary_key="vendor_uid")) + + path = graph.find_relationship_path("invoices", "vendors") + + assert [(step.from_columns, step.to_columns, step.relationship) for step in path] == [ + (["vendor_ref"], ["vendor_uid"], "many_to_one") + ] diff --git a/tests/test_validation.py b/tests/test_validation.py index c0ee0cbd..8d9dbce8 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -136,6 +136,44 @@ def test_query_validation_metric_not_found(layer): assert "Metric 'nonexistent_metric' not found" in str(exc_info.value) +def test_query_validation_accepts_multidot_graph_metric_name(layer): + """Exact graph metric names with multiple dots should not be split as model.metric.""" + layer.add_model( + Model( + name="orders", + table="orders", + primary_key="id", + dimensions=[Dimension(name="status", type="categorical", sql="status")], + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + ) + ) + layer.add_metric(Metric(name="company.sales.revenue", sql="orders.revenue")) + + sql = layer.compile(metrics=["company.sales.revenue"], dimensions=["orders.status"]) + + assert '"company.sales.revenue"' in sql + assert "amount AS revenue_raw" in sql + + +def test_graph_metric_exact_name_wins_over_model_metric_reference(layer): + """Exact graph metric names should resolve before model.metric interpretation.""" + layer.add_model( + Model( + name="orders", + table="orders", + primary_key="id", + dimensions=[Dimension(name="status", type="categorical", sql="status")], + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + ) + ) + layer.add_metric(Metric(name="orders.revenue", sql="SUM(orders.amount) * 2")) + + sql = layer.compile(metrics=["orders.revenue"], dimensions=["orders.status"]) + + assert 'AS "orders.revenue"' in sql + assert "* 2" in sql + + def test_query_validation_dimension_not_found(layer): """Test that queries with non-existent dimensions fail validation.""" layer.add_model( From af0f21883a11eb942f58d7ce709c7bf554e3861c Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 15:05:18 -0700 Subject: [PATCH 02/13] Project custom join predicate columns --- sidemantic/sql/generator.py | 39 ++++++++++++++++++++++++ tests/queries/test_basic.py | 60 +++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/sidemantic/sql/generator.py b/sidemantic/sql/generator.py index 239be481..d621bc78 100644 --- a/sidemantic/sql/generator.py +++ b/sidemantic/sql/generator.py @@ -309,6 +309,42 @@ def _custom_join_condition(self, join_path) -> str: to_alias = self._quote_identifier(self._cte_name(join_path.to_model)) return join_path.custom_condition.replace("{from}", from_alias).replace("{to}", to_alias) + def _custom_join_columns(self, join_path) -> dict[str, set[str]]: + """Extract raw columns that a custom join predicate reads from each side.""" + if not join_path.custom_condition: + return {} + + from_marker = "__from__" + to_marker = "__to__" + condition = join_path.custom_condition.replace("{from}", from_marker).replace("{to}", to_marker) + try: + parsed = sqlglot.parse_one(condition, dialect=self.dialect) + except Exception as exc: + raise ValueError( + "Could not parse custom relationship SQL for " + f"{join_path.from_model} -> {join_path.to_model}: {join_path.custom_condition}" + ) from exc + + columns: dict[str, set[str]] = {join_path.from_model: set(), join_path.to_model: set()} + for column in parsed.find_all(exp.Column): + if column.table == from_marker: + columns[join_path.from_model].add(column.name) + elif column.table == to_marker: + columns[join_path.to_model].add(column.name) + + return {model_name: cols for model_name, cols in columns.items() if cols} + + def _custom_join_columns_by_model(self, base_model_name: str, other_models: list[str]) -> dict[str, set[str]]: + columns_by_model: dict[str, set[str]] = {} + for other_model in other_models: + join_path = self.graph.find_relationship_path(base_model_name, other_model) + if not join_path: + continue + for join_step in join_path: + for model_name, columns in self._custom_join_columns(join_step).items(): + columns_by_model.setdefault(model_name, set()).update(columns) + return columns_by_model + def _apply_default_time_dimensions(self, metrics: list[str], dimensions: list[str]) -> list[str]: """Auto-include default_time_dimension from models if not already present. @@ -669,6 +705,9 @@ def metric_needs_window(m): # Extract columns needed for metric-level filters (before building CTEs) metric_filter_cols_by_model = self._extract_metric_filter_columns(metrics) + custom_join_cols_by_model = self._custom_join_columns_by_model(base_model_name, model_names[1:]) + for model_name, column_names in custom_join_cols_by_model.items(): + metric_filter_cols_by_model.setdefault(model_name, set()).update(column_names) # Ensure dimensions referenced in outer-query filters (e.g. window dims) # are included in the relevant CTE SELECT lists. diff --git a/tests/queries/test_basic.py b/tests/queries/test_basic.py index 0553a6d2..1c020b7c 100644 --- a/tests/queries/test_basic.py +++ b/tests/queries/test_basic.py @@ -399,6 +399,66 @@ def test_no_prefix_when_no_collision(layer): assert "AS customers_customer_name" not in sql +def test_custom_join_sql_projects_extra_predicate_columns(): + conn = duckdb.connect(":memory:") + conn.execute(""" + CREATE TABLE orders ( + order_id INTEGER, + customer_id INTEGER, + amount INTEGER + ) + """) + conn.execute(""" + CREATE TABLE customers ( + customer_id INTEGER, + country VARCHAR, + valid_to DATE + ) + """) + conn.execute("INSERT INTO orders VALUES (1, 100, 50)") + conn.execute(""" + INSERT INTO customers VALUES + (100, 'US', NULL), + (100, 'Expired', DATE '2024-01-01') + """) + + layer = SemanticLayer() + layer.conn = conn + layer.add_model( + Model( + name="orders", + table="orders", + primary_key="order_id", + relationships=[ + Relationship( + name="customers", + type="many_to_one", + foreign_key="customer_id", + sql="{from}.customer_id = {to}.customer_id AND {to}.valid_to IS NULL", + ) + ], + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + ) + ) + layer.add_model( + Model( + name="customers", + table="customers", + primary_key="customer_id", + dimensions=[Dimension(name="country", type="categorical")], + ) + ) + + sql = layer.compile(metrics=["orders.revenue"], dimensions=["customers.country"], order_by=["customers.country"]) + assert "valid_to AS valid_to" in sql + assert "customers_cte.valid_to IS NULL" in sql + + rows = df_rows( + layer.query(metrics=["orders.revenue"], dimensions=["customers.country"], order_by=["customers.country"]) + ) + assert rows == [("Expired", None), ("US", 50)] + + def test_count_distinct_without_sql_uses_primary_key(layer): """Test that count_distinct without sql field uses primary key. From c246c6682af0945278598c08badade12e398103f Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 15:35:29 -0700 Subject: [PATCH 03/13] Accept legacy native metric dependencies --- sidemantic-rs/src/config/schema.rs | 23 ++++++++++++++++++ sidemantic/adapters/sidemantic.py | 1 + .../sidemantic_adapter/test_parsing.py | 24 +++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/sidemantic-rs/src/config/schema.rs b/sidemantic-rs/src/config/schema.rs index 96ff4243..71fd5b7e 100644 --- a/sidemantic-rs/src/config/schema.rs +++ b/sidemantic-rs/src/config/schema.rs @@ -134,6 +134,8 @@ pub struct MetricConfig { pub agg: Option, #[serde(default, alias = "expr", alias = "measure")] pub sql: Option, + #[serde(default, rename = "metrics", skip_serializing_if = "Option::is_none")] + _legacy_metric_dependencies: Option>, pub numerator: Option, pub denominator: Option, pub offset_window: Option, @@ -1359,6 +1361,27 @@ models: ); } + #[test] + fn test_native_yaml_accepts_legacy_metric_dependencies() { + let yaml = r#" +version: 1 +metrics: + - name: revenue_per_order + type: derived + sql: revenue / order_count + metrics: + - revenue + - order_count +"#; + + let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); + let (_, metrics, _) = config.into_parts().unwrap(); + + assert_eq!(metrics.len(), 1); + assert_eq!(metrics[0].name, "revenue_per_order"); + assert_eq!(metrics[0].sql.as_deref(), Some("revenue / order_count")); + } + #[test] fn test_native_yaml_rejects_auto_dimensions_true() { let yaml = r#" diff --git a/sidemantic/adapters/sidemantic.py b/sidemantic/adapters/sidemantic.py index ab045ba4..1193d10a 100644 --- a/sidemantic/adapters/sidemantic.py +++ b/sidemantic/adapters/sidemantic.py @@ -81,6 +81,7 @@ "sql", "expr", "measure", + "metrics", "numerator", "denominator", "offset_window", diff --git a/tests/adapters/sidemantic_adapter/test_parsing.py b/tests/adapters/sidemantic_adapter/test_parsing.py index 2283af21..176b4dc7 100644 --- a/tests/adapters/sidemantic_adapter/test_parsing.py +++ b/tests/adapters/sidemantic_adapter/test_parsing.py @@ -138,6 +138,30 @@ def test_parse_native_yaml_accepts_compatibility_aliases(tmp_path): assert orders.metrics[1].sql == "total_revenue / order_count" +def test_parse_native_yaml_accepts_legacy_metric_dependencies(tmp_path): + """Legacy exported derived metrics used `metrics` for dependency hints.""" + adapter = SidemanticAdapter() + yaml_path = tmp_path / "metrics.yml" + yaml_path.write_text( + """ +version: 1 +metrics: + - name: revenue_per_order + type: derived + sql: total_revenue / order_count + metrics: + - total_revenue + - order_count +""" + ) + + graph = adapter.parse(yaml_path) + + metric = graph.metrics["revenue_per_order"] + assert metric.type == "derived" + assert metric.sql == "total_revenue / order_count" + + @pytest.mark.parametrize( ("yaml_body", "error_text"), [ From 9b8531dc49fda85aa415ea58413551156a9ddc6a Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 16:23:38 -0700 Subject: [PATCH 04/13] Fix dotted graph metric SQL generation --- sidemantic-rs/src/core/graph.rs | 63 +++++++++ sidemantic-rs/src/sql/generator.rs | 217 +++++++++++++++++++++-------- sidemantic/sql/generator.py | 38 +++-- tests/queries/test_basic.py | 41 ++++++ 4 files changed, 296 insertions(+), 63 deletions(-) diff --git a/sidemantic-rs/src/core/graph.rs b/sidemantic-rs/src/core/graph.rs index 30028a1a..a0a59a57 100644 --- a/sidemantic-rs/src/core/graph.rs +++ b/sidemantic-rs/src/core/graph.rs @@ -308,6 +308,12 @@ impl SemanticGraph { ))); } + if Self::metric_uses_inline_aggregation(metric) + && self.inline_aggregate_column_dependency_exists(&dependency) + { + continue; + } + if self.metric_dependency_exists(&dependency)? { continue; } @@ -339,6 +345,63 @@ impl SemanticGraph { .any(|model| model.get_metric(dependency).is_some())) } + fn metric_uses_inline_aggregation(metric: &Metric) -> bool { + metric.r#type == MetricType::Derived + && metric + .sql + .as_deref() + .is_some_and(Self::sql_has_inline_aggregation) + } + + fn sql_has_inline_aggregation(sql: &str) -> bool { + let lower = sql.to_ascii_lowercase(); + let bytes = lower.as_bytes(); + let aggregate_names = [ + "sum", + "avg", + "count", + "min", + "max", + "median", + "stddev", + "stddev_pop", + "variance", + "variance_pop", + ]; + + for name in aggregate_names { + let mut start = 0; + while let Some(offset) = lower[start..].find(name) { + let name_start = start + offset; + let name_end = name_start + name.len(); + let before_is_ident = name_start > 0 + && (bytes[name_start - 1].is_ascii_alphanumeric() + || bytes[name_start - 1] == b'_'); + let after_is_ident = name_end < bytes.len() + && (bytes[name_end].is_ascii_alphanumeric() || bytes[name_end] == b'_'); + if before_is_ident || after_is_ident { + start = name_end; + continue; + } + + if lower[name_end..].trim_start().starts_with('(') { + return true; + } + start = name_end; + } + } + + false + } + + fn inline_aggregate_column_dependency_exists(&self, dependency: &str) -> bool { + if let Some((model_name, _)) = dependency.rsplit_once('.') { + return self.models.contains_key(model_name); + } + + self.models.len() == 1 + } + /// Get a graph-level metric by name. pub fn get_metric(&self, name: &str) -> Option<&Metric> { self.metrics diff --git a/sidemantic-rs/src/sql/generator.rs b/sidemantic-rs/src/sql/generator.rs index 17b12dc1..fb382533 100644 --- a/sidemantic-rs/src/sql/generator.rs +++ b/sidemantic-rs/src/sql/generator.rs @@ -308,11 +308,8 @@ impl<'a> SqlGenerator<'a> { let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); SidemanticError::model_not_found(&model_name, &available) })?; - let metric = model.get_metric(&metric_name).ok_or_else(|| { - let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); - SidemanticError::metric_not_found(&model_name, &metric_name, &available) - })?; - let raw_alias = format!("{metric_name}_raw"); + let metric = self.metric_for_model(&model_name, &metric_name)?; + let raw_alias = self.metric_raw_alias(model, &metric_name, metric); let mut raw_expr = self.normalize_cte_source_expression(&self.metric_raw_expression(metric, model)); if !metric.filters.is_empty() { @@ -457,16 +454,13 @@ impl<'a> SqlGenerator<'a> { let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); SidemanticError::model_not_found(&metric_ref.model, &available) })?; - let metric = model.get_metric(&metric_ref.name).ok_or_else(|| { - let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); - SidemanticError::metric_not_found(&metric_ref.model, &metric_ref.name, &available) - })?; + let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; let alias = self.model_alias(&metric_ref.model); let use_symmetric = fan_out_at_risk.contains(&metric_ref.model); let output_alias = self.output_alias(&metric_ref.model, &metric_ref.alias, &alias_collisions); - let raw_alias = format!("{}_raw", metric_ref.name); + let raw_alias = self.metric_raw_alias(model, &metric_ref.name, metric); let raw_col = format!("{alias}.{}", self.quote_identifier(&raw_alias)); let sql_expr = match metric.r#type { @@ -793,12 +787,136 @@ impl<'a> SqlGenerator<'a> { } match owners.len() { - 0 => Ok(None), - 1 => Ok(Some((owners[0].clone(), reference.to_string()))), - _ => Err(SidemanticError::InvalidReference { + 0 => {} + 1 => return Ok(Some((owners[0].clone(), reference.to_string()))), + _ => { + return Err(SidemanticError::InvalidReference { + reference: reference.to_string(), + }); + } + } + + if let Some(metric) = self.graph.get_metric(reference) { + let graph_metric_owners = self.graph_metric_owner_models(reference, metric)?; + return match graph_metric_owners.len() { + 0 => Ok(None), + 1 => Ok(Some(( + graph_metric_owners[0].clone(), + reference.to_string(), + ))), + _ => Err(SidemanticError::InvalidReference { + reference: reference.to_string(), + }), + }; + } + + Ok(None) + } + + fn graph_metric_owner_models(&self, reference: &str, metric: &Metric) -> Result> { + let mut owners = HashSet::new(); + for model in self.graph.models() { + if model.get_metric(reference).is_some() { + owners.insert(model.name.clone()); + } + } + + if owners.is_empty() { + for fragment in [ + metric.sql.as_deref(), + metric.numerator.as_deref(), + metric.denominator.as_deref(), + metric.base_metric.as_deref(), + metric.entity.as_deref(), + metric.base_event.as_deref(), + metric.conversion_event.as_deref(), + metric.cohort_event.as_deref(), + metric.activity_event.as_deref(), + metric.having.as_deref(), + ] + .into_iter() + .flatten() + { + self.collect_owner_models_from_fragment(fragment, &mut owners); + } + + if let Some(steps) = metric.steps.as_ref() { + for step in steps { + self.collect_owner_models_from_fragment(step, &mut owners); + } + } + if let Some(inner_metrics) = metric.inner_metrics.as_ref() { + for inner_metric in inner_metrics { + if let Some(sql) = inner_metric.sql.as_deref() { + self.collect_owner_models_from_fragment(sql, &mut owners); + } + } + } + if let Some(entity_dimensions) = metric.entity_dimensions.as_ref() { + for dimension in entity_dimensions { + self.collect_owner_models_from_fragment(dimension, &mut owners); + } + } + } + + if owners.is_empty() { + let mut model_names: Vec = self + .graph + .models() + .map(|model| model.name.clone()) + .collect(); + if model_names.len() == 1 { + owners.insert(model_names.pop().expect("single model name")); + } + } + + let mut owners: Vec = owners.into_iter().collect(); + owners.sort(); + if owners.len() > 1 { + return Err(SidemanticError::InvalidReference { reference: reference.to_string(), - }), + }); + } + Ok(owners) + } + + fn collect_owner_models_from_fragment(&self, fragment: &str, owners: &mut HashSet) { + let model_ref_re = + regex::Regex::new(r"\b([A-Za-z_][A-Za-z0-9_]*)\.([A-Za-z_][A-Za-z0-9_]*)\b") + .expect("valid model reference regex"); + for cap in model_ref_re.captures_iter(fragment) { + let Some(model_match) = cap.get(1) else { + continue; + }; + let model_name = model_match.as_str(); + if self.graph.get_model(model_name).is_some() { + owners.insert(model_name.to_string()); + } + } + } + + fn metric_for_model(&self, model_name: &str, metric_name: &str) -> Result<&Metric> { + let model = self.graph.get_model(model_name).ok_or_else(|| { + let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); + SidemanticError::model_not_found(model_name, &available) + })?; + if let Some(metric) = model.get_metric(metric_name) { + return Ok(metric); + } + + if let Some(metric) = self.graph.get_metric(metric_name) { + let owners = self.graph_metric_owner_models(metric_name, metric)?; + if owners.iter().any(|owner| owner == model_name) { + return Ok(metric); + } } + + let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); + Err(SidemanticError::metric_not_found( + model_name, + metric_name, + &available, + )) } /// Find all models required by the query @@ -852,14 +970,7 @@ impl<'a> SqlGenerator<'a> { return Ok(()); } - let model = self.graph.get_model(&metric_ref.model).ok_or_else(|| { - let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); - SidemanticError::model_not_found(&metric_ref.model, &available) - })?; - let metric = model.get_metric(&metric_ref.name).ok_or_else(|| { - let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); - SidemanticError::metric_not_found(&metric_ref.model, &metric_ref.name, &available) - })?; + let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; let exprs: Vec<&str> = [ metric.sql.as_deref(), @@ -991,6 +1102,29 @@ impl<'a> SqlGenerator<'a> { } } + fn metric_raw_alias(&self, model: &Model, metric_name: &str, metric: &Metric) -> String { + if metric_name.contains('.') && metric.r#type == MetricType::Simple { + if let Some(column_name) = self.simple_metric_source_column(model, metric) { + return column_name; + } + } + format!("{metric_name}_raw") + } + + fn simple_metric_source_column(&self, model: &Model, metric: &Metric) -> Option { + let sql = metric.sql.as_deref()?.trim(); + if Self::is_simple_identifier(sql) { + return Some(sql.to_string()); + } + + let (model_name, field_name) = sql.split_once('.')?; + if model_name == model.name && Self::is_simple_identifier(field_name) { + return Some(field_name.to_string()); + } + + None + } + fn collect_simple_metric_dependencies( &self, metric_ref: &MetricRef, @@ -1002,14 +1136,7 @@ impl<'a> SqlGenerator<'a> { return Ok(()); } - let model = self.graph.get_model(&metric_ref.model).ok_or_else(|| { - let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); - SidemanticError::model_not_found(&metric_ref.model, &available) - })?; - let metric = model.get_metric(&metric_ref.name).ok_or_else(|| { - let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); - SidemanticError::metric_not_found(&metric_ref.model, &metric_ref.name, &available) - })?; + let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; match metric.r#type { MetricType::Simple => { @@ -1090,14 +1217,7 @@ impl<'a> SqlGenerator<'a> { return Ok(()); } - let model = self.graph.get_model(&metric_ref.model).ok_or_else(|| { - let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); - SidemanticError::model_not_found(&metric_ref.model, &available) - })?; - let metric = model.get_metric(&metric_ref.name).ok_or_else(|| { - let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); - SidemanticError::metric_not_found(&metric_ref.model, &metric_ref.name, &available) - })?; + let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; match metric.r#type { MetricType::Derived if Self::is_inline_aggregate_expression(metric.sql_expr()) => { @@ -1356,14 +1476,7 @@ impl<'a> SqlGenerator<'a> { fn has_cumulative_metrics(&self, metric_refs: &[MetricRef]) -> Result { for metric_ref in metric_refs { - let model = self.graph.get_model(&metric_ref.model).ok_or_else(|| { - let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); - SidemanticError::model_not_found(&metric_ref.model, &available) - })?; - let metric = model.get_metric(&metric_ref.name).ok_or_else(|| { - let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); - SidemanticError::metric_not_found(&metric_ref.model, &metric_ref.name, &available) - })?; + let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; if metric.r#type == MetricType::Cumulative || metric.r#type == MetricType::TimeComparison || metric.r#type == MetricType::Conversion @@ -4029,7 +4142,8 @@ impl<'a> SqlGenerator<'a> { metric_name: &str, alias: &str, ) -> String { - let raw_col = format!("{alias}.{metric_name}_raw"); + let raw_alias = format!("{metric_name}_raw"); + let raw_col = format!("{alias}.{}", self.quote_identifier(&raw_alias)); match metric.agg.as_ref() { Some(Aggregation::CountDistinct) => format!("COUNT(DISTINCT {raw_col})"), Some(Aggregation::Count) => format!("COUNT({raw_col})"), @@ -4078,14 +4192,7 @@ impl<'a> SqlGenerator<'a> { return Ok(None); } - let model = self.graph.get_model(&model_name).ok_or_else(|| { - let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); - SidemanticError::model_not_found(&model_name, &available) - })?; - let metric = model.get_metric(&metric_name).ok_or_else(|| { - let available: Vec<&str> = model.metrics.iter().map(|m| m.name.as_str()).collect(); - SidemanticError::metric_not_found(&model_name, &metric_name, &available) - })?; + let metric = self.metric_for_model(&model_name, &metric_name)?; let alias = self.model_alias(&model_name); let expanded = match metric.r#type { diff --git a/sidemantic/sql/generator.py b/sidemantic/sql/generator.py index d621bc78..7702324d 100644 --- a/sidemantic/sql/generator.py +++ b/sidemantic/sql/generator.py @@ -1464,7 +1464,9 @@ def collect_measures_from_metric(metric_ref: str, visited: set[str] | None = Non if resolved_metric and ref_model_name is None: for dep in resolved_metric.get_dependencies(self.graph, model_name): collect_measures_from_metric(dep, visited) - if resolved_metric.sql and sql_has_aggregate(resolved_metric.sql, self.dialect): + if resolved_metric.sql and ( + resolved_metric.agg or sql_has_aggregate(resolved_metric.sql, self.dialect) + ): collect_sql_columns_for_model(resolved_metric.sql) return @@ -2297,6 +2299,7 @@ def _build_main_select( # Build SELECT columns select_exprs = [] + output_aliases: dict[str, str] = {} # Add dimensions for dim_ref, gran in parsed_dims: @@ -2305,17 +2308,20 @@ def _build_main_select( # Check for custom alias first full_ref = f"{model_name}.{dim_name}__{gran}" if gran else dim_ref + base_alias = f"{dim_name}__{gran}" if gran else dim_name if full_ref in aliases: alias = aliases[full_ref] else: # Generate alias (with model prefix if collision) - base_alias = f"{dim_name}__{gran}" if gran else dim_name if has_collision.get(base_alias, False): alias = f"{model_name}_{base_alias}" else: alias = base_alias select_exprs.append(f"{self._cte_ref(model_name, cte_col_name)} AS {self._quote_alias(alias)}") + output_aliases[full_ref] = alias + output_aliases[dim_ref] = alias + output_aliases[base_alias] = alias previous_bsl_all_context = getattr(self, "_bsl_all_query_context", None) self._bsl_all_query_context = { @@ -2338,6 +2344,8 @@ def _build_main_select( metric_expr = self._wrap_with_fill_nulls(metric_expr, resolved_metric) alias = aliases.get(metric_ref, resolved_metric.name) select_exprs.append(f"{metric_expr} AS {self._quote_alias(alias)}") + output_aliases[metric_ref] = alias + output_aliases[resolved_metric.name] = alias elif resolved_metric and resolved_model_name: # It's a measure reference (model.measure) model_name = resolved_model_name @@ -2355,6 +2363,9 @@ def _build_main_select( else: alias = measure_name + output_aliases[metric_ref] = alias + output_aliases[measure_name] = alias + # Complex metric types (derived, ratio) can be built inline # Note: cumulative, time_comparison, conversion are handled via special query generators # and won't appear in this code path @@ -2415,6 +2426,8 @@ def _build_main_select( metric_expr = self._build_metric_sql(metric) metric_expr = self._wrap_with_fill_nulls(metric_expr, metric) select_exprs.append(f"{metric_expr} AS {self._quote_alias(metric.name)}") + output_aliases[metric_ref] = metric.name + output_aliases[metric.name] = metric.name else: # It's a metric reference (just metric name) metric = self.graph.get_metric(metric_ref) @@ -2422,6 +2435,8 @@ def _build_main_select( metric_expr = self._build_metric_sql(metric) metric_expr = self._wrap_with_fill_nulls(metric_expr, metric) select_exprs.append(f"{metric_expr} AS {self._quote_alias(metric.name)}") + output_aliases[metric_ref] = metric.name + output_aliases[metric.name] = metric.name else: raise ValueError(f"Metric {metric_ref} not found") @@ -2475,15 +2490,22 @@ def replace_metric_ref(match): # Add ORDER BY if order_by: - # Strip model prefixes from order_by fields to use column aliases order_by_aliases = [] for field in order_by: - if "." in field: - # Extract just the field name (with optional granularity) - field_alias = field.split(".", 1)[1] + parts = field.rsplit(" ", 1) + direction = "" + field_ref = field + if len(parts) == 2 and parts[1].upper() in {"ASC", "DESC"}: + field_ref = parts[0] + direction = f" {parts[1].upper()}" + + if field_ref in output_aliases: + field_alias = self._quote_alias(output_aliases[field_ref]) + elif "." in field_ref: + field_alias = field_ref.split(".", 1)[1] else: - field_alias = field - order_by_aliases.append(field_alias) + field_alias = field_ref + order_by_aliases.append(f"{field_alias}{direction}") query = query.order_by(*order_by_aliases) # Add LIMIT and OFFSET diff --git a/tests/queries/test_basic.py b/tests/queries/test_basic.py index 1c020b7c..34837f75 100644 --- a/tests/queries/test_basic.py +++ b/tests/queries/test_basic.py @@ -459,6 +459,47 @@ def test_custom_join_sql_projects_extra_predicate_columns(): assert rows == [("Expired", None), ("US", 50)] +def test_dotted_graph_metric_projects_sql_column_and_orders_by_alias(layer): + layer.conn.execute("CREATE TABLE events (event_id INTEGER, status VARCHAR, latency INTEGER)") + layer.conn.execute( + """ + INSERT INTO events VALUES + (1, 'ok', 100), + (2, 'ok', 250), + (3, 'slow', 400) + """ + ) + + layer.add_model( + Model( + name="events", + table="events", + primary_key="event_id", + dimensions=[Dimension(name="status", type="categorical")], + ) + ) + layer.add_metric(Metric(name="events.p95.latency", agg="max", sql="latency")) + + sql = layer.compile( + metrics=["events.p95.latency"], + dimensions=["events.status"], + order_by=["events.p95.latency DESC"], + ) + + assert "latency AS latency" in sql + assert "ORDER BY" in sql + assert '"events.p95.latency" DESC' in sql + + rows = df_rows( + layer.query( + metrics=["events.p95.latency"], + dimensions=["events.status"], + order_by=["events.p95.latency DESC"], + ) + ) + assert rows == [("slow", 400), ("ok", 250)] + + def test_count_distinct_without_sql_uses_primary_key(layer): """Test that count_distinct without sql field uses primary key. From c259a04105f02b966ba1fea551cb7a75f52bd54c Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 17:05:23 -0700 Subject: [PATCH 05/13] Defer native metric inheritance until model merge --- sidemantic/adapters/sidemantic.py | 12 ++++++---- sidemantic/core/inheritance.py | 9 +++++++ sidemantic/loaders.py | 3 ++- tests/test_loaders.py | 39 +++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/sidemantic/adapters/sidemantic.py b/sidemantic/adapters/sidemantic.py index 1193d10a..72d00091 100644 --- a/sidemantic/adapters/sidemantic.py +++ b/sidemantic/adapters/sidemantic.py @@ -422,7 +422,11 @@ def _parse_embedded_sql_definitions( raise ValueError(f"{location}invalid {scope}: {exc}") from exc def _resolve_inheritance(self, graph: SemanticGraph) -> None: - from sidemantic.core.inheritance import resolve_metric_inheritance, resolve_model_inheritance + from sidemantic.core.inheritance import ( + resolve_metric_inheritance, + resolve_model_inheritance, + resolve_model_metric_inheritance, + ) if any(model.extends for model in graph.models.values()): missing_parent = any(model.extends and model.extends not in graph.models for model in graph.models.values()) @@ -431,9 +435,9 @@ def _resolve_inheritance(self, graph: SemanticGraph) -> None: graph._mark_dirty() for model in graph.models.values(): - if any(metric.extends for metric in model.metrics): - resolved_metrics = resolve_metric_inheritance({metric.name: metric for metric in model.metrics}) - model.metrics = list(resolved_metrics.values()) + if model.extends and model.extends not in graph.models: + continue + resolve_model_metric_inheritance(model) if any(metric.extends for metric in graph.metrics.values()): graph.metrics = resolve_metric_inheritance(graph.metrics) diff --git a/sidemantic/core/inheritance.py b/sidemantic/core/inheritance.py index f6cfab86..27985e47 100644 --- a/sidemantic/core/inheritance.py +++ b/sidemantic/core/inheritance.py @@ -230,3 +230,12 @@ def resolve(name: str) -> Metric: resolve(name) return resolved + + +def resolve_model_metric_inheritance(model: Model) -> None: + """Resolve inheritance between metrics declared on a single model.""" + if not any(metric.extends for metric in model.metrics): + return + + resolved_metrics = resolve_metric_inheritance({metric.name: metric for metric in model.metrics}) + model.metrics = list(resolved_metrics.values()) diff --git a/sidemantic/loaders.py b/sidemantic/loaders.py index eddbbd9c..a90d7302 100644 --- a/sidemantic/loaders.py +++ b/sidemantic/loaders.py @@ -472,7 +472,7 @@ def _resolve_native_model_inheritance(all_models: dict, *, strict: bool) -> None if not native_children: return - from sidemantic.core.inheritance import merge_model + from sidemantic.core.inheritance import merge_model, resolve_model_metric_inheritance resolved = {} resolving = set() @@ -513,6 +513,7 @@ def resolve(name: str): return None merged = _run_without_auto_registration(merge_model, model, parent) + _run_without_auto_registration(resolve_model_metric_inheritance, merged) _copy_model_source_attrs(model, merged) resolved[name] = merged all_models[name] = merged diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 8c811d29..0d0bf069 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -127,6 +127,45 @@ def test_load_from_directory_resolves_native_inheritance_across_files(tmp_path): assert paid_orders.get_metric("revenue") is not None +def test_load_from_directory_resolves_native_metric_inheritance_after_model_merge(tmp_path): + (tmp_path / "base.yml").write_text( + """ +version: 1 +models: + - name: base_orders + table: orders + primary_key: id + metrics: + - name: revenue + agg: sum + sql: amount +""" + ) + (tmp_path / "child.yml").write_text( + """ +version: 1 +models: + - name: paid_orders + extends: base_orders + metrics: + - name: paid_revenue + extends: revenue + filters: + - status = 'paid' +""" + ) + + layer = SemanticLayer() + load_from_directory(layer, tmp_path) + + paid_revenue = layer.graph.models["paid_orders"].get_metric("paid_revenue") + assert paid_revenue is not None + assert paid_revenue.extends is None + assert paid_revenue.agg == "sum" + assert paid_revenue.sql == "amount" + assert paid_revenue.filters == ["status = 'paid'"] + + def test_load_from_directory_strict_raises_on_missing_native_parent(tmp_path): (tmp_path / "child.yml").write_text( """ From 96e167f987da836c86270749f57a339c17332fdc Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 17:50:01 -0700 Subject: [PATCH 06/13] Detect native metrics-only YAML files --- sidemantic/loaders.py | 9 +++++++++ tests/test_loaders.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/sidemantic/loaders.py b/sidemantic/loaders.py index a90d7302..213e379a 100644 --- a/sidemantic/loaders.py +++ b/sidemantic/loaders.py @@ -130,6 +130,8 @@ def load_from_directory(layer: "SemanticLayer", directory: str | Path, *, strict # Check for Sidemantic native format (explicit models: key) elif _yaml_has_top_level_key(yaml_data, "models"): adapter = SidemanticAdapter() + elif _looks_like_native_sidemantic_yaml(yaml_data): + adapter = SidemanticAdapter() elif _yaml_has_top_level_key(yaml_data, "metrics") and "type: " in content: adapter = MetricFlowAdapter() elif _contains_yaml_key(yaml_data, "base_sql_table") and _contains_yaml_key(yaml_data, "measures"): @@ -341,6 +343,13 @@ def _looks_like_semantic_yaml_text(content: str) -> bool: return any(line.lstrip().startswith(prefixes) for line in content.splitlines()) +def _looks_like_native_sidemantic_yaml(data: dict) -> bool: + """Return True for explicit native Sidemantic YAML files without models.""" + if not isinstance(data, dict) or data.get("version") != 1: + return False + return any(_yaml_has_top_level_key(data, key) for key in ("metrics", "parameters", "sql_metrics", "sql_segments")) + + def _yaml_has_top_level_key(data: dict, key: str) -> bool: """Return True when a YAML mapping has an exact top-level key.""" return isinstance(data, dict) and key in data diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 0d0bf069..0e3078d5 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -166,6 +166,41 @@ def test_load_from_directory_resolves_native_metric_inheritance_after_model_merg assert paid_revenue.filters == ["status = 'paid'"] +def test_load_from_directory_detects_native_metrics_only_file(tmp_path): + (tmp_path / "models.yml").write_text( + """ +version: 1 +models: + - name: orders + table: orders + primary_key: id + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count +""" + ) + (tmp_path / "metrics.yml").write_text( + """ +version: 1 +metrics: + - name: finance.revenue_per_order + type: ratio + numerator: orders.revenue + denominator: orders.order_count +""" + ) + + layer = SemanticLayer() + load_from_directory(layer, tmp_path) + + metric = layer.graph.metrics["finance.revenue_per_order"] + assert metric.numerator == "orders.revenue" + assert metric.denominator == "orders.order_count" + + def test_load_from_directory_strict_raises_on_missing_native_parent(tmp_path): (tmp_path / "child.yml").write_text( """ From 6db8e8fdf8d14fddf44adb6831f8953a71ad1119 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 18:00:35 -0700 Subject: [PATCH 07/13] Ignore root-only SQL frontmatter --- sidemantic/adapters/sidemantic.py | 3 +++ tests/core/test_sql_definitions.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/sidemantic/adapters/sidemantic.py b/sidemantic/adapters/sidemantic.py index 72d00091..8e8dfa85 100644 --- a/sidemantic/adapters/sidemantic.py +++ b/sidemantic/adapters/sidemantic.py @@ -264,6 +264,9 @@ def normalize_sql_frontmatter(frontmatter: dict) -> dict: validate_native_format_version(frontmatter) normalized = dict(frontmatter) normalized.pop("version", None) + normalized.pop("connection", None) + normalized.pop("models", None) + normalized.pop("parameters", None) return normalized diff --git a/tests/core/test_sql_definitions.py b/tests/core/test_sql_definitions.py index c4ddd58f..afee38af 100644 --- a/tests/core/test_sql_definitions.py +++ b/tests/core/test_sql_definitions.py @@ -391,6 +391,35 @@ def test_adapter_parse_sql_file(): temp_path.unlink(missing_ok=True) +def test_adapter_parse_sql_file_with_root_only_frontmatter(): + """Root-only native frontmatter should not be parsed as a model.""" + sql_content = """--- +version: 1 +connection: + type: duckdb +--- + +METRIC ( + name total_revenue, + sql orders.revenue +); +""" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as f: + f.write(sql_content) + temp_path = Path(f.name) + + try: + adapter = SidemanticAdapter() + graph = adapter.parse(temp_path) + + assert not graph.models + assert "total_revenue" in graph.metrics + + finally: + temp_path.unlink(missing_ok=True) + + def test_yaml_with_embedded_sql_metrics(): """Test YAML file with embedded sql_metrics field.""" yaml_content = """ From f9765bbe6cde6c011b5c4c80e87e8461aa9e8b1c Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 18:10:50 -0700 Subject: [PATCH 08/13] Defer native graph metric inheritance --- sidemantic/adapters/sidemantic.py | 8 +++- sidemantic/loaders.py | 70 ++++++++++++++++++++++++++++++- tests/test_loaders.py | 48 +++++++++++++++++++++ 3 files changed, 122 insertions(+), 4 deletions(-) diff --git a/sidemantic/adapters/sidemantic.py b/sidemantic/adapters/sidemantic.py index 8e8dfa85..a9407245 100644 --- a/sidemantic/adapters/sidemantic.py +++ b/sidemantic/adapters/sidemantic.py @@ -443,8 +443,12 @@ def _resolve_inheritance(self, graph: SemanticGraph) -> None: resolve_model_metric_inheritance(model) if any(metric.extends for metric in graph.metrics.values()): - graph.metrics = resolve_metric_inheritance(graph.metrics) - graph._mark_dirty() + missing_parent = any( + metric.extends and metric.extends not in graph.metrics for metric in graph.metrics.values() + ) + if not missing_parent: + graph.metrics = resolve_metric_inheritance(graph.metrics) + graph._mark_dirty() def export(self, graph: SemanticGraph, output_path: str | Path) -> None: """Export semantic graph to Sidemantic YAML. diff --git a/sidemantic/loaders.py b/sidemantic/loaders.py index 213e379a..a8b89cf6 100644 --- a/sidemantic/loaders.py +++ b/sidemantic/loaders.py @@ -179,6 +179,11 @@ def load_from_directory(layer: "SemanticLayer", directory: str | Path, *, strict model._source_format = adapter_name if not hasattr(model, "_source_file"): model._source_file = str(file_path.relative_to(directory)) + for metric in graph.metrics.values(): + if not hasattr(metric, "_source_format"): + metric._source_format = adapter_name + if not hasattr(metric, "_source_file"): + metric._source_file = str(file_path.relative_to(directory)) all_models.update(graph.models) all_metrics.update(graph.metrics) all_parameters.update(graph.parameters) @@ -186,6 +191,7 @@ def load_from_directory(layer: "SemanticLayer", directory: str | Path, *, strict _handle_parse_error(file_path, e, strict=strict) _resolve_native_model_inheritance(all_models, strict=strict) + _resolve_native_metric_inheritance(all_metrics, strict=strict) # BSL files are parsed one at a time during auto-discovery. Finalize join # aliases after all files have been loaded so aliases can target models @@ -465,7 +471,7 @@ def _run_without_auto_registration(callback, *args): set_current_layer(previous_layer) -def _copy_model_source_attrs(source, target) -> None: +def _copy_source_attrs(source, target) -> None: for attr in ("_source_format", "_source_file"): if hasattr(source, attr): setattr(target, attr, getattr(source, attr)) @@ -523,7 +529,7 @@ def resolve(name: str): merged = _run_without_auto_registration(merge_model, model, parent) _run_without_auto_registration(resolve_model_metric_inheritance, merged) - _copy_model_source_attrs(model, merged) + _copy_source_attrs(model, merged) resolved[name] = merged all_models[name] = merged return merged @@ -532,6 +538,66 @@ def resolve(name: str): resolve(name) +def _resolve_native_metric_inheritance(all_metrics: dict, *, strict: bool) -> None: + """Resolve Sidemantic-native graph metric inheritance after directory-wide parsing.""" + native_children = { + name: metric + for name, metric in all_metrics.items() + if getattr(metric, "_source_format", None) == "Sidemantic" and metric.extends + } + if not native_children: + return + + from sidemantic.core.inheritance import merge_metric + + resolved = {} + resolving = set() + + def fail(message: str): + if strict: + raise ValueError(message) + logging.warning(message) + return None + + def resolve(name: str): + if name in resolved: + return resolved[name] + + metric = all_metrics.get(name) + if metric is None: + return fail(f"Native metric '{name}' not found") + + if name in resolving: + return fail(f"Circular native metric inheritance detected for metric '{name}'") + + if not metric.extends: + resolved[name] = metric + return metric + + parent = all_metrics.get(metric.extends) + if parent is None: + return fail(f"Native metric '{name}' extends unknown metric '{metric.extends}'") + + resolving.add(name) + try: + if parent.extends: + parent = resolve(metric.extends) + finally: + resolving.remove(name) + + if parent is None: + return None + + merged = _run_without_auto_registration(merge_metric, metric, parent) + _copy_source_attrs(metric, merged) + resolved[name] = merged + all_metrics[name] = merged + return merged + + for name in native_children: + resolve(name) + + def _try_load_python_file(file_path: Path, directory: Path, all_models: dict, *, strict: bool) -> bool: """Load semantic definitions from a Python file if it looks like Sidemantic code.""" if not _looks_like_python_semantic_definition(file_path): diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 0e3078d5..15d178c4 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -201,6 +201,54 @@ def test_load_from_directory_detects_native_metrics_only_file(tmp_path): assert metric.denominator == "orders.order_count" +def test_load_from_directory_resolves_native_graph_metric_inheritance_across_files(tmp_path): + (tmp_path / "base_metrics.yml").write_text( + """ +version: 1 +metrics: + - name: gross_revenue + agg: sum + sql: orders.amount +""" + ) + (tmp_path / "child_metrics.yml").write_text( + """ +version: 1 +metrics: + - name: paid_revenue + extends: gross_revenue + filters: + - orders.status = 'paid' +""" + ) + + layer = SemanticLayer() + load_from_directory(layer, tmp_path) + + metric = layer.graph.metrics["paid_revenue"] + assert metric.extends is None + assert metric.agg == "sum" + assert metric.sql == "orders.amount" + assert metric.filters == ["orders.status = 'paid'"] + + +def test_load_from_directory_strict_raises_on_missing_native_graph_metric_parent(tmp_path): + (tmp_path / "metrics.yml").write_text( + """ +version: 1 +metrics: + - name: paid_revenue + extends: missing_revenue + filters: + - orders.status = 'paid' +""" + ) + + layer = SemanticLayer() + with pytest.raises(ValueError, match="Native metric 'paid_revenue' extends unknown metric 'missing_revenue'"): + load_from_directory(layer, tmp_path) + + def test_load_from_directory_strict_raises_on_missing_native_parent(tmp_path): (tmp_path / "child.yml").write_text( """ From 0a94caf300b9572b26faa94f0d4d9eddf0d01c34 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 18:45:22 -0700 Subject: [PATCH 09/13] Prefer exact graph metrics in Rust queries --- sidemantic-rs/src/sql/generator.rs | 249 ++++++++++++++++++----------- 1 file changed, 155 insertions(+), 94 deletions(-) diff --git a/sidemantic-rs/src/sql/generator.rs b/sidemantic-rs/src/sql/generator.rs index fb382533..b6bdf1c8 100644 --- a/sidemantic-rs/src/sql/generator.rs +++ b/sidemantic-rs/src/sql/generator.rs @@ -121,6 +121,7 @@ struct MetricRef { model: String, name: String, alias: String, + graph_metric: bool, } /// SQL generator for semantic queries @@ -265,7 +266,7 @@ impl<'a> SqlGenerator<'a> { &mut HashSet::new(), )?; } - let mut raw_metric_dependencies: Vec<(String, String)> = + let mut raw_metric_dependencies: Vec<(String, String, bool)> = raw_metric_dependencies.into_iter().collect(); raw_metric_dependencies.sort(); let mut raw_column_dependencies: Vec<(String, String)> = @@ -303,12 +304,13 @@ impl<'a> SqlGenerator<'a> { .insert(dimension.name.clone()); } } - for (model_name, metric_name) in raw_metric_dependencies { + for (model_name, metric_name, graph_metric) in raw_metric_dependencies { let model = self.graph.get_model(&model_name).ok_or_else(|| { let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); SidemanticError::model_not_found(&model_name, &available) })?; - let metric = self.metric_for_model(&model_name, &metric_name)?; + let metric = + self.metric_for_model_with_source(&model_name, &metric_name, graph_metric)?; let raw_alias = self.metric_raw_alias(model, &metric_name, metric); let mut raw_expr = self.normalize_cte_source_expression(&self.metric_raw_expression(metric, model)); @@ -454,7 +456,7 @@ impl<'a> SqlGenerator<'a> { let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); SidemanticError::model_not_found(&metric_ref.model, &available) })?; - let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; + let metric = self.metric_for_ref(metric_ref)?; let alias = self.model_alias(&metric_ref.model); let use_symmetric = fan_out_at_risk.contains(&metric_ref.model); @@ -706,31 +708,33 @@ impl<'a> SqlGenerator<'a> { let mut refs = Vec::new(); for metric in metrics { - let (model, name) = if let Some((model, name)) = self.exact_metric_reference(metric)? { - (model, name) - } else if metric.contains('.') { - let (model, name, _) = self.graph.parse_reference(metric)?; - (model, name) - } else { - let mut owners = Vec::new(); - for model in self.graph.models() { - if model.get_metric(metric).is_some() { - owners.push(model.name.clone()); - } - } - if owners.len() == 1 { - (owners[0].clone(), metric.clone()) + let (model, name, graph_metric) = + if let Some((model, name, graph_metric)) = self.exact_metric_reference(metric)? { + (model, name, graph_metric) + } else if metric.contains('.') { + let (model, name, _) = self.graph.parse_reference(metric)?; + (model, name, false) } else { - return Err(SidemanticError::Validation(format!( - "Metric '{metric}' not found" - ))); - } - }; + let mut owners = Vec::new(); + for model in self.graph.models() { + if model.get_metric(metric).is_some() { + owners.push(model.name.clone()); + } + } + if owners.len() == 1 { + (owners[0].clone(), metric.clone(), false) + } else { + return Err(SidemanticError::Validation(format!( + "Metric '{metric}' not found" + ))); + } + }; refs.push(MetricRef { model, name: name.clone(), alias: name, + graph_metric, }); } @@ -778,7 +782,22 @@ impl<'a> SqlGenerator<'a> { Ok(()) } - fn exact_metric_reference(&self, reference: &str) -> Result> { + fn exact_metric_reference(&self, reference: &str) -> Result> { + if let Some(metric) = self.graph.get_metric(reference) { + let graph_metric_owners = self.graph_metric_owner_models(reference, metric)?; + return match graph_metric_owners.len() { + 0 => Ok(None), + 1 => Ok(Some(( + graph_metric_owners[0].clone(), + reference.to_string(), + true, + ))), + _ => Err(SidemanticError::InvalidReference { + reference: reference.to_string(), + }), + }; + } + let mut owners = Vec::new(); for model in self.graph.models() { if model.get_metric(reference).is_some() { @@ -788,7 +807,7 @@ impl<'a> SqlGenerator<'a> { match owners.len() { 0 => {} - 1 => return Ok(Some((owners[0].clone(), reference.to_string()))), + 1 => return Ok(Some((owners[0].clone(), reference.to_string(), false))), _ => { return Err(SidemanticError::InvalidReference { reference: reference.to_string(), @@ -796,20 +815,6 @@ impl<'a> SqlGenerator<'a> { } } - if let Some(metric) = self.graph.get_metric(reference) { - let graph_metric_owners = self.graph_metric_owner_models(reference, metric)?; - return match graph_metric_owners.len() { - 0 => Ok(None), - 1 => Ok(Some(( - graph_metric_owners[0].clone(), - reference.to_string(), - ))), - _ => Err(SidemanticError::InvalidReference { - reference: reference.to_string(), - }), - }; - } - Ok(None) } @@ -895,11 +900,33 @@ impl<'a> SqlGenerator<'a> { } } - fn metric_for_model(&self, model_name: &str, metric_name: &str) -> Result<&Metric> { + fn metric_for_ref(&self, metric_ref: &MetricRef) -> Result<&Metric> { + self.metric_for_model_with_source( + &metric_ref.model, + &metric_ref.name, + metric_ref.graph_metric, + ) + } + + fn metric_for_model_with_source( + &self, + model_name: &str, + metric_name: &str, + graph_metric: bool, + ) -> Result<&Metric> { let model = self.graph.get_model(model_name).ok_or_else(|| { let available: Vec<&str> = self.graph.models().map(|m| m.name.as_str()).collect(); SidemanticError::model_not_found(model_name, &available) })?; + if graph_metric { + if let Some(metric) = self.graph.get_metric(metric_name) { + let owners = self.graph_metric_owner_models(metric_name, metric)?; + if owners.iter().any(|owner| owner == model_name) { + return Ok(metric); + } + } + } + if let Some(metric) = model.get_metric(metric_name) { return Ok(metric); } @@ -963,14 +990,18 @@ impl<'a> SqlGenerator<'a> { &self, metric_ref: &MetricRef, models: &mut HashSet, - visiting: &mut HashSet<(String, String)>, + visiting: &mut HashSet<(String, String, bool)>, ) -> Result<()> { - let key = (metric_ref.model.clone(), metric_ref.name.clone()); + let key = ( + metric_ref.model.clone(), + metric_ref.name.clone(), + metric_ref.graph_metric, + ); if !visiting.insert(key.clone()) { return Ok(()); } - let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; + let metric = self.metric_for_ref(metric_ref)?; let exprs: Vec<&str> = [ metric.sql.as_deref(), @@ -1005,7 +1036,7 @@ impl<'a> SqlGenerator<'a> { if Self::is_sql_keyword_or_function(token) { continue; } - if let Some((model_name, metric_name)) = + if let Some((model_name, metric_name, graph_metric)) = self.resolve_metric_reference_location(token, &metric_ref.model)? { models.insert(model_name.clone()); @@ -1014,6 +1045,7 @@ impl<'a> SqlGenerator<'a> { model: model_name, name: metric_name.clone(), alias: metric_name, + graph_metric, }, models, visiting, @@ -1128,15 +1160,19 @@ impl<'a> SqlGenerator<'a> { fn collect_simple_metric_dependencies( &self, metric_ref: &MetricRef, - deps: &mut HashSet<(String, String)>, - visiting: &mut HashSet<(String, String)>, + deps: &mut HashSet<(String, String, bool)>, + visiting: &mut HashSet<(String, String, bool)>, ) -> Result<()> { - let key = (metric_ref.model.clone(), metric_ref.name.clone()); + let key = ( + metric_ref.model.clone(), + metric_ref.name.clone(), + metric_ref.graph_metric, + ); if !visiting.insert(key.clone()) { return Ok(()); } - let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; + let metric = self.metric_for_ref(metric_ref)?; match metric.r#type { MetricType::Simple => { @@ -1174,8 +1210,8 @@ impl<'a> SqlGenerator<'a> { &self, expr: &str, default_model: &str, - deps: &mut HashSet<(String, String)>, - visiting: &mut HashSet<(String, String)>, + deps: &mut HashSet<(String, String, bool)>, + visiting: &mut HashSet<(String, String, bool)>, ) -> Result<()> { let ref_re = regex::Regex::new( r"\b([A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*|[A-Za-z_][A-Za-z0-9_]*)\b", @@ -1187,7 +1223,7 @@ impl<'a> SqlGenerator<'a> { continue; }; let token = token_match.as_str(); - let Some((model, name)) = + let Some((model, name, graph_metric)) = self.resolve_metric_reference_location(token, default_model)? else { continue; @@ -1197,6 +1233,7 @@ impl<'a> SqlGenerator<'a> { model, name: name.clone(), alias: name, + graph_metric, }, deps, visiting, @@ -1210,14 +1247,18 @@ impl<'a> SqlGenerator<'a> { &self, metric_ref: &MetricRef, deps: &mut HashSet<(String, String)>, - visiting: &mut HashSet<(String, String)>, + visiting: &mut HashSet<(String, String, bool)>, ) -> Result<()> { - let key = (metric_ref.model.clone(), metric_ref.name.clone()); + let key = ( + metric_ref.model.clone(), + metric_ref.name.clone(), + metric_ref.graph_metric, + ); if !visiting.insert(key.clone()) { return Ok(()); } - let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; + let metric = self.metric_for_ref(metric_ref)?; match metric.r#type { MetricType::Derived if Self::is_inline_aggregate_expression(metric.sql_expr()) => { @@ -1237,7 +1278,7 @@ impl<'a> SqlGenerator<'a> { continue; }; let token = token_match.as_str(); - let Some((model, name)) = + let Some((model, name, graph_metric)) = self.resolve_metric_reference_location(token, &metric_ref.model)? else { continue; @@ -1247,6 +1288,7 @@ impl<'a> SqlGenerator<'a> { model, name: name.clone(), alias: name, + graph_metric, }, deps, visiting, @@ -1258,7 +1300,7 @@ impl<'a> SqlGenerator<'a> { .into_iter() .flatten() { - if let Some((model, name)) = + if let Some((model, name, graph_metric)) = self.resolve_metric_reference_location(expr, &metric_ref.model)? { self.collect_inline_metric_column_dependencies( @@ -1266,6 +1308,7 @@ impl<'a> SqlGenerator<'a> { model, name: name.clone(), alias: name, + graph_metric, }, deps, visiting, @@ -1335,18 +1378,21 @@ impl<'a> SqlGenerator<'a> { &self, reference: &str, default_model: &str, - ) -> Result> { + ) -> Result> { + if let Some((model_name, metric_name, graph_metric)) = + self.exact_metric_reference(reference)? + { + return Ok(Some((model_name, metric_name, graph_metric))); + } + if reference.contains('.') { - if let Some((model_name, metric_name)) = self.exact_metric_reference(reference)? { - return Ok(Some((model_name, metric_name))); - } let (model_name, metric_name, _) = self.graph.parse_reference(reference)?; let Some(model) = self.graph.get_model(&model_name) else { return Ok(None); }; return Ok(model .get_metric(&metric_name) - .map(|_| (model_name, metric_name))); + .map(|_| (model_name, metric_name, false))); } let mut owners = Vec::new(); @@ -1356,11 +1402,15 @@ impl<'a> SqlGenerator<'a> { } } if owners.len() == 1 { - return Ok(Some((owners[0].clone(), reference.to_string()))); + return Ok(Some((owners[0].clone(), reference.to_string(), false))); } if let Some(default) = self.graph.get_model(default_model) { if default.get_metric(reference).is_some() { - return Ok(Some((default_model.to_string(), reference.to_string()))); + return Ok(Some(( + default_model.to_string(), + reference.to_string(), + false, + ))); } } @@ -1476,7 +1526,7 @@ impl<'a> SqlGenerator<'a> { fn has_cumulative_metrics(&self, metric_refs: &[MetricRef]) -> Result { for metric_ref in metric_refs { - let metric = self.metric_for_model(&metric_ref.model, &metric_ref.name)?; + let metric = self.metric_for_ref(metric_ref)?; if metric.r#type == MetricType::Cumulative || metric.r#type == MetricType::TimeComparison || metric.r#type == MetricType::Conversion @@ -3281,7 +3331,7 @@ impl<'a> SqlGenerator<'a> { let mut seen_models = HashSet::new(); for metric_ref in metrics { let model_name = - if let Some((model_name, _)) = self.exact_metric_reference(metric_ref)? { + if let Some((model_name, _, _)) = self.exact_metric_reference(metric_ref)? { model_name } else { if !metric_ref.contains('.') { @@ -4158,41 +4208,20 @@ impl<'a> SqlGenerator<'a> { &self, reference: &str, default_model: &str, - visited: &mut HashSet<(String, String)>, + visited: &mut HashSet<(String, String, bool)>, ) -> Result> { - let (model_name, metric_name) = if reference.contains('.') { - if let Some((model_name, metric_name)) = self.exact_metric_reference(reference)? { - (model_name, metric_name) - } else { - let (m, n, _) = self.graph.parse_reference(reference)?; - (m, n) - } - } else { - let mut owners = Vec::new(); - for model in self.graph.models() { - if model.get_metric(reference).is_some() { - owners.push(model.name.clone()); - } - } - if owners.len() == 1 { - (owners[0].clone(), reference.to_string()) - } else if let Some(default) = self.graph.get_model(default_model) { - if default.get_metric(reference).is_some() { - (default_model.to_string(), reference.to_string()) - } else { - return Ok(None); - } - } else { - return Ok(None); - } + let Some((model_name, metric_name, graph_metric)) = + self.resolve_metric_reference_location(reference, default_model)? + else { + return Ok(None); }; - let key = (model_name.clone(), metric_name.clone()); + let key = (model_name.clone(), metric_name.clone(), graph_metric); if !visited.insert(key.clone()) { return Ok(None); } - let metric = self.metric_for_model(&model_name, &metric_name)?; + let metric = self.metric_for_model_with_source(&model_name, &metric_name, graph_metric)?; let alias = self.model_alias(&model_name); let expanded = match metric.r#type { @@ -4228,7 +4257,7 @@ impl<'a> SqlGenerator<'a> { &self, expr: &str, default_model: &str, - visited: &mut HashSet<(String, String)>, + visited: &mut HashSet<(String, String, bool)>, ) -> Result { if Self::is_inline_aggregate_expression(expr) { return self.rewrite_inline_aggregate_expression(expr, default_model); @@ -4613,6 +4642,38 @@ mod tests { assert!(sql.contains("GROUP BY 1")); } + #[test] + fn test_unqualified_graph_metric_wins_over_same_name_model_metric() { + let mut graph = create_test_graph(); + graph + .add_metric(Metric::sum("revenue", "gross_cents")) + .unwrap(); + let generator = SqlGenerator::new(&graph); + + let query = SemanticQuery::new().with_metrics(vec!["revenue".into()]); + + let sql = generator.generate(&query).unwrap(); + + assert!(sql.contains("gross_cents AS revenue_raw"), "{sql}"); + assert!(!sql.contains("amount AS revenue_raw"), "{sql}"); + } + + #[test] + fn test_qualified_model_metric_wins_over_same_name_graph_metric() { + let mut graph = create_test_graph(); + graph + .add_metric(Metric::sum("revenue", "gross_cents")) + .unwrap(); + let generator = SqlGenerator::new(&graph); + + let query = SemanticQuery::new().with_metrics(vec!["orders.revenue".into()]); + + let sql = generator.generate(&query).unwrap(); + + assert!(sql.contains("amount AS revenue_raw"), "{sql}"); + assert!(!sql.contains("gross_cents AS revenue_raw"), "{sql}"); + } + #[test] fn test_statistical_aggregation_metrics_render_supported_sql() { let mut graph = SemanticGraph::new(); From 30675820ec5892eb5f0c4da976d8b75202050d69 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 19:25:44 -0700 Subject: [PATCH 10/13] Allow compact SQL graph definitions in Rust loader --- sidemantic-rs/src/config/loader.rs | 28 ++++++++++ sidemantic-rs/src/config/sql_parser.rs | 75 ++++++++++++++++++++++---- 2 files changed, 92 insertions(+), 11 deletions(-) diff --git a/sidemantic-rs/src/config/loader.rs b/sidemantic-rs/src/config/loader.rs index 4876a3a2..f9d751ff 100644 --- a/sidemantic-rs/src/config/loader.rs +++ b/sidemantic-rs/src/config/loader.rs @@ -1277,6 +1277,34 @@ model orders from orders ( assert!(orders.get_metric("revenue").is_some()); } + #[test] + fn test_load_from_sql_string_collects_graph_definitions_after_compact_model() { + let sql = r#" +model orders from orders ( + primary key (order_id) + sum(amount) as revenue +) + +METRIC ( + name total_revenue, + sql orders.revenue +); + +PARAMETER ( + name region, + type string, + allowed_values [us, eu] +); +"#; + + let loaded = load_from_sql_string_with_metadata(sql).unwrap(); + + let orders = loaded.graph.get_model("orders").unwrap(); + assert!(orders.get_metric("revenue").is_some()); + assert!(loaded.graph.get_metric("total_revenue").is_some()); + assert!(loaded.graph.get_parameter("region").is_some()); + } + #[test] fn test_load_from_sql_string_keeps_multiple_legacy_models_separate() { let sql = r#" diff --git a/sidemantic-rs/src/config/sql_parser.rs b/sidemantic-rs/src/config/sql_parser.rs index dd590be6..f2611465 100644 --- a/sidemantic-rs/src/config/sql_parser.rs +++ b/sidemantic-rs/src/config/sql_parser.rs @@ -1490,15 +1490,23 @@ fn build_compact_model( Ok(model) } -fn parse_compact_sql_models(sql: &str) -> Result> { +fn parse_compact_sql_model_prefix(sql: &str) -> Result<(Vec, &str)> { let header_re = Regex::new(r"(?is)\bmodel\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s*") .expect("valid compact model header regex"); let mut models = Vec::new(); let mut remaining = sql; - while let Some(captures) = header_re.captures(remaining) { + loop { + remaining = remaining.trim_start(); + if let Some(after_semicolon) = remaining.strip_prefix(';') { + remaining = after_semicolon; + continue; + } + let Some(captures) = header_re.captures(remaining) else { + break; + }; let matched = captures.get(0).unwrap(); - if !remaining[..matched.start()].trim().is_empty() { + if matched.start() != 0 { return Err(SidemanticError::Validation( "Rust compact SQL model parser does not support non-model statements before compact model blocks".to_string(), )); @@ -1548,10 +1556,7 @@ fn parse_compact_sql_models(sql: &str) -> Result> { SidemanticError::Validation(format!("compact model '{name}' has an unterminated body")) })?; models.push(build_compact_model(name, table, source_sql, body)?); - remaining = rest.trim_start(); - if let Some(after_semicolon) = remaining.strip_prefix(';') { - remaining = after_semicolon; - } + remaining = rest; } if models.is_empty() { @@ -1560,10 +1565,18 @@ fn parse_compact_sql_models(sql: &str) -> Result> { )); } - if !remaining.trim().is_empty() { - return Err(SidemanticError::Validation( - "Rust compact SQL model parser does not support trailing graph-level definitions after compact model blocks".to_string(), - )); + Ok((models, remaining)) +} + +fn parse_compact_sql_models(sql: &str) -> Result> { + let (models, remaining) = parse_compact_sql_model_prefix(sql)?; + let trailing = remaining.trim(); + if !trailing.is_empty() { + parse_sql_graph_definitions_extended(trailing).map_err(|err| { + SidemanticError::Validation(format!( + "failed to parse trailing graph-level definitions after compact model blocks: {err}" + )) + })?; } Ok(models) @@ -1634,6 +1647,12 @@ pub fn parse_sql_graph_definitions( /// Parse SQL definitions for graph-level definitions including pre-aggregations. pub fn parse_sql_graph_definitions_extended(sql: &str) -> Result { + let sql = if has_compact_model_syntax(sql) { + let (_, remaining) = parse_compact_sql_model_prefix(sql)?; + remaining + } else { + sql + }; let (_, statements) = parse_file(sql).map_err(|e| SidemanticError::Validation(format!("Parse error: {e}")))?; @@ -2492,6 +2511,40 @@ model customers from public.customers ( assert!(models[1].get_dimension("region").is_some()); } + #[test] + fn test_parse_compact_sql_models_with_trailing_graph_definitions() { + let sql = r#" +model orders from orders ( + primary key (order_id) + sum(amount) as revenue +) + +METRIC ( + name total_revenue, + sql orders.revenue +); + +PARAMETER ( + name region, + type string, + allowed_values [us, eu] +); +"#; + + let models = parse_sql_models(sql).unwrap(); + assert_eq!( + models.iter().map(|m| m.name.as_str()).collect::>(), + vec!["orders"] + ); + + let (metrics, segments, parameters) = parse_sql_graph_definitions(sql).unwrap(); + assert_eq!(metrics.len(), 1); + assert_eq!(metrics[0].name, "total_revenue"); + assert!(segments.is_empty()); + assert_eq!(parameters.len(), 1); + assert_eq!(parameters[0].name, "region"); + } + #[test] fn test_parse_legacy_sql_models_multiple() { let sql = r#" From 858e1d7423e752d56f1c5eb306a328636b70a53e Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 19:46:27 -0700 Subject: [PATCH 11/13] Use graph metric SQL to resolve Rust owner --- sidemantic-rs/src/sql/generator.rs | 95 +++++++++++++++++++----------- 1 file changed, 61 insertions(+), 34 deletions(-) diff --git a/sidemantic-rs/src/sql/generator.rs b/sidemantic-rs/src/sql/generator.rs index b6bdf1c8..b7fd533a 100644 --- a/sidemantic-rs/src/sql/generator.rs +++ b/sidemantic-rs/src/sql/generator.rs @@ -820,46 +820,51 @@ impl<'a> SqlGenerator<'a> { fn graph_metric_owner_models(&self, reference: &str, metric: &Metric) -> Result> { let mut owners = HashSet::new(); - for model in self.graph.models() { - if model.get_metric(reference).is_some() { - owners.insert(model.name.clone()); - } + + for fragment in [ + metric.sql.as_deref(), + metric.numerator.as_deref(), + metric.denominator.as_deref(), + metric.base_metric.as_deref(), + metric.entity.as_deref(), + metric.base_event.as_deref(), + metric.conversion_event.as_deref(), + metric.cohort_event.as_deref(), + metric.activity_event.as_deref(), + metric.having.as_deref(), + ] + .into_iter() + .flatten() + { + self.collect_owner_models_from_fragment(fragment, &mut owners); } - if owners.is_empty() { - for fragment in [ - metric.sql.as_deref(), - metric.numerator.as_deref(), - metric.denominator.as_deref(), - metric.base_metric.as_deref(), - metric.entity.as_deref(), - metric.base_event.as_deref(), - metric.conversion_event.as_deref(), - metric.cohort_event.as_deref(), - metric.activity_event.as_deref(), - metric.having.as_deref(), - ] - .into_iter() - .flatten() - { - self.collect_owner_models_from_fragment(fragment, &mut owners); - } + for filter in &metric.filters { + self.collect_owner_models_from_fragment(filter, &mut owners); + } - if let Some(steps) = metric.steps.as_ref() { - for step in steps { - self.collect_owner_models_from_fragment(step, &mut owners); - } + if let Some(steps) = metric.steps.as_ref() { + for step in steps { + self.collect_owner_models_from_fragment(step, &mut owners); } - if let Some(inner_metrics) = metric.inner_metrics.as_ref() { - for inner_metric in inner_metrics { - if let Some(sql) = inner_metric.sql.as_deref() { - self.collect_owner_models_from_fragment(sql, &mut owners); - } + } + if let Some(inner_metrics) = metric.inner_metrics.as_ref() { + for inner_metric in inner_metrics { + if let Some(sql) = inner_metric.sql.as_deref() { + self.collect_owner_models_from_fragment(sql, &mut owners); } } - if let Some(entity_dimensions) = metric.entity_dimensions.as_ref() { - for dimension in entity_dimensions { - self.collect_owner_models_from_fragment(dimension, &mut owners); + } + if let Some(entity_dimensions) = metric.entity_dimensions.as_ref() { + for dimension in entity_dimensions { + self.collect_owner_models_from_fragment(dimension, &mut owners); + } + } + + if owners.is_empty() { + for model in self.graph.models() { + if model.get_metric(reference).is_some() { + owners.insert(model.name.clone()); } } } @@ -4658,6 +4663,28 @@ mod tests { assert!(!sql.contains("amount AS revenue_raw"), "{sql}"); } + #[test] + fn test_graph_metric_owner_comes_from_metric_sql_not_same_named_model_metric() { + let mut graph = create_test_graph(); + let sales = Model::new("sales", "sale_id") + .with_table("sales") + .with_metric(Metric::sum("gross_sales", "amount")); + graph.add_model(sales).unwrap(); + graph + .add_metric(Metric::sum("revenue", "sales.amount")) + .unwrap(); + let generator = SqlGenerator::new(&graph); + + let query = SemanticQuery::new().with_metrics(vec!["revenue".into()]); + + let sql = generator.generate(&query).unwrap(); + + assert!(sql.contains("sales.amount AS revenue_raw"), "{sql}"); + assert!(sql.contains("FROM sales"), "{sql}"); + assert!(!sql.contains("\n amount AS revenue_raw"), "{sql}"); + assert!(!sql.contains("FROM orders"), "{sql}"); + } + #[test] fn test_qualified_model_metric_wins_over_same_name_graph_metric() { let mut graph = create_test_graph(); From fce0ca275b89d274453c977da99df3fc90e1f92c Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 20:21:54 -0700 Subject: [PATCH 12/13] Preserve explicit derived inline aggregates --- sidemantic-rs/src/config/schema.rs | 27 ++++++++++++++++++------- sidemantic-rs/src/config/sql_parser.rs | 22 ++++++++++++++------ sidemantic-rs/src/sql/generator.rs | 28 ++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/sidemantic-rs/src/config/schema.rs b/sidemantic-rs/src/config/schema.rs index 71fd5b7e..d292e7a2 100644 --- a/sidemantic-rs/src/config/schema.rs +++ b/sidemantic-rs/src/config/schema.rs @@ -606,7 +606,10 @@ impl DimensionConfig { impl MetricConfig { fn into_metric(self) -> Metric { - let inline_aggregation = if self.agg.is_none() { + let explicit_metric_type = self.metric_type.as_deref().map(str::to_ascii_lowercase); + let can_normalize_inline_aggregation = + matches!(explicit_metric_type.as_deref(), None | Some("simple")); + let inline_aggregation = if self.agg.is_none() && can_normalize_inline_aggregation { self.sql .as_deref() .and_then(parse_inline_metric_aggregation) @@ -614,12 +617,7 @@ impl MetricConfig { None }; - let metric_type = match self - .metric_type - .as_deref() - .map(str::to_ascii_lowercase) - .as_deref() - { + let metric_type = match explicit_metric_type.as_deref() { Some("simple") => MetricType::Simple, Some("derived") => MetricType::Derived, Some("ratio") => MetricType::Ratio, @@ -1795,6 +1793,9 @@ models: sql: VARIANCE_POP(amount) - name: revenue_per_order sql: SUM(amount) / COUNT(*) + - name: explicit_derived_revenue + type: derived + sql: SUM(orders.amount) "#; let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); @@ -1844,6 +1845,18 @@ models: revenue_per_order.sql.as_deref(), Some("SUM(amount) / COUNT(*)") ); + + let explicit_derived_revenue = orders + .metrics + .iter() + .find(|m| m.name == "explicit_derived_revenue") + .unwrap(); + assert_eq!(explicit_derived_revenue.r#type, MetricType::Derived); + assert_eq!(explicit_derived_revenue.agg, None); + assert_eq!( + explicit_derived_revenue.sql.as_deref(), + Some("SUM(orders.amount)") + ); } #[test] diff --git a/sidemantic-rs/src/config/sql_parser.rs b/sidemantic-rs/src/config/sql_parser.rs index f2611465..41d06944 100644 --- a/sidemantic-rs/src/config/sql_parser.rs +++ b/sidemantic-rs/src/config/sql_parser.rs @@ -1938,12 +1938,8 @@ fn build_metric(props: &HashMap) -> Option { metric.agg = parse_metric_aggregation(props.get("agg")); - if metric.agg.is_none() - && matches!( - metric.r#type, - MetricType::Simple | MetricType::Cumulative | MetricType::Derived - ) - { + let explicit_metric_type = props.get("type").map(|value| value.to_ascii_lowercase()); + if metric.agg.is_none() && matches!(explicit_metric_type.as_deref(), None | Some("simple")) { if let Some(sql) = metric.sql.as_deref() { if let Some((agg, inner_expr)) = extract_aggregation_with_polyglot(sql) { metric.agg = parse_metric_aggregation(Some(&agg)); @@ -2228,6 +2224,20 @@ mod tests { assert_eq!(revenue.sql, Some("amount".to_string())); } + #[test] + fn test_parse_explicit_derived_metric_preserves_inline_aggregate_expression() { + let sql = r#" + MODEL (name orders, table orders); + METRIC (name revenue, type derived, sql SUM(orders.amount)); + "#; + + let model = parse_sql_model(sql).unwrap(); + let revenue = model.get_metric("revenue").unwrap(); + assert_eq!(revenue.r#type, MetricType::Derived); + assert_eq!(revenue.agg, None); + assert_eq!(revenue.sql, Some("SUM(orders.amount)".to_string())); + } + #[test] fn test_parse_cohort_metric_preserves_outer_aggregation() { let sql = r#" diff --git a/sidemantic-rs/src/sql/generator.rs b/sidemantic-rs/src/sql/generator.rs index b7fd533a..8e88259d 100644 --- a/sidemantic-rs/src/sql/generator.rs +++ b/sidemantic-rs/src/sql/generator.rs @@ -4804,6 +4804,34 @@ mod tests { ); } + #[test] + fn test_explicit_derived_inline_aggregate_from_yaml_generates_aggregate() { + let graph = crate::config::load_from_string( + r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: derived_revenue + type: derived + sql: SUM(orders.amount) +"#, + ) + .unwrap(); + let generator = SqlGenerator::new(&graph); + + let query = SemanticQuery::new().with_metrics(vec!["orders.derived_revenue".into()]); + + let sql = generator.generate(&query).unwrap(); + + assert!(sql.contains("amount AS amount"), "{sql}"); + assert!( + sql.contains("SUM(orders_cte.amount) AS derived_revenue"), + "{sql}" + ); + } + #[test] fn test_ordered_set_aggregate_metric_is_not_treated_as_metric_reference() { let mut graph = SemanticGraph::new(); From 30aa6da35ba9ad435f5908d9ec867d732a3cc12e Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sun, 31 May 2026 20:42:08 -0700 Subject: [PATCH 13/13] Resolve graph metric dependency owners --- sidemantic-rs/src/sql/generator.rs | 107 +++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/sidemantic-rs/src/sql/generator.rs b/sidemantic-rs/src/sql/generator.rs index 8e88259d..e8790a41 100644 --- a/sidemantic-rs/src/sql/generator.rs +++ b/sidemantic-rs/src/sql/generator.rs @@ -819,6 +819,31 @@ impl<'a> SqlGenerator<'a> { } fn graph_metric_owner_models(&self, reference: &str, metric: &Metric) -> Result> { + let mut visiting = HashSet::new(); + self.graph_metric_owner_models_inner(reference, metric, &mut visiting) + } + + fn graph_metric_owner_models_inner( + &self, + reference: &str, + metric: &Metric, + visiting: &mut HashSet, + ) -> Result> { + if !visiting.insert(reference.to_string()) { + return Ok(Vec::new()); + } + + let result = self.graph_metric_owner_models_uncycled(reference, metric, visiting); + visiting.remove(reference); + result + } + + fn graph_metric_owner_models_uncycled( + &self, + reference: &str, + metric: &Metric, + visiting: &mut HashSet, + ) -> Result> { let mut owners = HashSet::new(); for fragment in [ @@ -838,6 +863,13 @@ impl<'a> SqlGenerator<'a> { { self.collect_owner_models_from_fragment(fragment, &mut owners); } + for fragment in self.graph_metric_dependency_fragments(metric) { + self.collect_owner_models_from_graph_metric_dependencies( + fragment, + &mut owners, + visiting, + )?; + } for filter in &metric.filters { self.collect_owner_models_from_fragment(filter, &mut owners); @@ -890,6 +922,23 @@ impl<'a> SqlGenerator<'a> { Ok(owners) } + fn graph_metric_dependency_fragments<'b>(&self, metric: &'b Metric) -> Vec<&'b str> { + match metric.r#type { + MetricType::Derived => metric.sql.iter().map(String::as_str).collect(), + MetricType::Ratio => [metric.numerator.as_deref(), metric.denominator.as_deref()] + .into_iter() + .flatten() + .collect(), + MetricType::Cumulative | MetricType::TimeComparison => { + [metric.base_metric.as_deref(), metric.sql.as_deref()] + .into_iter() + .flatten() + .collect() + } + _ => Vec::new(), + } + } + fn collect_owner_models_from_fragment(&self, fragment: &str, owners: &mut HashSet) { let model_ref_re = regex::Regex::new(r"\b([A-Za-z_][A-Za-z0-9_]*)\.([A-Za-z_][A-Za-z0-9_]*)\b") @@ -905,6 +954,34 @@ impl<'a> SqlGenerator<'a> { } } + fn collect_owner_models_from_graph_metric_dependencies( + &self, + fragment: &str, + owners: &mut HashSet, + visiting: &mut HashSet, + ) -> Result<()> { + let metric_ref_re = regex::Regex::new( + r"\b([A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*|[A-Za-z_][A-Za-z0-9_]*)\b", + ) + .expect("valid metric reference regex"); + for cap in metric_ref_re.captures_iter(fragment) { + let Some(token_match) = cap.get(1) else { + continue; + }; + let token = token_match.as_str(); + if token.contains('.') || Self::is_sql_keyword_or_function(token) { + continue; + } + let Some(metric) = self.graph.get_metric(token) else { + continue; + }; + for owner in self.graph_metric_owner_models_inner(token, metric, visiting)? { + owners.insert(owner); + } + } + Ok(()) + } + fn metric_for_ref(&self, metric_ref: &MetricRef) -> Result<&Metric> { self.metric_for_model_with_source( &metric_ref.model, @@ -4685,6 +4762,36 @@ mod tests { assert!(!sql.contains("FROM orders"), "{sql}"); } + #[test] + fn test_graph_metric_owner_follows_unqualified_graph_metric_dependencies() { + let mut graph = create_test_graph(); + graph + .add_metric(Metric::sum("signups", "orders.signups")) + .unwrap(); + graph + .add_metric(Metric::sum("visitors", "orders.visitors")) + .unwrap(); + graph + .add_metric(Metric::ratio("conversion_rate", "signups", "visitors")) + .unwrap(); + let generator = SqlGenerator::new(&graph); + + let refs = generator + .parse_metric_refs(&["conversion_rate".to_string()]) + .unwrap(); + + assert_eq!(refs.len(), 1); + assert_eq!(refs[0].model, "orders"); + assert_eq!(refs[0].name, "conversion_rate"); + assert!(refs[0].graph_metric); + + let query = SemanticQuery::new().with_metrics(vec!["conversion_rate".into()]); + let sql = generator.generate(&query).unwrap(); + + assert!(sql.contains("orders.signups AS signups_raw"), "{sql}"); + assert!(sql.contains("orders.visitors AS visitors_raw"), "{sql}"); + } + #[test] fn test_qualified_model_metric_wins_over_same_name_graph_metric() { let mut graph = create_test_graph();