Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def run_pdp_inference(
"""Triggers PDP inference Databricks run."""
if (
not req.filepath_to_type
or not check_types(req.filepath_to_type.values(), SchemaType.PDP_COURSE)
or not check_types(req.filepath_to_type.values(), SchemaType.PDP_COHORT)
or not check_types(req.filepath_to_type.values(), SchemaType.COURSE)
or not check_types(req.filepath_to_type.values(), SchemaType.STUDENT)
):
raise ValueError(
"run_pdp_inference() requires PDP_COURSE and PDP_COHORT type files to run."
Expand All @@ -147,10 +147,10 @@ def run_pdp_inference(
job_id,
job_parameters={
"cohort_file_name": get_filepath_of_filetype(
req.filepath_to_type, SchemaType.PDP_COHORT
req.filepath_to_type, SchemaType.STUDENT
),
"course_file_name": get_filepath_of_filetype(
req.filepath_to_type, SchemaType.PDP_COURSE
req.filepath_to_type, SchemaType.COURSE
),
"databricks_institution_name": db_inst_name,
"DB_workspace": databricks_vars[
Expand Down
8 changes: 2 additions & 6 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,16 +855,12 @@ def infer_models_from_filename(file_path: str, institution_id: str) -> List[str]
inferred.add("COURSE")
if "student" in name:
inferred.add("STUDENT")
if institution_id == "pdp":
inferred.add("SEMESTER")
if "semester" in name:
inferred.add("SEMESTER")
if "cohort" in name:
inferred.add("STUDENT")
inferred.add("SEMESTER")
if "course" not in name and ("ar" in name or "deidentified" in name):
inferred.add("STUDENT")
inferred.add("SEMESTER")

if not inferred:
logging.error(
Expand Down Expand Up @@ -938,7 +934,7 @@ def validation_helper(
uploader=str_to_uuid(current_user.user_id),
source=source_str,
sst_generated=False,
schemas=list(inferred_schemas),
schemas=list(allowed_schemas),
valid=True,
)
local_session.get().add(new_file_record)
Expand All @@ -960,7 +956,7 @@ def validation_helper(
return {
"name": file_name,
"inst_id": inst_id,
"file_types": list(inferred_schemas),
"file_types": list(allowed_schemas),
"source": source_str,
"status": db_status,
}
Expand Down
6 changes: 3 additions & 3 deletions src/webapp/routers/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def session_fixture():
updated_at=DATETIME_TESTING,
sst_generated=True,
valid=True,
schemas=[SchemaType.PDP_COHORT],
schemas=[SchemaType.STUDENT],
)
file_4 = FileTable(
id=SAMPLE_UUID,
Expand All @@ -145,7 +145,7 @@ def session_fixture():
updated_at=DATETIME_TESTING,
sst_generated=True,
valid=True,
schemas=[SchemaType.PDP_COHORT],
schemas=[SchemaType.STUDENT],
)
try:
with sqlalchemy.orm.Session(engine) as session:
Expand All @@ -168,7 +168,7 @@ def session_fixture():
updated_at=DATETIME_TESTING,
sst_generated=False,
valid=False,
schemas=[SchemaType.PDP_COURSE],
schemas=[SchemaType.COURSE],
),
file_3,
file_4,
Expand Down
16 changes: 13 additions & 3 deletions src/webapp/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
get_external_bucket_name,
SchemaType,
decode_url_piece,
LEGACY_TO_NEW_SCHEMA,
)
from ..database import (
get_session,
Expand Down Expand Up @@ -517,13 +518,22 @@ def trigger_inference_run(
detail="Unexpected number of batches found: Expected 1, got "
+ str(len(inst_result)),
)
inst_file_schemas = [x.schemas for x in batch_result[0][0].files]
schema_configs = jsonpickle.decode(query_result[0][0].schema_configs)

for config_group in schema_configs:
for config in config_group:
config.schema_type = LEGACY_TO_NEW_SCHEMA.get(
config.schema_type, config.schema_type
)

if not check_file_types_valid_schema_configs(
[x.schemas for x in batch_result[0][0].files],
jsonpickle.decode(query_result[0][0].schema_configs),
inst_file_schemas,
schema_configs,
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The files in this batch don't conform to the schema configs allowed by this model.",
detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}",
)
# Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines.
db_req = DatabricksInferenceRunRequest(
Expand Down
68 changes: 23 additions & 45 deletions src/webapp/routers/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def session_fixture():
updated_at=DATETIME_TESTING,
sst_generated=False,
valid=True,
schemas=[SchemaType.PDP_COURSE],
schemas=[SchemaType.COURSE],
)
file_3 = FileTable(
id=FILE_UUID_3,
Expand All @@ -129,7 +129,7 @@ def session_fixture():
updated_at=DATETIME_TESTING,
sst_generated=True,
valid=True,
schemas=[SchemaType.PDP_COHORT],
schemas=[SchemaType.STUDENT],
)
model_1 = ModelTable(
id=SAMPLE_UUID,
Expand All @@ -139,20 +139,15 @@ def session_fixture():
[
[
SchemaConfigObj(
schema_type=SchemaType.PDP_COURSE,
schema_type=SchemaType.COURSE,
optional=False,
multiple_allowed=False,
),
SchemaConfigObj(
schema_type=SchemaType.PDP_COHORT,
schema_type=SchemaType.STUDENT,
optional=False,
multiple_allowed=False,
),
SchemaConfigObj(
schema_type=SchemaType.SST_PDP_FINANCE,
optional=True,
multiple_allowed=False,
),
]
]
),
Expand Down Expand Up @@ -189,7 +184,7 @@ def session_fixture():
updated_at=DATETIME_TESTING,
sst_generated=False,
valid=False,
schemas=[SchemaType.PDP_COURSE],
schemas=[SchemaType.COURSE],
),
file_3,
model_1,
Expand Down Expand Up @@ -329,23 +324,18 @@ def test_read_inst_model_output(client: TestClient):
def test_create_model(client: TestClient):
"""Depending on timeline, fellows may not get to this."""
schema_config_1 = {
"schema_type": SchemaType.PDP_COURSE,
"schema_type": SchemaType.COURSE,
"count": 1,
}
schema_config_2 = {
"schema_type": SchemaType.PDP_COHORT,
"schema_type": SchemaType.STUDENT,
"count": 1,
}
schema_config_3 = {
"schema_type": SchemaType.SST_PDP_FINANCE,
"count": 1,
"optional": True,
}
response = client.post(
"/institutions/" + uuid_to_str(USER_VALID_INST_UUID) + "/models/",
json={
"name": "my_model",
"schema_configs": [[schema_config_1, schema_config_2, schema_config_3]],
"schema_configs": [[schema_config_1, schema_config_2]],
},
)

Expand All @@ -368,9 +358,8 @@ def test_trigger_inference_run(client: TestClient):
)

assert response.status_code == 400
assert (
response.text
== '{"detail":"The files in this batch don\'t conform to the schema configs allowed by this model."}'
assert response.json()["detail"].startswith(
"The files in this batch don't conform to the schema configs allowed by this model."
)

response = client.post(
Expand All @@ -395,56 +384,45 @@ def test_trigger_inference_run(client: TestClient):
def test_check_file_types_valid_schema_configs():
"""Test batch schema validation logic."""
file_types1 = [
[SchemaType.PDP_COURSE],
[SchemaType.PDP_COHORT],
[SchemaType.COURSE],
[SchemaType.STUDENT],
[SchemaType.UNKNOWN],
]
file_types2 = [
[SchemaType.SST_PDP_COHORT],
[SchemaType.SST_PDP_COURSE],
[SchemaType.SST_PDP_FINANCE],
[SchemaType.STUDENT],
[SchemaType.COURSE],
]
file_types3 = [
[SchemaType.SST_PDP_COHORT, SchemaType.UNKNOWN],
[SchemaType.SST_PDP_COURSE],
[SchemaType.STUDENT, SchemaType.UNKNOWN],
[SchemaType.COURSE],
]
file_types4 = [
[SchemaType.SST_PDP_COHORT, SchemaType.UNKNOWN],
[SchemaType.STUDENT, SchemaType.UNKNOWN],
[SchemaType.UNKNOWN],
]
pdp_configs = [
SchemaConfigObj(
schema_type=SchemaType.PDP_COURSE,
schema_type=SchemaType.COURSE,
optional=False,
multiple_allowed=False,
),
SchemaConfigObj(
schema_type=SchemaType.PDP_COHORT,
schema_type=SchemaType.STUDENT,
optional=False,
multiple_allowed=False,
),
SchemaConfigObj(
schema_type=SchemaType.SST_PDP_FINANCE,
optional=True,
multiple_allowed=False,
),
]
sst_configs = [
SchemaConfigObj(
schema_type=SchemaType.SST_PDP_COHORT,
schema_type=SchemaType.STUDENT,
optional=False,
multiple_allowed=False,
),
SchemaConfigObj(
schema_type=SchemaType.SST_PDP_COURSE,
schema_type=SchemaType.COURSE,
optional=False,
multiple_allowed=False,
),
SchemaConfigObj(
schema_type=SchemaType.SST_PDP_FINANCE,
optional=True,
multiple_allowed=False,
),
]
custom = [
SchemaConfigObj(
Expand All @@ -463,10 +441,10 @@ def test_check_file_types_valid_schema_configs():
assert not check_file_types_valid_schema_configs(file_types1, [custom])
assert not check_file_types_valid_schema_configs(file_types1, schema_configs1)
assert check_file_types_valid_schema_configs(file_types2, [sst_configs])
assert not check_file_types_valid_schema_configs(file_types2, [pdp_configs])
assert check_file_types_valid_schema_configs(file_types2, [pdp_configs])
assert not check_file_types_valid_schema_configs(file_types2, [custom])
assert check_file_types_valid_schema_configs(file_types3, [sst_configs])
assert not check_file_types_valid_schema_configs(file_types3, [pdp_configs])
assert check_file_types_valid_schema_configs(file_types3, [pdp_configs])
assert not check_file_types_valid_schema_configs(file_types3, [custom])
assert not check_file_types_valid_schema_configs(file_types4, [sst_configs])
assert not check_file_types_valid_schema_configs(file_types4, [pdp_configs])
Expand Down
24 changes: 12 additions & 12 deletions src/webapp/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,25 +125,25 @@ class SchemaType(StrEnum):

# If an institution uses UNKNOWN as an allowed schema, it means bypass the validation check.
UNKNOWN = "UNKNOWN"
# The standard PDP ARF file schemas
PDP_COHORT = "PDP_COHORT"
PDP_COURSE = "PDP_COURSE"
# The PDP aligned SST schemas
SST_PDP_COHORT = "SST_PDP_COHORT"
SST_PDP_COURSE = "SST_PDP_COURSE"
SST_PDP_FINANCE = "SST_PDP_FINANCE"
STUDENT = "STUDENT"
SEMESTER = "SEMESTER"
COURSE = "COURSE"

# Schema Types of output files
SST_OUTPUT = "SST_OUTPUT"
PNG = "PNG"


LEGACY_TO_NEW_SCHEMA = {
"PDP_COHORT": "STUDENT",
"SST_PDP_COHORT": "STUDENT",
"PDP_COURSE": "COURSE",
"SST_PDP_COURSE": "COURSE",
}

PDP_SCHEMA_GROUP: Final = {
SchemaType.PDP_COHORT,
SchemaType.PDP_COURSE,
SchemaType.SST_PDP_FINANCE,
SchemaType.SST_PDP_COHORT,
SchemaType.SST_PDP_COURSE,
SchemaType.STUDENT,
SchemaType.COURSE,
}


Expand Down
1 change: 1 addition & 0 deletions src/webapp/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def validate_dataset(
print("Optional column validation errors on: ", opt_failures)
return {
"validation_status": "passed_with_soft_errors",
"schemas": model_list,
"missing_optional": missing_optional,
"optional_validation_failures": opt_failures,
"failure_cases": err.failure_cases.to_dict(orient="records"),
Expand Down
Loading
Loading