diff --git a/scripts/test_foraging_integration.py b/scripts/test_foraging_integration.py new file mode 100644 index 0000000..712e90c --- /dev/null +++ b/scripts/test_foraging_integration.py @@ -0,0 +1,196 @@ +"""Integration tests for the foraging acorn against the upstream public S3 tables. + +Run with FOREST_TYPE=memory (default). These tests read from the public upstream +S3 cache — no AWS credentials required. They compare row values against known +sessions to catch schema or data drift between zombie-squirrel and the upstream build. + +Usage: + cd /path/to/zombie-squirrel + python -m pytest scripts/test_foraging_integration.py -v +""" + +import math + +import duckdb +import pandas as pd +import pytest + +from zombie_squirrel.acorn_helpers.foraging.query import ( + SESSION_DB, + TRIAL_DB, + fetch_trials, + select_sessions, +) +from zombie_squirrel.acorn_helpers.foraging.session import ( + UPSTREAM_SESSION_S3, + _add_asset_name, + _fetch_upstream, +) + +# A known stable session from 2024 (unlikely to be modified retroactively). +# Values verified against the upstream parquet on 2026-06-08. 
@pytest.fixture(scope="module")
def upstream_session_sample():
    """Fetch the known session's rows from the upstream session table.

    Returns a DataFrame holding every upstream row whose ``subject_id`` and
    ``session_date`` equal ``_KNOWN_SUBJECT`` / ``_KNOWN_DATE``. The query
    result is materialized with ``.df()`` before the connection is closed, so
    the returned frame does not depend on the connection.
    """
    conn = duckdb.connect()
    try:
        conn.execute("INSTALL httpfs; LOAD httpfs;")
        return conn.sql(
            f"SELECT * FROM read_parquet('{UPSTREAM_SESSION_S3}') "
            f"WHERE subject_id = '{_KNOWN_SUBJECT}' AND session_date = '{_KNOWN_DATE}'"
        ).df()
    finally:
        # The original never closed this connection; release it explicitly.
        conn.close()
class TestSelectSessions:
    """Integration checks for ``select_sessions`` against the default session table."""

    def test_returns_dataframe(self):
        result = select_sessions(subjects=[_KNOWN_SUBJECT])
        assert isinstance(result, pd.DataFrame)

    def test_subject_filter(self):
        result = select_sessions(subjects=[_KNOWN_SUBJECT])
        assert (result["subject_id"] == _KNOWN_SUBJECT).all()

    def test_where_clause_filters(self):
        result = select_sessions(
            subjects=[_KNOWN_SUBJECT],
            where="session_date IS NOT NULL",
        )
        assert len(result) > 0

    def test_extra_columns_carried(self):
        result = select_sessions(
            subjects=[_KNOWN_SUBJECT],
            columns=["foraging_eff", "finished_trials"],
        )
        assert "foraging_eff" in result.columns
        assert "finished_trials" in result.columns
        assert "_session_id" in result.columns

    def test_empty_result_for_nonexistent_subject(self):
        result = select_sessions(subjects=["000000_nonexistent"])
        assert len(result) == 0

    def test_values_match_upstream(self, upstream_session_sample):
        metrics = ["foraging_eff", "bias_naive", "finished_trials"]
        result = select_sessions(
            subjects=[_KNOWN_SUBJECT],
            columns=metrics,
            where=f"session_date = '{_KNOWN_DATE}'",
        )
        assert len(result) >= 1

        # Pair rows by session key instead of by position: neither query
        # guarantees that row 0 is the same session in both frames.
        ours = result.drop_duplicates("_session_id").set_index("_session_id")
        upstream = upstream_session_sample.drop_duplicates("_session_id").set_index("_session_id")
        shared = ours.index.intersection(upstream.index)
        assert len(shared) >= 1, "No session shared between select_sessions and the upstream sample"

        for session_id in shared:
            for metric in metrics:
                our_value = ours.at[session_id, metric]
                up_value = upstream.at[session_id, metric]
                # Missing values on either side are skipped, as in the original test.
                if pd.isna(our_value) or pd.isna(up_value):
                    continue
                assert math.isclose(our_value, up_value, rel_tol=1e-6), (
                    f"{metric} differs for session {session_id}: {our_value} != {up_value}"
                )
class TestFetchUpstream:
    """Checks ``_add_asset_name`` on a small sample read straight from the upstream table."""

    def test_returns_dataframe_with_asset_name(self):
        conn = duckdb.connect()
        try:
            conn.execute("INSTALL httpfs; LOAD httpfs;")
            sample = conn.sql(
                f"SELECT * FROM read_parquet('{UPSTREAM_SESSION_S3}') "
                f"WHERE subject_id = '{_KNOWN_SUBJECT}' LIMIT 5"
            ).df()
        finally:
            # The sample is already materialized; do not leave the connection open.
            conn.close()
        enriched = _add_asset_name(sample)
        assert "asset_name" in enriched.columns
        co_rows = enriched[enriched["co_s3_nwb_uri"].notna()]
        # Only rows that carry a CO NWB URI are expected to yield an asset name.
        if len(co_rows):
            assert co_rows["asset_name"].notna().all()
TRIAL_DB, + EVENT_DB, + clear_caches, + fetch_events, + fetch_trials, + read_events, + read_trials, + select_sessions, +) diff --git a/src/zombie_squirrel/acorn_helpers/foraging/query.py b/src/zombie_squirrel/acorn_helpers/foraging/query.py new file mode 100644 index 0000000..74fc02b --- /dev/null +++ b/src/zombie_squirrel/acorn_helpers/foraging/query.py @@ -0,0 +1,254 @@ +"""DuckDB query helpers for the foraging parquet database. + +Two layers — reach for the simple helpers first, drop to native SQL when you need more: + + Layer 1 (convenience): + select_sessions -> fetch_trials / fetch_events + Filter the (small) session table on any metric / metadata, then pull those sessions' + trials or events with the session metadata already joined on — in one call. + + Layer 0 (escape hatch): + read_trials / read_events + Return a fast, partition-scoped ``read_parquet(...)`` clause for a set of subjects. + Drop it into whatever SQL you write — aggregations, window functions, trial<->event + joins, custom GROUP BY. + +Everything reads the public S3 database (no AWS credentials needed). Pass ``base=`` to +redirect to a local build or a custom S3 path. + +Ported from aind-dynamic-foraging-database with minor adaptations. 
def _conn(con):
    """Return ``con`` when one is given, otherwise the ``duckdb`` module itself."""
    return con if con is not None else duckdb


def _as_values(values):
    """Return ``values`` as a list, treating a bare string as a single value.

    Without this, a caller passing ``"754372"`` instead of ``["754372"]`` would have
    the string iterated character by character and silently match nothing.
    """
    if isinstance(values, str):
        return [values]
    return list(values)


def _quote_in(values):
    """Render ``values`` as the body of a SQL ``IN (...)`` list of string literals.

    Each value is converted with ``str`` and embedded single quotes are doubled.
    An empty input renders as ``NULL`` so the enclosing ``IN (...)`` stays valid
    SQL (and matches no row) instead of producing ``IN ()``.
    """
    rendered = ["'" + str(v).replace("'", "''") + "'" for v in _as_values(values)]
    return ", ".join(rendered) if rendered else "NULL"


def _partition_subjects(base, con=None):
    """Subject ids that have a partition file under ``base`` (memoized per base)."""
    cached = _PARTITION_CACHE.get(base)
    if cached is not None:
        return cached
    rows = _conn(con).sql(f"SELECT file FROM glob('{base}/subject_id=*/*.parquet')").df()
    # The subject id is taken from the ``subject_id=<id>/`` path segment of each file.
    found = rows["file"].str.extract(r"subject_id=([^/]+)/", expand=False).dropna()
    _PARTITION_CACHE[base] = result = set(found)
    return result


def clear_caches():
    """Drop memoized partition listings (call after rebuilding a local cache in-session)."""
    _PARTITION_CACHE.clear()


def _full_glob(base):
    """Return a ``read_parquet(...)`` clause covering every parquet file under ``base``."""
    return f"read_parquet('{base}/**/*.parquet', hive_partitioning=true, union_by_name=true)"


def _scoped_read(base, subjects, con):
    """Return a SQL source for ``base`` restricted to ``subjects``.

    - ``subjects is None``: the full glob over ``base``.
    - no requested subject has a partition: the full glob wrapped in ``WHERE false``.
    - more than ``_SCOPED_MAX`` matching subjects: the full glob filtered with ``IN (...)``.
    - otherwise: an explicit list of per-subject partition globs.

    A bare string is treated as one subject id rather than a sequence of characters.
    """
    if subjects is None:
        return _full_glob(base)
    requested = {str(s) for s in _as_values(subjects)}
    want = sorted(requested & _partition_subjects(base, con))
    if not want:
        return f"(SELECT * FROM {_full_glob(base)} WHERE false)"
    if len(want) > _SCOPED_MAX:
        return (
            f"(SELECT * FROM {_full_glob(base)} "
            f"WHERE CAST(subject_id AS VARCHAR) IN ({_quote_in(want)}))"
        )
    files = [f"'{base}/subject_id={s}/*.parquet'" for s in want]
    return f"read_parquet([{', '.join(files)}], hive_partitioning=true, union_by_name=true)"
def select_sessions(where=None, subjects=None, columns=None, base=None, con=None,
                    order_by="subject_id, session_date"):
    """Filter the session table; return selected sessions as a DataFrame.

    Parameters
    ----------
    where : str, optional
        Raw SQL predicate, e.g. ``"task LIKE '%Uncoupled%' AND foraging_eff > 0.8"``.
        It is inserted into the query verbatim (wrapped in parentheses).
    subjects : iterable or str, optional
        Restrict to these subject ids. A bare string is treated as a single
        subject id. An empty iterable selects no sessions.
    columns : list[str], optional
        Extra session-metadata columns to carry onto trials/events. ``_session_id``,
        ``subject_id``, ``session_date`` are always included.
    base : str, optional
        Session parquet file path (default: production S3 ``session_table.parquet``).
    con : duckdb connection, optional
    order_by : str, optional
        SQL ORDER BY clause (default: ``"subject_id, session_date"``).

    Returns
    -------
    pandas.DataFrame
        One row per selected session, with ``_session_id`` as the join key.
    """
    base = base or SESSION_DB
    extra = [c for c in (columns or []) if c not in ("_session_id", *_KEYS)]
    sel_cols = ", ".join(["_session_id", "subject_id", "session_date", *extra])
    clauses = []
    if subjects is not None:
        # A bare string would otherwise be iterated character by character.
        wanted = [subjects] if isinstance(subjects, str) else list(subjects)
        if wanted:
            clauses.append(f"subject_id IN ({_quote_in(wanted)})")
        else:
            # ``IN ()`` is a SQL syntax error; an empty selection matches nothing.
            clauses.append("false")
    if where:
        clauses.append(f"({where})")
    where_sql = ("WHERE " + " AND ".join(clauses)) if clauses else ""
    order_sql = f"ORDER BY {order_by}" if order_by else ""
    return _conn(con).sql(
        f"SELECT {sel_cols} FROM read_parquet('{base}') {where_sql} {order_sql}"
    ).df()
def fetch_events(sessions, events=None, columns=None, base=None, con=None):
    """Pull event rows for a set of selected sessions, with session metadata joined on.

    Parameters
    ----------
    sessions : pandas.DataFrame
        Selected sessions from :func:`select_sessions`.
    events : iterable or str, optional
        Restrict to these event types, e.g. ``['left_lick_time', 'right_lick_time']``.
        A bare string is treated as a single event type. ``None`` or an empty
        iterable applies no event filter.
    columns : list[str] or "*", optional
        Event columns to project (default: ``trial, timestamps, event, data``).
    base : str, optional
        Event-table directory prefix (default: production S3).
    con : duckdb connection, optional

    Returns
    -------
    pandas.DataFrame
        One row per event, ordered by ``subject_id, session_date, timestamps``.
    """
    if events is None:
        wanted = []
    elif isinstance(events, str):
        # Avoid splitting "left_lick_time" into single characters.
        wanted = [events]
    else:
        # Materialize once so array-likes and generators are handled uniformly.
        wanted = list(events)
    extra_where = f"t.event IN ({_quote_in(wanted)})" if wanted else None
    return _fetch(sessions, base or EVENT_DB, columns or DEFAULT_EVENT_COLUMNS,
                  con, order_tail="timestamps", extra_where=extra_where)


def _fetch(sessions, base, columns, con, order_tail, extra_where=None, lead=None):
    """Join ``sessions`` onto the table at ``base`` and return the rows as a DataFrame.

    ``sessions`` is registered on the connection as ``_sel_sessions`` for the
    duration of the query and always unregistered afterwards. An empty
    ``sessions`` frame short-circuits to an empty DataFrame without querying.
    """
    import pandas as pd

    if len(sessions) == 0:
        return pd.DataFrame()
    conn = _conn(con)
    # Scope the read to the subjects that actually appear in the selection.
    src = _scoped_read(base, sessions["subject_id"].unique().tolist(), con)
    conn.register("_sel_sessions", sessions)
    try:
        try:
            return _run_fetch(conn, src, sessions, columns, order_tail, extra_where, None, lead)
        except duckdb.BinderException:
            # Retry with the set of columns the source really has, so the
            # projection can be built against it (see ``_col_expr``).
            avail = set(conn.sql(f"DESCRIBE SELECT * FROM {src}").df()["column_name"])
            return _run_fetch(conn, src, sessions, columns, order_tail, extra_where, avail, lead)
    finally:
        conn.unregister("_sel_sessions")


def _col_expr(col, avail):
    """Return the SELECT expression for ``col``.

    ``t.<col>`` when ``avail`` is ``None`` or contains ``col``; otherwise a
    ``NULL`` placeholder cast to DOUBLE and aliased to ``col``.
    """
    return f"t.{col}" if (avail is None or col in avail) else f"CAST(NULL AS DOUBLE) AS {col}"
def _add_asset_name(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with an ``asset_name`` column derived from ``co_s3_nwb_uri``.

    The asset name is the text between ``/nwb/`` and a trailing ``.nwb`` in the
    URI; rows whose URI is missing or does not match get a missing value. The
    input frame is not modified.
    """
    enriched = df.copy()
    # Object dtype keeps the ``.str`` accessor usable on NaN-only columns.
    nwb_uri = enriched["co_s3_nwb_uri"].astype(object)
    enriched["asset_name"] = nwb_uri.str.extract(r"/nwb/(.+?)\.nwb$", expand=False)
    return enriched


def _fetch_upstream() -> pd.DataFrame:
    """Read the whole upstream session table and add the ``asset_name`` column.

    The DuckDB connection is used as a context manager so it is closed once the
    result has been materialized as a DataFrame (the original left it open).
    """
    with duckdb.connect() as conn:
        conn.execute("INSTALL httpfs; LOAD httpfs;")
        df = conn.sql(f"SELECT * FROM read_parquet('{UPSTREAM_SESSION_S3}')").df()
    return _add_asset_name(df)
+ Column(name="PI", description="Principal investigator"), + Column(name="curriculum_name", description="Auto-training curriculum name"), + Column(name="curriculum_version", description="Auto-training curriculum version"), + Column(name="current_stage_actual", description="Actual curriculum stage at session time"), + Column(name="task", description="Task name"), + Column(name="session_start_time", description="Session start timestamp"), + Column(name="session_end_time", description="Session end timestamp"), + Column(name="session_run_time_in_min", description="Session duration in minutes"), + Column(name="total_trials", description="Total number of trials"), + Column(name="finished_trials", description="Number of finished (non-ignored) trials"), + Column(name="finished_rate", description="Fraction of trials that were finished"), + Column(name="foraging_eff", description="Foraging efficiency metric"), + Column(name="foraging_eff_random_seed", description="Foraging efficiency with random seed baseline"), + Column(name="bias_naive", description="Naive side bias estimate"), + Column(name="reaction_time_median", description="Median reaction time in seconds"), + Column(name="early_lick_rate", description="Rate of early lick trials"), + ] diff --git a/src/zombie_squirrel/acorn_helpers/foraging_sessions.py b/src/zombie_squirrel/acorn_helpers/foraging_sessions.py deleted file mode 100644 index 53564dc..0000000 --- a/src/zombie_squirrel/acorn_helpers/foraging_sessions.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Foraging sessions acorn: one row per behavior session from df_sessions.pkl.""" - -import io -import logging - -import boto3 -import pandas as pd - -import zombie_squirrel.acorns as acorns -from zombie_squirrel.squirrel import Column -from zombie_squirrel.utils import SquirrelMessage, normalize_name, setup_logging - -_SOURCE_BUCKET = "aind-behavior-data" -_SOURCE_KEY = "foraging_nwb_bonsai_processed/df_sessions.pkl" - -_COLUMN_MAP = { - ("metadata", "rig"): "rig", - 
("metadata", "user_name"): "trainer", - ("metadata", "task"): "task", - ("auto_train", "curriculum_name"): "curriculum_name", - ("auto_train", "curriculum_version"): "curriculum_version", - ("auto_train", "current_stage_actual"): "current_stage_actual", - ("session_stats", "foraging_eff"): "foraging_eff", - ("session_stats", "foraging_eff_random_seed"): "foraging_eff_random_seed", - ("session_stats", "finished_trials"): "finished_trials", - ("session_stats", "finished_rate"): "finished_rate", - ("session_stats", "total_trials"): "total_trials", - ("session_stats", "bias_naive"): "bias_naive", -} - - -def _fetch_foraging_sessions() -> pd.DataFrame: - """Download df_sessions.pkl from S3 and return a flattened DataFrame.""" - s3 = boto3.client("s3") - obj = s3.get_object(Bucket=_SOURCE_BUCKET, Key=_SOURCE_KEY) - raw = pd.read_pickle(io.BytesIO(obj["Body"].read())) - - # Extract row-index levels into a plain dict — avoids reset_index() which - # pads them into the MultiIndex column structure as ('subject_id', '') etc., - # causing to_parquet() to write tuple-string column names that DuckDB can't use. - idx = raw.index.to_frame(index=False) - data = { - "subject_id": idx["subject_id"], - "session_date": idx["session_date"].astype(str), - "session": idx["session"], - "nwb_suffix": idx["nwb_suffix"], - } - for src_col, dest_col in _COLUMN_MAP.items(): - data[dest_col] = raw[src_col].values - - df = pd.DataFrame(data) - df["trainer_normalized"] = df["trainer"].apply(lambda v: normalize_name(v) if pd.notna(v) else "") - return df - - -@acorns.register_acorn(acorns.NAMES["foraging"]) -def foraging_sessions(force_update: bool = False) -> pd.DataFrame: - """Return a flattened table of foraging behavior sessions. - - Source: s3://aind-behavior-data/foraging_nwb_bonsai_processed/df_sessions.pkl - - Args: - force_update: If True, bypass cache and rebuild from source pkl. - - Returns: - DataFrame with one row per session and the columns listed in - foraging_sessions_columns(). 
- """ - df = acorns.TREE.scurry(acorns.NAMES["foraging"]) - - if df.empty and not force_update: - raise ValueError("Cache is empty. Use force_update=True to fetch data from source.") - - if df.empty or force_update: - setup_logging() - logging.info( - SquirrelMessage( - tree=acorns.TREE.__class__.__name__, - acorn=acorns.NAMES["foraging"], - message="Updating cache from S3 pkl", - ).to_json() - ) - - df = _fetch_foraging_sessions() - acorns.TREE.hide(acorns.NAMES["foraging"], df) - - return df - - -def foraging_sessions_columns() -> list[Column]: - """Return foraging_sessions acorn column definitions.""" - return [ - Column(name="subject_id", description="Subject/mouse ID"), - Column(name="session_date", description="Date of the session (YYYY-MM-DD)"), - Column(name="session", description="Session number within the day"), - Column(name="nwb_suffix", description="NWB file suffix identifying the session file"), - Column(name="rig", description="Rig/apparatus used for the session"), - Column(name="trainer", description="User/trainer who ran the session"), - Column(name="trainer_normalized", description="Normalized display name of the trainer"), - Column(name="task", description="Task name (e.g. 
def _is_hive_partition(table_name: str) -> bool:
    """Tell whether ``table_name`` names a partitioned write.

    True only when the name contains a ``/`` and the part before the first
    ``/`` is a key of ``HIVE_PARTITION_KEYS``; any other name is treated as a
    plain table.
    """
    prefix, slash, _rest = table_name.partition("/")
    # ``slash`` is empty when the name has no "/" at all.
    return bool(slash) and prefix in HIVE_PARTITION_KEYS
table_name.split("/", 1) partition_key = HIVE_PARTITION_KEYS[base] s3_key = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}/{partition_key}={value}/data.pqt" @@ -114,7 +121,7 @@ def scurry(self, table_name: str | list[str]) -> pd.DataFrame: def _scurry_single(self, table_name: str) -> pd.DataFrame: """Fetch a single table from S3.""" - if "/" in table_name: + if _is_hive_partition(table_name): base, value = table_name.split("/", 1) partition_key = HIVE_PARTITION_KEYS[base] s3_key = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}/{partition_key}={value}/data.pqt" @@ -146,7 +153,7 @@ def get_location(self, table_name: str, partitioned: bool = False) -> str: """Return the S3 URI for a given table.""" if partitioned: return f"s3://{self.bucket}/{_CACHE_ROOT}/{_VERSION_FOLDER}/{table_name}/" - if "/" in table_name: + if _is_hive_partition(table_name): base, value = table_name.split("/", 1) partition_key = HIVE_PARTITION_KEYS[base] return f"s3://{self.bucket}/{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}/{partition_key}={value}/data.pqt" @@ -260,7 +267,7 @@ def get_location(self, table_name: str, partitioned: bool = False) -> str: """Return the in-memory identifier for a given table.""" if partitioned: return f"{_VERSION_FOLDER}/{table_name}/" - if "/" in table_name: + if _is_hive_partition(table_name): base, value = table_name.split("/", 1) partition_key = HIVE_PARTITION_KEYS[base] return f"{_VERSION_FOLDER}/{base}/{partition_key}={value}/data.pqt" diff --git a/src/zombie_squirrel/sync.py b/src/zombie_squirrel/sync.py index 96cd92d..8d3f23d 100644 --- a/src/zombie_squirrel/sync.py +++ b/src/zombie_squirrel/sync.py @@ -4,7 +4,7 @@ from .acorn_helpers.asset_basics import asset_basics_columns from .acorn_helpers.assets_smartspim import assets_smartspim_columns -from .acorn_helpers.foraging_sessions import foraging_sessions_columns +from .acorn_helpers.foraging.session import foraging_session_columns from .acorn_helpers.behavior_curriculum import behavior_curriculum_columns from 
def _make_session_df(**overrides):
    """Build a one-row session DataFrame; keyword arguments replace or add columns."""
    defaults = {
        "subject_id": "123456",
        "session_date": "2024-01-15",
        "nwb_suffix": 100000,
        "_session_id": "123456_2024-01-15_100000",
        "co_s3_nwb_uri": "s3://bucket/asset-id/nwb/behavior_123456_2024-01-15_10-00-00.nwb",
        "foraging_eff": 0.75,
        "finished_trials": 300.0,
    }
    # Overrides win over defaults; new keys are appended after the defaults.
    return pd.DataFrame([{**defaults, **overrides}])
+ result = _add_asset_name(df) + assert result["asset_name"].iloc[0] == "behavior_123456_2024-01-15_10-00-00" + + def test_nan_when_uri_is_nan(self): + df = _make_session_df(co_s3_nwb_uri=float("nan")) + result = _add_asset_name(df) + assert pd.isna(result["asset_name"].iloc[0]) + + def test_does_not_mutate_input(self): + df = _make_session_df() + original_cols = set(df.columns) + _add_asset_name(df) + assert set(df.columns) == original_cols + + def test_multiple_rows(self): + df = pd.DataFrame([ + {"co_s3_nwb_uri": "s3://b/a1/nwb/behavior_111_2024-01-01_09-00-00.nwb"}, + {"co_s3_nwb_uri": "s3://b/a2/nwb/behavior_222_2024-02-01_10-30-00.nwb"}, + {"co_s3_nwb_uri": float("nan")}, + ]) + result = _add_asset_name(df) + assert result["asset_name"].iloc[0] == "behavior_111_2024-01-01_09-00-00" + assert result["asset_name"].iloc[1] == "behavior_222_2024-02-01_10-30-00" + assert pd.isna(result["asset_name"].iloc[2]) + + +class TestForagingSessionAcorn: + @patch("zombie_squirrel.acorn_helpers.foraging.session.acorns.TREE") + def test_cache_hit_returns_df(self, mock_tree): + cached = _make_session_df() + mock_tree.scurry.return_value = cached + result = foraging_session(force_update=False) + mock_tree.scurry.assert_called_once_with(_TABLE_NAME) + assert len(result) == 1 + + @patch("zombie_squirrel.acorn_helpers.foraging.session.acorns.TREE") + def test_empty_cache_raises(self, mock_tree): + mock_tree.scurry.return_value = pd.DataFrame() + with pytest.raises(ValueError, match="Cache is empty"): + foraging_session(force_update=False) + + @patch("zombie_squirrel.acorn_helpers.foraging.session._fetch_upstream") + @patch("zombie_squirrel.acorn_helpers.foraging.session.acorns.TREE") + def test_force_update_fetches_and_hides(self, mock_tree, mock_fetch): + mock_tree.scurry.return_value = pd.DataFrame() + fresh = _make_session_df() + mock_fetch.return_value = fresh + result = foraging_session(force_update=True) + mock_fetch.assert_called_once() + 
mock_tree.hide.assert_called_once_with(_TABLE_NAME, fresh) + assert len(result) == 1 + + @patch("zombie_squirrel.acorn_helpers.foraging.session._fetch_upstream") + @patch("zombie_squirrel.acorn_helpers.foraging.session.acorns.TREE") + def test_cold_cache_with_force_update_fetches(self, mock_tree, mock_fetch): + mock_tree.scurry.return_value = _make_session_df() + fresh = _make_session_df(foraging_eff=0.9) + mock_fetch.return_value = fresh + result = foraging_session(force_update=True) + mock_fetch.assert_called_once() + assert result["foraging_eff"].iloc[0] == 0.9 + + @patch("zombie_squirrel.acorn_helpers.foraging.session._fetch_upstream") + @patch("zombie_squirrel.acorn_helpers.foraging.session.acorns.TREE") + def test_no_fetch_on_warm_cache(self, mock_tree, mock_fetch): + mock_tree.scurry.return_value = _make_session_df() + foraging_session(force_update=False) + mock_fetch.assert_not_called() + + +class TestForagingSessionColumns: + def test_returns_list_of_columns(self): + cols = foraging_session_columns() + assert isinstance(cols, list) + assert len(cols) > 0 + + def test_has_asset_name_column(self): + names = [c.name for c in foraging_session_columns()] + assert "asset_name" in names + + def test_has_session_id_column(self): + names = [c.name for c in foraging_session_columns()] + assert "_session_id" in names + + def test_all_columns_have_descriptions(self): + for col in foraging_session_columns(): + assert col.description, f"Column '{col.name}' has no description" diff --git a/tests/test_sync.py b/tests/test_sync.py index f93790a..65e901b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -21,7 +21,7 @@ def _make_registry(mock_upn, mock_usi, mock_ugt, mock_basics, mock_d2r, mock_r2d "assets_smartspim": mock_smartspim, "metadata_upgrade": MagicMock(), "platform_fib": MagicMock(), - "foraging_sessions": MagicMock(), + "foraging_session": MagicMock(), "behavior_curriculum": MagicMock(), "platform_qc": MagicMock(), } @@ -204,7 +204,7 @@ def 
test_hide_acorns_calls_all_acorns(mock_registry, mock_tree): "assets_smartspim": mock_smartspim, "metadata_upgrade": MagicMock(), "platform_fib": mock_fib, - "foraging_sessions": MagicMock(), + "foraging_session": MagicMock(), "behavior_curriculum": MagicMock(), "platform_qc": MagicMock(), }[x] @@ -239,7 +239,7 @@ def test_hide_acorns_empty_registry(mock_registry, mock_tree): "assets_smartspim": MagicMock(), "metadata_upgrade": MagicMock(), "platform_fib": MagicMock(), - "foraging_sessions": MagicMock(), + "foraging_session": MagicMock(), "behavior_curriculum": MagicMock(), "platform_qc": MagicMock(), }[x] @@ -271,7 +271,7 @@ def test_hide_acorns_single_acorn(mock_registry, mock_tree): "assets_smartspim": MagicMock(), "metadata_upgrade": MagicMock(), "platform_fib": MagicMock(), - "foraging_sessions": MagicMock(), + "foraging_session": MagicMock(), "behavior_curriculum": MagicMock(), "platform_qc": MagicMock(), }[x] @@ -298,7 +298,7 @@ def test_hide_acorns_acorn_order_independent(mock_registry, mock_tree): "assets_smartspim": MagicMock(), "metadata_upgrade": MagicMock(), "platform_fib": MagicMock(), - "foraging_sessions": MagicMock(), + "foraging_session": MagicMock(), "behavior_curriculum": MagicMock(), "platform_qc": MagicMock(), }[x] @@ -325,7 +325,7 @@ def test_hide_acorns_propagates_exceptions(mock_registry): "assets_smartspim": MagicMock(), "metadata_upgrade": MagicMock(), "platform_fib": MagicMock(), - "foraging_sessions": MagicMock(), + "foraging_session": MagicMock(), "behavior_curriculum": MagicMock(), "platform_qc": MagicMock(), }[x] diff --git a/tests/test_sync_coverage.py b/tests/test_sync_coverage.py index 4719de7..8065bf4 100644 --- a/tests/test_sync_coverage.py +++ b/tests/test_sync_coverage.py @@ -33,6 +33,9 @@ def qc_side_effect(*args, **kwargs): "assets_smartspim": MagicMock(), "metadata_upgrade": MagicMock(), "platform_fib": MagicMock(), + "foraging_session": MagicMock(), + "behavior_curriculum": MagicMock(), + "platform_qc": MagicMock(), 
}.__getitem__ failed_future = MagicMock()