Skip to content

Commit ef7893d

Browse files
authored
Feat: Generate per-dialect options on init project (#3733)
1 parent cf779ad commit ef7893d

2 files changed

Lines changed: 75 additions & 11 deletions

File tree

sqlmesh/cli/example_project.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
from sqlmesh.integrations.dlt import generate_dlt_models_and_settings
99
from sqlmesh.utils.date import yesterday_ds
1010

11+
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE
12+
13+
14+
PRIMITIVES = (str, int, bool, float)
15+
1116

1217
class ProjectTemplate(Enum):
1318
AIRFLOW = "airflow"
@@ -23,30 +28,66 @@ def _gen_config(
2328
start: t.Optional[str],
2429
template: ProjectTemplate,
2530
) -> str:
26-
connection_settings = (
27-
settings
28-
or """ type: duckdb
31+
if not settings:
32+
connection_settings = """ type: duckdb
2933
database: db.db"""
30-
)
34+
35+
doc_link = "# Visit https://sqlmesh.readthedocs.io/en/stable/integrations/engines{engine_link} for more information on configuring the connection to your execution engine."
36+
engine_link = ""
37+
38+
engine = "mssql" if dialect == "tsql" else dialect
39+
40+
if engine in CONNECTION_CONFIG_TO_TYPE:
41+
required_fields = []
42+
non_required_fields = []
43+
44+
for name, field in CONNECTION_CONFIG_TO_TYPE[engine].model_fields.items():
45+
field_name = field.alias or name
46+
default_value = field.get_default()
47+
48+
if isinstance(default_value, Enum):
49+
default_value = default_value.value
50+
elif not isinstance(default_value, PRIMITIVES):
51+
default_value = None
52+
53+
required = field.is_required() or field_name == "type"
54+
option_str = (
55+
f" {'# ' if not required else ''}{field_name}: {default_value or ''}\n"
56+
)
57+
58+
if required:
59+
required_fields.append(option_str)
60+
else:
61+
non_required_fields.append(option_str)
62+
63+
connection_settings = "".join(required_fields + non_required_fields)
64+
65+
engine_link = f"/{engine}/#connection-options"
66+
67+
connection_settings = (
68+
f" {doc_link.format(engine_link=engine_link)}\n{connection_settings}"
69+
)
70+
else:
71+
connection_settings = settings
3172

3273
default_configs = {
3374
ProjectTemplate.DEFAULT: f"""gateways:
34-
local:
75+
dev:
3576
connection:
3677
{connection_settings}
3778
38-
default_gateway: local
79+
default_gateway: dev
3980
4081
model_defaults:
4182
dialect: {dialect}
4283
start: {start or yesterday_ds()}
4384
""",
4485
ProjectTemplate.AIRFLOW: f"""gateways:
45-
local:
86+
dev:
4687
connection:
47-
{connection_settings}
88+
{connection_settings}
4889
49-
default_gateway: local
90+
default_gateway: dev
5091
5192
default_scheduler:
5293
type: airflow

tests/cli/test_cli.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -793,12 +793,12 @@ def test_plan_dlt(runner, tmp_path):
793793
init_example_project(tmp_path, "duckdb", ProjectTemplate.DLT, "sushi")
794794

795795
expected_config = f"""gateways:
796-
local:
796+
dev:
797797
connection:
798798
type: duckdb
799799
database: {dataset_path}
800800
801-
default_gateway: local
801+
default_gateway: dev
802802
803803
model_defaults:
804804
dialect: duckdb
@@ -948,3 +948,26 @@ def test_plan_dlt(runner, tmp_path):
948948
assert dlt_sushi_twice_nested_model_path.exists()
949949
finally:
950950
remove(dataset_path)
951+
952+
953+
def test_init_project_dialects(runner, tmp_path):
954+
dialect_to_config = {
955+
"redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: ",
956+
"bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ",
957+
"snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ",
958+
"databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: \n # force_databricks_connect: \n # disable_databricks_connect: \n # disable_spark_session: ",
959+
"postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: ",
960+
}
961+
962+
for dialect, expected_config in dialect_to_config.items():
963+
init_example_project(tmp_path, dialect=dialect)
964+
965+
config_start = f"gateways:\n dev:\n connection:\n # Visit https://sqlmesh.readthedocs.io/en/stable/integrations/engines/{dialect}/#connection-options for more information on configuring the connection to your execution engine.\n type: {dialect}\n "
966+
config_end = f"\n\n\ndefault_gateway: dev\n\nmodel_defaults:\n dialect: {dialect}\n start: 2025-01-29\n"
967+
968+
with open(tmp_path / "config.yaml") as file:
969+
config = file.read()
970+
971+
assert config == f"{config_start}{expected_config}{config_end}"
972+
973+
remove(tmp_path / "config.yaml")

0 commit comments

Comments
 (0)