Skip to content

Commit bc8a359

Browse files
authored
Merge pull request #95 from datakind/develop
merging adjustments
2 parents 3a50f17 + 1010235 commit bc8a359

9 files changed

Lines changed: 354 additions & 80 deletions

File tree

src/webapp/authn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def get_api_key(
5353
)
5454

5555

56+
def check_creds(username: str, password: str) -> bool:
    """Check a username/password pair against the configured credentials.

    The expected values are read from ``env_vars["USERNAME"]`` and
    ``env_vars["PASSWORD"]``.

    Args:
        username: The username supplied by the caller.
        password: The password supplied by the caller.

    Returns:
        True only when both values match the configured credentials. Always
        False when either configured credential is empty/unset, so that an
        unconfigured deployment cannot be logged into with blank credentials.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    expected_username = env_vars["USERNAME"]
    expected_password = env_vars["PASSWORD"]

    # The configuration defaults for these keys are empty strings; a plain
    # equality check would then accept ("", "") as valid credentials.
    if not expected_username or not expected_password:
        return False
    if not isinstance(username, str) or not isinstance(password, str):
        return False

    # Compare as bytes (compare_digest rejects non-ASCII str) and evaluate both
    # comparisons before combining them, so timing does not reveal which of the
    # two fields was wrong.
    username_ok = hmac.compare_digest(
        username.encode("utf-8"), str(expected_username).encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        password.encode("utf-8"), str(expected_password).encode("utf-8")
    )
    return username_ok and password_ok
58+
59+
5660
def verify_password(plain_password: str, hashed_password: str) -> bool:
5761
"""Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel
5862
Generates hashes that start with 2y. The hashing scheme recognizes both."""

src/webapp/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
"API_KEY_ISSUERS": [],
1414
"INITIAL_API_KEY": "",
1515
"INITIAL_API_KEY_ID": "",
16+
"CATALOG_NAME": "",
17+
"SQL_WAREHOUSE_ID": "",
18+
"USERNAME": "",
19+
"PASSWORD": "",
1620
}
1721

1822
# The INSTANCE_HOST is the private IP of the Cloud SQL instance, e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex)

src/webapp/databricks.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from pydantic import BaseModel
55
from databricks.sdk import WorkspaceClient
66
from databricks.sdk.service import catalog
7-
7+
from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout
88
from .config import databricks_vars, gcs_vars
99
from .utilities import databricksify_inst_name, SchemaType
10+
from typing import List, Any
11+
import time
1012

1113
# List of data medallion levels
1214
MEDALLION_LEVELS = ["silver", "gold", "bronze"]
@@ -191,3 +193,65 @@ def delete_inst(self, inst_name: str) -> None:
191193
full_name=f"{cat_name}.{db_inst_name}_{medallion}.{table}"
192194
)
193195
w.schemas.delete(full_name=f"{cat_name}.{db_inst_name}_{medallion}")
196+
197+
def fetch_table_data(
    self,
    catalog_name: Any,
    schema_name: Any,
    table_name: Any,
    warehouse_id: Any,
    limit: int = 1000,
    max_wait_seconds: float = 120.0,
) -> List[dict[str, Any]]:
    """Run ``SELECT * FROM <catalog>.<schema>.<table> LIMIT <limit>``.

    The statement is executed against the given SQL warehouse. If it has not
    finished within the initial 10 second wait, its status is polled once per
    second until it reaches a terminal state or ``max_wait_seconds`` elapses.

    Args:
        catalog_name: Catalog that contains the table.
        schema_name: Schema that contains the table.
        table_name: Name of the table to read.
        warehouse_id: ID of the SQL warehouse that runs the statement.
        limit: Maximum number of rows to request (non-negative).
        max_wait_seconds: Upper bound on the time spent polling for completion.

    Returns:
        One dict per row, mapping column name to value.

    Raises:
        ValueError: If the limit is invalid, the client cannot be created, the
            response carries no statement id while still unfinished, the
            statement ends in a state other than SUCCEEDED, polling times out,
            or rows come back without column metadata.
    """

    def _quote(identifier: Any) -> str:
        # Double embedded backticks so a value cannot terminate the quoted
        # identifier and inject additional SQL.
        return "`" + str(identifier).replace("`", "``") + "`"

    def _state_name(response: Any) -> Any:
        # The SDK reports the state as an enum member; comparing that directly
        # with a plain string is never equal, so normalise to its string value.
        state = getattr(getattr(response, "status", None), "state", None)
        return getattr(state, "value", state)

    row_limit = int(limit)
    if row_limit < 0:
        raise ValueError("fetch_table_data(): limit must be non-negative.")

    w = WorkspaceClient(
        host=databricks_vars["DATABRICKS_HOST_URL"],
        google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
    )
    if not w:
        raise ValueError(
            "fetch_table_data(): could not initialize WorkspaceClient."
        )

    fq_table = ".".join(
        _quote(part) for part in (catalog_name, schema_name, table_name)
    )
    sql = f"SELECT * FROM {fq_table} LIMIT {row_limit}"

    resp = w.statement_execution.execute_statement(
        warehouse_id=warehouse_id,
        statement=sql,
        format=Format.JSON_ARRAY,
        wait_timeout="10s",
        on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE,
    )

    terminal_states = ("SUCCEEDED", "FAILED", "CANCELED", "CLOSED")
    state = _state_name(resp)
    stmt_id = getattr(resp, "statement_id", None)
    deadline = time.monotonic() + max_wait_seconds

    # Poll until the statement reaches a terminal state, bounded by the deadline
    # so a stuck statement cannot block the caller forever.
    while state not in terminal_states:
        if not stmt_id:
            raise ValueError(
                f"fetch_table_data(): unexpected response state: {resp}"
            )
        if time.monotonic() >= deadline:
            raise ValueError(
                "fetch_table_data(): query did not finish within "
                f"{max_wait_seconds} seconds (last state {state})"
            )
        time.sleep(1)
        resp = w.statement_execution.get_statement(statement_id=stmt_id)
        state = _state_name(resp)

    if state != "SUCCEEDED":
        raise ValueError(f"fetch_table_data(): query ended with state {state}")

    # Column metadata lives under manifest.schema.columns.
    manifest = getattr(resp, "manifest", None)
    schema = getattr(manifest, "schema", None)
    columns = getattr(schema, "columns", None) or []
    column_names = [col.name for col in columns]

    # NOTE(review): only the chunk attached to this response is read; confirm
    # that the LIMIT keeps results within a single chunk.
    result = getattr(resp, "result", None)
    rows = getattr(result, "data_array", None) or []  # None when no rows matched

    if rows and not column_names:
        raise ValueError(
            "fetch_table_data(): result rows returned without column metadata."
        )

    # Transform each row (a list of values) into a dict keyed by column name.
    return [dict(zip(column_names, row)) for row in rows]

src/webapp/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import secrets
77
from fastapi import FastAPI, Depends, HTTPException, status, Security
88
from fastapi.responses import FileResponse
9+
from fastapi.security import OAuth2PasswordRequestForm
910
from pydantic import BaseModel
1011
from sqlalchemy.future import select
1112
from sqlalchemy import update
@@ -37,6 +38,7 @@
3738
create_access_token,
3839
get_api_key,
3940
get_api_key_hash,
41+
check_creds,
4042
)
4143

4244
# Set the logging
@@ -99,13 +101,16 @@ async def access_token_from_api_key(
99101
) -> Token:
100102
"""Generate a token from an API key."""
101103
local_session.set(sql_session)
104+
102105
user = authenticate_api_key(api_key_enduser_tuple, local_session.get())
106+
103107
if not user:
104108
raise HTTPException(
105109
status_code=status.HTTP_401_UNAUTHORIZED,
106-
detail="API key not valid",
110+
detail="Invalid API key and credentials",
107111
headers={"WWW-Authenticate": "X-API-KEY"},
108112
)
113+
109114
access_token_expires = timedelta(
110115
minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"])
111116
)

src/webapp/main_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_session,
1414
ApiKeyTable,
1515
)
16+
from unittest.mock import patch
1617
from .authn import get_password_hash, get_api_key_hash
1718
from .test_helper import (
1819
DATAKINDER,

src/webapp/routers/data.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import logging
1414
from sqlalchemy.exc import IntegrityError
15+
from ..config import env_vars
1516

1617
from ..utilities import (
1718
has_access_to_inst_or_err,
@@ -31,8 +32,10 @@
3132
local_session,
3233
BatchTable,
3334
FileTable,
35+
InstTable,
3436
)
3537

38+
from ..databricks import DatabricksControl
3639
from ..gcsdbutils import update_db_from_bucket
3740

3841
from ..gcsutil import StorageControl
@@ -1018,3 +1021,220 @@ def get_upload_url(
10181021
except ValueError as ve:
10191022
# Return a 400 error with the specific message from ValueError
10201023
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
1024+
1025+
1026+
# Get SHAP Values for Inference
1027+
@router.get(
    "/{inst_id}/inference/top-features/{run_id}",
    response_model=List[dict[str, Any]],
)
def get_top_features(
    inst_id: str,
    run_id: str,
    # TODO(review): authentication is disabled on this endpoint; restore the
    # dependency below (and the access check in the body) before exposing it.
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return the rows of an institution's "features with most impact" table.

    Reads up to 500 rows from
    ``<CATALOG_NAME>.<institution name>_silver.sample_inference_<run_id>_features_with_most_impact``.

    Raises:
        HTTPException: 400 for an invalid ``run_id`` or a ValueError from the
            Databricks query, 404 when the institution does not exist, 500 when
            more than one institution matches ``inst_id``.
    """
    # has_access_to_inst_or_err(inst_id, current_user)

    # run_id is interpolated into a backtick-quoted table identifier, so a
    # backtick could close the identifier and inject SQL.
    if not run_id or "`" in run_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id.",
        )

    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    # NOTE(review): the raw institution name is used for the schema; confirm
    # whether it should first be normalised the way the Databricks schemas are
    # named when they are created.
    inst_name = query_result[0][0].name

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{inst_name}_silver",
            table_name=f"sample_inference_{run_id}_features_with_most_impact",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Return a 400 error with the specific message from ValueError
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve
1068+
1069+
1070+
# Get the support overview for inference
1071+
@router.get(
    "/{inst_id}/inference/support-overview/{run_id}",
    response_model=List[dict[str, Any]],
)
def get_support_overview(
    inst_id: str,
    run_id: str,
    # TODO(review): authentication is disabled on this endpoint; restore the
    # dependency below (and the access check in the body) before exposing it.
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return the rows of an institution's inference "support overview" table.

    Reads up to 500 rows from
    ``<CATALOG_NAME>.<institution name>_silver.sample_inference_<run_id>_support_overview``.

    Raises:
        HTTPException: 400 for an invalid ``run_id`` or a ValueError from the
            Databricks query, 404 when the institution does not exist, 500 when
            more than one institution matches ``inst_id``.
    """
    # has_access_to_inst_or_err(inst_id, current_user)

    # run_id is interpolated into a backtick-quoted table identifier, so a
    # backtick could close the identifier and inject SQL.
    if not run_id or "`" in run_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id.",
        )

    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    # NOTE(review): the raw institution name is used for the schema; confirm
    # whether it should first be normalised the way the Databricks schemas are
    # named when they are created.
    inst_name = query_result[0][0].name

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{inst_name}_silver",
            table_name=f"sample_inference_{run_id}_support_overview",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Return a 400 error with the specific message from ValueError
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve
1112+
1113+
1114+
@router.get(
    "/{inst_id}/inference/feature_value/{run_id}",
    response_model=List[dict[str, Any]],
)
def get_feature_value(
    inst_id: str,
    run_id: str,
    # TODO(review): authentication is disabled on this endpoint; restore the
    # dependency below (and the access check in the body) before exposing it.
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return the rows of an institution's "SHAP feature importance" table.

    Reads up to 500 rows from
    ``<CATALOG_NAME>.<institution name>_silver.sample_inference_<run_id>_shap_feature_importance``.

    Raises:
        HTTPException: 400 for an invalid ``run_id`` or a ValueError from the
            Databricks query, 404 when the institution does not exist, 500 when
            more than one institution matches ``inst_id``.
    """
    # has_access_to_inst_or_err(inst_id, current_user)

    # run_id is interpolated into a backtick-quoted table identifier, so a
    # backtick could close the identifier and inject SQL.
    if not run_id or "`" in run_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id.",
        )

    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    # NOTE(review): the raw institution name is used for the schema; confirm
    # whether it should first be normalised the way the Databricks schemas are
    # named when they are created.
    inst_name = query_result[0][0].name

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{inst_name}_silver",
            table_name=f"sample_inference_{run_id}_shap_feature_importance",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Return a 400 error with the specific message from ValueError
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve
1155+
1156+
1157+
@router.get(
    "/{inst_id}/training/confusion_matrix/{run_id}",
    response_model=List[dict[str, Any]],
)
def get_confusion_matrix(
    inst_id: str,
    run_id: str,
    # TODO(review): authentication is disabled on this endpoint; restore the
    # dependency below (and the access check in the body) before exposing it.
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return the rows of an institution's training "confusion matrix" table.

    Reads up to 500 rows from
    ``<CATALOG_NAME>.<institution name>_silver.sample_training_<run_id>_confusion_matrix``.

    Raises:
        HTTPException: 400 for an invalid ``run_id`` or a ValueError from the
            Databricks query, 404 when the institution does not exist, 500 when
            more than one institution matches ``inst_id``.
    """
    # has_access_to_inst_or_err(inst_id, current_user)

    # run_id is interpolated into a backtick-quoted table identifier, so a
    # backtick could close the identifier and inject SQL.
    if not run_id or "`" in run_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id.",
        )

    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    # NOTE(review): the raw institution name is used for the schema; confirm
    # whether it should first be normalised the way the Databricks schemas are
    # named when they are created.
    inst_name = query_result[0][0].name

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{inst_name}_silver",
            table_name=f"sample_training_{run_id}_confusion_matrix",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Return a 400 error with the specific message from ValueError
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve
1198+
1199+
1200+
@router.get(
    "/{inst_id}/training/roc_curve/{run_id}",
    response_model=List[dict[str, Any]],
)
def get_roc_curve(
    inst_id: str,
    run_id: str,
    # TODO(review): authentication is disabled on this endpoint; restore the
    # dependency below (and the access check in the body) before exposing it.
    # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
    """Return the rows of an institution's training "ROC curve" table.

    Reads up to 500 rows from
    ``<CATALOG_NAME>.<institution name>_silver.sample_training_<run_id>_roc_curve``.

    Raises:
        HTTPException: 400 for an invalid ``run_id`` or a ValueError from the
            Databricks query, 404 when the institution does not exist, 500 when
            more than one institution matches ``inst_id``.
    """
    # has_access_to_inst_or_err(inst_id, current_user)

    # run_id is interpolated into a backtick-quoted table identifier, so a
    # backtick could close the identifier and inject SQL.
    if not run_id or "`" in run_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid run_id.",
        )

    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    # NOTE(review): the raw institution name is used for the schema; confirm
    # whether it should first be normalised the way the Databricks schemas are
    # named when they are created.
    inst_name = query_result[0][0].name

    try:
        dbc = DatabricksControl()
        return dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],
            schema_name=f"{inst_name}_silver",
            table_name=f"sample_training_{run_id}_roc_curve",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
            limit=500,
        )
    except ValueError as ve:
        # Return a 400 error with the specific message from ValueError
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve

0 commit comments

Comments
 (0)