Skip to content

Commit 692ffff

Browse files
authored
fix: allow subclassing of config again (#4209)
1 parent 9952bbb commit 692ffff

2 files changed

Lines changed: 40 additions & 40 deletions

File tree

sqlmesh/core/config/root.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from sqlmesh.core.user import User
4444
from sqlmesh.utils.date import to_timestamp, now
4545
from sqlmesh.utils.errors import ConfigError
46-
from sqlmesh.utils.pydantic import model_validator
46+
from sqlmesh.utils.pydantic import model_validator, field_validator
4747

4848
if t.TYPE_CHECKING:
4949
from sqlmesh.core._typing import Self
@@ -89,7 +89,7 @@ class Config(BaseConfig):
8989
after_all: SQL statements or macros to be executed at the end of the `sqlmesh plan` and `sqlmesh run` commands.
9090
"""
9191

92-
gateways: GatewayDict = {"": GatewayConfig()}
92+
gateways: t.Dict[str, GatewayConfig] = {"": GatewayConfig()}
9393
default_connection: SerializableConnectionConfig = DuckDBConnectionConfig()
9494
default_test_connection_: t.Optional[SerializableConnectionConfig] = Field(
9595
default=None, alias="default_test_connection"
@@ -98,8 +98,8 @@ class Config(BaseConfig):
9898
default_gateway: str = ""
9999
notification_targets: t.List[NotificationTarget] = []
100100
project: str = ""
101-
snapshot_ttl: NoPastTTLString = c.DEFAULT_SNAPSHOT_TTL
102-
environment_ttl: t.Optional[NoPastTTLString] = c.DEFAULT_ENVIRONMENT_TTL
101+
snapshot_ttl: str = c.DEFAULT_SNAPSHOT_TTL
102+
environment_ttl: t.Optional[str] = c.DEFAULT_ENVIRONMENT_TTL
103103
ignore_patterns: t.List[str] = c.IGNORE_PATTERNS
104104
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT
105105
users: t.List[User] = []
@@ -109,12 +109,12 @@ class Config(BaseConfig):
109109
loader_kwargs: t.Dict[str, t.Any] = {}
110110
env_vars: t.Dict[str, str] = {}
111111
username: str = ""
112-
physical_schema_mapping: RegexKeyDict = {}
112+
physical_schema_mapping: t.Dict[re.Pattern, str] = {}
113113
environment_suffix_target: EnvironmentSuffixTarget = Field(
114114
default=EnvironmentSuffixTarget.default
115115
)
116116
gateway_managed_virtual_layer: bool = False
117-
environment_catalog_mapping: RegexKeyDict = {}
117+
environment_catalog_mapping: t.Dict[re.Pattern, str] = {}
118118
default_target_environment: str = c.PROD
119119
log_limit: int = c.DEFAULT_LOG_LIMIT
120120
cicd_bot: t.Optional[CICDBotConfig] = None
@@ -155,6 +155,33 @@ class Config(BaseConfig):
155155
_scheduler_config_validator = scheduler_config_validator # type: ignore
156156
_variables_validator = variables_validator
157157

158+
@field_validator("gateways", mode="before")
159+
@classmethod
160+
def _gateways_ensure_dict(cls, value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
161+
try:
162+
if not isinstance(value, GatewayConfig):
163+
GatewayConfig.parse_obj(value)
164+
return {"": value}
165+
except Exception:
166+
return value
167+
168+
@field_validator("environment_catalog_mapping", "physical_schema_mapping", mode="before")
169+
@classmethod
170+
def _validate_regex_keys(
171+
cls, value: t.Dict[str | re.Pattern, t.Any]
172+
) -> t.Dict[re.Pattern, t.Any]:
173+
return compile_regex_mapping(value)
174+
175+
@field_validator("snapshot_ttl", "environment_ttl", mode="before")
176+
@classmethod
177+
def validate_no_past_ttl(cls, v: str) -> str:
178+
current_time = now()
179+
if to_timestamp(v, relative_base=current_time) < to_timestamp(current_time):
180+
raise ValueError(
181+
f"TTL '{v}' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`."
182+
)
183+
return v
184+
158185
@model_validator(mode="before")
159186
def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any:
160187
if not isinstance(data, dict):
@@ -286,37 +313,3 @@ def dialect(self) -> t.Optional[str]:
286313
@property
287314
def fingerprint(self) -> str:
288315
return str(zlib.crc32(pickle.dumps(self.dict(exclude={"loader", "notification_targets"}))))
289-
290-
291-
def validate_no_past_ttl(v: str) -> str:
292-
current_time = now()
293-
if to_timestamp(v, relative_base=current_time) < to_timestamp(current_time):
294-
raise ValueError(
295-
f"TTL '{v}' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`."
296-
)
297-
return v
298-
299-
300-
def gateways_ensure_dict(value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
301-
try:
302-
if not isinstance(value, GatewayConfig):
303-
GatewayConfig.parse_obj(value)
304-
return {"": value}
305-
except Exception:
306-
return value
307-
308-
309-
def validate_regex_key_dict(value: t.Dict[str | re.Pattern, t.Any]) -> t.Dict[re.Pattern, t.Any]:
310-
return compile_regex_mapping(value)
311-
312-
313-
if t.TYPE_CHECKING:
314-
NoPastTTLString = str
315-
GatewayDict = t.Dict[str, GatewayConfig]
316-
RegexKeyDict = t.Dict[re.Pattern, str]
317-
else:
318-
from pydantic.functional_validators import BeforeValidator
319-
320-
NoPastTTLString = t.Annotated[str, BeforeValidator(validate_no_past_ttl)]
321-
GatewayDict = t.Annotated[t.Dict[str, GatewayConfig], BeforeValidator(gateways_ensure_dict)]
322-
RegexKeyDict = t.Annotated[t.Dict[re.Pattern, str], BeforeValidator(validate_regex_key_dict)]

tests/core/test_config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,3 +974,10 @@ class TestConfig(DuckDBConnectionConfig):
974974
pass
975975

976976
TestConfig()
977+
978+
979+
# @pytest.mark.isolated
980+
def test_config_subclassing() -> None:
981+
class ConfigSubclass(Config): ...
982+
983+
ConfigSubclass()

0 commit comments

Comments
 (0)