|
12 | 12 | from sqlmesh.core.model import Model, PythonModel, SqlModel |
13 | 13 | from sqlmesh.utils.errors import SQLMeshError |
14 | 14 |
|
| 15 | +Row = t.Dict[str, t.Any] |
| 16 | + |
15 | 17 |
|
16 | 18 | class TestError(SQLMeshError): |
17 | 19 | """Test error""" |
@@ -55,7 +57,9 @@ def __init__( |
55 | 57 |
|
56 | 58 | def setUp(self) -> None: |
57 | 59 | """Load all input tables""" |
58 | | - inputs = {name: table["rows"] for name, table in self.body.get("inputs", {}).items()} |
| 60 | + inputs = { |
| 61 | + name: self._get_rows(table) for name, table in self.body.get("inputs", {}).items() |
| 62 | + } |
59 | 63 |
|
60 | 64 | for table, rows in inputs.items(): |
61 | 65 | df = pd.DataFrame.from_records(rows) # noqa |
@@ -128,6 +132,18 @@ def create_test( |
128 | 132 | def __str__(self) -> str: |
129 | 133 | return f"{self.test_name} ({self.path}:{self.body.lc.line})" # type: ignore |
130 | 134 |
|
| 135 | + def _get_rows(self, table: list[Row] | dict[str, list[Row]]) -> list[Row]: |
| 136 | + """Get a list of rows for input and output table data. |
| 137 | +
|
| 138 | + Input and output table data might be a list of rows or it might be a dictionary |
| 139 | + with a "rows" key. |
| 140 | + """ |
| 141 | + if isinstance(table, dict): |
| 142 | + if "rows" not in table: |
| 143 | + _raise_error("Incomplete test, missing row data for table", self.path) |
| 144 | + return table["rows"] |
| 145 | + return table |
| 146 | + |
131 | 147 |
|
132 | 148 | class SqlModelTest(ModelTest): |
133 | 149 | def __init__( |
@@ -180,16 +196,16 @@ def test_ctes(self) -> None: |
180 | 196 | cte_query = self.ctes[cte_name].this |
181 | 197 | for alias, cte in self.ctes.items(): |
182 | 198 | cte_query = cte_query.with_(alias, cte.this) |
183 | | - expected_df = pd.DataFrame.from_records(value["rows"]) |
| 199 | + expected_df = pd.DataFrame.from_records(self._get_rows(value)) |
184 | 200 | actual_df = self.execute(cte_query) |
185 | 201 | self.assert_equal(expected_df, actual_df) |
186 | 202 |
|
187 | 203 | def runTest(self) -> None: |
188 | 204 | self.test_ctes() |
189 | 205 |
|
190 | 206 | # Test model query |
191 | | - if "rows" in self.body["outputs"].get("query", {}): |
192 | | - expected_df = pd.DataFrame.from_records(self.body["outputs"]["query"]["rows"]) |
| 207 | + if "query" in self.body["outputs"]: |
| 208 | + expected_df = pd.DataFrame.from_records(self._get_rows(self.body["outputs"]["query"])) |
193 | 209 | actual_df = self.execute(self.query) |
194 | 210 | self.assert_equal(expected_df, actual_df) |
195 | 211 |
|
@@ -224,8 +240,8 @@ def __init__( |
224 | 240 | ) |
225 | 241 |
|
226 | 242 | def runTest(self) -> None: |
227 | | - if "rows" in self.body["outputs"].get("query", {}): |
228 | | - expected_df = pd.DataFrame.from_records(self.body["outputs"]["query"]["rows"]) |
| 243 | + if "query" in self.body["outputs"]: |
| 244 | + expected_df = pd.DataFrame.from_records(self._get_rows(self.body["outputs"]["query"])) |
229 | 245 | actual_df = next( |
230 | 246 | self.model.render( |
231 | 247 | context=self.context, |
|
0 commit comments