Skip to content

Commit 26e9294

Browse files
authored
Add post-processing SQL over semantic query results (#135)
* Add post-processing SQL over semantic query results

  Support arbitrary SQL (CASE, window functions, arithmetic, etc.) on top of semantic query results via subquery wrapping. The rewriter now walks the query tree recursively so nested subqueries and JOIN subqueries that reference semantic models are compiled correctly. Also adds a post_process parameter to compile() and query() for the Python API path, with automatic CTE hoisting.

* Preserve root semantic rewrite when JOIN has subquery

  When the root FROM references a semantic model and a JOIN contains a subquery, the routing into _rewrite_with_ctes_or_subqueries must still apply _rewrite_simple_query to the root SELECT so the explicit JOIN guard and semantic rewriting are not bypassed.

* Keep CTE scope in root semantic rewrite and merge post_process CTEs

  Fix two issues from PR review:

  1. When the root SELECT references a semantic model and has user-defined CTEs, _rewrite_simple_query was discarding them. Now user CTEs are saved before rewriting and merged back into the generated SQL so filter references like IN (SELECT ... FROM user_cte) remain valid.
  2. When post_process SQL has its own WITH clause and the inner semantic query also produces CTEs, the two WITH clauses are now merged into one instead of producing invalid double-WITH SQL.

* Preserve WITH RECURSIVE when merging root CTEs

  Propagate the recursive flag from user-defined CTEs to the merged WITH clause so self-referencing CTEs execute correctly.

* Handle CTE name collisions between user and generated CTEs

  Two changes:

  1. post_process path: remove CTE hoisting entirely. Inner SQL (with CTEs) is placed directly in the subquery position. CTEs inside subqueries are valid in all target databases and naturally scoped, so name collisions with post_process CTEs cannot occur.
  2. Root semantic + user CTEs: detect name collisions between user CTEs and generated CTEs, raising a clear error instead of producing invalid SQL.

  Walk-based renaming was too aggressive (renamed user CTE references inside filter subqueries).
1 parent 98406be commit 26e9294

3 files changed

Lines changed: 493 additions & 32 deletions

File tree

sidemantic/core/semantic_layer.py

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -435,6 +435,7 @@ def query(
435435
ungrouped: bool = False,
436436
parameters: dict[str, any] | None = None,
437437
use_preaggregations: bool | None = None,
438+
post_process: str | None = None,
438439
):
439440
"""Execute a query against the semantic layer.
440441
@@ -448,6 +449,9 @@ def query(
448449
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
449450
parameters: Template parameters for Jinja2 rendering
450451
use_preaggregations: Override pre-aggregation routing setting for this query
452+
post_process: Optional SQL to wrap around the semantic query result.
453+
Use {inner} as a placeholder for the compiled semantic query, e.g.:
454+
"SELECT *, revenue / count AS avg_value FROM ({inner})"
451455
452456
Returns:
453457
DuckDB relation object (can convert to DataFrame with .df() or .to_df())
@@ -462,6 +466,7 @@ def query(
462466
ungrouped=ungrouped,
463467
parameters=parameters,
464468
use_preaggregations=use_preaggregations,
469+
post_process=post_process,
465470
)
466471

467472
return self.adapter.execute(sql)
@@ -479,6 +484,7 @@ def compile(
479484
ungrouped: bool = False,
480485
parameters: dict[str, any] | None = None,
481486
use_preaggregations: bool | None = None,
487+
post_process: str | None = None,
482488
) -> str:
483489
"""Compile a query to SQL without executing.
484490
@@ -493,6 +499,9 @@ def compile(
493499
dialect: SQL dialect override (defaults to layer's dialect)
494500
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
495501
use_preaggregations: Override pre-aggregation routing setting for this query
502+
post_process: Optional SQL to wrap around the semantic query result.
503+
Use {inner} as a placeholder for the compiled semantic query, e.g.:
504+
"SELECT *, revenue / count AS avg_value FROM ({inner})"
496505
497506
Returns:
498507
SQL query string
@@ -520,7 +529,7 @@ def compile(
520529
preagg_schema=self.preagg_schema,
521530
)
522531

523-
return generator.generate(
532+
inner_sql = generator.generate(
524533
metrics=metrics,
525534
dimensions=dimensions,
526535
filters=filters,
@@ -533,6 +542,24 @@ def compile(
533542
use_preaggregations=use_preaggs,
534543
)
535544

545+
if post_process is not None:
546+
if "{inner}" not in post_process:
547+
raise ValueError("post_process must contain a {inner} placeholder")
548+
549+
# Strip sidemantic instrumentation comment
550+
stripped = inner_sql.rstrip()
551+
last_line = stripped.split("\n")[-1].strip()
552+
if last_line.startswith("-- sidemantic:"):
553+
stripped = "\n".join(stripped.split("\n")[:-1])
554+
555+
# Inner SQL (including any CTEs) is placed directly in the
556+
# subquery position. CTEs inside subqueries are valid SQL in
557+
# all target databases and naturally scoped, avoiding name
558+
# collisions with CTEs in the post_process SQL.
559+
return post_process.replace("{inner}", stripped)
560+
561+
return inner_sql
562+
536563
def explain(
537564
self,
538565
metrics: list[str] | None = None,

sidemantic/sql/query_rewriter.py

Lines changed: 78 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -118,8 +118,9 @@ def rewrite(self, sql: str, strict: bool = True) -> str:
118118
# Check if this is a CTE-based query or has subqueries
119119
has_ctes = parsed.args.get("with") is not None
120120
has_subquery_in_from = self._has_subquery_in_from(parsed)
121+
has_subquery_in_joins = any(isinstance(join.this, exp.Subquery) for join in (parsed.args.get("joins") or []))
121122

122-
if has_ctes or has_subquery_in_from:
123+
if has_ctes or has_subquery_in_from or has_subquery_in_joins:
123124
# Handle CTEs and subqueries
124125
return self._rewrite_with_ctes_or_subqueries(parsed)
125126

@@ -1851,41 +1852,89 @@ def _has_subquery_in_from(self, select: exp.Select) -> bool:
18511852
def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str:
18521853
"""Rewrite query that contains CTEs or subqueries.
18531854
1854-
Strategy:
1855-
1. Rewrite each CTE that references semantic models
1856-
2. Rewrite subqueries in FROM clause
1857-
3. Return the modified SQL
1855+
Recursively walks the query tree bottom-up, rewriting any
1856+
SELECT whose FROM target resolves to a semantic model.
1857+
Outer queries are left as plain SQL, so post-processing
1858+
(CASE, window functions, arithmetic, etc.) works naturally.
18581859
"""
1859-
# Handle CTEs
1860-
if parsed.args.get("with"):
1861-
with_clause = parsed.args["with"]
1862-
for cte in with_clause.expressions:
1863-
# Each CTE has a name (alias) and a query (this)
1860+
self._rewrite_select_tree(parsed)
1861+
1862+
# If the root SELECT itself references a semantic model, it must
1863+
# still go through _rewrite_simple_query (which enforces the
1864+
# explicit JOIN guard and performs semantic rewriting).
1865+
if self._references_semantic_model(parsed):
1866+
# Save user-defined CTEs before _rewrite_simple_query replaces
1867+
# the entire query with fresh generator output.
1868+
original_with = parsed.args.get("with")
1869+
1870+
rewritten_sql = self._rewrite_simple_query(parsed)
1871+
1872+
if original_with:
1873+
# Merge user CTEs into the generated SQL so references
1874+
# from filters/expressions (e.g. IN (SELECT ... FROM cte))
1875+
# remain valid.
1876+
rewritten = sqlglot.parse_one(rewritten_sql, dialect=self.dialect)
1877+
gen_with = rewritten.args.get("with")
1878+
if gen_with:
1879+
# Check for CTE name collisions between user and generated CTEs
1880+
user_names = {cte.alias for cte in original_with.expressions}
1881+
for gen_cte in gen_with.expressions:
1882+
if gen_cte.alias in user_names:
1883+
raise ValueError(
1884+
f"CTE name '{gen_cte.alias}' conflicts with an internally "
1885+
f"generated name. Please choose a different CTE name."
1886+
)
1887+
1888+
user_ctes = [cte.copy() for cte in original_with.expressions]
1889+
gen_with.set("expressions", user_ctes + list(gen_with.expressions))
1890+
# Preserve WITH RECURSIVE from the original query
1891+
if original_with.args.get("recursive"):
1892+
gen_with.set("recursive", True)
1893+
else:
1894+
rewritten.set("with", original_with.copy())
1895+
return rewritten.sql(dialect=self.dialect)
1896+
1897+
return rewritten_sql
1898+
1899+
return parsed.sql(dialect=self.dialect)
1900+
1901+
def _rewrite_select_tree(self, select: exp.Select):
1902+
"""Recursively rewrite semantic subqueries and CTEs (bottom-up).
1903+
1904+
At each level: recurse into children first, then rewrite this
1905+
node if it directly references a semantic model.
1906+
"""
1907+
# Recurse into CTEs
1908+
if select.args.get("with"):
1909+
for cte in select.args["with"].expressions:
18641910
cte_query = cte.this
18651911
if isinstance(cte_query, exp.Select):
1866-
# Check if this CTE references a semantic model
1912+
self._rewrite_select_tree(cte_query)
18671913
if self._references_semantic_model(cte_query):
1868-
# Rewrite the CTE query
1869-
rewritten_cte_sql = self._rewrite_simple_query(cte_query)
1870-
# Parse the rewritten SQL and replace the CTE query
1871-
rewritten_cte = sqlglot.parse_one(rewritten_cte_sql, dialect=self.dialect)
1872-
cte.set("this", rewritten_cte)
1873-
1874-
# Handle subquery in FROM
1875-
from_clause = parsed.args.get("from")
1914+
rewritten_sql = self._rewrite_simple_query(cte_query)
1915+
cte.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))
1916+
1917+
# Recurse into FROM subquery
1918+
from_clause = select.args.get("from")
18761919
if from_clause and isinstance(from_clause.this, exp.Subquery):
18771920
subquery = from_clause.this
18781921
subquery_select = subquery.this
1879-
if isinstance(subquery_select, exp.Select) and self._references_semantic_model(subquery_select):
1880-
# Rewrite the subquery
1881-
rewritten_subquery_sql = self._rewrite_simple_query(subquery_select)
1882-
rewritten_subquery = sqlglot.parse_one(rewritten_subquery_sql, dialect=self.dialect)
1883-
subquery.set("this", rewritten_subquery)
1884-
1885-
# Return the modified SQL
1886-
# Note: Individual CTEs/subqueries are already instrumented by _rewrite_simple_query -> generator
1887-
# The outer query wrapper doesn't need separate instrumentation
1888-
return parsed.sql(dialect=self.dialect)
1922+
if isinstance(subquery_select, exp.Select):
1923+
self._rewrite_select_tree(subquery_select)
1924+
if self._references_semantic_model(subquery_select):
1925+
rewritten_sql = self._rewrite_simple_query(subquery_select)
1926+
subquery.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))
1927+
1928+
# Recurse into JOIN subqueries
1929+
for join in select.args.get("joins") or []:
1930+
join_expr = join.this
1931+
if isinstance(join_expr, exp.Subquery):
1932+
join_select = join_expr.this
1933+
if isinstance(join_select, exp.Select):
1934+
self._rewrite_select_tree(join_select)
1935+
if self._references_semantic_model(join_select):
1936+
rewritten_sql = self._rewrite_simple_query(join_select)
1937+
join_expr.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))
18891938

18901939
def _references_semantic_model(self, select: exp.Select) -> bool:
18911940
"""Check if a SELECT statement references any semantic models."""

0 commit comments

Comments
 (0)