Skip to content

Commit dfd83a7

Browse files
authored
Remove pickle use from SCDL [BIO-408] (#1545)
### Description Remove the allow_pickle=True option when loading feature name paths in scdl. ### Type of changes <!-- Mark the relevant option with an [x] --> - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Refactor - [ ] Documentation update - [ ] Other (please describe): ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [x] I have tested these changes locally - [ ] I have updated the documentation accordingly - [x] I have added/updated tests as needed - [x] All existing tests pass successfully <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Dataset labels now stored in JSON format for improved portability and safety. * **Bug Fixes** * Enhanced error handling with clearer messages for missing or corrupted label files. * Added backward compatibility with deprecation warnings for legacy label formats. * **Tests** * Added tests for label format migration and error scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: John St John <jstjohn@nvidia.com> Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent c4a160d commit dfd83a7

3 files changed

Lines changed: 97 additions & 3 deletions

File tree

sub-packages/bionemo-scdl/src/bionemo/scdl/index/row_feature_index.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from __future__ import annotations
3131

3232
import importlib.metadata
33+
import json
34+
import warnings
3335
from abc import ABC, abstractmethod
3436
from pathlib import Path
3537
from typing import Optional, Sequence, Tuple
@@ -143,7 +145,31 @@ def _load_common(datapath: str, instance: "RowFeatureIndex") -> "RowFeatureIndex
143145
for features in instance._feature_arr:
144146
instance._extend_num_entries_per_row(features)
145147
instance._cumulative_sum_index = np.load(Path(datapath) / "cumulative_sum_index.npy")
146-
instance._labels = np.load(Path(datapath) / "labels.npy", allow_pickle=True)
148+
labels_json_path = Path(datapath) / "labels.json"
149+
legacy_labels_npy_path = Path(datapath) / "labels.npy"
150+
if labels_json_path.exists():
151+
with open(labels_json_path) as f:
152+
instance._labels = json.load(f)
153+
elif legacy_labels_npy_path.exists():
154+
warnings.warn(
155+
f"Found legacy labels.npy in '{datapath}'. This format is deprecated due to a "
156+
"security vulnerability (arbitrary code execution via pickle deserialization). "
157+
"To re-index, load this dataset and call .save() to write the new labels.json format. "
158+
"Support for labels.npy will be removed in a future release.",
159+
FutureWarning,
160+
stacklevel=3,
161+
)
162+
try:
163+
instance._labels = list(np.load(legacy_labels_npy_path, allow_pickle=False))
164+
except ValueError:
165+
raise ValueError(
166+
f"Cannot safely load labels.npy in '{datapath}' because it contains pickled objects. "
167+
"This is a security risk and is no longer supported. To migrate, re-create the dataset "
168+
"from source (e.g. re-run your h5ad-to-SCDL conversion) so that labels are saved in "
169+
"the new JSON format."
170+
)
171+
else:
172+
raise FileNotFoundError(f"No labels file found in {datapath}. Expected labels.json or labels.npy.")
147173
instance._version = np.load(Path(datapath) / "version.npy").item()
148174
return instance
149175

@@ -305,7 +331,8 @@ def save(self, datapath: str) -> None:
305331
dataframe_str_index = f"{index:0{num_digits}d}"
306332
pq.write_table(table, f"{datapath}/dataframe_{dataframe_str_index}.parquet")
307333
np.save(Path(datapath) / "cumulative_sum_index.npy", self._cumulative_sum_index)
308-
np.save(Path(datapath) / "labels.npy", self._labels)
334+
with open(Path(datapath) / "labels.json", "w") as f:
335+
json.dump([str(label) if label is not None else None for label in self._labels], f)
309336
np.save(Path(datapath) / "version.npy", np.array(self._version))
310337

311338

sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ def _write_header(self):
985985
features_rel_path = f"{feature_index_path}"
986986
index_files: List[str] = [
987987
f"{features_rel_path}/cumulative_sum_index.npy",
988-
f"{features_rel_path}/labels.npy",
988+
f"{features_rel_path}/labels.json",
989989
f"{features_rel_path}/version.npy",
990990
]
991991
if num_frames > 0:

sub-packages/bionemo-scdl/tests/bionemo/scdl/index/test_obs_feature_index.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import json
17+
import warnings
18+
1619
import numpy as np
1720
import pandas as pd
1821
import pytest
@@ -261,3 +264,67 @@ def testObeservedFetureIndex_getitem_slice_with_step_and_order_preserved(make_fe
261264
assert set(actual.keys()) == set(exp.keys())
262265
for k in exp:
263266
assert np.array_equal(actual[k], exp[k])
267+
268+
269+
def test_load_prefers_json_over_npy(tmp_path, make_feat_dictionary):
270+
"""When both labels.json and labels.npy exist, load should use labels.json and emit no FutureWarning."""
271+
idx = ObservedFeatureIndex()
272+
idx.append_features(make_feat_dictionary(2, 3), label="A")
273+
idx.save(tmp_path / "features")
274+
275+
# Write an extra legacy labels.npy alongside the existing labels.json
276+
legacy_labels = np.array(["WRONG_LABEL"])
277+
np.save(tmp_path / "features" / "labels.npy", legacy_labels)
278+
279+
with warnings.catch_warnings():
280+
warnings.simplefilter("error", FutureWarning)
281+
reloaded = ObservedFeatureIndex.load(tmp_path / "features")
282+
283+
assert reloaded.number_of_rows() == idx.number_of_rows()
284+
# Verify it used the JSON data (label "A"), not the npy data ("WRONG_LABEL")
285+
_, labels = reloaded[0 : reloaded.number_of_rows()]
286+
assert labels == ["A"]
287+
288+
289+
def test_load_legacy_labels_npy_emits_future_warning(tmp_path, make_feat_dictionary):
290+
"""Loading an index saved with the old labels.npy format should emit a FutureWarning."""
291+
idx = ObservedFeatureIndex()
292+
idx.append_features(make_feat_dictionary(2, 3), label="A")
293+
idx.save(tmp_path / "features")
294+
# Replace labels.json with a legacy labels.npy
295+
labels_json = tmp_path / "features" / "labels.json"
296+
labels_npy = tmp_path / "features" / "labels.npy"
297+
with open(labels_json) as f:
298+
labels = json.load(f)
299+
np.save(labels_npy, np.array([l if l is not None else "" for l in labels]))
300+
labels_json.unlink()
301+
302+
with pytest.warns(FutureWarning, match="legacy labels.npy"):
303+
reloaded = ObservedFeatureIndex.load(tmp_path / "features")
304+
assert reloaded.number_of_rows() == idx.number_of_rows()
305+
306+
307+
def test_load_legacy_labels_npy_with_pickle_raises(tmp_path, make_feat_dictionary):
308+
"""Loading a labels.npy that requires pickle should raise ValueError."""
309+
idx = ObservedFeatureIndex()
310+
idx.append_features(make_feat_dictionary(2, 3), label="A")
311+
idx.save(tmp_path / "features")
312+
# Replace labels.json with a labels.npy containing an object array (requires pickle)
313+
labels_json = tmp_path / "features" / "labels.json"
314+
labels_npy = tmp_path / "features" / "labels.npy"
315+
np.save(labels_npy, np.array([None, "A"], dtype=object), allow_pickle=True)
316+
labels_json.unlink()
317+
318+
with pytest.raises(ValueError, match="contains pickled objects"):
319+
ObservedFeatureIndex.load(tmp_path / "features")
320+
321+
322+
def test_load_missing_labels_file_raises(tmp_path, make_feat_dictionary):
323+
"""Loading an index with no labels file should raise FileNotFoundError."""
324+
idx = ObservedFeatureIndex()
325+
idx.append_features(make_feat_dictionary(2, 3), label="A")
326+
idx.save(tmp_path / "features")
327+
(tmp_path / "features" / "labels.json").unlink()
328+
329+
with pytest.raises(FileNotFoundError, match="No labels file found"):
330+
ObservedFeatureIndex.load(tmp_path / "features")

0 commit comments

Comments
 (0)