diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 735e4596..8acb3be4 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -202,7 +202,7 @@ def delete_inst(self, inst_name: str) -> None: def fetch_table_data( self, catalog_name: str, - schema_name: str, + inst_name: str, table_name: str, warehouse_id: str, limit: int = 1000, @@ -221,7 +221,10 @@ def fetch_table_data( raise ValueError(f"Failed to initialize WorkspaceClient: {e}") # Construct the fully qualified table name - fully_qualified_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`" + schema_name = databricksify_inst_name(inst_name) + fully_qualified_table = ( + f"`{catalog_name}`.`{schema_name}__silver`.`{table_name}`" + ) sql_query = f"SELECT * FROM {fully_qualified_table} LIMIT {limit}" try: diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 4208ce66..c93d621a 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -1023,28 +1023,6 @@ def get_upload_url( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) -# Get SHAP Values for Inference -@router.get("/inference/test") -def test() -> Any: - """Returns a signed URL for uploading data to a specific institution.""" - # raise error at this level instead bc otherwise it's getting wrapped as a 200 - - try: - dbc = DatabricksControl() - rows = dbc.fetch_table_data( - catalog_name="dev_sst_02", - schema_name="default", - table_name="test_dataset", - warehouse_id="28e1cbabfe6deb87", - limit=500, - ) - - return rows - except ValueError as ve: - # Return a 400 error with the specific message from ValueError - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) - - # Get SHAP Values for Inference @router.get("/{inst_id}/inference/top-features/{run_id}", response_model=str) def get_top_features( @@ -1077,7 +1055,7 @@ def get_top_features( dbc = DatabricksControl() rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], - schema_name=f"{query_result[0][0].name}_silver", + inst_name=f"{query_result[0][0].name}", table_name=f"sample_inference_{run_id}_features_with_most_impact", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], limit=500, @@ -1121,7 +1099,7 @@ def get_support_overview( dbc = DatabricksControl() rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], - schema_name=f"{query_result[0][0].name}_silver", + inst_name=f"{query_result[0][0].name}", table_name=f"sample_inference_{run_id}_support_overview", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], limit=500, @@ -1164,7 +1142,7 @@ def get_feature_value( dbc = DatabricksControl() rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], - schema_name=f"{query_result[0][0].name}_silver", + inst_name=f"{query_result[0][0].name}", table_name=f"sample_inference_{run_id}_shap_feature_importance", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], limit=500, @@ -1207,7 +1185,7 @@ def get_confusion_matrix( dbc = DatabricksControl() rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], - schema_name=f"{query_result[0][0].name}_silver", + inst_name=f"{query_result[0][0].name}", table_name=f"sample_training_{run_id}_confusion_matrix", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], limit=500, @@ -1250,7 +1228,7 @@ def get_roc_curve( dbc = DatabricksControl() rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], - schema_name=f"{query_result[0][0].name}_silver", + inst_name=f"{query_result[0][0].name}", table_name=f"sample_training_{run_id}_roc_curve", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], limit=500,