-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabricks.py
More file actions
257 lines (228 loc) · 10.6 KB
/
Copy pathdatabricks.py
File metadata and controls
257 lines (228 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""Databricks SDK related helper functions."""
import os
from pydantic import BaseModel
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout
from .config import databricks_vars, gcs_vars
from .utilities import databricksify_inst_name, SchemaType
from typing import List, Any
import time
# Data medallion levels; setup_new_inst creates one schema per level for each
# institution, and delete_inst drops them again.
MEDALLION_LEVELS = ["silver", "gold", "bronze"]
# Name of the deployed inference job in Databricks. The job is looked up by this
# name (w.jobs.list(name=...)), so it must match the deployed job name exactly.
PDP_INFERENCE_JOB_NAME = "github_sourced_pdp_inference_pipeline"
class DatabricksInferenceRunRequest(BaseModel):
    """Databricks parameters for an inference run.

    Consumed by ``DatabricksControl.run_pdp_inference``, which forwards these
    values as parameters of the PDP inference job.
    """

    # Institution name; converted with databricksify_inst_name before use.
    inst_name: str
    # Maps each input file path to the schema types detected for that file.
    # The keys (file paths) are what get passed to the job.
    filepath_to_type: dict[str, list[SchemaType]]
    model_name: str
    model_type: str = "sklearn"
    # The email where notifications will get sent.
    email: str
    # GCS bucket name, passed to the job as its "gcp_bucket_name" parameter.
    gcp_external_bucket_name: str
class DatabricksInferenceRunResponse(BaseModel):
    """Result of triggering a Databricks inference run."""

    # Run ID of the triggered job run, taken from the jobs.run_now response.
    job_run_id: int
def get_filepath_of_filetype(
    file_dict: dict[str, list[SchemaType]], file_type: SchemaType
):
    """Return the first file path whose schema types include ``file_type``.

    Paths are examined in the mapping's iteration order; the first match wins.
    When no path carries ``file_type``, an empty string is returned.
    """
    matching_paths = (
        path for path, schema_types in file_dict.items() if file_type in schema_types
    )
    return next(matching_paths, "")
def check_types(dict_values, file_type: SchemaType):
    """Report whether any schema-type collection in ``dict_values`` holds ``file_type``.

    ``dict_values`` is an iterable of containers (for example the ``.values()``
    of a file-path -> schema-types mapping). Scanning stops at the first hit.
    """
    return any(file_type in schema_types for schema_types in dict_values)
def _workspace_client(error_message: str) -> WorkspaceClient:
    """Build a WorkspaceClient for the configured Databricks host.

    Authentication uses the Cloud Run service account, since Cloud Run is what
    triggers the Databricks calls. That account must also exist on Databricks
    with the permissions the calling operation needs.

    Raises:
        ValueError: with ``error_message`` if no usable client is obtained.
            Kept for parity with the original per-method guards; a real
            constructor call returns an instance, so this presumably only
            matters when the constructor is patched (e.g. in tests).
    """
    w = WorkspaceClient(
        host=databricks_vars["DATABRICKS_HOST_URL"],
        google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
    )
    if not w:
        raise ValueError(error_message)
    return w


def _statement_state(resp: Any) -> Any:
    """Return ``resp.status.state`` normalised to a plain string (or None).

    The SDK reports the state as a ``StatementState`` enum, which is a plain
    Enum and does not compare equal to strings such as ``"SUCCEEDED"``; using
    its ``.value`` makes string comparisons work. Plain strings pass through
    unchanged and a missing status yields ``None``.
    """
    status = getattr(resp, "status", None)
    state = getattr(status, "state", None)
    return getattr(state, "value", state)


def _quote_identifier(name: Any) -> str:
    """Backtick-quote a SQL identifier, doubling any embedded backticks."""
    return "`" + str(name).replace("`", "``") + "`"


# Wrapping the usages in a class makes it easier to unit test via mocks.
class DatabricksControl(BaseModel):
    """Object to manage interfacing with Databricks."""

    def setup_new_inst(self, inst_name: str) -> None:
        """Set up the Databricks resources for a new institution.

        Creates one schema per medallion level in the configured catalog, a
        managed volume in each of the bronze/silver/gold schemas, and the
        ``configuration_files`` (gold) and ``raw_files`` (bronze) directories
        on those volumes.

        Raises:
            ValueError: if the workspace client or a volume cannot be obtained.
        """
        w = _workspace_client("setup_new_inst() workspace retrieval failed.")
        db_inst_name = databricksify_inst_name(inst_name)
        cat_name = databricks_vars["CATALOG_NAME"]
        for medallion in MEDALLION_LEVELS:
            w.schemas.create(name=f"{db_inst_name}_{medallion}", catalog_name=cat_name)
        # Managed volumes for internal pipeline data, created bronze, silver, gold.
        # TODO: include a managed volume for toml files.
        created_volumes = [
            w.volumes.create(
                catalog_name=cat_name,
                schema_name=f"{db_inst_name}_{medallion}",
                name=f"{medallion}_volume",
                volume_type=catalog.VolumeType.MANAGED,
            )
            for medallion in ("bronze", "silver", "gold")
        ]
        if any(volume is None for volume in created_volumes):
            raise ValueError("setup_new_inst() volume creation failed.")
        # /Volumes paths live in Unity Catalog, not on the host running this code
        # (Cloud Run), so the directories are created through the Files API;
        # os.makedirs would only create local directories.
        volume_root = f"/Volumes/{cat_name}/{db_inst_name}"
        w.files.create_directory(f"{volume_root}_gold/gold_volume/configuration_files")
        w.files.create_directory(f"{volume_root}_bronze/bronze_volume/raw_files")

    # Note that for each unique PIPELINE, we'll need a new function, this is by
    # nature of how unique pipelines may have unique parameters and would have a
    # unique name (i.e. the name field specified in w.jobs.list()). But any run
    # of a given pipeline (even across institutions) can use the same function.
    # E.g. there is one PDP inference pipeline, so one PDP inference function here.
    def run_pdp_inference(
        self, req: DatabricksInferenceRunRequest
    ) -> DatabricksInferenceRunResponse:
        """Trigger a run of the PDP inference Databricks job.

        Args:
            req: run parameters; ``req.filepath_to_type`` must contain at least
                one PDP_COURSE file and one PDP_COHORT file.

        Returns:
            The run ID of the triggered job run.

        Raises:
            ValueError: if the required file types are missing, the workspace
                client cannot be obtained, no job named
                ``PDP_INFERENCE_JOB_NAME`` exists, or the run cannot be started.
        """
        if (
            not req.filepath_to_type
            or not check_types(req.filepath_to_type.values(), SchemaType.PDP_COURSE)
            or not check_types(req.filepath_to_type.values(), SchemaType.PDP_COHORT)
        ):
            raise ValueError(
                "run_pdp_inference() requires PDP_COURSE and PDP_COHORT type files to run."
            )
        w = _workspace_client("run_pdp_inference(): Databricks workspace not found.")
        db_inst_name = databricksify_inst_name(req.inst_name)
        # The listing may be empty; use a default so a missing job surfaces as
        # the ValueError below rather than a bare StopIteration.
        job = next(iter(w.jobs.list(name=PDP_INFERENCE_JOB_NAME)), None)
        job_id = job.job_id if job is not None else None
        if not job_id:
            raise ValueError("run_pdp_inference(): Job was not created.")
        run_job = w.jobs.run_now(
            job_id,
            job_parameters={
                "cohort_file_name": get_filepath_of_filetype(
                    req.filepath_to_type, SchemaType.PDP_COHORT
                ),
                "course_file_name": get_filepath_of_filetype(
                    req.filepath_to_type, SchemaType.PDP_COURSE
                ),
                "databricks_institution_name": db_inst_name,
                # NOTE(review): is this value the same per environment
                # (dev/staging/prod)? Confirm.
                "DB_workspace": databricks_vars["DATABRICKS_WORKSPACE"],
                "gcp_bucket_name": req.gcp_external_bucket_name,
                "model_name": req.model_name,
                "model_type": req.model_type,
                "notification_email": req.email,
            },
        )
        if not run_job:
            raise ValueError("run_pdp_inference(): Job could not be run.")
        return DatabricksInferenceRunResponse(job_run_id=run_job.response.run_id)

    def delete_inst(self, inst_name: str) -> None:
        """Clean up the Databricks resources of an institution.

        Deletes the bronze/silver/gold managed volumes, then every table in
        each medallion schema, then the schemas themselves.

        Raises:
            ValueError: if the workspace client cannot be obtained.
        """
        db_inst_name = databricksify_inst_name(inst_name)
        cat_name = databricks_vars["CATALOG_NAME"]
        w = _workspace_client("delete_inst(): Databricks workspace not found.")
        # Delete the managed volumes (bronze, silver, gold).
        for medallion in ("bronze", "silver", "gold"):
            w.volumes.delete(
                name=f"{cat_name}.{db_inst_name}_{medallion}.{medallion}_volume"
            )
        # TODO implement model deletion
        # Delete tables and schemas for each medallion level.
        for medallion in MEDALLION_LEVELS:
            schema_name = f"{db_inst_name}_{medallion}"
            # Collect all table names before deleting any of them.
            table_names = [
                table.name
                for table in w.tables.list(catalog_name=cat_name, schema_name=schema_name)
            ]
            for table_name in table_names:
                w.tables.delete(full_name=f"{cat_name}.{schema_name}.{table_name}")
            w.schemas.delete(full_name=f"{cat_name}.{schema_name}")

    def fetch_table_data(
        self,
        catalog_name: Any,
        schema_name: Any,
        table_name: Any,
        warehouse_id: Any,
        limit: int = 1000,
    ) -> List[dict[str, Any]]:
        """Run ``SELECT * FROM <catalog>.<schema>.<table> LIMIT <limit>``.

        The statement runs on the given SQL warehouse. If it has not finished
        within the 10 s synchronous wait, it is polled once per second until it
        reaches a terminal state. All result chunks are collected.

        Args:
            catalog_name, schema_name, table_name: identifier parts; each is
                backtick-quoted, with embedded backticks doubled.
            warehouse_id: ID of the SQL warehouse to run the query on.
            limit: maximum number of rows; must convert to a non-negative int.

        Returns:
            One dict per row mapping column name to value (JSON_ARRAY format).

        Raises:
            ValueError: on an invalid ``limit``, if the client cannot be
                created, if an unfinished response carries no statement ID, or
                if the query ends in any state other than SUCCEEDED.
        """
        try:
            row_limit = int(limit)
        except (TypeError, ValueError) as err:
            raise ValueError(f"fetch_table_data(): invalid limit {limit!r}") from err
        if row_limit < 0:
            raise ValueError(f"fetch_table_data(): invalid limit {limit!r}")
        w = _workspace_client(
            "fetch_table_data(): could not initialize WorkspaceClient."
        )
        fq_table = ".".join(
            _quote_identifier(part) for part in (catalog_name, schema_name, table_name)
        )
        sql = f"SELECT * FROM {fq_table} LIMIT {row_limit}"
        resp = w.statement_execution.execute_statement(
            warehouse_id=warehouse_id,
            statement=sql,
            format=Format.JSON_ARRAY,
            wait_timeout="10s",
            on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE,
        )
        stmt_id = getattr(resp, "statement_id", None)
        state = _statement_state(resp)
        if state != "SUCCEEDED":
            # Not finished within the synchronous wait: poll by statement ID
            # until the statement reaches a terminal state.
            if not stmt_id:
                raise ValueError(
                    f"fetch_table_data(): unexpected response state: {resp}"
                )
            while state not in ("SUCCEEDED", "FAILED", "CANCELED", "CLOSED"):
                time.sleep(1)
                resp = w.statement_execution.get_statement(statement_id=stmt_id)
                state = _statement_state(resp)
            if state != "SUCCEEDED":
                error = getattr(getattr(resp, "status", None), "error", None)
                detail = getattr(error, "message", None)
                message = f"fetch_table_data(): query ended with state {state}"
                raise ValueError(f"{message}: {detail}" if detail else message)
        # manifest.schema is a ResultSchema; its .columns lists the ColumnInfo
        # entries (the schema object itself is not iterable).
        result_schema = getattr(getattr(resp, "manifest", None), "schema", None)
        columns = getattr(result_schema, "columns", None) or []
        column_names = [col.name for col in columns]
        # An empty result may omit data_array entirely.
        result = getattr(resp, "result", None)
        rows = list(getattr(result, "data_array", None) or [])
        # Larger results are split into chunks; follow next_chunk_index until
        # no further chunk is announced.
        next_chunk = getattr(result, "next_chunk_index", None)
        while stmt_id and isinstance(next_chunk, int):
            chunk = w.statement_execution.get_statement_result_chunk_n(
                statement_id=stmt_id, chunk_index=next_chunk
            )
            rows.extend(getattr(chunk, "data_array", None) or [])
            next_chunk = getattr(chunk, "next_chunk_index", None)
        # Transform each row (a list of values) into a dict.
        return [dict(zip(column_names, row)) for row in rows]