Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"requests",
"rich[jupyter]",
"ruamel.yaml",
"sqlglot[rs]~=27.24.2",
"sqlglot[rs]~=27.26.0",
"tenacity",
"time-machine",
"json-stream"
Expand Down
22 changes: 17 additions & 5 deletions sqlmesh/dbt/column.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing as t
import logging

from sqlglot import exp, parse_one
from sqlglot.helper import ensure_list
Expand All @@ -9,6 +10,8 @@
from sqlmesh.utils.conversions import ensure_bool
from sqlmesh.utils.pydantic import field_validator

logger = logging.getLogger(__name__)


def yaml_to_columns(
yaml: t.Dict[str, ColumnConfig] | t.List[t.Dict[str, ColumnConfig]],
Expand All @@ -31,11 +34,20 @@ def column_types_to_sqlmesh(
Returns:
A dict of column name to exp.DataType
"""
return {
name: parse_one(column.data_type, into=exp.DataType, dialect=dialect or "")
for name, column in columns.items()
if column.enabled and column.data_type
}
col_types_to_sqlmesh: t.Dict[str, exp.DataType] = {}
for name, column in columns.items():
if column.enabled and column.data_type:
column_def = parse_one(
f"{name} {column.data_type}", into=exp.ColumnDef, dialect=dialect or ""
)
if column_def.args.get("constraints"):
logger.warning(
f"Ignoring unsupported constraints for column '{name}' with definition '{column.data_type}'."
)
kind = column_def.kind
if kind:
col_types_to_sqlmesh[name] = kind
return col_types_to_sqlmesh


def column_descriptions_to_sqlmesh(columns: t.Dict[str, ColumnConfig]) -> t.Dict[str, str]:
Expand Down
23 changes: 23 additions & 0 deletions tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,29 @@ def test_seed_column_types():
sqlmesh_seed = seed.to_sqlmesh(context)
assert sqlmesh_seed.columns_to_types == expected_column_types

seed = SeedConfig(
name="foo",
package="package",
path=Path("examples/sushi_dbt/seeds/waiter_names.csv"),
column_types={
"id": "TEXT",
"name": "TEXT NOT NULL",
},
quote_columns=True,
)

expected_column_types = {
"id": exp.DataType.build("text"),
"name": exp.DataType.build("text"),
}

logger = logging.getLogger("sqlmesh.dbt.column")
with patch.object(logger, "warning") as mock_logger:
sqlmesh_seed = seed.to_sqlmesh(context)
mock_logger.assert_called_once()
assert "Ignoring unsupported constraints" in mock_logger.call_args[0][0]
assert sqlmesh_seed.columns_to_types == expected_column_types


def test_seed_column_inference(tmp_path):
seed_csv = tmp_path / "seed.csv"
Expand Down