|
16 | 16 | from sqlglot.diff import Insert, Keep |
17 | 17 | from sqlglot.helper import ensure_list |
18 | 18 | from sqlglot.optimizer.scope import traverse_scope |
| 19 | +from sqlglot.schema import MappingSchema, nested_set |
19 | 20 | from sqlglot.time import format_time |
20 | 21 |
|
21 | 22 | from sqlmesh.core import constants as c |
@@ -405,6 +406,19 @@ def convert_to_time_column(self, time: TimeLike) -> exp.Expression: |
405 | 406 | return exp.cast(exp.Literal.string(time), time_column_type) |
406 | 407 | return exp.convert(time) |
407 | 408 |
|
| 409 | + def update_schema(self, schema: MappingSchema) -> None: |
| 410 | + """Updates the schema for this model's dependencies based on the given mapping schema.""" |
| 411 | + for dep in self.depends_on: |
| 412 | + table = exp.to_table(dep, dialect=self.dialect) |
| 413 | + mapping_schema = schema.find(table) |
| 414 | + |
| 415 | + if mapping_schema: |
| 416 | + nested_set( |
| 417 | + self.mapping_schema, |
| 418 | + tuple(str(part) for part in table.parts), |
| 419 | + {k: str(v) for k, v in mapping_schema.items()}, |
| 420 | + ) |
| 421 | + |
408 | 422 | @property |
409 | 423 | def depends_on(self) -> t.Set[str]: |
410 | 424 | """All of the upstream dependencies referenced in the model's query, excluding self references. |
@@ -701,6 +715,10 @@ def column_descriptions(self) -> t.Dict[str, str]: |
701 | 715 | } |
702 | 716 | return self._column_descriptions |
703 | 717 |
|
| 718 | + def update_schema(self, schema: MappingSchema) -> None: |
| 719 | + super().update_schema(schema) |
| 720 | + self._columns_to_types = None |
| 721 | + |
704 | 722 | def validate_definition(self) -> None: |
705 | 723 | query = self._query_renderer.render() |
706 | 724 |
|
|
0 commit comments