Pass full_param_layout into DDP (Megatron-LM #3812) #3483

@yaoyu-33

Summary

Megatron-LM PR NVIDIA/Megatron-LM#3812 refactors DistributedDataParallel to accept a full_param_layout argument describing how parameters and gradients are mapped in _ParamAndGradBuffer. Distributed optimizers compute this mapping via a static compute_full_param_layout method.

MBridge will need to pass full_param_layout into DDP to fully support this change.

Urgency

Not pressing. When full_param_layout is not passed, DDP currently falls back to the existing behavior via _compute_default_per_buffer_param_layout. However, Deepak plans to remove that fallback in a future cleanup pass, at which point MBridge must supply full_param_layout or DDP initialization will break.

What needs to happen

  1. Understand the full_param_layout structure and the PerBufferParamLayout and BufferKey dataclasses introduced in param_layout.py.
  2. Wire up DistributedOptimizer.compute_full_param_layout() in MBridge's training initialization and pass the result to DDP.
  3. Update any MBridge code that constructs DistributedDataParallel to forward the layout.
  4. Test with the distributed optimizer enabled to verify that the gradient buffer layout matches expectations.
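
Steps 2 and 3 could be sketched roughly as below. This is a minimal illustration, not the Megatron-LM or MBridge API: `build_ddp`, `layout_fn`, `require_layout`, `FakeDDP`, and `fake_layout` are hypothetical names introduced here. The only name taken from this issue is the `full_param_layout` keyword. The real argument list of `compute_full_param_layout` is not assumed, so the caller is expected to bind it into `layout_fn` (for example with `functools.partial`).

```python
from typing import Any, Callable, Optional


def build_ddp(
    ddp_cls: Callable[..., Any],
    module: Any,
    layout_fn: Optional[Callable[[Any], Any]] = None,
    require_layout: bool = False,
    **ddp_kwargs: Any,
) -> Any:
    """Construct the DDP wrapper, forwarding full_param_layout when available.

    layout_fn is a hypothetical hook: the caller binds it to whatever computes
    the layout, since the real upstream signature is not assumed here.
    """
    if layout_fn is not None:
        ddp_kwargs["full_param_layout"] = layout_fn(module)
    elif require_layout:
        # Fail early instead of relying on the upstream fallback path.
        raise ValueError("full_param_layout is required but no layout_fn was given")
    return ddp_cls(module, **ddp_kwargs)


class FakeDDP:
    """Stand-in for DistributedDataParallel; records what it was given."""

    def __init__(self, module: Any, full_param_layout: Any = None, **kwargs: Any):
        self.module = module
        self.full_param_layout = full_param_layout
        self.kwargs = kwargs


def fake_layout(module: Any) -> dict:
    """Stand-in layout: maps each parameter name to a running offset."""
    layout, offset = {}, 0
    for name, numel in module.items():
        layout[name] = offset
        offset += numel
    return layout


if __name__ == "__main__":
    model = {"weight": 6, "bias": 2}  # toy "module": name -> element count
    ddp = build_ddp(FakeDDP, model, layout_fn=fake_layout)
    print(ddp.full_param_layout)  # {'weight': 0, 'bias': 6}
```

Routing every DDP construction site through a single helper like this would keep the forwarding logic in one place. Turning on `require_layout` ahead of the upstream cleanup would surface any call site that still depends on the fallback.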

Context

From a Slack discussion with @deepakn94 (2026-04-22). This should be addressed before the fallback path is removed upstream.

cc @deepakn94

Labels

  - area:training: Training loop, callbacks, and runtime integration
  - feature: New capabilities, enhancements, or enablement work
  - mlm-sync: Requires API/behavior sync with upstream Megatron-LM changes
  - tracking: Tracking issue for an ongoing project with smaller steps