diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 88092ee2..6c3289a1 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -4,7 +4,7 @@ from datetime import datetime, date from typing import Annotated, Any, Dict, List, Optional, Tuple, Union, cast from pydantic import BaseModel, Field -from fastapi import APIRouter, Depends, HTTPException, status, Response +from fastapi import APIRouter, Depends, HTTPException, status, Response, Query from sqlalchemy import and_, or_ from sqlalchemy.orm import Session from sqlalchemy.future import select @@ -664,12 +664,14 @@ def get_eda_data( current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], storage_control: Annotated[StorageControl, Depends(StorageControl)], + clear_cache: Annotated[Optional[str], Query(alias="clear-cache")] = None, ) -> Any: """Returns EDA (Exploratory Data Analysis) data for a specific batch. This endpoint provides all the data needed to populate the EDA dashboard, including summary statistics, GPA charts, enrollment data, and demographic breakdowns. Analyzes all files in the batch together to provide comprehensive insights. + Pass query ``clear-cache=1`` to drop any cached EDA result for this batch before serving. """ has_access_to_inst_or_err(inst_id, current_user) has_full_data_access_or_err(current_user, "EDA data") @@ -694,6 +696,9 @@ def get_eda_data( ) cache_key = f"{inst_id}:{batch_id}" + if clear_cache == "1": + EDA_CACHE.pop(cache_key, None) + cached_result = EDA_CACHE.get(cache_key) if cached_result is not None: logger.debug(f"EDA cache hit for {cache_key}") diff --git a/src/webapp/routers/data_test.py b/src/webapp/routers/data_test.py index 7add80bd..3da9cc95 100644 --- a/src/webapp/routers/data_test.py +++ b/src/webapp/routers/data_test.py @@ -29,6 +29,7 @@ get_session, ) from ..utilities import uuid_to_str, get_current_active_user, SchemaType +from . 
def test_get_eda_data_clear_cache(
    client: TestClient, session: sqlalchemy.orm.Session
) -> None:
    """clear-cache=1 evicts the TTL entry so batch files are read from storage again.

    Scenario exercised against the ``/eda`` endpoint of a single batch:

    1. First request: cache is empty, so the batch files are read once.
    2. Second request: served from ``EDA_CACHE``; no additional read.
    3. Request with ``?clear-cache=1``: the cache entry is dropped first, so
       the batch files are read a second time.
    4. Plain request again: served from the entry repopulated by step 3.

    The test leaves shared module state as it found it: the storage mock's
    ``side_effect`` is restored and ``EDA_CACHE`` is emptied on exit.
    """
    import pandas as pd

    # Start from a known-empty cache so the first request is a guaranteed miss.
    data_router.EDA_CACHE.clear()

    eda_batch = BatchTable(
        id=uuid.UUID("66666666-6666-6666-6666-666666666666"),
        inst_id=USER_VALID_INST_UUID,
        name="batch_eda_clear_cache",
        created_by=CREATOR_UUID,
        created_at=DATETIME_TESTING,
        updated_at=DATETIME_TESTING,
        completed=True,
    )
    student_file = FileTable(
        id=uuid.UUID("77777777-7777-7777-7777-777777777777"),
        inst_id=USER_VALID_INST_UUID,
        name="student_clear_cache.csv",
        source="MANUAL_UPLOAD",
        batches={eda_batch},
        created_at=DATETIME_TESTING,
        updated_at=DATETIME_TESTING,
        sst_generated=False,
        valid=True,
        schemas=[SchemaType.STUDENT],
    )
    session.add_all([eda_batch, student_file])
    session.commit()

    df_cohort = pd.DataFrame(
        {
            "student_id": ["S001"],
            "cohort": ["2020"],
            "cohort_term": ["FALL"],
            "enrollment_type": ["FIRST-TIME"],
            "enrollment_intensity_first_term": ["Full-Time"],
            "gpa_group_year_1": [3.5],
            "credential_type_sought_year_1": ["Bachelor"],
            "pell_status_first_year": ["N"],
            "first_gen": ["N"],
            "gender": ["Female"],
            "race": ["White"],
            "student_age": ["20 - 24"],
        }
    )

    def mock_read_csv(bucket_name: str, blob_path: str) -> pd.DataFrame:
        if "student" in blob_path.lower():
            return df_cohort
        raise ValueError(f"File not found: {blob_path}")

    base_path = (
        "/institutions/"
        + uuid_to_str(USER_VALID_INST_UUID)
        + "/batch/"
        + uuid_to_str(eda_batch.id)
        + "/eda"
    )

    # Mock.reset_mock() keeps side_effect by default, so remember whatever was
    # configured before this test and put it back explicitly afterwards.
    previous_side_effect = MOCK_STORAGE.read_csv_as_dataframe.side_effect
    try:
        MOCK_STORAGE.read_csv_as_dataframe.side_effect = mock_read_csv

        # wraps= keeps the real reader running while counting its invocations.
        with mock.patch.object(
            data_router,
            "read_batch_files_as_dataframes",
            wraps=data_router.read_batch_files_as_dataframes,
        ) as read_dfs:
            # Cold cache: the batch files must be read.
            r1 = client.get(base_path)
            assert r1.status_code == 200
            assert read_dfs.call_count == 1

            # Warm cache: no additional read.
            r2 = client.get(base_path)
            assert r2.status_code == 200
            assert read_dfs.call_count == 1

            # Explicit eviction: the files are read again.
            r3 = client.get(base_path + "?clear-cache=1")
            assert r3.status_code == 200
            assert read_dfs.call_count == 2

            # The eviction request repopulated the cache.
            r4 = client.get(base_path)
            assert r4.status_code == 200
            assert read_dfs.call_count == 2
    finally:
        MOCK_STORAGE.read_csv_as_dataframe.side_effect = previous_side_effect
        MOCK_STORAGE.reset_mock()
        # Do not leak this batch's cached payload into other tests.
        data_router.EDA_CACHE.clear()