Skip to content

Commit 3892983

Browse files
committed
Project custom join predicate columns
1 parent 4b9c638 commit 3892983

2 files changed

Lines changed: 99 additions & 0 deletions

File tree

sidemantic/sql/generator.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,42 @@ def _custom_join_condition(self, join_path) -> str:
309309
to_alias = self._quote_identifier(self._cte_name(join_path.to_model))
310310
return join_path.custom_condition.replace("{from}", from_alias).replace("{to}", to_alias)
311311

312+
def _custom_join_columns(self, join_path) -> dict[str, set[str]]:
313+
"""Extract raw columns that a custom join predicate reads from each side."""
314+
if not join_path.custom_condition:
315+
return {}
316+
317+
from_marker = "__from__"
318+
to_marker = "__to__"
319+
condition = join_path.custom_condition.replace("{from}", from_marker).replace("{to}", to_marker)
320+
try:
321+
parsed = sqlglot.parse_one(condition, dialect=self.dialect)
322+
except Exception as exc:
323+
raise ValueError(
324+
"Could not parse custom relationship SQL for "
325+
f"{join_path.from_model} -> {join_path.to_model}: {join_path.custom_condition}"
326+
) from exc
327+
328+
columns: dict[str, set[str]] = {join_path.from_model: set(), join_path.to_model: set()}
329+
for column in parsed.find_all(exp.Column):
330+
if column.table == from_marker:
331+
columns[join_path.from_model].add(column.name)
332+
elif column.table == to_marker:
333+
columns[join_path.to_model].add(column.name)
334+
335+
return {model_name: cols for model_name, cols in columns.items() if cols}
336+
337+
def _custom_join_columns_by_model(self, base_model_name: str, other_models: list[str]) -> dict[str, set[str]]:
338+
columns_by_model: dict[str, set[str]] = {}
339+
for other_model in other_models:
340+
join_path = self.graph.find_relationship_path(base_model_name, other_model)
341+
if not join_path:
342+
continue
343+
for join_step in join_path:
344+
for model_name, columns in self._custom_join_columns(join_step).items():
345+
columns_by_model.setdefault(model_name, set()).update(columns)
346+
return columns_by_model
347+
312348
def _apply_default_time_dimensions(self, metrics: list[str], dimensions: list[str]) -> list[str]:
313349
"""Auto-include default_time_dimension from models if not already present.
314350
@@ -669,6 +705,9 @@ def metric_needs_window(m):
669705

670706
# Extract columns needed for metric-level filters (before building CTEs)
671707
metric_filter_cols_by_model = self._extract_metric_filter_columns(metrics)
708+
custom_join_cols_by_model = self._custom_join_columns_by_model(base_model_name, model_names[1:])
709+
for model_name, column_names in custom_join_cols_by_model.items():
710+
metric_filter_cols_by_model.setdefault(model_name, set()).update(column_names)
672711

673712
# Ensure dimensions referenced in outer-query filters (e.g. window dims)
674713
# are included in the relevant CTE SELECT lists.

tests/queries/test_basic.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,66 @@ def test_no_prefix_when_no_collision(layer):
399399
assert "AS customers_customer_name" not in sql
400400

401401

402+
def test_custom_join_sql_projects_extra_predicate_columns():
403+
conn = duckdb.connect(":memory:")
404+
conn.execute("""
405+
CREATE TABLE orders (
406+
order_id INTEGER,
407+
customer_id INTEGER,
408+
amount INTEGER
409+
)
410+
""")
411+
conn.execute("""
412+
CREATE TABLE customers (
413+
customer_id INTEGER,
414+
country VARCHAR,
415+
valid_to DATE
416+
)
417+
""")
418+
conn.execute("INSERT INTO orders VALUES (1, 100, 50)")
419+
conn.execute("""
420+
INSERT INTO customers VALUES
421+
(100, 'US', NULL),
422+
(100, 'Expired', DATE '2024-01-01')
423+
""")
424+
425+
layer = SemanticLayer()
426+
layer.conn = conn
427+
layer.add_model(
428+
Model(
429+
name="orders",
430+
table="orders",
431+
primary_key="order_id",
432+
relationships=[
433+
Relationship(
434+
name="customers",
435+
type="many_to_one",
436+
foreign_key="customer_id",
437+
sql="{from}.customer_id = {to}.customer_id AND {to}.valid_to IS NULL",
438+
)
439+
],
440+
metrics=[Metric(name="revenue", agg="sum", sql="amount")],
441+
)
442+
)
443+
layer.add_model(
444+
Model(
445+
name="customers",
446+
table="customers",
447+
primary_key="customer_id",
448+
dimensions=[Dimension(name="country", type="categorical")],
449+
)
450+
)
451+
452+
sql = layer.compile(metrics=["orders.revenue"], dimensions=["customers.country"], order_by=["customers.country"])
453+
assert "valid_to AS valid_to" in sql
454+
assert "customers_cte.valid_to IS NULL" in sql
455+
456+
rows = df_rows(
457+
layer.query(metrics=["orders.revenue"], dimensions=["customers.country"], order_by=["customers.country"])
458+
)
459+
assert rows == [("Expired", None), ("US", 50)]
460+
461+
402462
def test_count_distinct_without_sql_uses_primary_key(layer):
403463
"""Test that count_distinct without sql field uses primary key.
404464

0 commit comments

Comments
 (0)