Skip to content

Commit 45797d2

Browse files
committed
fix: add chunking to get past memory issues for fib_traces
1 parent 0075cb9 commit 45797d2

4 files changed

Lines changed: 92 additions & 14 deletions

File tree

src/biodata_cache/backend.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,60 @@ def read(self, table_name: str | list[str]) -> pd.DataFrame:
131131
return self._read_multiple(table_name)
132132
return self._read_single(table_name)
133133

134+
def clear_partition(self, table_name: str) -> None:
135+
"""Delete all parquet chunk files in a hive partition."""
136+
if "/" not in table_name:
137+
return
138+
base, value = table_name.split("/", 1)
139+
partition_key = HIVE_PARTITION_KEYS[base]
140+
prefix = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}/{partition_key}={value}/"
141+
paginator = self.s3_client.get_paginator("list_objects_v2")
142+
to_delete = []
143+
for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
144+
for obj in page.get("Contents", []):
145+
to_delete.append({"Key": obj["Key"]})
146+
for i in range(0, len(to_delete), 1000):
147+
self.s3_client.delete_objects(
148+
Bucket=self.bucket,
149+
Delete={"Objects": to_delete[i : i + 1000]},
150+
)
151+
152+
def write_chunk(self, table_name: str, data: pd.DataFrame, chunk_idx: int) -> None:
153+
"""Append one numbered parquet chunk to a hive partition."""
154+
base, value = table_name.split("/", 1)
155+
partition_key = HIVE_PARTITION_KEYS[base]
156+
s3_key = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}/{partition_key}={value}/data_{chunk_idx:04d}.pqt"
157+
json_key = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}.json"
158+
159+
parquet_buffer = io.BytesIO()
160+
table = pa.Table.from_pandas(data, preserve_index=False)
161+
float_cols = [f.name for f in table.schema if pa.types.is_floating(f.type)]
162+
dict_cols = [f.name for f in table.schema if f.name not in float_cols]
163+
pq.write_table(
164+
table,
165+
parquet_buffer,
166+
compression="zstd",
167+
use_dictionary=dict_cols if dict_cols else False,
168+
column_encoding={col: "BYTE_STREAM_SPLIT" for col in float_cols} or None,
169+
)
170+
parquet_buffer.seek(0)
171+
self.s3_client.put_object(Bucket=self.bucket, Key=s3_key, Body=parquet_buffer.getvalue())
172+
logging.info(
173+
CacheLogMessage(
174+
backend="S3Backend", table=table_name, message=f"Stored chunk {chunk_idx} to s3://{self.bucket}/{s3_key}"
175+
).to_json()
176+
)
177+
metadata = {"columns": data.columns.tolist()}
178+
self.s3_client.put_object(
179+
Bucket=self.bucket, Key=json_key, Body=json.dumps(metadata)
180+
)
181+
134182
def _read_single(self, table_name: str) -> pd.DataFrame:
135183
"""Fetch a single table from S3."""
136184
if "/" in table_name:
137185
base, value = table_name.split("/", 1)
138186
partition_key = HIVE_PARTITION_KEYS[base]
139-
s3_key = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}/{partition_key}={value}/data.pqt"
187+
s3_key = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{base}/{partition_key}={value}/data*.pqt"
140188
else:
141189
s3_key = f"{_CACHE_ROOT}/{_VERSION_FOLDER}/{table_name}.pqt"
142190

@@ -315,6 +363,17 @@ def get_versions_index(self) -> list[str]:
315363
"""Return the list of all available version folders from the in-memory index."""
316364
return json.loads(self._json_store.get("cache_versions.json", "[]"))
317365

366+
def clear_partition(self, table_name: str) -> None:
367+
"""Remove all chunks stored for a partitioned table."""
368+
self._store.pop(table_name, None)
369+
370+
def write_chunk(self, table_name: str, data: pd.DataFrame, chunk_idx: int) -> None:
371+
"""Append one chunk to the in-memory store for a partitioned table."""
372+
existing = self._store.get(table_name, pd.DataFrame())
373+
self._store[table_name] = (
374+
pd.concat([existing, data], ignore_index=True) if not existing.empty else data.copy()
375+
)
376+
318377
def _read_multiple(self, table_names: list[str]) -> pd.DataFrame:
319378
"""Fetch and merge multiple tables from memory."""
320379
dfs = []

src/biodata_cache/cache_table_helpers/platform_fib_traces.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_FALLBACK_METHOD = "dff-bright"
2828
_S3_URI_RE = re.compile(r"^s3://([^/]+)/(.+)$")
2929
_MAX_WORKERS = 32
30+
_CHUNK_SIZE = 10
3031

3132

3233
def _log(message: str) -> None:
@@ -151,7 +152,12 @@ def _extract_session_traces(root, asset_name: str, subject_id: str) -> pd.DataFr
151152

152153

153154
def _fetch_subject_fib_traces(subject_id: str) -> pd.DataFrame:
154-
"""Fetch and cache all processed dF/F traces for a subject from S3 NWB files."""
155+
"""Fetch and cache all processed dF/F traces for a subject from S3 NWB files.
156+
157+
Sessions are flushed to storage every ``_CHUNK_SIZE`` sessions so the
158+
in-memory footprint stays bounded. Returns an empty DataFrame; callers
159+
should read back from the backend.
160+
"""
155161
setup_logging()
156162
cache_key = f"{registry.NAMES['fib_traces']}/{subject_id}"
157163
_log(f"Updating cache for subject {subject_id}")
@@ -165,8 +171,14 @@ def _fetch_subject_fib_traces(subject_id: str) -> pd.DataFrame:
165171
]
166172
subject_assets = subject_assets[subject_assets["data_level"] == "derived"]
167173

168-
frames = []
169-
for _, row in subject_assets.iterrows():
174+
registry.BACKEND.clear_partition(cache_key)
175+
176+
rows = list(subject_assets.iterrows())
177+
frames: list[pd.DataFrame] = []
178+
chunk_idx = 0
179+
n_sessions = 0
180+
181+
for i, (_, row) in enumerate(rows):
170182
location = row["location"]
171183
if not location:
172184
continue
@@ -175,16 +187,21 @@ def _fetch_subject_fib_traces(subject_id: str) -> pd.DataFrame:
175187
_log(f"No NWB file found for asset {row['name']}")
176188
continue
177189
session_df = _extract_session_traces(root, row["name"], subject_id)
190+
del root
178191
if not session_df.empty:
179192
frames.append(session_df)
193+
n_sessions += 1
180194

181-
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
182-
if not df.empty:
183-
df = df.sort_values(["asset_name", "channel", "timestamp", "fiber"]).reset_index(drop=True)
195+
if frames and (n_sessions % _CHUNK_SIZE == 0 or i == len(rows) - 1):
196+
chunk_df = pd.concat(frames, ignore_index=True)
197+
chunk_df = chunk_df.sort_values(["asset_name", "channel", "timestamp", "fiber"]).reset_index(drop=True)
198+
frames = []
199+
registry.BACKEND.write_chunk(cache_key, chunk_df, chunk_idx)
200+
del chunk_df
201+
chunk_idx += 1
184202

185-
_log(f"Cached fib traces for subject {subject_id} ({len(frames)} sessions, {len(df)} samples)")
186-
registry.BACKEND.write(cache_key, df)
187-
return df
203+
_log(f"Cached fib traces for subject {subject_id} ({n_sessions} sessions)")
204+
return pd.DataFrame()
188205

189206

190207
@registry.register_table(registry.NAMES["fib_traces"])
@@ -228,6 +245,8 @@ def platform_fib_traces(
228245

229246
if force_update:
230247
df = _fetch_subject_fib_traces(subject_id)
248+
if df.empty:
249+
df = registry.BACKEND.read(cache_key)
231250

232251
return df
233252

tests/cache_table_helpers/test_platform_fib_traces.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ def test_fetch_subject_filters_and_writes(mock_basics, mock_extract, mock_open,
191191

192192
# Only the single derived fib asset for subject 856239 is processed.
193193
mock_open.assert_called_once_with("s3://bucket/abc")
194-
mock_registry.BACKEND.write.assert_called_once()
195-
assert mock_registry.BACKEND.write.call_args[0][0] == "platform_fib_traces/856239"
196-
assert not result.empty
194+
mock_registry.BACKEND.write_chunk.assert_called_once()
195+
assert mock_registry.BACKEND.write_chunk.call_args[0][0] == "platform_fib_traces/856239"
196+
assert result.empty
197197

198198

199199
@patch("biodata_cache.cache_table_helpers.platform_fib_traces.registry")

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_s3_scurry_partitioned_table(mock_boto3_client, mock_duckdb_query):
187187
mock_result.to_df.return_value = expected_df
188188
mock_duckdb_query.return_value = mock_result
189189
result = S3Backend().read("qc/subject123")
190-
assert f"data-asset-cache/{_VF}/qc/subject_id=subject123/data.pqt" in mock_duckdb_query.call_args[0][0]
190+
assert f"data-asset-cache/{_VF}/qc/subject_id=subject123/data*.pqt" in mock_duckdb_query.call_args[0][0]
191191
pd.testing.assert_frame_equal(result, expected_df)
192192

193193

0 commit comments

Comments
 (0)