Skip to content

Commit aee881a

Browse files
authored
fix: allow subclassing of config again (#4210)
1 parent 692ffff commit aee881a

2 files changed

Lines changed: 38 additions & 34 deletions

File tree

sqlmesh/core/config/root.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import zlib
77

88
from pydantic import Field
9+
from pydantic.functional_validators import BeforeValidator
910
from sqlglot import exp
1011
from sqlglot.helper import first
1112
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
@@ -43,11 +44,42 @@
4344
from sqlmesh.core.user import User
4445
from sqlmesh.utils.date import to_timestamp, now
4546
from sqlmesh.utils.errors import ConfigError
46-
from sqlmesh.utils.pydantic import model_validator, field_validator
47+
from sqlmesh.utils.pydantic import model_validator
48+
49+
50+
def validate_no_past_ttl(v: str) -> str:
51+
current_time = now()
52+
if to_timestamp(v, relative_base=current_time) < to_timestamp(current_time):
53+
raise ValueError(
54+
f"TTL '{v}' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`."
55+
)
56+
return v
57+
58+
59+
def gateways_ensure_dict(value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
60+
try:
61+
if not isinstance(value, GatewayConfig):
62+
GatewayConfig.parse_obj(value)
63+
return {"": value}
64+
except Exception:
65+
return value
66+
67+
68+
def validate_regex_key_dict(value: t.Dict[str | re.Pattern, t.Any]) -> t.Dict[re.Pattern, t.Any]:
69+
return compile_regex_mapping(value)
70+
4771

4872
if t.TYPE_CHECKING:
4973
from sqlmesh.core._typing import Self
5074

75+
NoPastTTLString = str
76+
GatewayDict = t.Dict[str, GatewayConfig]
77+
RegexKeyDict = t.Dict[re.Pattern, str]
78+
else:
79+
NoPastTTLString = t.Annotated[str, BeforeValidator(validate_no_past_ttl)]
80+
GatewayDict = t.Annotated[t.Dict[str, GatewayConfig], BeforeValidator(gateways_ensure_dict)]
81+
RegexKeyDict = t.Annotated[t.Dict[re.Pattern, str], BeforeValidator(validate_regex_key_dict)]
82+
5183

5284
class Config(BaseConfig):
5385
"""An object used by a Context to configure your SQLMesh project.
@@ -89,7 +121,7 @@ class Config(BaseConfig):
89121
after_all: SQL statements or macros to be executed at the end of the `sqlmesh plan` and `sqlmesh run` commands.
90122
"""
91123

92-
gateways: t.Dict[str, GatewayConfig] = {"": GatewayConfig()}
124+
gateways: GatewayDict = {"": GatewayConfig()}
93125
default_connection: SerializableConnectionConfig = DuckDBConnectionConfig()
94126
default_test_connection_: t.Optional[SerializableConnectionConfig] = Field(
95127
default=None, alias="default_test_connection"
@@ -98,8 +130,8 @@ class Config(BaseConfig):
98130
default_gateway: str = ""
99131
notification_targets: t.List[NotificationTarget] = []
100132
project: str = ""
101-
snapshot_ttl: str = c.DEFAULT_SNAPSHOT_TTL
102-
environment_ttl: t.Optional[str] = c.DEFAULT_ENVIRONMENT_TTL
133+
snapshot_ttl: NoPastTTLString = c.DEFAULT_SNAPSHOT_TTL
134+
environment_ttl: t.Optional[NoPastTTLString] = c.DEFAULT_ENVIRONMENT_TTL
103135
ignore_patterns: t.List[str] = c.IGNORE_PATTERNS
104136
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT
105137
users: t.List[User] = []
@@ -109,12 +141,12 @@ class Config(BaseConfig):
109141
loader_kwargs: t.Dict[str, t.Any] = {}
110142
env_vars: t.Dict[str, str] = {}
111143
username: str = ""
112-
physical_schema_mapping: t.Dict[re.Pattern, str] = {}
144+
physical_schema_mapping: RegexKeyDict = {}
113145
environment_suffix_target: EnvironmentSuffixTarget = Field(
114146
default=EnvironmentSuffixTarget.default
115147
)
116148
gateway_managed_virtual_layer: bool = False
117-
environment_catalog_mapping: t.Dict[re.Pattern, str] = {}
149+
environment_catalog_mapping: RegexKeyDict = {}
118150
default_target_environment: str = c.PROD
119151
log_limit: int = c.DEFAULT_LOG_LIMIT
120152
cicd_bot: t.Optional[CICDBotConfig] = None
@@ -155,33 +187,6 @@ class Config(BaseConfig):
155187
_scheduler_config_validator = scheduler_config_validator # type: ignore
156188
_variables_validator = variables_validator
157189

158-
@field_validator("gateways", mode="before")
159-
@classmethod
160-
def _gateways_ensure_dict(cls, value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
161-
try:
162-
if not isinstance(value, GatewayConfig):
163-
GatewayConfig.parse_obj(value)
164-
return {"": value}
165-
except Exception:
166-
return value
167-
168-
@field_validator("environment_catalog_mapping", "physical_schema_mapping", mode="before")
169-
@classmethod
170-
def _validate_regex_keys(
171-
cls, value: t.Dict[str | re.Pattern, t.Any]
172-
) -> t.Dict[re.Pattern, t.Any]:
173-
return compile_regex_mapping(value)
174-
175-
@field_validator("snapshot_ttl", "environment_ttl", mode="before")
176-
@classmethod
177-
def validate_no_past_ttl(cls, v: str) -> str:
178-
current_time = now()
179-
if to_timestamp(v, relative_base=current_time) < to_timestamp(current_time):
180-
raise ValueError(
181-
f"TTL '{v}' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`."
182-
)
183-
return v
184-
185190
@model_validator(mode="before")
186191
def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any:
187192
if not isinstance(data, dict):

tests/core/test_config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,6 @@ class TestConfig(DuckDBConnectionConfig):
976976
TestConfig()
977977

978978

979-
# @pytest.mark.isolated
980979
def test_config_subclassing() -> None:
981980
class ConfigSubclass(Config): ...
982981

0 commit comments

Comments
 (0)