Skip to content

Commit 632c496

Browse files
committed
Fix geneformer tests
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 106496e commit 632c496

3 files changed

Lines changed: 6 additions & 3 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
**.png
2+
pretrain-recipe-short.yaml

sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/celltype_classification_bench/bench.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ def load_data_run_benchmark(result_path, adata_path, write_results=True):
153153
import torch
154154

155155
adata = read_h5ad(adata_path)
156-
157-
infer_Xs = torch.load(result_path / "predictions__rank_0__dp_rank_0.pt")["embeddings"].float().cpu().numpy()
156+
# TODO: update the prediction writer to support model and data parallelism, and modify the
157+
# path accordingly.
158+
infer_Xs = torch.load(result_path / "predictions__rank_0.pt")["embeddings"].float().cpu().numpy()
158159
assert len(adata) == len(infer_Xs), (len(adata), len(infer_Xs))
159160

160161
infer_metadata = adata.obs

sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_celltype_bench.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def test_celltype_embeddings_golden_values():
8888
config_class=GeneformerConfig,
8989
include_unrecognized_vocab_in_dataset=False,
9090
)
91-
91+
# TODO: update the prediction writer to support model and data parallelism, and modify the
92+
# path accordingly.
9293
result = torch.load(results_path / "predictions__rank_0.pt")["embeddings"]
9394
expected_vals = torch.load(golden_values_path / "predictions__rank_0.pt")["embeddings"]
9495

0 commit comments

Comments
 (0)