Commit 5d33825

Fix: Propagation of boolean model fields in Python models (#4130)

1 parent: 759b404 — commit 5d33825

14 files changed

Lines changed: 118 additions & 61 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def to_sql(self, alias: str) -> str:
152152
# TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema
153153
if self.type == "motherduck":
154154
# MotherDuck does not support aliasing
155-
if (md_db := self.path.replace("md:", "")) != alias.replace('"', ""):
155+
md_db = self.path.replace("md:", "")
156+
if md_db != alias.replace('"', ""):
156157
raise ConfigError(
157158
f"MotherDuck does not support assigning an alias different from the database name {md_db}."
158159
)
@@ -195,7 +196,8 @@ def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
195196
if not isinstance(data, dict):
196197
return data
197198

198-
if db_path := data.get("database") and data.get("catalogs"):
199+
db_path = data.get("database")
200+
if db_path and data.get("catalogs"):
199201
raise ConfigError(
200202
"Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
201203
)
@@ -302,7 +304,8 @@ def create_engine_adapter(self, register_comments_override: bool = False) -> Eng
302304
data_files.discard(":memory:")
303305
for data_file in data_files:
304306
key = data_file if isinstance(data_file, str) else data_file.path
305-
if adapter := BaseDuckDBConnectionConfig._data_file_to_adapter.get(key):
307+
adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
308+
if adapter is not None:
306309
logger.info(
307310
f"Using existing DuckDB adapter due to overlapping data file: {self._mask_motherduck_token(key)}"
308311
)

sqlmesh/core/engine_adapter/base.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,17 +1046,25 @@ def create_view(
10461046
if materialized_properties:
10471047
partitioned_by = materialized_properties.pop("partitioned_by", None)
10481048
clustered_by = materialized_properties.pop("clustered_by", None)
1049-
if partitioned_by and (
1050-
partitioned_by_prop := self._build_partitioned_by_exp(
1051-
partitioned_by, **materialized_properties
1049+
if (
1050+
partitioned_by
1051+
and (
1052+
partitioned_by_prop := self._build_partitioned_by_exp(
1053+
partitioned_by, **materialized_properties
1054+
)
10521055
)
1056+
is not None
10531057
):
10541058
materialized_properties["catalog_name"] = exp.to_table(view_name).catalog
10551059
properties.append("expressions", partitioned_by_prop)
1056-
if clustered_by and (
1057-
clustered_by_prop := self._build_clustered_by_exp(
1058-
clustered_by, **materialized_properties
1060+
if (
1061+
clustered_by
1062+
and (
1063+
clustered_by_prop := self._build_clustered_by_exp(
1064+
clustered_by, **materialized_properties
1065+
)
10591066
)
1067+
is not None
10601068
):
10611069
properties.append("expressions", clustered_by_prop)
10621070

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,8 @@ def _build_struct_with_descriptions(
830830
return exp.DataType(this=col_type.this, expressions=column_expressions, nested=True)
831831

832832
# Recursively build column definitions for BigQuery's RECORDs (struct) and REPEATED RECORDs (array of struct)
833-
if isinstance(col_type, exp.DataType) and (expressions := col_type.expressions):
833+
if isinstance(col_type, exp.DataType) and col_type.expressions:
834+
expressions = col_type.expressions
834835
if col_type.is_type(exp.DataType.Type.STRUCT):
835836
col_type = _build_struct_with_descriptions(col_type, nested_names + [col_name])
836837
elif col_type.is_type(exp.DataType.Type.ARRAY) and expressions[0].is_type(

sqlmesh/core/engine_adapter/clickhouse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,9 @@ def _build_table_properties_exp(
785785
)
786786
)
787787

788-
if partitioned_by and (
789-
partitioned_by_prop := self._build_partitioned_by_exp(partitioned_by)
788+
if (
789+
partitioned_by
790+
and (partitioned_by_prop := self._build_partitioned_by_exp(partitioned_by)) is not None
790791
):
791792
properties.append(partitioned_by_prop)
792793

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,10 @@ def _build_table_properties_exp(
221221
)
222222
)
223223

224-
if clustered_by and (clustered_by_prop := self._build_clustered_by_exp(clustered_by)):
224+
if (
225+
clustered_by
226+
and (clustered_by_prop := self._build_clustered_by_exp(clustered_by)) is not None
227+
):
225228
properties.append(clustered_by_prop)
226229

227230
if table_properties:

sqlmesh/core/engine_adapter/trino.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ def _schema_location(self, schema_name: SchemaName) -> t.Optional[str]:
353353
match_key = schema.db
354354

355355
# only consider the catalog if it is present
356-
if catalog := schema.catalog:
357-
match_key = f"{catalog}.{match_key}"
356+
if schema.catalog:
357+
match_key = f"{schema.catalog}.{match_key}"
358358

359359
for k, v in mapping.items():
360360
if re.match(k, match_key):

sqlmesh/core/environment.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def execute_environment_statements(
233233
end: t.Optional[TimeLike] = None,
234234
execution_time: t.Optional[TimeLike] = None,
235235
) -> None:
236-
if rendered_expressions := [
236+
rendered_expressions = [
237237
expr
238238
for statements in environment_statements
239239
for expr in render_statements(
@@ -250,7 +250,8 @@ def execute_environment_statements(
250250
runtime_stage=runtime_stage,
251251
engine_adapter=adapter,
252252
)
253-
]:
253+
]
254+
if rendered_expressions:
254255
with adapter.transaction():
255256
for expr in rendered_expressions:
256257
adapter.execute(expr)

sqlmesh/core/model/definition.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2406,7 +2406,8 @@ def _create_model(
24062406
# since rendering shifted from load time to run time.
24072407
# Note: we check for Tuple since that's what we expect from _resolve_properties
24082408
for property_name in PROPERTIES:
2409-
if isinstance(property_values := kwargs.get(property_name), exp.Tuple):
2409+
property_values = kwargs.get(property_name)
2410+
if isinstance(property_values, exp.Tuple):
24102411
statements.extend(property_values.expressions)
24112412

24122413
jinja_macro_references, used_variables = extract_macro_references_and_variables(
@@ -2682,30 +2683,41 @@ def render_field_value(value: t.Any) -> t.Any:
26822683

26832684
for field_name, field_info in ModelMeta.all_field_infos().items():
26842685
field = field_info.alias or field_name
2685-
if field not in RUNTIME_RENDERED_MODEL_FIELDS and (field_value := fields.get(field)):
2686-
if isinstance(field_value, dict):
2687-
rendered_dict = {}
2688-
for key, value in field_value.items():
2689-
if key in RUNTIME_RENDERED_MODEL_FIELDS:
2690-
rendered_dict[key] = value
2691-
elif rendered := render_field_value(value):
2692-
rendered_dict[key] = rendered
2693-
if rendered_dict:
2694-
fields[field] = rendered_dict
2695-
else:
2696-
fields.pop(field)
2697-
elif isinstance(field_value, list):
2698-
if rendered_list := [
2699-
rendered for value in field_value if (rendered := render_field_value(value))
2700-
]:
2701-
fields[field] = rendered_list
2702-
else:
2703-
fields.pop(field)
2686+
2687+
if field in RUNTIME_RENDERED_MODEL_FIELDS:
2688+
continue
2689+
2690+
field_value = fields.get(field)
2691+
if field_value is None:
2692+
continue
2693+
2694+
if isinstance(field_value, dict):
2695+
rendered_dict = {}
2696+
for key, value in field_value.items():
2697+
if key in RUNTIME_RENDERED_MODEL_FIELDS:
2698+
rendered_dict[key] = value
2699+
elif (rendered := render_field_value(value)) is not None:
2700+
rendered_dict[key] = rendered
2701+
if rendered_dict:
2702+
fields[field] = rendered_dict
27042703
else:
2705-
if rendered_field := render_field_value(field_value):
2706-
fields[field] = rendered_field
2707-
else:
2708-
fields.pop(field)
2704+
fields.pop(field)
2705+
elif isinstance(field_value, list):
2706+
rendered_list = [
2707+
rendered
2708+
for value in field_value
2709+
if (rendered := render_field_value(value)) is not None
2710+
]
2711+
if rendered_list:
2712+
fields[field] = rendered_list
2713+
else:
2714+
fields.pop(field)
2715+
else:
2716+
rendered_field = render_field_value(field_value)
2717+
if rendered_field is not None:
2718+
fields[field] = rendered_field
2719+
else:
2720+
fields.pop(field)
27092721

27102722
return fields
27112723

@@ -2733,12 +2745,13 @@ def render_model_defaults(
27332745

27342746
# Validate defaults that have macros are rendered to boolean
27352747
for boolean in {"optimize_query", "allow_partials", "enabled"}:
2736-
if var := rendered_defaults.get(boolean):
2737-
if not isinstance(var, (exp.Boolean, bool)):
2738-
raise ConfigError(f"Expected boolean for '{var}', got '{type(var)}' instead")
2748+
var = rendered_defaults.get(boolean)
2749+
if var is not None and not isinstance(var, (exp.Boolean, bool)):
2750+
raise ConfigError(f"Expected boolean for '{var}', got '{type(var)}' instead")
27392751

27402752
# Validate the 'interval_unit' if present is an Interval Unit
2741-
if (var := rendered_defaults.get("interval_unit")) and isinstance(var, str):
2753+
var = rendered_defaults.get("interval_unit")
2754+
if isinstance(var, str):
27422755
try:
27432756
rendered_defaults["interval_unit"] = IntervalUnit(var)
27442757
except ValueError as e:
@@ -2751,10 +2764,10 @@ def parse_defaults_properties(
27512764
defaults: t.Dict[str, t.Any], dialect: DialectType
27522765
) -> t.Dict[str, t.Any]:
27532766
for prop in PROPERTIES:
2754-
if default_properties := defaults.get(prop):
2755-
for key, value in default_properties.items():
2756-
if isinstance(key, str) and d.SQLMESH_MACRO_PREFIX in str(value):
2757-
defaults[prop][key] = exp.maybe_parse(value, dialect=dialect)
2767+
default_properties = defaults.get(prop)
2768+
for key, value in (default_properties or {}).items():
2769+
if isinstance(key, str) and d.SQLMESH_MACRO_PREFIX in str(value):
2770+
defaults[prop][key] = exp.maybe_parse(value, dialect=dialect)
27582771

27592772
return defaults
27602773

sqlmesh/core/model/kind.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -996,9 +996,10 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M
996996
# we dont want to throw an error here because we still want Models with a CustomKind to be able
997997
# to be serialized / deserialized in contexts where the custom materialization class may not be available,
998998
# such as in HTTP request handlers
999-
if custom_materialization := get_custom_materialization_type(
999+
custom_materialization = get_custom_materialization_type(
10001000
validate_string(props.get("materialization")), raise_errors=False
1001-
):
1001+
)
1002+
if custom_materialization is not None:
10021003
actual_kind_type, _ = custom_materialization
10031004
return actual_kind_type(**props)
10041005

sqlmesh/core/plan/evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,15 +500,16 @@ def _restatement_intervals_across_all_environments(
500500
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
501501
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
502502
# so we only do it if necessary
503-
if full_history_restatement_snapshot_ids := [
503+
full_history_restatement_snapshot_ids = [
504504
# FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
505505
# however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
506506
# is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
507507
# So for now, these are not considered
508508
s_id
509509
for s_id, s in snapshots_to_restate.items()
510510
if s[0].full_history_restatement_only
511-
]:
511+
]
512+
if full_history_restatement_snapshot_ids:
512513
# only load full snapshot records that we havent already loaded
513514
additional_snapshots = self.state_sync.get_snapshots(
514515
[

Commit comments (0)