Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 4cf56db

Browse files
authored
Merge pull request #491 from pik94/ilia-dx-564-grab-credentials-from-profilesyml
Grab credentials from profiles.yml
2 parents acd1643 + e391403 commit 4cf56db

5 files changed

Lines changed: 421 additions & 294 deletions

File tree

data_diff/cloud/data_source.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import enum
21
import time
3-
from typing import List, Optional
2+
from typing import List, Optional, Union, overload
43

54
import pydantic
65
import rich
76
from rich.table import Table
87
from rich.prompt import Confirm, Prompt, FloatPrompt, IntPrompt, InvalidResponse
8+
from typing_extensions import Literal
99

1010
from .datafold_api import (
1111
DatafoldAPI,
@@ -14,6 +14,7 @@
1414
TDsConfig,
1515
TestDataSourceStatus,
1616
)
17+
from ..dbt_parser import DbtParser
1718

1819

1920
UNKNOWN_VALUE = "unknown_value"
@@ -49,8 +50,12 @@ def _validate_temp_schema(temp_schema: str):
4950
raise ValueError("Temporary schema should have a format <database>.<schema>")
5051

5152

52-
def create_ds_config(ds_config: TCloudApiDataSourceConfigSchema, data_source_name: str) -> TDsConfig:
53-
options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True)
53+
def create_ds_config(
54+
ds_config: TCloudApiDataSourceConfigSchema,
55+
data_source_name: str,
56+
dbt_parser: Optional[DbtParser] = None,
57+
) -> TDsConfig:
58+
options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True, dbt_parser=dbt_parser)
5459

5560
temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
5661
float_tolerance = FloatPrompt.ask("Float tolerance", default=0.000001)
@@ -64,7 +69,41 @@ def create_ds_config(ds_config: TCloudApiDataSourceConfigSchema, data_source_nam
6469
)
6570

6671

67-
def _parse_ds_credentials(ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True):
72+
@overload
73+
def _cast_value(value: str, type_: Literal["integer"]) -> int:
74+
...
75+
76+
77+
@overload
78+
def _cast_value(value: str, type_: Literal["boolean"]) -> bool:
79+
...
80+
81+
82+
@overload
83+
def _cast_value(value: str, type_: Literal["string"]) -> str:
84+
...
85+
86+
87+
def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
88+
if type_ == "integer":
89+
return int(value)
90+
elif type_ == "boolean":
91+
return bool(value)
92+
return value
93+
94+
95+
def _parse_ds_credentials(
96+
ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True, dbt_parser: Optional[DbtParser] = None
97+
):
98+
creds = {}
99+
use_dbt_data = False
100+
if dbt_parser is not None:
101+
use_dbt_data = Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?")
102+
try:
103+
creds = dbt_parser.get_connection_creds()[0]
104+
except Exception as e:
105+
rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")
106+
68107
ds_options = {}
69108
basic_required_fields = set(ds_config.config_schema.required)
70109
for param_name, param_data in ds_config.config_schema.properties.items():
@@ -83,6 +122,14 @@ def _parse_ds_credentials(ds_config: TCloudApiDataSourceConfigSchema, only_basic
83122
if default_value != UNKNOWN_VALUE:
84123
input_values["default"] = default_value
85124

125+
if use_dbt_data:
126+
value = creds.get(param_name, UNKNOWN_VALUE)
127+
if value == UNKNOWN_VALUE:
128+
rich.print(f'[red]Cannot extract "{param_name}" from dbt profiles.yml. Please, type it manually')
129+
else:
130+
ds_options[param_name] = _cast_value(value, type_)
131+
continue
132+
86133
if type_ == "integer":
87134
value = IntPrompt.ask(**input_values)
88135
elif type_ == "boolean":
@@ -177,7 +224,7 @@ def _render_data_source_test_results(test_results: List[TDataSourceTestStage]) -
177224
rich.print(table)
178225

179226

180-
def get_or_create_data_source(api: DatafoldAPI) -> int:
227+
def get_or_create_data_source(api: DatafoldAPI, dbt_parser: Optional[DbtParser] = None) -> int:
181228
ds_configs = api.get_data_source_schema_config()
182229
data_sources = api.get_data_sources()
183230

@@ -198,10 +245,10 @@ def get_or_create_data_source(api: DatafoldAPI) -> int:
198245
_render_data_source(data_source=ds, title=f'Found existing data source for name "{ds.name}"')
199246
use_existing_ds = Confirm.ask("Would you like to continue with the existing data source?")
200247
if not use_existing_ds:
201-
return get_or_create_data_source(api=api)
248+
return get_or_create_data_source(api=api, dbt_parser=dbt_parser)
202249
return ds.id
203250

204-
ds_config = create_ds_config(ds_config, ds_name)
251+
ds_config = create_ds_config(ds_config=ds_config, data_source_name=ds_name, dbt_parser=dbt_parser)
205252
ds = api.create_data_source(ds_config)
206253
data_source_url = f"{api.host}/settings/integrations/dwh/{ds.type}/{ds.id}"
207254
_render_data_source(data_source=ds, title=f"Created a new data source with ID = {ds.id} ({data_source_url})")

0 commit comments

Comments
 (0)