Commit dbd58fb

Follow "More flexible cluster configuration". (#194)
### Description

Follows "More flexible cluster configuration" at dbt-labs/dbt-spark#467.

- Reuse `dbt-spark`'s implementation
- Remove the dependency on `databricks-cli`
- Internal refactorings

Co-authored-by: allisonwang-db <allison.wang@databricks.com>
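For orientation, here is a minimal sketch of what the new `job_cluster` submission method looks like from a Python model, based on the dbt-spark change this commit follows. The `job_cluster_config` keys below are ordinary Databricks Jobs API `new_cluster` fields chosen for illustration, and whether that dict is passed inline via `dbt.config` or kept in YAML config is not shown by this commit; treat this as a sketch, not the commit's own example.

```py
def model(dbt, session):
    dbt.config(
        materialized="table",
        # "job_cluster" runs the model on an ephemeral job cluster;
        # "all_purpose_cluster" (presumably the inherited default) reuses an existing cluster.
        submission_method="job_cluster",
        # Illustrative cluster spec; any valid Jobs API new_cluster config should work.
        job_cluster_config={
            "spark_version": "11.3.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "num_workers": 2,
        },
    )
    return dbt.ref("upstream_model")
```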
1 parent 3a41729 commit dbd58fb

8 files changed

Lines changed: 119 additions & 205 deletions


CHANGELOG.md

Lines changed: 10 additions & 0 deletions
@@ -3,6 +3,16 @@
 ### Features
 - Support python model through run command API, currently supported materializations are table and incremental. ([dbt-labs/dbt-spark#377](https://github.com/dbt-labs/dbt-spark/pull/377), [#126](https://github.com/databricks/dbt-databricks/pull/126))
 - Enable Pandas and Pandas-on-Spark DataFrames for dbt python models ([dbt-labs/dbt-spark#469](https://github.com/dbt-labs/dbt-spark/pull/469), [#181](https://github.com/databricks/dbt-databricks/pull/181))
+- Support job cluster in notebook submission method ([dbt-labs/dbt-spark#467](https://github.com/dbt-labs/dbt-spark/pull/467), [#194](https://github.com/databricks/dbt-databricks/pull/194))
+- In the `all_purpose_cluster` submission method, an `http_path` config can be specified in the Python model config to switch the cluster where the Python model runs.
+  ```py
+  def model(dbt, _):
+      dbt.config(
+          materialized='table',
+          http_path='...'
+      )
+      ...
+  ```
 - Use builtin timestampadd and timestampdiff functions for dateadd/datediff macros if available ([#185](https://github.com/databricks/dbt-databricks/pull/185))
 - Implement a test for various Python models ([#189](https://github.com/databricks/dbt-databricks/pull/189))
 - Implement testing for `type_boolean` in Databricks ([dbt-labs/dbt-spark#471](https://github.com/dbt-labs/dbt-spark/pull/471), [#188](https://github.com/databricks/dbt-databricks/pull/188))

dbt/adapters/databricks/api_client.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

dbt/adapters/databricks/connections.py

Lines changed: 56 additions & 52 deletions
@@ -129,6 +129,50 @@ def __post_init__(self) -> None:
             )
         self.connection_parameters = connection_parameters
 
+    def validate_creds(self) -> None:
+        for key in ["host", "http_path", "token"]:
+            if not getattr(self, key):
+                raise dbt.exceptions.DbtProfileError(
+                    "The config '{}' is required to connect to Databricks".format(key)
+                )
+
+    @classmethod
+    def get_invocation_env(cls) -> Optional[str]:
+        invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
+        if invocation_env:
+            # Thrift doesn't allow nested () so we need to ensure
+            # that the passed user agent is valid.
+            if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env):
+                raise dbt.exceptions.ValidationException(
+                    f"Invalid invocation environment: {invocation_env}"
+                )
+        return invocation_env
+
+    @classmethod
+    def get_all_http_headers(cls, user_http_session_headers: Dict[str, str]) -> Dict[str, str]:
+        http_session_headers_str: Optional[str] = os.environ.get(
+            DBT_DATABRICKS_HTTP_SESSION_HEADERS
+        )
+
+        http_session_headers_dict: Dict[str, str] = (
+            {k: json.dumps(v) for k, v in json.loads(http_session_headers_str).items()}
+            if http_session_headers_str is not None
+            else {}
+        )
+
+        intersect_http_header_keys = (
+            user_http_session_headers.keys() & http_session_headers_dict.keys()
+        )
+
+        if len(intersect_http_header_keys) > 0:
+            raise dbt.exceptions.ValidationException(
+                f"Intersection with reserved http_headers in keys: {intersect_http_header_keys}"
+            )
+
+        http_session_headers_dict.update(user_http_session_headers)
+
+        return http_session_headers_dict
+
     @property
     def type(self) -> str:
         return "databricks"
@@ -165,14 +209,18 @@ def _connection_keys(self, *, with_aliases: bool = False) -> Tuple[str, ...]:
             connection_keys.append("session_properties")
         return tuple(connection_keys)
 
-    @property
-    def cluster_id(self) -> Optional[str]:
-        m = EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX.match(self.http_path)  # type: ignore[arg-type]
+    @classmethod
+    def extract_cluster_id(cls, http_path: str) -> Optional[str]:
+        m = EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX.match(http_path)
         if m:
             return m.group(1).strip()
         else:
             return None
 
+    @property
+    def cluster_id(self) -> Optional[str]:
+        return self.extract_cluster_id(self.http_path)  # type: ignore[arg-type]
+
 
 class DatabricksSQLConnectionWrapper:
     """Wrap a Databricks SQL connector in a way that no-ops transactions"""
@@ -437,69 +485,25 @@ def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
             lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema),
         )
 
-    @classmethod
-    def validate_creds(cls, creds: DatabricksCredentials, required: List[str]) -> None:
-        for key in required:
-            if not getattr(creds, key):
-                raise dbt.exceptions.DbtProfileError(
-                    "The config '{}' is required to connect to Databricks".format(key)
-                )
-
-    @classmethod
-    def validate_invocation_env(cls, invocation_env: str) -> None:
-        # Thrift doesn't allow nested () so we need to ensure that the passed user agent is valid
-        if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env):
-            raise dbt.exceptions.ValidationException(
-                f"Invalid invocation environment: {invocation_env}"
-            )
-
-    @classmethod
-    def get_all_http_headers(
-        cls, user_http_session_headers: Dict[str, str]
-    ) -> List[Tuple[str, str]]:
-        http_session_headers_str: Optional[str] = os.environ.get(
-            DBT_DATABRICKS_HTTP_SESSION_HEADERS
-        )
-
-        http_session_headers_dict: Dict[str, str] = (
-            {k: json.dumps(v) for k, v in json.loads(http_session_headers_str).items()}
-            if http_session_headers_str is not None
-            else {}
-        )
-
-        intersect_http_header_keys = (
-            user_http_session_headers.keys() & http_session_headers_dict.keys()
-        )
-
-        if len(intersect_http_header_keys) > 0:
-            raise dbt.exceptions.ValidationException(
-                f"Intersection with reserved http_headers in keys: {intersect_http_header_keys}"
-            )
-
-        http_session_headers_dict.update(user_http_session_headers)
-
-        return list(http_session_headers_dict.items())
-
     @classmethod
     def open(cls, connection: Connection) -> Connection:
         if connection.state == ConnectionState.OPEN:
             logger.debug("Connection is already open, skipping open.")
             return connection
 
         creds: DatabricksCredentials = connection.credentials
-        cls.validate_creds(creds, ["host", "http_path", "token"])
+        creds.validate_creds()
 
         user_agent_entry = f"dbt-databricks/{__version__}"
 
-        invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
-        if invocation_env is not None and len(invocation_env) > 0:
-            cls.validate_invocation_env(invocation_env)
+        invocation_env = creds.get_invocation_env()
+        if invocation_env:
             user_agent_entry = f"{user_agent_entry}; {invocation_env}"
 
         connection_parameters = creds.connection_parameters.copy()  # type: ignore[union-attr]
 
-        http_headers: List[Tuple[str, str]] = cls.get_all_http_headers(
-            connection_parameters.pop("http_headers", {})
+        http_headers: List[Tuple[str, str]] = list(
+            creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
         )
 
         exc: Optional[Exception] = None

dbt/adapters/databricks/impl.py

Lines changed: 8 additions & 6 deletions
@@ -26,7 +26,10 @@
 
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.connections import DatabricksConnectionManager
-from dbt.adapters.databricks.python_submissions import CommandApiPythonJobHelper
+from dbt.adapters.databricks.python_submissions import (
+    DbtDatabricksAllPurposeClusterPythonJobHelper,
+    DbtDatabricksJobClusterPythonJobHelper,
+)
 from dbt.adapters.databricks.relation import DatabricksRelation
 from dbt.adapters.databricks.utils import undefined_proof
 
@@ -264,13 +267,12 @@ def run_sql_for_tests(
     def valid_incremental_strategies(self) -> List[str]:
         return ["append", "merge", "insert_overwrite"]
 
-    @property
-    def default_python_submission_method(self) -> str:
-        return "commands"
-
     @property
     def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
-        return {"commands": CommandApiPythonJobHelper}
+        return {
+            "job_cluster": DbtDatabricksJobClusterPythonJobHelper,
+            "all_purpose_cluster": DbtDatabricksAllPurposeClusterPythonJobHelper,
+        }
 
     @contextmanager
     def _catalog(self, catalog: Optional[str]) -> Iterator[None]:
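A hedged sketch of how this mapping is consumed: dbt looks up a helper class by the model's `submission_method` config and asks it to submit the compiled code. The two helper classes are from this commit; the lookup shape, the `(parsed_model, credentials)` constructor, and the `all_purpose_cluster` default (presumably inherited from dbt-spark once the override is removed) are assumptions for illustration.

```py
from dbt.adapters.databricks.python_submissions import (
    DbtDatabricksAllPurposeClusterPythonJobHelper,
    DbtDatabricksJobClusterPythonJobHelper,
)

# Illustrative inputs; in dbt these come from the parsed model and the profile.
parsed_model = {"alias": "my_model", "config": {"submission_method": "job_cluster"}}
compiled_code = "def model(dbt, session): ..."

helpers = {
    "job_cluster": DbtDatabricksJobClusterPythonJobHelper,
    "all_purpose_cluster": DbtDatabricksAllPurposeClusterPythonJobHelper,
}
method = parsed_model["config"].get("submission_method", "all_purpose_cluster")
# helper = helpers[method](parsed_model, credentials)  # credentials: DatabricksCredentials
# helper.submit(compiled_code)
```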
