|
4 | 4 | import sys |
5 | 5 | import typing as t |
6 | 6 |
|
7 | | -from pydantic import Field |
| 7 | +from pydantic import Field, root_validator |
8 | 8 |
|
9 | 9 | from sqlmesh.core import engine_adapter |
10 | 10 | from sqlmesh.core.config.base import BaseConfig |
|
13 | 13 | http_headers_validator, |
14 | 14 | ) |
15 | 15 | from sqlmesh.core.engine_adapter import EngineAdapter |
| 16 | +from sqlmesh.utils.errors import ConfigError |
16 | 17 |
|
17 | 18 | if sys.version_info >= (3, 9): |
18 | 19 | from typing import Annotated, Literal |
@@ -89,31 +90,45 @@ class SnowflakeConnectionConfig(_ConnectionConfig): |
89 | 90 | """Configuration for the Snowflake connection. |
90 | 91 |
|
91 | 92 | Args: |
| 93 | + account: The Snowflake account name. |
92 | 94 | user: The Snowflake username. |
93 | 95 | password: The Snowflake password. |
94 | | - account: The Snowflake account name. |
95 | 96 | warehouse: The optional warehouse name. |
96 | 97 | database: The optional database name. |
97 | 98 | role: The optional role name. |
98 | 99 | concurrent_tasks: The maximum number of tasks that can use this connection concurrently. |
| 100 | + authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). |
| 101 | + Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183 |
99 | 102 | """ |
100 | 103 |
|
101 | | - user: str |
102 | | - password: str |
103 | 104 | account: str |
| 105 | + user: t.Optional[str] |
| 106 | + password: t.Optional[str] |
104 | 107 | warehouse: t.Optional[str] |
105 | 108 | database: t.Optional[str] |
106 | 109 | role: t.Optional[str] |
| 110 | + authenticator: t.Optional[str] |
107 | 111 |
|
108 | 112 | concurrent_tasks: int = 4 |
109 | 113 |
|
110 | 114 | type_: Literal["snowflake"] = Field(alias="type", default="snowflake") |
111 | 115 |
|
112 | 116 | _concurrent_tasks_validator = concurrent_tasks_validator |
113 | 117 |
|
| 118 | + @root_validator() |
| 119 | + def _validate_authenticator( |
| 120 | + cls, fields: t.Dict[str, t.Optional[str]] |
| 121 | + ) -> t.Dict[str, t.Optional[str]]: |
| 122 | + auth = fields.get("authenticator") |
| 123 | + user = fields.get("user") |
| 124 | + password = fields.get("password") |
| 125 | + if not auth and (not user or not password): |
| 126 | + raise ConfigError("User and password must be provided if using default authentication") |
| 127 | + return fields |
| 128 | + |
114 | 129 | @property |
115 | 130 | def _connection_kwargs_keys(self) -> t.Set[str]: |
116 | | - return {"user", "password", "account", "warehouse", "database", "role"} |
| 131 | + return {"user", "password", "account", "warehouse", "database", "role", "authenticator"} |
117 | 132 |
|
118 | 133 | @property |
119 | 134 | def _engine_adapter(self) -> t.Type[EngineAdapter]: |
|
0 commit comments