|
49 | 49 | from sqlmesh.core.config import Config, load_config_from_paths |
50 | 50 | from sqlmesh.core.console import Console, get_console |
51 | 51 | from sqlmesh.core.context_diff import ContextDiff |
52 | | -from sqlmesh.core.dialect import format_model_expressions, pandas_to_sql, parse |
| 52 | +from sqlmesh.core.dialect import ( |
| 53 | + format_model_expressions, |
| 54 | + normalize_model_name, |
| 55 | + pandas_to_sql, |
| 56 | + parse, |
| 57 | +) |
53 | 58 | from sqlmesh.core.engine_adapter import EngineAdapter |
54 | 59 | from sqlmesh.core.environment import Environment |
55 | 60 | from sqlmesh.core.loader import Loader, SqlMeshLoader, update_model_schemas |
@@ -277,13 +282,13 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: |
277 | 282 | Returns: |
278 | 283 | A new instance of the updated or inserted model. |
279 | 284 | """ |
280 | | - if isinstance(model, str): |
281 | | - model = self._models[model] |
| 285 | + model = self.get_model(model, raise_if_missing=True) |
| 286 | + path = model._path |
282 | 287 |
|
283 | | - path = model._path # type: ignore |
284 | 288 | # model.copy() can't be used here due to a cached state that can be a part of a model instance. |
285 | 289 | model = t.cast(Model, type(model)(**{**t.cast(Model, model).dict(), **kwargs})) |
286 | 290 | model._path = path |
| 291 | + |
287 | 292 | self._models.update({model.name: model}) |
288 | 293 |
|
289 | 294 | self._add_model_to_dag(model) |
@@ -409,14 +414,16 @@ def get_model( |
409 | 414 | The expected model. |
410 | 415 | """ |
411 | 416 | if isinstance(model_or_snapshot, str): |
412 | | - model = self._models.get(model_or_snapshot) |
| 417 | + normalized_name = normalize_model_name(model_or_snapshot, dialect=self.config.dialect) |
| 418 | + model = self._models.get(normalized_name) |
413 | 419 | elif isinstance(model_or_snapshot, Snapshot): |
414 | 420 | model = model_or_snapshot.model |
415 | 421 | else: |
416 | 422 | model = model_or_snapshot |
417 | 423 |
|
418 | 424 | if raise_if_missing and not model: |
419 | 425 | raise SQLMeshError(f"Cannot find model for '{model_or_snapshot}'") |
| 426 | + |
420 | 427 | return model |
421 | 428 |
|
422 | 429 | @t.overload |
@@ -444,7 +451,8 @@ def get_snapshot( |
444 | 451 | The expected snapshot. |
445 | 452 | """ |
446 | 453 | if isinstance(model_or_snapshot, str): |
447 | | - snapshot = self.snapshots.get(model_or_snapshot) |
| 454 | + normalized_name = normalize_model_name(model_or_snapshot, dialect=self.config.dialect) |
| 455 | + snapshot = self.snapshots.get(normalized_name) |
448 | 456 | elif isinstance(model_or_snapshot, Snapshot): |
449 | 457 | snapshot = model_or_snapshot |
450 | 458 | else: |
@@ -908,7 +916,9 @@ def audit( |
908 | 916 | """ |
909 | 917 |
|
910 | 918 | snapshots = ( |
911 | | - [self.snapshots[model] for model in models] if models else self.snapshots.values() |
| 919 | + [self.get_snapshot(model, raise_if_missing=True) for model in models] |
| 920 | + if models |
| 921 | + else self.snapshots.values() |
912 | 922 | ) |
913 | 923 |
|
914 | 924 | num_audits = sum(len(snapshot.model.audits) for snapshot in snapshots) |
@@ -942,6 +952,7 @@ def audit( |
942 | 952 | ) |
943 | 953 | self.console.log_status_update(f"Got {error.count} results, expected 0.") |
944 | 954 | self.console.show_sql(f"{error.query}") |
| 955 | + |
945 | 956 | self.console.log_status_update("Done.") |
946 | 957 |
|
947 | 958 | def migrate(self) -> None: |
|
0 commit comments