feat: Wire WeightSynchronizer into algorithm layer, replacing inline refit logic #2467

Open
saumishr wants to merge 1 commit into NVIDIA-NeMo:main from saumishr:modularity/wire-weight-sync

saumishr (Contributor) commented May 12, 2026

What does this PR do?

Wires the WeightSynchronizer (from MR 1 / PR #2466) into the algorithm layer. Inline refit_policy_generation() calls and the NEED_REFIT / POLICY_GENERATION_STALE flags are replaced by the WeightSynchronizer API: sync_weights(), mark_stale(), and is_stale.

Issues

Part of the modularity/interfaces initiative. Depends on PR #2466 (modularity/weight-sync-abc).

Usage

Before (old pattern):

# In setup()
state_dict_info = policy.prepare_refit_info()
policy_generation.prepare_refit_info(state_dict_info)
if not colocated_inference:
    futures_train = policy.init_collective(ip, port, world_size, ...)
    futures_inference = policy_generation.init_collective(ip, port, world_size, ...)
    ray.get(futures_train + futures_inference)

# In grpo_train()
NEED_REFIT = True
POLICY_GENERATION_STALE = True
...
if NEED_REFIT and POLICY_GENERATION_STALE:
    refit_policy_generation(policy, policy_generation, colocated_inference, timer=timer, kv_scales=...)
    POLICY_GENERATION_STALE = False
...
POLICY_GENERATION_STALE = True  # after training

After (new pattern):

# In setup()
weight_sync = create_weight_synchronizer(
    policy=policy, generation=policy_generation,
    generation_backend=backend, colocated=colocated_inference,
    train_cluster=train_cluster, inference_cluster=inference_cluster,
)
weight_sync.init_communicator()

# In grpo_train()
if weight_sync is not None and weight_sync.is_stale:
    weight_sync.sync_weights(timer=timer, kv_scales=...)
...
weight_sync.mark_stale()  # after training
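
The staleness gating in the new pattern can be illustrated with a minimal, self-contained sketch. Everything below is hypothetical: `WeightSynchronizerSketch`, `CountingSynchronizer`, `_transfer`, and `maybe_sync` are illustration-only names, not the real ABC from #2466. The sketch assumes that `sync_weights()` clears the stale flag itself (the "After" snippet has no explicit reset) and that the synchronizer starts out stale, mirroring the initial `POLICY_GENERATION_STALE = True` in the old pattern.

```python
from abc import ABC, abstractmethod


class WeightSynchronizerSketch(ABC):
    """Hypothetical stand-in for the WeightSynchronizer interface (not the real ABC)."""

    def __init__(self) -> None:
        # Assumption: start stale so the first generation step triggers a sync,
        # like the initial POLICY_GENERATION_STALE = True in the old pattern.
        self._stale = True

    @property
    def is_stale(self) -> bool:
        """True when generation workers may hold outdated weights."""
        return self._stale

    def mark_stale(self) -> None:
        """Called after a training step changes the policy weights."""
        self._stale = True

    def sync_weights(self, **kwargs) -> None:
        """Push weights to generation, then clear the stale flag (assumed behavior)."""
        self._transfer(**kwargs)
        self._stale = False

    @abstractmethod
    def _transfer(self, **kwargs) -> None:
        """Transport-specific weight transfer (hypothetical hook)."""


class CountingSynchronizer(WeightSynchronizerSketch):
    """Toy transport that only counts how many transfers happened."""

    def __init__(self) -> None:
        super().__init__()
        self.transfers = 0

    def _transfer(self, **kwargs) -> None:
        self.transfers += 1


def maybe_sync(weight_sync, **kwargs) -> None:
    """The gate from the 'After' snippet: sync only when a synchronizer exists and is stale."""
    if weight_sync is not None and weight_sync.is_stale:
        weight_sync.sync_weights(**kwargs)
```

With this sketch, calling `maybe_sync` twice in a row performs a single transfer, and a later `mark_stale()` re-arms the gate, so the caller no longer tracks two separate flags.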

Tests

Updated test_grpo.py and test_distillation.py to use weight_sync mocks.
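
As a rough idea of what a weight_sync mock can look like, here is a sketch built on `unittest.mock.MagicMock`. It is not taken from test_grpo.py or test_distillation.py: `refit_if_stale`, `make_weight_sync_mock`, and the test function are hypothetical names, and the gate is copied from the "After" snippet (without `kv_scales`) only so the example runs on its own.

```python
from unittest.mock import MagicMock


def refit_if_stale(weight_sync, timer=None):
    """Gate from the 'After' snippet, reproduced so the example is self-contained."""
    if weight_sync is not None and weight_sync.is_stale:
        weight_sync.sync_weights(timer=timer)


def make_weight_sync_mock(stale: bool) -> MagicMock:
    """Hypothetical helper: a mock synchronizer with a fixed is_stale value."""
    mock = MagicMock(name="weight_sync")
    mock.is_stale = stale
    return mock


def test_sync_called_only_when_stale():
    # A stale synchronizer should be synced exactly once, with the timer passed through.
    stale = make_weight_sync_mock(stale=True)
    refit_if_stale(stale, timer="t")
    stale.sync_weights.assert_called_once_with(timer="t")

    # A fresh synchronizer should not be synced at all.
    fresh = make_weight_sync_mock(stale=False)
    refit_if_stale(fresh)
    fresh.sync_weights.assert_not_called()
```

Setting `is_stale` as a plain attribute on the mock is enough here because the gate only reads it; a test that needs the flag to change after `sync_weights()` would have to add a side effect.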

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

Stacked on #2466

copy-pr-bot (Bot) commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

saumishr force-pushed the modularity/wire-weight-sync branch from 8b5be01 to accfb63 on May 12, 2026 03:23
saumishr force-pushed the modularity/wire-weight-sync branch 2 times, most recently from 0c6164e to 657abc9, on May 26, 2026 18:49
saumishr added the CI:L1 label (Run doctests, unit tests, and functional tests) on May 26, 2026
saumishr requested a review from terrykong on May 26, 2026 19:18
saumishr marked this pull request as ready for review on May 26, 2026 19:18
saumishr requested review from a team as code owners on May 26, 2026 19:18

saumishr (Contributor, Author) commented

/ok to test 657abc9

saumishr (Contributor, Author) commented

/ok to test 616f2b0

saumishr force-pushed the modularity/wire-weight-sync branch from df6b623 to 40f6d46 on May 28, 2026 22:38

saumishr (Contributor, Author) commented

/ok to test 40f6d46

terrykong (Collaborator) commented May 29, 2026

@kajalj22 the L1 tests are timing out at 6 hr. Is something off? I did notice a lot of downloads where I expected the HF cache to be hit: https://github.com/NVIDIA-NeMo/RL/actions/runs/26606449648/job/78418561361?pr=2467

terrykong (Collaborator) commented

@saumishr given the nature of the change, could you run nightlies from these categories to make sure the refit change doesn't introduce a convergence issue? See #2467 (comment)

saumishr force-pushed the modularity/wire-weight-sync branch from 40f6d46 to f303df0 on May 31, 2026 17:59

saumishr (Contributor, Author) commented

/ok to test f303df0

saumishr (Contributor, Author) commented Jun 4, 2026

/ok to test 3555948

…refit logic

Replaces the inline refit/weight-sync logic in the GRPO algorithm layer with
the WeightSynchronizer abstraction (IPC / HTTP / collective transports).

Also restores policy.prepare_refit_info() for the megatron-framework
generation path, and keeps refits gated on weight_sync.is_stale, which is set
only after a training step (matching main's semantics).

Signed-off-by: Saurabh Mishra <sauramishra@nvidia.com>

saumishr (Contributor, Author) commented Jun 8, 2026

/ok to test 196fd7c

Labels

CI:L1 (Run doctests, unit tests, and functional tests)
CI:L2 (Run doctests, unit tests, functional tests, and convergence tests)
