Commit 91869ec

Reduce moco notebook training iterations in CI via FAST_CI_MODE (#1542)
## Summary

Fixes [BIO-403](https://linear.app/nvidia/issue/BIO-403/investigate-flaky-tests-and-timeout-errors-in-bionemo-moco-notebook): Investigate flaky tests and timeout errors in bionemo-moco notebook.

The `discrete_data_interpolant_tutorial.ipynb` notebook has 3 training loops that hard-code `range(50000)` iterations. When executed in CI via `nbval`, these loops cause timeouts.

## Changes

1. Added a configuration cell that detects the `FAST_CI_MODE` environment variable (already set by `unit-tests-framework.yml` when running notebook tests)
2. Sets `NUM_TRAINING_STEPS = 500` when in CI mode, `50000` otherwise
3. Replaced all 3 `range(50000)` occurrences with `range(NUM_TRAINING_STEPS)` in the DFM, D3PM, and MDLM training loops

This follows the existing pattern used in `bionemo-recipes/recipes/evo2_megatron/examples/` notebooks.

## Testing

- The notebook behavior is unchanged when run outside CI (`FAST_CI_MODE` not set)
- When `FAST_CI_MODE=true` (as in framework CI), iterations drop from 50,000 to 500, preventing timeouts

Signed-off-by: svc-bionemo <267129667+svc-bionemo@users.noreply.github.com>
Co-authored-by: svc-bionemo <267129667+svc-bionemo@users.noreply.github.com>
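The flag handling described in the Changes list can be sketched as a small standalone function. This is an illustration only: `resolve_training_steps` and its parameters are hypothetical names that do not appear in the notebook, and the truthy values mirror the ones used in the added configuration cell.

```python
import os

# Values treated as "on", matching the tuple used in the notebook cell.
TRUTHY = ("1", "true", "yes")


def resolve_training_steps(env=None, fast_steps=500, full_steps=50000):
    """Return the iteration count for the training loops.

    `env` is any mapping of environment variables; it defaults to
    os.environ. The function name and parameters are hypothetical and
    exist only for this sketch.
    """
    env = os.environ if env is None else env
    # Missing variable -> empty string -> not in TRUTHY -> full run.
    fast_ci_mode = env.get("FAST_CI_MODE", "").lower() in TRUTHY
    return fast_steps if fast_ci_mode else full_steps


if __name__ == "__main__":
    print(resolve_training_steps())
```

Passing the environment as a mapping keeps the logic testable without mutating `os.environ`. Note that the comparison is case-insensitive, so `TRUE` and `Yes` also enable the short run, while `0`, `false`, or an unset variable leave the full 50000 iterations in place.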
1 parent 0883c09 commit 91869ec

1 file changed

Lines changed: 16 additions & 3 deletions

sub-packages/bionemo-moco/examples/discrete_data_interpolant_tutorial.ipynb

@@ -22,6 +22,19 @@
 "torch.cuda.manual_seed(42)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"import os\n",
+"\n",
+"\n",
+"FAST_CI_MODE: bool = os.environ.get(\"FAST_CI_MODE\", \"\").lower() in (\"1\", \"true\", \"yes\")\n",
+"NUM_TRAINING_STEPS = 500 if FAST_CI_MODE else 50000"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -144,7 +157,7 @@
 "source": [
 "model = model.to(DEVICE)\n",
 "losses = []\n",
-"for _ in tqdm(range(50000)):\n",
+"for _ in tqdm(range(NUM_TRAINING_STEPS)):\n",
 "    num_ones = torch.randint(0, D + 1, (B,))\n",
 "    x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long().to(DEVICE)\n",
 "    # x1 e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] or [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]\n",
@@ -659,7 +672,7 @@
 "# NBVAL_SKIP\n",
 "model = model.to(DEVICE)\n",
 "losses = []\n",
-"for _ in tqdm(range(50000)):\n",
+"for _ in tqdm(range(NUM_TRAINING_STEPS)):\n",
 "    num_ones = torch.randint(0, D + 1, (B,))\n",
 "    x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long().to(DEVICE)\n",
 "    # x1 e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] or [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]\n",
@@ -892,7 +905,7 @@
 "\n",
 "model = model.to(DEVICE)\n",
 "losses = []\n",
-"for _ in tqdm(range(50000)):\n",
+"for _ in tqdm(range(NUM_TRAINING_STEPS)):\n",
 "    num_ones = torch.randint(0, D + 1, (B,))\n",
 "    x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long().to(DEVICE)\n",
 "    # x1 e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] or [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]\n",

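Since the change replaces three identical literals by hand, a small check could guard against a hard-coded iteration count creeping back in. The following is a sketch under stated assumptions, not part of this commit: `find_hard_coded_loops` is a hypothetical helper, and the "four or more digits" threshold for what counts as a suspicious literal is an arbitrary choice. It relies only on the standard notebook JSON layout (a top-level `cells` list whose code cells carry a `source` list of lines).

```python
import json
import re

# Flags range(<integer literal with 4+ digits>), e.g. range(50000).
# The digit threshold is an assumption made for this sketch.
HARD_CODED_RANGE = re.compile(r"range\(\s*\d{4,}\s*\)")


def find_hard_coded_loops(notebook_text):
    """Return (cell_index, line) pairs for code lines with a large range literal."""
    notebook = json.loads(notebook_text)
    hits = []
    for index, cell in enumerate(notebook.get("cells", [])):
        if cell.get("cell_type") != "code":
            continue
        source = cell.get("source", [])
        # nbformat allows "source" to be a single string or a list of lines.
        if isinstance(source, str):
            source = source.splitlines()
        for line in source:
            if HARD_CODED_RANGE.search(line):
                hits.append((index, line.strip()))
    return hits


if __name__ == "__main__":
    demo = json.dumps({
        "cells": [
            {"cell_type": "code", "source": ["for _ in tqdm(range(50000)):\n"]},
            {"cell_type": "code", "source": ["for _ in tqdm(range(NUM_TRAINING_STEPS)):\n"]},
        ]
    })
    print(find_hard_coded_loops(demo))
```

Run against the notebook text before this commit, a check like this would be expected to report the three training loops; after the commit it should report none, because the loops now read `range(NUM_TRAINING_STEPS)`.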