Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
3691156
Merge pull request #73 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
2b0c35b
Merge pull request #74 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
4b53cf7
Merge pull request #75 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
bed6487
Merge pull request #76 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
3566c13
Merge pull request #77 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
5dcc4ab
Merge pull request #78 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
b51b4a8
Merge pull request #79 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
3e3f4c0
Merge pull request #80 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
a9cda3f
Merge pull request #81 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
9333002
Merge pull request #82 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
b66ec69
Merge pull request #83 from datakind/Validation-Errors
Mesh-ach Jun 2, 2025
5a4a7b2
Merge pull request #84 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
3601164
Merge pull request #85 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
1070b7d
Merge pull request #86 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
e3ef82a
Merge pull request #87 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
f1bbfc1
Merge pull request #88 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
8009c52
Merge pull request #89 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
d85c4ec
Merge pull request #90 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
5e87ef8
Merge pull request #92 from datakind/Validation-Errors
Mesh-ach Jun 3, 2025
aba20d2
Merge pull request #93 from datakind/Validation-Errors
Mesh-ach Jun 4, 2025
9831998
Merge pull request #94 from datakind/Validation-Errors
Mesh-ach Jun 4, 2025
3b57641
feat: added option for api auth
Mesh-ach Jun 4, 2025
9ed2271
feat: added option for api auth
Mesh-ach Jun 4, 2025
4bdaea0
feat: added option for api auth
Mesh-ach Jun 4, 2025
ea08663
feat: added option for api auth
Mesh-ach Jun 4, 2025
b69c428
feat: added option for api auth
Mesh-ach Jun 4, 2025
788093f
feat: added option for api auth
Mesh-ach Jun 4, 2025
1010235
feat: added option for api auth
Mesh-ach Jun 4, 2025
cd538ee
feat: added option for api auth
Mesh-ach Jun 4, 2025
715019b
feat: added option for api auth
Mesh-ach Jun 4, 2025
fdc272e
feat: added option for api auth
Mesh-ach Jun 4, 2025
908aef6
feat: added option for api auth
Mesh-ach Jun 5, 2025
7e44786
feat: added option for api auth
Mesh-ach Jun 5, 2025
93fbd24
feat: added option for api auth
Mesh-ach Jun 5, 2025
7fa1584
feat: added option for api auth
Mesh-ach Jun 5, 2025
127cb12
feat: added option for api auth
Mesh-ach Jun 5, 2025
5bd7474
feat: added option for api auth
Mesh-ach Jun 5, 2025
d525382
feat: added option for api auth
Mesh-ach Jun 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/webapp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
"INITIAL_API_KEY_ID": "",
"CATALOG_NAME": "",
"SQL_WAREHOUSE_ID": "",
"USERNAME": "",
"PASSWORD": "",
}

# The INSTANCE_HOST is the private IP of CLoudSQL instance e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex)
Expand Down
112 changes: 60 additions & 52 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
from pydantic import BaseModel
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
from databricks.sdk.service.sql import Format, ExecuteStatementRequestOnWaitTimeout
from databricks.sdk.service.sql import (
Format,
ExecuteStatementRequestOnWaitTimeout,
Disposition,
)
from .config import databricks_vars, gcs_vars
from .utilities import databricksify_inst_name, SchemaType
from typing import List, Any
import time
from typing import List, Any, Dict
from databricks.sdk.errors import DatabricksError


# List of data medallion levels
MEDALLION_LEVELS = ["silver", "gold", "bronze"]
Expand Down Expand Up @@ -196,62 +201,65 @@ def delete_inst(self, inst_name: str) -> None:

def fetch_table_data(
self,
catalog_name: Any,
schema_name: Any,
table_name: Any,
warehouse_id: Any,
catalog_name: str,
inst_name: str,
table_name: str,
warehouse_id: str,
limit: int = 1000,
) -> List[dict[str, Any]]:
) -> List[Dict[str, Any]]:
"""
Runs a simple SELECT * FROM <catalog>.<schema>.<table> LIMIT <limit>
against the specified SQL warehouse, and returns a list of row‐dicts.
Executes a SELECT * query on the specified table within the given catalog and schema,
using the provided SQL warehouse. Returns the result as a list of dictionaries.
"""
w = WorkspaceClient(
host=databricks_vars["DATABRICKS_HOST_URL"],
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
if not w:
raise ValueError(
"fetch_table_data(): could not initialize WorkspaceClient."
try:
# Initialize the WorkspaceClient with default authentication
client = WorkspaceClient(
host=databricks_vars["DATABRICKS_HOST_URL"],
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
except Exception as e:
raise ValueError(f"Failed to initialize WorkspaceClient: {e}")

fq_table = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
sql = f"SELECT * FROM {fq_table} LIMIT {limit}"

resp = w.statement_execution.execute_statement(
warehouse_id=warehouse_id,
statement=sql,
format=Format.JSON_ARRAY,
wait_timeout="10s",
on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE,
# Construct the fully qualified table name
schema_name = databricksify_inst_name(inst_name)
fully_qualified_table = (
f"`{catalog_name}`.`{schema_name}_silver`.`{table_name}`"
)
sql_query = f"SELECT * FROM {fully_qualified_table} LIMIT {limit}"

status = getattr(resp, "status", None)
if status and status.state == "SUCCEEDED" and getattr(resp, "result", None):
# resp.results is a list of row‐arrays, resp.schema is a list of column metadata
column_names = [col.name for col in resp.manifest.schema]
rows = resp.result.data_array
else:
# A. If the SQL didn’t finish in 10 seconds, resp.statement_id will be set.
stmt_id = getattr(resp, "statement_id", None)
if not stmt_id:
raise ValueError(
f"fetch_table_data(): unexpected response state: {resp}"
)
try:
# Execute the SQL statement
response = client.statement_execution.execute_statement(
warehouse_id=warehouse_id,
statement=sql_query,
disposition=Disposition.INLINE, # Use Enum member
format=Format.JSON_ARRAY, # Use Enum member
wait_timeout="30s", # Wait up to 30 seconds for execution
on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CANCEL, # Use Enum member
)
except DatabricksError as e:
raise ValueError(f"Databricks API call failed: {e}")

# Check if the query execution was successful
if response.status.state != "SUCCEEDED":
error_message = (
response.status.error.message
if response.status.error
else "No additional error info."
)
raise ValueError(
f"Query did not succeed (state={response.status.state}): {error_message}"
)

# B. Poll until the statement succeeds (or fails/cancels)
status = resp.status.state if getattr(resp, "status", None) else None
while status not in ("SUCCEEDED", "FAILED", "CANCELED"):
time.sleep(1)
resp2 = w.statement_execution.get_statement(statement_id=stmt_id)
status = resp2.status.state if getattr(resp2, "status", None) else None
resp = resp2
if status != "SUCCEEDED":
raise ValueError(f"fetch_table_data(): query ended with state {status}")
# Validate the presence of the result and schema
if not response.manifest or not response.manifest.schema:
raise ValueError("Query succeeded but schema manifest is missing.")
if not response.result or not response.result.data_array:
raise ValueError("Query succeeded but result data is missing.")

# C. At this point, resp holds the final manifest and first chunk
column_names = [col.name for col in resp.manifest.schema]
rows = resp.result.data_array
# Extract column names and data rows
column_names = [column.name for column in response.manifest.schema]
data_rows = response.result.data_array

# Transform each row (a list of values) into a dict
return [dict(zip(column_names, row)) for row in rows]
# Combine column names with corresponding row values
return [dict(zip(column_names, row)) for row in data_rows]
13 changes: 3 additions & 10 deletions src/webapp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import secrets
from fastapi import FastAPI, Depends, HTTPException, status, Security
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from sqlalchemy.future import select
from sqlalchemy import update
Expand Down Expand Up @@ -38,7 +37,6 @@
create_access_token,
get_api_key,
get_api_key_hash,
check_creds,
)

# Set the logging
Expand Down Expand Up @@ -97,30 +95,25 @@ def read_root() -> Any:
@app.post("/token-from-api-key")
async def access_token_from_api_key(
sql_session: Annotated[Session, Depends(get_session)],
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
api_key_enduser_tuple: str = Security(get_api_key),
) -> Token:
"""Generate a token from an API key."""
local_session.set(sql_session)

user = authenticate_api_key(api_key_enduser_tuple, local_session.get())
valid = check_creds(form_data.username, form_data.password)
logger.info(f"api_key input: {api_key_enduser_tuple}")
logger.info(f"user: {user}")
logger.info(f"valid creds: {valid}")

if not user and not valid:
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key and credentials",
headers={"WWW-Authenticate": "X-API-KEY"},
)
email = user.email if user else form_data.username

access_token_expires = timedelta(
minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"])
)
access_token = create_access_token(
data={"sub": email}, expires_delta=access_token_expires
data={"sub": user.email}, expires_delta=access_token_expires
)
return Token(access_token=access_token, token_type="bearer")

Expand Down
16 changes: 7 additions & 9 deletions src/webapp/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
get_session,
ApiKeyTable,
)
from unittest.mock import patch
from .authn import get_password_hash, get_api_key_hash
from .test_helper import (
DATAKINDER,
Expand Down Expand Up @@ -146,14 +145,13 @@ def test_get_root(client: TestClient):


def test_retrieve_token_gen_from_api_key(client: TestClient):
with patch.dict("os.environ", {"USERNAME": "fake", "PASSWORD": "fake"}):
response = client.post(
"/token-from-api-key",
headers={"X-API-KEY": "key_1"},
data={"username": "fake", "password": "fake"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"
"""Test POST /token-from-api-key."""
response = client.post(
"/token-from-api-key",
headers={"X-API-KEY": "key_1"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"


def test_get_cross_isnt_users(client: TestClient):
Expand Down
30 changes: 15 additions & 15 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,12 +1028,12 @@ def get_upload_url(
def get_top_features(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
# has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand All @@ -1055,7 +1055,7 @@ def get_top_features(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_features_with_most_impact",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand All @@ -1072,12 +1072,12 @@ def get_top_features(
def get_support_overview(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
# has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand All @@ -1099,7 +1099,7 @@ def get_support_overview(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_support_overview",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand All @@ -1115,12 +1115,12 @@ def get_support_overview(
def get_feature_value(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
# has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand All @@ -1142,7 +1142,7 @@ def get_feature_value(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_shap_feature_importance",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand All @@ -1158,12 +1158,12 @@ def get_feature_value(
def get_confusion_matrix(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
##current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
# has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand All @@ -1185,7 +1185,7 @@ def get_confusion_matrix(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_confusion_matrix",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand All @@ -1201,12 +1201,12 @@ def get_confusion_matrix(
def get_roc_curve(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
# has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand All @@ -1228,7 +1228,7 @@ def get_roc_curve(
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
schema_name=f"{query_result[0][0].name}_silver",
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_roc_curve",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
Expand Down
Loading