@@ -309,6 +309,42 @@ def _custom_join_condition(self, join_path) -> str:
309309 to_alias = self ._quote_identifier (self ._cte_name (join_path .to_model ))
310310 return join_path .custom_condition .replace ("{from}" , from_alias ).replace ("{to}" , to_alias )
311311
312+ def _custom_join_columns (self , join_path ) -> dict [str , set [str ]]:
313+ """Extract raw columns that a custom join predicate reads from each side."""
314+ if not join_path .custom_condition :
315+ return {}
316+
317+ from_marker = "__from__"
318+ to_marker = "__to__"
319+ condition = join_path .custom_condition .replace ("{from}" , from_marker ).replace ("{to}" , to_marker )
320+ try :
321+ parsed = sqlglot .parse_one (condition , dialect = self .dialect )
322+ except Exception as exc :
323+ raise ValueError (
324+ "Could not parse custom relationship SQL for "
325+ f"{ join_path .from_model } -> { join_path .to_model } : { join_path .custom_condition } "
326+ ) from exc
327+
328+ columns : dict [str , set [str ]] = {join_path .from_model : set (), join_path .to_model : set ()}
329+ for column in parsed .find_all (exp .Column ):
330+ if column .table == from_marker :
331+ columns [join_path .from_model ].add (column .name )
332+ elif column .table == to_marker :
333+ columns [join_path .to_model ].add (column .name )
334+
335+ return {model_name : cols for model_name , cols in columns .items () if cols }
336+
337+ def _custom_join_columns_by_model (self , base_model_name : str , other_models : list [str ]) -> dict [str , set [str ]]:
338+ columns_by_model : dict [str , set [str ]] = {}
339+ for other_model in other_models :
340+ join_path = self .graph .find_relationship_path (base_model_name , other_model )
341+ if not join_path :
342+ continue
343+ for join_step in join_path :
344+ for model_name , columns in self ._custom_join_columns (join_step ).items ():
345+ columns_by_model .setdefault (model_name , set ()).update (columns )
346+ return columns_by_model
347+
312348 def _apply_default_time_dimensions (self , metrics : list [str ], dimensions : list [str ]) -> list [str ]:
313349 """Auto-include default_time_dimension from models if not already present.
314350
@@ -669,6 +705,9 @@ def metric_needs_window(m):
669705
670706 # Extract columns needed for metric-level filters (before building CTEs)
671707 metric_filter_cols_by_model = self ._extract_metric_filter_columns (metrics )
708+ custom_join_cols_by_model = self ._custom_join_columns_by_model (base_model_name , model_names [1 :])
709+ for model_name , column_names in custom_join_cols_by_model .items ():
710+ metric_filter_cols_by_model .setdefault (model_name , set ()).update (column_names )
672711
673712 # Ensure dimensions referenced in outer-query filters (e.g. window dims)
674713 # are included in the relevant CTE SELECT lists.
0 commit comments