Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def delete_inst(self, inst_name: str) -> None:
def fetch_table_data(
self,
catalog_name: str,
schema_name: str,
inst_name: str,
table_name: str,
warehouse_id: str,
limit: int = 1000,
Expand All @@ -221,7 +221,10 @@ def fetch_table_data(
raise ValueError(f"Failed to initialize WorkspaceClient: {e}")

# Construct the fully qualified table name
fully_qualified_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
schema_name = databricksify_inst_name(inst_name)
fully_qualified_table = (
f"`{catalog_name}`.`{schema_name}__silver`.`{table_name}`"
)
sql_query = f"SELECT * FROM {fully_qualified_table} LIMIT {limit}"

try:
Expand Down
32 changes: 5 additions & 27 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,28 +1023,6 @@ def get_upload_url(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


# Get SHAP Values for Inference
@router.get("/inference/test")
def test() -> Any:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200

try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name="dev_sst_02",
schema_name="default",
table_name="test_dataset",
warehouse_id="28e1cbabfe6deb87",
limit=500,
)

return rows
except ValueError as ve:
# Return a 400 error with the specific message from ValueError
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


# Get SHAP Values for Inference
@router.get("/{inst_id}/inference/top-features/{run_id}", response_model=str)
def get_top_features(
Expand Down Expand Up @@ -1077,7 +1055,7 @@ def get_top_features(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_features_with_most_impact",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand Down Expand Up @@ -1121,7 +1099,7 @@ def get_support_overview(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_support_overview",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand Down Expand Up @@ -1164,7 +1142,7 @@ def get_feature_value(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_shap_feature_importance",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand Down Expand Up @@ -1207,7 +1185,7 @@ def get_confusion_matrix(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_confusion_matrix",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand Down Expand Up @@ -1250,7 +1228,7 @@ def get_roc_curve(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_roc_curve",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand Down
Loading