-
Notifications
You must be signed in to change notification settings - Fork 373
Expand file tree
/
Copy pathdefinition.py
More file actions
1062 lines (867 loc) · 41.2 KB
/
definition.py
File metadata and controls
1062 lines (867 loc) · 41.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import sys
import datetime
import threading
import typing as t
import unittest
from collections import Counter
from contextlib import nullcontext, contextmanager, AbstractContextManager
from itertools import chain
from pathlib import Path
from unittest.mock import patch
from io import StringIO
from sqlglot import Dialect, exp
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlmesh.core import constants as c
from sqlmesh.core.dialect import normalize_model_name, schema_
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.model import Model, PythonModel, SqlModel
from sqlmesh.utils import UniqueKeyDict, random_id, type_is_known, yaml
from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime
from sqlmesh.utils.errors import ConfigError, TestError
from sqlmesh.utils.yaml import load as yaml_load
from sqlmesh.utils import Verbosity
from sqlmesh.utils.rich import df_to_table
if t.TYPE_CHECKING:
import pandas as pd
from sqlglot.dialects.dialect import DialectType
# A single fixture/result row: column name -> value.
Row = t.Dict[str, t.Any]

# Variable names that parameterize a test's time context. Keys found here are
# popped out of the test's "vars" and passed as render kwargs (see
# `SqlModelTest._render_model_query` / `PythonModelTest._execute_model`) rather
# than being treated as regular user-defined variables.
TIME_KWARG_KEYS = {
    "start",
    "end",
    "execution_time",
    "latest",
    # all built-in datetime macro var names
    *date_dict(execution_time="1970-01-01", start="1970-01-01", end="1970-01-01").keys(),
}
class ModelTest(unittest.TestCase):
    """Base class for a single model unit test, backed by `unittest.TestCase`."""

    # Prevent pytest/unittest collection from treating this class itself as a test.
    __test__ = False

    # Serializes rendering across threads when `concurrency` is enabled, because
    # time freezing and dialect TRANSFORMS patching are process-global (see
    # `_concurrent_render_context`).
    CONCURRENT_RENDER_LOCK = threading.Lock()

    def __init__(
        self,
        body: t.Dict[str, t.Any],
        test_name: str,
        model: Model,
        models: UniqueKeyDict[str, Model],
        engine_adapter: EngineAdapter,
        dialect: str | None = None,
        path: Path | None = None,
        preserve_fixtures: bool = False,
        default_catalog: str | None = None,
        concurrency: bool = False,
        verbosity: Verbosity = Verbosity.DEFAULT,
    ) -> None:
        """ModelTest encapsulates a unit test for a model.

        Args:
            body: A dictionary that contains test metadata like inputs and outputs.
            test_name: The name of the test.
            model: The model that is being tested.
            models: All models to use for expansion and mapping of physical locations.
            engine_adapter: The engine adapter to use.
            dialect: The models' dialect, used for normalization purposes.
            path: An optional path to the test definition yaml file.
            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        """
        self.body = body
        self.test_name = test_name
        self.model = model
        self.models = models
        self.engine_adapter = engine_adapter
        self.path = path
        self.preserve_fixtures = preserve_fixtures
        self.default_catalog = default_catalog
        self.dialect = dialect
        self.concurrency = concurrency
        self.verbosity = verbosity

        # Per-instance caches to avoid repeated normalization / table construction.
        self._fixture_table_cache: t.Dict[str, exp.Table] = {}
        self._normalized_column_name_cache: t.Dict[str, str] = {}
        self._normalized_model_name_cache: t.Dict[t.Tuple[str, bool], str] = {}

        self._test_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect)

        self._validate_and_normalize_test()

        if self.engine_adapter.default_catalog:
            self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers(
                exp.parse_identifier(
                    self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect
                ),
                dialect=self._test_adapter_dialect,
            )
        else:
            self._fixture_catalog = None

        # The test schema name is randomized to avoid concurrency issues,
        # unless a schema is provided in the unit tests's body
        self._fixture_schema = exp.parse_identifier(
            self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}"
        )
        self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog)

        self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS

        self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")
        if self._execution_time:
            # Normalizes the execution time by converting it into UTC timezone
            self._execution_time = str(to_datetime(self._execution_time))

        # When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it
        if self._execution_time:
            exec_time = exp.Literal.string(self._execution_time)
            self._transforms = {
                **self._transforms,
                exp.CurrentDate: lambda self, _: self.sql(
                    exp.cast(exec_time, "date", dialect=dialect)
                ),
                exp.CurrentDatetime: lambda self, _: self.sql(
                    exp.cast(exec_time, "datetime", dialect=dialect)
                ),
                exp.CurrentTime: lambda self, _: self.sql(
                    exp.cast(exec_time, "time", dialect=dialect)
                ),
                exp.CurrentTimestamp: lambda self, _: self.sql(
                    exp.cast(exec_time, "timestamp", dialect=dialect)
                ),
            }

        super().__init__()

    def defaultTestResult(self) -> unittest.TestResult:
        """Return the custom result class so failures render as Rich tables."""
        from sqlmesh.core.test.result import ModelTextTestResult

        return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity)

    def shortDescription(self) -> t.Optional[str]:
        """Return the test's user-supplied description, if any."""
        return self.body.get("description")

    def setUp(self) -> None:
        """Load all input tables"""
        import pandas as pd
        import numpy as np

        self.engine_adapter.create_schema(self._qualified_fixture_schema)

        for name, values in self.body.get("inputs", {}).items():
            all_types_are_known = False
            columns_to_known_types: t.Dict[str, exp.DataType] = {}

            model = self.models.get(name)
            if model:
                inferred_columns_to_types = model.columns_to_types or {}
                columns_to_known_types = {
                    c: t for c, t in inferred_columns_to_types.items() if type_is_known(t)
                }
                all_types_are_known = bool(inferred_columns_to_types) and (
                    len(columns_to_known_types) == len(inferred_columns_to_types)
                )

            # Types specified in the test will override the corresponding inferred ones
            columns_to_known_types.update(values.get("columns", {}))

            rows = values.get("rows")
            if not all_types_are_known and rows:
                # Fall back to inferring each missing column's type from the
                # first row's value.
                for col, value in rows[0].items():
                    if col not in columns_to_known_types:
                        v_type = annotate_types(exp.convert(value)).type or type(value).__name__
                        v_type = exp.maybe_parse(
                            v_type, into=exp.DataType, dialect=self._test_adapter_dialect
                        )
                        if not type_is_known(v_type):
                            _raise_error(
                                f"Failed to infer the data type of column '{col}' for '{name}'. This issue can be "
                                "mitigated by casting the column in the model definition, setting its type in "
                                "external_models.yaml if it's an external model, setting the model's 'columns' property, "
                                "or setting its 'columns' mapping in the test itself",
                                self.path,
                            )

                        columns_to_known_types[col] = v_type

            if rows is None:
                # Input is defined by a query instead of literal rows.
                query_or_df: exp.Query | pd.DataFrame = self._add_missing_columns(
                    values["query"], columns_to_known_types
                )
                if columns_to_known_types:
                    columns_to_known_types = {
                        col: columns_to_known_types[col] for col in query_or_df.named_selects
                    }
            else:
                query_or_df = self._create_df(values, columns=columns_to_known_types)

            # Convert NaN/NaT values to None if DataFrame
            if isinstance(query_or_df, pd.DataFrame):
                query_or_df = query_or_df.replace({np.nan: None})

            self.engine_adapter.create_view(
                self._test_fixture_table(name), query_or_df, columns_to_known_types
            )

    def tearDown(self) -> None:
        """Drop all fixture tables."""
        if not self.preserve_fixtures:
            self.engine_adapter.drop_schema(self._qualified_fixture_schema, cascade=True)

    def assert_equal(
        self,
        expected: pd.DataFrame,
        actual: pd.DataFrame,
        sort: bool,
        partial: t.Optional[bool] = False,
    ) -> None:
        """Compare two DataFrames"""
        import numpy as np
        import pandas as pd
        from pandas.api.types import is_object_dtype

        if partial:
            # Only compare the columns the test explicitly declared.
            intersection = actual[actual.columns.intersection(expected.columns)]
            if len(intersection.columns) > 0:
                actual = intersection

        # Two astypes are necessary, pandas converts strings to times as NS,
        # but if the actual is US, it doesn't take effect until the 2nd try!
        actual_types = actual.dtypes.to_dict()
        expected = expected.astype(actual_types, errors="ignore").astype(
            actual_types, errors="ignore"
        )

        # The `actual` df's dtypes will almost always be pd.Timestamp for datetime values,
        # but in some scenarios (e.g., DuckDB >=0.10.2) it will be a pandas `object` type
        # containing python `datetime.xxx` values.
        #
        # Pandas `object` columns result in a noop for the `astype` call above. Because any
        # quoted YAML value is a string, we must manually convert the `expected` df string
        # values to the correct `datetime.xxx` type.
        #
        # We determine the type from a single sentinel value, but since the `actual` df is
        # coming from a database query, it is safe to assume that the column contains only
        # a single type.
        object_sentinel_values = {
            col: actual[col][0]
            for col in actual_types
            if is_object_dtype(actual_types[col]) and len(actual[col]) != 0
        }
        for col, value in object_sentinel_values.items():
            try:
                # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525
                if type(value) is datetime.date:
                    expected[col] = pd.to_datetime(expected[col]).dt.date
                elif type(value) is datetime.time:
                    expected[col] = pd.to_datetime(expected[col]).dt.time
                elif type(value) is datetime.datetime:
                    expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime()
            except Exception as e:
                from sqlmesh.core.console import get_console

                get_console().log_warning(
                    f"Failed to convert expected value for {col} into `datetime` "
                    f"for unit test '{str(self)}'. {str(e)}."
                )

        actual = actual.replace({np.nan: None})
        expected = expected.replace({np.nan: None})

        # We define this here to avoid a top-level import of numpy and pandas
        DATETIME_TYPES = (
            datetime.datetime,
            datetime.date,
            datetime.time,
            np.datetime64,
            pd.Timestamp,
        )

        def _to_hashable(x: t.Any) -> t.Any:
            # Nested containers / datetimes must become hashable so rows can
            # be sorted and counted below.
            if isinstance(x, (list, np.ndarray)):
                return tuple(_to_hashable(v) for v in x)
            if isinstance(x, dict):
                return tuple((k, _to_hashable(v)) for k, v in x.items())
            return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x

        actual = actual.apply(lambda col: col.map(_to_hashable))
        expected = expected.apply(lambda col: col.map(_to_hashable))

        if sort:
            actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
            expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)

        try:
            pd.testing.assert_frame_equal(
                expected,
                actual,
                check_dtype=False,
                check_datetimelike_compat=True,
                check_like=True,  # Ignore column order
            )
        except AssertionError as e:
            # There are 2 concepts at play here:
            # 1. The Exception args will contain the error message plus the diff dataframe table stringified
            #    (backwards compatibility with existing tests, possible to serialize/send over network etc)
            # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll
            #    be surfaced to the user through Console for better UX (versus stringified dataframes)
            #
            # This is a bit of a hack, but it's a way to get the best of both worlds.
            args: t.List[t.Any] = []

            failed_subtest = ""
            if subtest := getattr(self, "_subtest", None):
                if cte := subtest.params.get("cte"):
                    failed_subtest = f" (CTE {cte})"

            if expected.shape != actual.shape:
                _raise_if_unexpected_columns(expected.columns, actual.columns)

                args.append("Data mismatch (rows are different)")
                missing_rows = _row_difference(expected, actual)
                if not missing_rows.empty:
                    args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
                    args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))

                unexpected_rows = _row_difference(actual, expected)
                if not unexpected_rows.empty:
                    args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
                    args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))
            else:
                diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
                args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}")

                diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
                if self.verbosity == Verbosity.DEFAULT:
                    args.extend(
                        df_to_table(f"Data mismatch{failed_subtest}", df)
                        for df in _split_df_by_column_pairs(diff)
                    )
                else:
                    from pandas import MultiIndex

                    levels = t.cast(MultiIndex, diff.columns).levels[0]
                    for col in levels:
                        col_diff = diff[col]
                        if not col_diff.empty:
                            table = df_to_table(
                                f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
                                col_diff,
                            )
                            args.append(table)

            e.args = (*args,)
            raise e

    def runTest(self) -> None:
        # Subclasses implement the actual execution/comparison logic.
        raise NotImplementedError

    def path_relative_to(self, other: Path) -> Path | None:
        """Compute a version of this test's path relative to the `other` path"""
        return self.path.relative_to(other) if self.path else None

    @staticmethod
    def create_test(
        body: t.Dict[str, t.Any],
        test_name: str,
        models: UniqueKeyDict[str, Model],
        engine_adapter: EngineAdapter,
        dialect: str | None,
        path: Path | None,
        preserve_fixtures: bool = False,
        default_catalog: str | None = None,
        concurrency: bool = False,
        verbosity: Verbosity = Verbosity.DEFAULT,
    ) -> t.Optional[ModelTest]:
        """Create a SqlModelTest or a PythonModelTest.

        Args:
            body: A dictionary that contains test metadata like inputs and outputs.
            test_name: The name of the test.
            models: All models to use for expansion and mapping of physical locations.
            engine_adapter: The engine adapter to use.
            dialect: The models' dialect, used for normalization purposes.
            path: An optional path to the test definition yaml file.
            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        """
        name = body.get("model")
        if name is None:
            _raise_error("Missing required 'model' field", path)

        name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect)
        model = models.get(name)
        if not model:
            from sqlmesh.core.console import get_console

            get_console().log_warning(
                f"Model '{name}' was not found{' at ' + str(path) if path else ''}"
            )
            return None

        if isinstance(model, SqlModel):
            test_type: t.Type[ModelTest] = SqlModelTest
        elif isinstance(model, PythonModel):
            test_type = PythonModelTest
        else:
            _raise_error(f"Model '{name}' is an unsupported model type for testing", path)

        try:
            return test_type(
                body,
                test_name,
                t.cast(Model, model),
                models,
                engine_adapter,
                dialect,
                path,
                preserve_fixtures,
                default_catalog,
                concurrency,
                verbosity,
            )
        except Exception as e:
            raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}")

    def __str__(self) -> str:
        return f"{self.test_name} ({self.path})"

    def _validate_and_normalize_test(self) -> None:
        """Validate the test body and normalize model/column identifiers in place."""
        inputs = self.body.get("inputs")
        outputs = self.body.get("outputs", {})
        if not outputs:
            _raise_error("Incomplete test, missing outputs", self.path)

        ctes = outputs.get("ctes")
        query = outputs.get("query")
        partial = outputs.pop("partial", None)
        if ctes is None and query is None:
            _raise_error("Incomplete test, outputs must contain 'query' or 'ctes'", self.path)

        def _normalize_rows(
            values: t.List[Row] | t.Dict,
            name: str,
            partial: bool = False,
            dialect: DialectType = None,
        ) -> t.Dict:
            # Normalize one input/output entry: load csv/yaml data, validate
            # query-vs-rows exclusivity, and normalize column names.
            import pandas as pd

            if not isinstance(values, dict):
                values = {"rows": values}

            rows = values.get("rows")
            query = values.get("query")
            fmt = values.get("format")
            path = values.get("path")

            if fmt == "csv":
                csv_settings = values.get("csv_settings") or {}
                rows = pd.read_csv(path or StringIO(rows), **csv_settings).to_dict(orient="records")
            elif fmt in (None, "yaml"):
                if path:
                    input_rows = yaml_load(Path(path))
                    rows = input_rows.get("rows") if isinstance(input_rows, dict) else input_rows
            else:
                _raise_error(f"Unsupported data format '{fmt}' for '{name}'", self.path)

            if query is not None:
                if rows is not None:
                    _raise_error(
                        f"Invalid test, cannot set both 'query' and 'rows' for '{name}'", self.path
                    )

                # We parse the user-supplied query using the testing adapter dialect, but we
                # normalize its identifiers according to the model's dialect, so that, e.g.,
                # the projection names match those in its `columns_to_types` field
                values["query"] = normalize_identifiers(
                    exp.maybe_parse(query, dialect=self._test_adapter_dialect), dialect=dialect
                )
                return values

            if rows is None:
                _raise_error(f"Incomplete test, missing row data for '{name}'", self.path)

            assert isinstance(rows, list)
            values["rows"] = [
                {self._normalize_column_name(column): value for column, value in row.items()}
                for row in rows
            ]

            if partial:
                values["partial"] = True

            return values

        def _normalize_sources(
            sources: t.Dict, partial: bool = False, with_default_catalog: bool = True
        ) -> t.Dict:
            # Normalize each source's model name and row data.
            normalized_sources = {}
            for name, values in sources.items():
                normalized_name = self._normalize_model_name(
                    name, with_default_catalog=with_default_catalog
                )
                model = self.models.get(normalized_name)
                dialect = model.dialect if model else self.dialect
                normalized_sources[normalized_name] = _normalize_rows(
                    values, name, partial=partial, dialect=dialect
                )
            return normalized_sources

        normalized_model_name = self._normalize_model_name(self.body["model"])
        self.body["model"] = normalized_model_name

        if inputs:
            inputs = _normalize_sources(inputs)
            for name, values in inputs.items():
                columns = values.get("columns")
                if columns is None:
                    continue

                if not isinstance(columns, dict):
                    _raise_error(
                        f"Invalid 'columns' value for model '{name}', expected a mapping name -> type",
                        self.path,
                    )

                values["columns"] = {
                    self._normalize_column_name(c): exp.DataType.build(
                        t, dialect=self._test_adapter_dialect
                    )
                    for c, t in columns.items()
                }

            for depends_on in self.model.depends_on:
                if depends_on not in inputs:
                    _raise_error(f"Incomplete test, missing input model '{depends_on}'", self.path)

            if self.model.depends_on_self and normalized_model_name not in inputs:
                inputs[normalized_model_name] = {"rows": []}

            self.body["inputs"] = inputs

        if ctes:
            outputs["ctes"] = _normalize_sources(ctes, partial=partial, with_default_catalog=False)
        if query or query == []:
            outputs["query"] = _normalize_rows(
                query, self.model.name, partial=partial, dialect=self.model.dialect
            )

    def _test_fixture_table(self, name: str) -> exp.Table:
        """Return (and cache) the fixture table that stands in for model `name`."""
        table = self._fixture_table_cache.get(name)
        if not table:
            table = exp.to_table(name, dialect=self._test_adapter_dialect)

            # We change the table path below, so this ensures there are no name clashes
            table.this.set("this", "__".join(part.name for part in table.parts))

            table.set("db", self._fixture_schema.copy())
            if self._fixture_catalog:
                table.set("catalog", self._fixture_catalog.copy())

            self._fixture_table_cache[name] = table

        return table

    def _normalize_model_name(self, name: str, with_default_catalog: bool = True) -> str:
        """Normalize `name` per the models' dialect, caching the result."""
        normalized_name = self._normalized_model_name_cache.get((name, with_default_catalog))
        if normalized_name is None:
            default_catalog = self.default_catalog if with_default_catalog else None
            normalized_name = normalize_model_name(
                name, default_catalog=default_catalog, dialect=self.dialect
            )
            self._normalized_model_name_cache[(name, with_default_catalog)] = normalized_name

        return normalized_name

    def _normalize_column_name(self, name: str) -> str:
        """Normalize a column identifier per the models' dialect, caching the result."""
        normalized_name = self._normalized_column_name_cache.get(name)
        if normalized_name is None:
            normalized_name = normalize_identifiers(name, dialect=self.dialect).name
            self._normalized_column_name_cache[name] = normalized_name

        return normalized_name

    @contextmanager
    def _concurrent_render_context(self) -> t.Iterator[None]:
        """
        Context manager that ensures that the tests are executed safely in a concurrent environment.

        This is needed in case `execution_time` is set, as we'd then have to:
        - Freeze time through `time_machine` (not thread safe)
        - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
        """
        import time_machine

        lock_ctx: AbstractContextManager = (
            self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext()
        )
        time_ctx: AbstractContextManager = nullcontext()
        dialect_patch_ctx: AbstractContextManager = nullcontext()

        if self._execution_time:
            time_ctx = time_machine.travel(self._execution_time, tick=False)
            dialect_patch_ctx = patch.dict(
                self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
            )

        with lock_ctx, time_ctx, dialect_patch_ctx:
            yield

    def _execute(self, query: exp.Query | str) -> pd.DataFrame:
        """Executes the given query using the testing engine adapter and returns a DataFrame."""
        return self.engine_adapter.fetchdf(query)

    def _create_df(
        self,
        values: t.Dict[str, t.Any],
        columns: t.Optional[t.Collection] = None,
        partial: t.Optional[bool] = False,
    ) -> pd.DataFrame:
        """Build a DataFrame from an entry's `query` or literal `rows`."""
        import pandas as pd

        query = values.get("query")
        if query:
            if not partial:
                query = self._add_missing_columns(query, columns)
            return self._execute(query)

        rows = values["rows"]

        columns_str: t.Optional[t.List[str]] = None
        if columns:
            columns_str = [str(c) for c in columns]
            # Preserve first-seen column order while de-duplicating.
            referenced_columns = list(dict.fromkeys(col for row in rows for col in row))
            _raise_if_unexpected_columns(columns, referenced_columns)

            if partial:
                columns_str = [c for c in columns_str if c in referenced_columns]

        return pd.DataFrame.from_records(rows, columns=columns_str)

    def _add_missing_columns(
        self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None
    ) -> exp.Query:
        """Append NULL projections for any of `all_columns` the query doesn't select."""
        if not all_columns or query.is_star:
            return query

        query_columns = set(query.named_selects)
        missing_columns = [col for col in all_columns if col not in query_columns]
        if missing_columns:
            query.select(*[exp.null().as_(col) for col in missing_columns], copy=False)

        return query
class SqlModelTest(ModelTest):
    """Unit test for a SQL model: renders and executes its query (and CTEs)."""

    def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False) -> None:
        """Run CTE queries and compare output to expected output"""
        for cte_name, values in self.body["outputs"].get("ctes", {}).items():
            with self.subTest(cte=cte_name):
                if cte_name not in ctes:
                    _raise_error(
                        f"No CTE named {cte_name} found in model {self.model.name}", self.path
                    )

                cte_query = ctes[cte_name].this
                sort = cte_query.args.get("order") is None
                partial = values.get("partial")

                # Select the CTE's projections through a wrapper query that
                # re-attaches all of the model's CTEs.
                cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_name)
                for alias, cte in ctes.items():
                    cte_query = cte_query.with_(alias, cte.this, recursive=recursive)

                with self._concurrent_render_context():
                    # Similar to the model's query, we render the CTE query under the locked context
                    # so that the execution (fetchdf) can continue concurrently between the threads
                    sql = cte_query.sql(
                        self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql
                    )

                actual = self._execute(sql)
                expected = self._create_df(values, columns=cte_query.named_selects, partial=partial)
                self.assert_equal(expected, actual, sort=sort, partial=partial)

    def runTest(self) -> None:
        """Render the model's query, then check its CTEs and final output."""
        with self._concurrent_render_context():
            # Render the model's query and generate the SQL under the locked context so that
            # execution (fetchdf) can continue concurrently between the threads
            query = self._render_model_query()
            sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)

        with_clause = query.args.get("with")
        if with_clause:
            self.test_ctes(
                {
                    self._normalize_model_name(cte.alias, with_default_catalog=False): cte
                    for cte in query.ctes
                },
                recursive=with_clause.recursive,
            )

        values = self.body["outputs"].get("query")
        if values is not None:
            partial = values.get("partial")
            sort = query.args.get("order") is None
            actual = self._execute(sql)
            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
            self.assert_equal(expected, actual, sort=sort, partial=partial)

    def _render_model_query(self) -> exp.Query:
        """Render the model's query with fixture tables substituted for its inputs."""
        variables = self.body.get("vars", {}).copy()
        # Time-context vars are passed as render kwargs, not user variables.
        time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}

        query = self.model.render_query_or_raise(
            **time_kwargs,
            variables=variables,
            engine_adapter=self.engine_adapter,
            table_mapping={
                name: self._test_fixture_table(name).sql() for name in self.body.get("inputs", {})
            },
            runtime_stage=RuntimeStage.TESTING,
        )

        return query
class PythonModelTest(ModelTest):
    """Unit test for a Python model: executes the model function and compares its output."""

    def __init__(
        self,
        body: t.Dict[str, t.Any],
        test_name: str,
        model: Model,
        models: UniqueKeyDict[str, Model],
        engine_adapter: EngineAdapter,
        dialect: str | None = None,
        path: Path | None = None,
        preserve_fixtures: bool = False,
        default_catalog: str | None = None,
        concurrency: bool = False,
        verbosity: Verbosity = Verbosity.DEFAULT,
    ) -> None:
        """PythonModelTest encapsulates a unit test for a Python model.

        Args:
            body: A dictionary that contains test metadata like inputs and outputs.
            test_name: The name of the test.
            model: The Python model that is being tested.
            models: All models to use for expansion and mapping of physical locations.
            engine_adapter: The engine adapter to use.
            dialect: The models' dialect, used for normalization purposes.
            path: An optional path to the test definition yaml file.
            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        """
        from sqlmesh.core.test.context import TestExecutionContext

        super().__init__(
            body,
            test_name,
            model,
            models,
            engine_adapter,
            dialect,
            path,
            preserve_fixtures,
            default_catalog,
            concurrency,
            verbosity,
        )

        # Execution context handed to the Python model so its table lookups
        # resolve to the test fixture tables.
        self.context = TestExecutionContext(
            engine_adapter=engine_adapter,
            models=models,
            test=self,
            default_dialect=dialect,
            default_catalog=default_catalog,
        )

    def runTest(self) -> None:
        """Execute the Python model and compare its output to the expected rows."""
        values = self.body["outputs"].get("query")
        if values is not None:
            partial = values.get("partial")
            actual_df = self._execute_model()
            actual_df.reset_index(drop=True, inplace=True)
            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
            self.assert_equal(expected, actual_df, sort=True, partial=partial)

    def _execute_model(self) -> pd.DataFrame:
        """Executes the python model and returns a DataFrame."""
        import pandas as pd

        with self._concurrent_render_context():
            variables = self.body.get("vars", {}).copy()
            # Time-context vars are passed as render kwargs, not user variables.
            time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
            df = next(self.model.render(context=self.context, variables=variables, **time_kwargs))

        assert not isinstance(df, exp.Expression)
        return df if isinstance(df, pd.DataFrame) else df.toPandas()
def generate_test(
    model: Model,
    input_queries: t.Dict[str, str],
    models: UniqueKeyDict[str, Model],
    engine_adapter: EngineAdapter,
    test_engine_adapter: EngineAdapter,
    project_path: Path,
    overwrite: bool = False,
    variables: t.Optional[t.Dict[str, str]] = None,
    path: t.Optional[str] = None,
    name: t.Optional[str] = None,
    include_ctes: bool = False,
) -> None:
    """Generate a unit test fixture for a given model.

    Args:
        model: The model to test.
        input_queries: Mapping of model names to queries. Each model included in this mapping
            will be populated in the test based on the results of the corresponding query.
        models: The context's models.
        engine_adapter: The target engine adapter.
        test_engine_adapter: The test engine adapter.
        project_path: The path pointing to the project's root directory.
        overwrite: Whether to overwrite the existing test in case of a file path collision.
            When set to False, an error will be raised if there is such a collision.
        variables: Key-value pairs that will define variables needed by the model.
        path: The file path corresponding to the fixture, relative to the test directory.
            By default, the fixture will be created under the test directory and the file name
            will be inferred from the test's name.
        name: The name of the test. This is inferred from the model name by default.
        include_ctes: When true, CTE fixtures will also be generated.
    """
    import numpy as np

    test_name = name or f"test_{model.view_name}"
    path = path or f"{test_name}.yaml"
    extension = path.split(".")[-1].lower()
    if extension not in ("yaml", "yml"):
        path = f"{path}.yaml"

    fixture_path = project_path / c.TESTS / path
    if not overwrite and fixture_path.exists():
        raise ConfigError(
            f"Fixture '{fixture_path}' already exists, make sure to set --overwrite if it can be safely overwritten."
        )

    # ruamel.yaml does not support pandas Timestamps, so we must convert them to python
    # datetime or datetime.date objects based on column type
    inputs = {
        dep: pandas_timestamp_to_pydatetime(
            engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)),
            models[dep].columns_to_types,
        )
        .replace({np.nan: None})
        .to_dict(orient="records")
        for dep, query in input_queries.items()
    }
    outputs: t.Dict[str, t.Any] = {"query": {}}
    variables = variables or {}
    test_body = {"model": model.fqn, "inputs": inputs, "outputs": outputs}
    if variables:
        test_body["vars"] = variables

    test = ModelTest.create_test(
        body=test_body.copy(),
        test_name=test_name,
        models=models,
        engine_adapter=test_engine_adapter,
        dialect=model.dialect,
        path=fixture_path,
        default_catalog=model.default_catalog,
    )
    if not test:
        return

    # Create the fixture views/tables so the model can actually run.
    test.setUp()

    if isinstance(model, SqlModel):
        assert isinstance(test, SqlModelTest)
        model_query = test._render_model_query()

        with_clause = model_query.args.get("with")
        if with_clause and include_ctes:
            ctes = {}

            recursive = with_clause.recursive
            previous_ctes: t.List[exp.CTE] = []
            for cte in model_query.ctes:
                cte_query = cte.this
                cte_identifier = cte.args["alias"].this
                cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_identifier)
                # Each CTE's wrapper query must re-attach all preceding CTEs
                # (plus itself) so its references resolve.
                for prev in chain(previous_ctes, [cte]):
                    cte_query = cte_query.with_(
                        prev.args["alias"].this, prev.this, recursive=recursive
                    )

                cte_output = test._execute(cte_query)
                ctes[cte.alias] = (
                    pandas_timestamp_to_pydatetime(
                        df=cte_output.apply(lambda col: col.map(_normalize_df_value)),
                    )
                    .replace({np.nan: None})
                    .to_dict(orient="records")
                )
                previous_ctes.append(cte)

            if ctes:
                outputs["ctes"] = ctes

        output = test._execute(model_query)
    else:
        output = t.cast(PythonModelTest, test)._execute_model()

    outputs["query"] = (
        pandas_timestamp_to_pydatetime(
            output.apply(lambda col: col.map(_normalize_df_value)), model.columns_to_types
        )
        .replace({np.nan: None})
        .to_dict(orient="records")
    )

    test.tearDown()
    fixture_path.parent.mkdir(exist_ok=True, parents=True)
    with open(fixture_path, "w", encoding="utf-8") as file:
        yaml.dump({test_name: test_body}, file)
def _projection_identifiers(query: exp.Query) -> t.List[str | exp.Identifier]:
    """Return one identifier (or output name) per projection in the query's SELECT list."""

    def _identifier_of(projection: exp.Expression) -> str | exp.Identifier:
        # Aliased projections contribute their alias identifier, bare columns
        # their column identifier; anything else falls back to its output name.
        if isinstance(projection, exp.Alias):
            return projection.args["alias"]
        if isinstance(projection, exp.Column):
            return projection.this
        return projection.output_name

    return [_identifier_of(projection) for projection in query.selects]
def _raise_if_unexpected_columns(
    expected_cols: t.Collection[str], actual_cols: t.Collection[str]
) -> None:
    """Raise a test error if `actual_cols` references columns absent from `expected_cols`.

    Args:
        expected_cols: The full collection of columns the test/model declares.
        actual_cols: The columns that were actually referenced or produced.
    """
    unique_expected_cols = set(expected_cols)
    # Preserve `actual_cols` ordering when reporting the unknown columns.
    unknown_cols = [col for col in actual_cols if col not in unique_expected_cols]
    if unknown_cols:
        # str.join accepts any iterable; the intermediate list() was unnecessary.
        expected = f"Expected column(s): {', '.join(expected_cols)}\n"
        unknown = f"Unknown column(s): {', '.join(unknown_cols)}"
        _raise_error(f"Detected unknown column(s)\n\n{expected}{unknown}")
def _row_difference(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
"""Returns all rows in `left` that don't appear in `right`."""
import numpy as np
import pandas as pd
rows_missing_from_right = []
# `None` replaces `np.nan` because `np.nan != np.nan` and this would affect the mapping lookup
right_row_count: t.MutableMapping[t.Tuple, int] = Counter(
right.replace({np.nan: None}).itertuples(index=False, name=None)
)
for left_row in left.replace({np.nan: None}).itertuples(index=False):
left_row_tuple = tuple(left_row)
if right_row_count[left_row_tuple] <= 0:
rows_missing_from_right.append(left_row)
else:
right_row_count[left_row_tuple] -= 1