Gemma4 26B-A4B MoE Expert Weights float32 After FSDP2 Prefetch PR #1863

@sharonyu-115

Description

Describe the bug

Ran into this error while working on NVIDIA-NeMo/RL#2212

It seems to be related to #1711

Gemma4 26B-A4B MoE GRPO training crashes with RuntimeError: Expected b.scalar_type() == torch::kBFloat16 to be true, but got false in GroupedExpertsDeepEP.forward() -> ops.gmm(). It appears the expert weights (gate_and_up_projs, down_projs) are float32 while the activations are bf16.
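
For context, the gmm kernel in grouped_gemm is what enforces the bf16 requirement that fires here. Below is a minimal sketch of what trips the check, using illustrative shapes and the ops.gmm(a, b, batch_sizes, trans_b) signature visible in the traceback; it is not a repro extracted from the failing run:

import torch
from grouped_gemm import ops

num_experts, tokens_per_expert, hidden, ffn = 4, 8, 64, 128
# Activations arrive in bf16 from the bf16 forward pass.
a = torch.randn(num_experts * tokens_per_expert, hidden, device="cuda", dtype=torch.bfloat16)
# Expert weights left in float32, mirroring the state described below.
b = torch.randn(num_experts, hidden, ffn, device="cuda", dtype=torch.float32)
batch_sizes = torch.full((num_experts,), tokens_per_expert, dtype=torch.long)  # kept on CPU

ops.gmm(a, b, batch_sizes, trans_b=False)
# RuntimeError: Expected b.scalar_type() == torch::kBFloat16 to be true, but got false.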

Full traceback:

  File "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/src/NeMo-RL/nemo-rl/3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/models/gemma4_moe/model.py", line 223, in forward
    moe_out = self.moe(moe_input, padding_mask)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/src/NeMo-RL/nemo-rl/3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/moe/layers.py", line 670, in forward
    y = self.experts(x_latent, token_mask, weights, indices)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/src/NeMo-RL/nemo-rl/3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/moe/experts.py", line 720, in forward
    output1 = ops.gmm(
              ^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/grouped_gemm/ops.py", line 41, in gmm
    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/autograd/function.py", line 583, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/grouped_gemm/ops.py", line 19, in forward
    return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/grouped_gemm/backend.py", line 27, in gmm
    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
RuntimeError: Expected b.scalar_type() == torch::kBFloat16 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Steps/Code to reproduce bug

I don't yet have a standalone Automodel test case that reproduces this.
What I've done so far is run the following tests from the NeMo RL side:

Bisected to commit 2013a4dd — "feat: FSDP2 w weight prefetching and async TP optimization (#1711)"

Automodel commit     Status                                  Job
c5f0f804 (Apr 10)    PASS — gen_kl=0.11, trains normally
2013a4dd (Apr 10)    CRASH — dtype mismatch

Additional context

Analysis from Claude, included here just for reference:

2013a4dd added nn.ModuleDict handling to apply_fsdp2_sharding_recursively() in parallelizer.py.

Before (c5f0f80): nn.ModuleDict fell into the else branch, which recursed into the children without calling fully_shard() on the layer itself:

if isinstance(module, nn.ModuleList):
    # ... handles ModuleList
else:
    for name, sub_module in module.named_children():
        apply_fsdp2_sharding_recursively(sub_module, mesh, mp_policy, offload_policy)

After (2013a4d): nn.ModuleDict is now handled alongside nn.ModuleList, with fully_shard() called directly on each layer:

if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
    # ...
    for enum_id, (layer_key, child_module) in enumerate(flat_layer_items):
        fully_shard(child_module, mesh=mesh, mp_policy=mp_policy, ...)

Gemma4 MoE uses nn.ModuleDict for decoder layers (in Gemma4MoETextModelBackend). Each layer contains EP-sharded expert DTensors on the EP mesh. When fully_shard() wraps the layer with MixedPrecisionPolicy(param_dtype=bf16) on the DP mesh, the EP-sharded DTensors (on a different mesh) are not properly cast to bf16, leaving them as float32.
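
As a rough illustration of that mechanism, the sketch below mimics the post-#1711 sharding of a ModuleDict of decoder layers and then inspects the expert weights. It is not the actual parallelizer code: the FSDP2 import path assumes PyTorch >= 2.6, and model.layers, dp_mesh, and the parameter-name filter are assumptions based on the modules and weights named above.

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy  # FSDP2 (PyTorch >= 2.6)


def shard_decoder_layers(model: nn.Module, dp_mesh) -> None:
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    # Post-#1711 behavior: each nn.ModuleDict entry is wrapped directly,
    # the same way nn.ModuleList entries are.
    for _, layer in model.layers.items():  # nn.ModuleDict of decoder layers
        fully_shard(layer, mesh=dp_mesh, mp_policy=mp_policy)


def list_expert_weight_dtypes(model: nn.Module) -> None:
    # Per the analysis above, these EP-sharded DTensors live on the EP mesh,
    # are not cast by the DP-mesh mp_policy, and reach ops.gmm() as float32.
    for name, param in model.named_parameters():
        if "gate_and_up_projs" in name or "down_projs" in name:
            print(name, param.dtype)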
