Skip to content

Commit cad830b

Browse files
trvachov and claude committed
Fix pre-commit: add missing docstring args, apply ruff formatting
Add missing expert_parallel_size arg descriptions to checkpoint.py docstrings (D417) in both recipes. Apply ruff auto-formatting to fused_token_router.py and test_distributed_checkpointing.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 54ace79 commit cad830b

6 files changed

Lines changed: 31 additions & 15 deletions

File tree

bionemo-recipes/recipes/mixtral_native_te/checkpoint.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,7 @@ def load_checkpoint_ddp(
186186
dataloader: The dataloader to load.
187187
weights_only: Whether to load the checkpoint weights only. We have to set this to True when loading FP8
188188
checkpoints.
189+
expert_parallel_size: Expert parallelism size. When > 1, loads expert weights with EP-aware state dict handling.
189190
"""
190191
checkpoint_path, _ = get_latest_checkpoint(ckpt_path)
191192

@@ -394,6 +395,7 @@ def load_checkpoint_fsdp2(
394395
dist_config: The distributed configuration.
395396
dataloader: The dataloader to load.
396397
process_group: The process group to use for checkpointing.
398+
expert_parallel_size: Expert parallelism size. When > 1, loads expert weights with EP-aware state dict handling.
397399
"""
398400
checkpoint_path, _ = get_latest_checkpoint(ckpt_path)
399401
if not checkpoint_path:
@@ -454,6 +456,7 @@ def save_checkpoint_fsdp2(
454456
process_group: The process group to use for checkpointing.
455457
max_checkpoints: The maximum number of checkpoints to keep.
456458
async_save: Whether to save the checkpoint asynchronously.
459+
expert_parallel_size: Expert parallelism size. When > 1, saves expert weights with EP-aware state dict handling.
457460
"""
458461
start_time = time.perf_counter()
459462
ckpt_path = Path(ckpt_path)

bionemo-recipes/recipes/mixtral_native_te/fused_token_router.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
import torch
3030
import torch.distributed as dist
3131
import transformer_engine.pytorch
32-
3332
from fused_a2a import fused_combine, fused_dispatch
3433
from fused_indices_converter import HAVE_TRITON, fused_indices_to_multihot
3534
from modeling_mixtral_te import DispatchOutput

bionemo-recipes/recipes/mixtral_native_te/tests/test_distributed_checkpointing.py

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,9 @@
2727
from tokenizers import Tokenizer
2828
from tokenizers.models import WordLevel
2929
from tokenizers.pre_tokenizers import Whitespace
30-
from transformers import PreTrainedTokenizerFast
31-
3230
from train_ddp import main as main_ddp
3331
from train_fsdp2 import main as main_fsdp2
32+
from transformers import PreTrainedTokenizerFast
3433

3534

3635
os.environ["WANDB_DISABLED"] = "true"
@@ -111,15 +110,21 @@ def _assert_checkpoint_step(ckpt_subdir, step, num_ranks, is_ddp, use_distribute
111110
model_files = [f for f in files if f.startswith("model_rank_")]
112111
optimizer_files = [f for f in files if f.startswith("optimizer_rank_")]
113112
assert len(model_files) >= num_ranks, f"Expected model files for {num_ranks} ranks in {step_dir}: {files}"
114-
assert len(optimizer_files) >= num_ranks, f"Expected optimizer files for {num_ranks} ranks in {step_dir}: {files}"
113+
assert len(optimizer_files) >= num_ranks, (
114+
f"Expected optimizer files for {num_ranks} ranks in {step_dir}: {files}"
115+
)
115116
assert "metadata.pt" in files, f"Missing metadata.pt in {step_dir}: {files}"
116117
dataloader_files = [f for f in files if "dataloader" in f]
117-
assert len(dataloader_files) >= num_ranks, f"Expected dataloader files for {num_ranks} ranks in {step_dir}: {files}"
118+
assert len(dataloader_files) >= num_ranks, (
119+
f"Expected dataloader files for {num_ranks} ranks in {step_dir}: {files}"
120+
)
118121

119122

120123
def _run_single_process_checkpoint_test(recipe_path, tmp_path, main_fn, ckpt_subdir_name, extra_overrides, is_ddp):
121124
tokenizer_path = _create_local_tokenizer(tmp_path)
122-
expert_parallel_size = int(next(o.split("=", 1)[1] for o in extra_overrides if o.startswith("expert_parallel_size=")))
125+
expert_parallel_size = int(
126+
next(o.split("=", 1)[1] for o in extra_overrides if o.startswith("expert_parallel_size="))
127+
)
123128
use_distributed_checkpoint = is_ddp and expert_parallel_size > 1
124129
common = [
125130
"checkpoint.save_every_n_steps=5",
@@ -138,7 +143,9 @@ def _run_single_process_checkpoint_test(recipe_path, tmp_path, main_fn, ckpt_sub
138143
torch.cuda.empty_cache()
139144

140145
ckpt_subdir = os.path.join(str(tmp_path / "ckpt"), ckpt_subdir_name)
141-
_assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint)
146+
_assert_checkpoint_step(
147+
ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
148+
)
142149

143150
cfg2 = _compose_config(
144151
recipe_path,
@@ -149,7 +156,9 @@ def _run_single_process_checkpoint_test(recipe_path, tmp_path, main_fn, ckpt_sub
149156
gc.collect()
150157
torch.cuda.empty_cache()
151158

152-
_assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint)
159+
_assert_checkpoint_step(
160+
ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
161+
)
153162
_assert_checkpoint_step(
154163
ckpt_subdir, 10, num_ranks=1, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
155164
)
@@ -162,7 +171,9 @@ def _run_multi_process_checkpoint_test(
162171
):
163172
ckpt_dir = str(tmp_path / "ckpt")
164173
tokenizer_path = _create_local_tokenizer(tmp_path)
165-
expert_parallel_size = int(next(o.split("=", 1)[1] for o in extra_overrides if o.startswith("expert_parallel_size=")))
174+
expert_parallel_size = int(
175+
next(o.split("=", 1)[1] for o in extra_overrides if o.startswith("expert_parallel_size="))
176+
)
166177
use_distributed_checkpoint = is_ddp and expert_parallel_size > 1
167178
env = os.environ.copy()
168179
env["WANDB_MODE"] = "disabled"
@@ -194,7 +205,9 @@ def _run_multi_process_checkpoint_test(
194205
assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}"
195206

196207
ckpt_subdir = os.path.join(ckpt_dir, ckpt_subdir_name)
197-
_assert_checkpoint_step(ckpt_subdir, 5, num_ranks=2, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint)
208+
_assert_checkpoint_step(
209+
ckpt_subdir, 5, num_ranks=2, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
210+
)
198211

199212
result2 = subprocess.run(
200213
[*base_cmd, "num_train_steps=15", "checkpoint.resume_from_checkpoint=true", *common],
@@ -205,7 +218,9 @@ def _run_multi_process_checkpoint_test(
205218
)
206219
assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}"
207220

208-
_assert_checkpoint_step(ckpt_subdir, 5, num_ranks=2, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint)
221+
_assert_checkpoint_step(
222+
ckpt_subdir, 5, num_ranks=2, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
223+
)
209224
_assert_checkpoint_step(
210225
ckpt_subdir, 10, num_ranks=2, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
211226
)

bionemo-recipes/recipes/opengenome2_mixtral_native_te/checkpoint.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@ def load_checkpoint_fsdp2(
215215
dist_config: The distributed configuration.
216216
dataloader: The dataloader to load.
217217
process_group: The process group to use for checkpointing.
218+
expert_parallel_size: Expert parallelism size. When > 1, loads expert weights with EP-aware state dict handling.
218219
"""
219220
checkpoint_path, _ = get_latest_checkpoint(ckpt_path)
220221
if not checkpoint_path:
@@ -280,6 +281,7 @@ def save_checkpoint_fsdp2(
280281
process_group: The process group to use for checkpointing.
281282
max_checkpoints: The maximum number of checkpoints to keep.
282283
async_save: Whether to save the checkpoint asynchronously.
284+
expert_parallel_size: Expert parallelism size. When > 1, saves expert weights with EP-aware state dict handling.
283285
"""
284286
start_time = time.perf_counter()
285287
ckpt_path = Path(ckpt_path)

bionemo-recipes/recipes/opengenome2_mixtral_native_te/fused_token_router.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
import torch
3030
import torch.distributed as dist
3131
import transformer_engine.pytorch
32-
3332
from fused_a2a import fused_combine, fused_dispatch
3433
from fused_indices_converter import HAVE_TRITON, fused_indices_to_multihot
3534
from modeling_mixtral_te import DispatchOutput

bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_distributed_checkpointing.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -70,9 +70,7 @@ def _assert_checkpoint_step(ckpt_subdir, step, num_ranks):
7070
distcp_files = [f for f in files if f.endswith(".distcp")]
7171
has_metadata = ".metadata" in files
7272
assert has_metadata, f"Missing .metadata in {step_dir}: {files}"
73-
assert len(distcp_files) >= num_ranks, (
74-
f"Expected at least {num_ranks} .distcp files in {step_dir}: {files}"
75-
)
73+
assert len(distcp_files) >= num_ranks, f"Expected at least {num_ranks} .distcp files in {step_dir}: {files}"
7674
dataloader_files = [f for f in files if "dataloader" in f]
7775
assert len(dataloader_files) >= num_ranks, (
7876
f"Expected dataloader files for {num_ranks} ranks in {step_dir}: {files}"

0 commit comments

Comments
 (0)