Merged
160 commits
8411f6f
plan
ZhiyuLi-Nvidia May 1, 2026
85acfdb
plan: align Stage 4 with rl-arena/verl 1-hop pattern
ZhiyuLi-Nvidia May 2, 2026
9a46c43
feat(data-plane): TransferQueue integration for GRPO with driver-side…
ZhiyuLi-Nvidia May 4, 2026
bcb451a
refactor(data-plane): extract driver-side balanced packing into presh…
ZhiyuLi-Nvidia May 5, 2026
196b6bb
feat(data-plane): AsyncTrajectoryCollector writes rollouts to TQ when…
ZhiyuLi-Nvidia May 5, 2026
0c216f4
feat(data-plane): wire async-on-TQ end-to-end with driver-side balanc…
ZhiyuLi-Nvidia May 5, 2026
bf092f7
fix(data-plane): preserve sample order and FLOPs semantics on @dp_dis…
ZhiyuLi-Nvidia May 5, 2026
a28b46d
feat(data-plane): grpo_sync routes logprob/ref-logprob through @dp_di…
ZhiyuLi-Nvidia May 5, 2026
c1bb667
refactor(data-plane): replace @dp_dispatch with TQPolicy subclass; ad…
ZhiyuLi-Nvidia May 5, 2026
67b242b
fix(data-plane): VLM extras, async fan-out, cleanup-on-failure
ZhiyuLi-Nvidia May 5, 2026
d05ad3f
docs(data-plane): add API lifecycle doc with verl comparison
ZhiyuLi-Nvidia May 7, 2026
9da2ec9
feat(data-plane): sync 1-hop trajectory collector + per-sample key li…
ZhiyuLi-Nvidia May 7, 2026
a7f4bcc
refactor(data-plane): extract make_actor_runtime_env, fix N² list copy
ZhiyuLi-Nvidia May 7, 2026
fc6ceea
feat(data-plane): jagged tensors on TQ wire + naming/factory cleanup
ZhiyuLi-Nvidia May 7, 2026
520bfef
refactor(data-plane): KVBatchMeta.subset/slice/concat methods
ZhiyuLi-Nvidia May 7, 2026
b732afe
Mooncake cpu backend
ZhiyuLi-Nvidia May 7, 2026
dcd62d8
Readability Refactor
ZhiyuLi-Nvidia May 8, 2026
fba1f32
wip test mooncake
ZhiyuLi-Nvidia May 8, 2026
69b09c1
refactor(data-plane): drop dead set_wire_format/_PACK_JAGGED + adapte…
ZhiyuLi-Nvidia May 8, 2026
a55ad5c
refactor(ray.sub): drop NETWORK_INIT_CMDS — MC_TCP_BIND_ADDRESS suffices
ZhiyuLi-Nvidia May 8, 2026
703bd36
docs(data-plane): consolidate README; drop stale plan/verl refs
ZhiyuLi-Nvidia May 8, 2026
d42e7b2
feat(data-plane): non-tensor object support on TQ wire
ZhiyuLi-Nvidia May 8, 2026
a8ff04e
feat(grpo-sync): equivalency fixes + content via TQ object column
ZhiyuLi-Nvidia May 9, 2026
77b0f6a
style: fix ruff lint errors and apply ruff format
ZhiyuLi-Nvidia May 9, 2026
d86aed2
style: apply pre-commit auto-fixes (ruff)
ZhiyuLi-Nvidia May 9, 2026
41258a4
chore(pyrefly): whitelist all new data_plane files + fix type errors
ZhiyuLi-Nvidia May 9, 2026
2b58c02
remove unnecessary script
ZhiyuLi-Nvidia May 9, 2026
1347c88
feat(data-plane): decompose message_log at wire boundary
ZhiyuLi-Nvidia May 12, 2026
1903125
refactor(data-plane): rename DataPlaneClient.get_meta → claim_meta
ZhiyuLi-Nvidia May 12, 2026
f527f77
docs(data-plane): tighten DataPlaneClient boundary docstring
ZhiyuLi-Nvidia May 12, 2026
0dea433
fix(data-plane): treat DataPlaneConfig.enabled as required field
ZhiyuLi-Nvidia May 12, 2026
de28a19
docs(data-plane): make build_data_plane_client docstring backend-agno…
ZhiyuLi-Nvidia May 12, 2026
0f710a4
refactor(data-plane): promote codec imports to module top-level
ZhiyuLi-Nvidia May 12, 2026
fe2aa71
refactor(data-plane): rename driver_io → column_io
ZhiyuLi-Nvidia May 12, 2026
e02d3c7
refactor(data-plane): validate dp_world at TQPolicy config time
ZhiyuLi-Nvidia May 12, 2026
0c985d4
refactor(data-plane): centralize packing-meta keys in schema.py
ZhiyuLi-Nvidia May 13, 2026
c5cf807
refactor(data-plane): drop redundant dp_world assert in shard_meta_fo…
ZhiyuLi-Nvidia May 13, 2026
734a01a
refactor(data-plane): move DP_SEED_FIELDS to schema.py as DP_TRAIN_FI…
ZhiyuLi-Nvidia May 13, 2026
5a6d53d
fix(data-plane): reject empty meta in shard_meta_for_dp
ZhiyuLi-Nvidia May 13, 2026
44c82aa
refactor(data-plane): print_event → log_event via stdlib logging
ZhiyuLi-Nvidia May 13, 2026
379dae1
style(data-plane): match repo logger naming convention
ZhiyuLi-Nvidia May 13, 2026
475b703
refactor(data-plane): convert DataPlaneStats to @dataclass
ZhiyuLi-Nvidia May 13, 2026
27f1d77
refactor(data-plane): type DataPlaneEvent as TypedDict
ZhiyuLi-Nvidia May 13, 2026
5a8e8a7
refactor(data-plane): drop placeholder 0s from _run; make sizes kw-only
ZhiyuLi-Nvidia May 13, 2026
e93fe5f
fix(data-plane): route check_consumption_status through _run
ZhiyuLi-Nvidia May 13, 2026
5d12647
fix(data-plane): route close() through _run
ZhiyuLi-Nvidia May 13, 2026
6ca3b47
perf(data-plane): single sync in to_nested_by_length
ZhiyuLi-Nvidia May 13, 2026
0d690f9
docs(data-plane): convert codec.py docstrings to Google style
ZhiyuLi-Nvidia May 13, 2026
d22709f
refactor(data-plane): centralize Layout type alias in schema.py
ZhiyuLi-Nvidia May 13, 2026
e23c400
fix(data-plane): validate pad_to_multiple >= 1 in materialize
ZhiyuLi-Nvidia May 13, 2026
e491025
fix(data-plane): fail fast on empty local IP at Mooncake bootstrap
ZhiyuLi-Nvidia May 13, 2026
f3dc3ee
fix(data-plane): surface chmod failure when mooncake_master is not exec
ZhiyuLi-Nvidia May 13, 2026
a8de1df
refactor(data-plane): scope mooncake_cpu 1D workaround to TQDataPlane…
ZhiyuLi-Nvidia May 13, 2026
8758b3a
docs(data-plane): clarify TQ module vs client access convention
ZhiyuLi-Nvidia May 13, 2026
245f04c
docs(data-plane): note trust boundary at pack_object_array pickle site
ZhiyuLi-Nvidia May 13, 2026
739c837
refactor(data-plane): drop codec pickle, use TQ-native NonTensorStack
ZhiyuLi-Nvidia May 13, 2026
f4b647f
refactor(data-plane): drop dead object-array codec helpers
ZhiyuLi-Nvidia May 13, 2026
4fa8a11
refactor(data-plane): centralize _meta_idx sentinel in schema.py
ZhiyuLi-Nvidia May 13, 2026
38921eb
docs(data-plane): convert interfaces.py docstrings to Google style
ZhiyuLi-Nvidia May 13, 2026
48802b0
refactor(data-plane): align schema constant names with their values
ZhiyuLi-Nvidia May 13, 2026
a65abaf
docs(data-plane): tighten preshard.py docstring to Google style
ZhiyuLi-Nvidia May 13, 2026
2aeb292
docs(data-plane): convert column_io.py docstrings to Google style
ZhiyuLi-Nvidia May 13, 2026
b44c0f4
docs(data-plane): convert factory.py docstring to Google style
ZhiyuLi-Nvidia May 13, 2026
0455e2e
docs(data-plane): add Args/Returns blocks to observability.py docstrings
ZhiyuLi-Nvidia May 13, 2026
c39e313
docs(data-plane): tighten transfer_queue.py docstrings, add Args/Retu…
ZhiyuLi-Nvidia May 13, 2026
2c12afd
docs(data-plane): add Args/Returns to worker_mixin.py docstrings
ZhiyuLi-Nvidia May 13, 2026
650e142
docs(data-plane): add Args/Returns blocks to tq_policy.py docstrings
ZhiyuLi-Nvidia May 13, 2026
db312b6
docs(data-plane): convert sync_rollout_actor.py docstrings to Google …
ZhiyuLi-Nvidia May 13, 2026
dbef790
docs(data-plane): add Args/Returns to grpo_sync.py dynamic-sampling h…
ZhiyuLi-Nvidia May 13, 2026
cb1dc34
refactor(data-plane): drop _to_wire's redundant promote_1d kwarg
ZhiyuLi-Nvidia May 13, 2026
b9c15ed
fix(data-plane): survive TQ simple-backend NonTensorData wire-strip
ZhiyuLi-Nvidia May 14, 2026
47d2f7f
build(data-plane): pin mooncake-transfer-engine-cuda13 wheel for cu13…
ZhiyuLi-Nvidia May 14, 2026
de22c5c
chore: ruff auto-fix and ruff-format pass
ZhiyuLi-Nvidia May 14, 2026
908ed7f
chore(pyrefly): rename driver_io → column_io in whitelist
ZhiyuLi-Nvidia May 14, 2026
356d166
chore(pyrefly): silence 5 latent type errors with targeted ignore com…
ZhiyuLi-Nvidia May 14, 2026
6666a89
chore(pyrefly): whitelist nemo_rl/data_plane/schema.py
ZhiyuLi-Nvidia May 14, 2026
5dbc600
fix(data-plane): preserve object-column identity through TQ wire
ZhiyuLi-Nvidia May 14, 2026
b9154bc
fix(data-plane): gate TQ write-back on TP×CP×PP leader to avoid dupli…
ZhiyuLi-Nvidia May 14, 2026
cab4bc0
chore: ruff auto-fix and D205 docstring fixes
ZhiyuLi-Nvidia May 14, 2026
db31b12
refactor(data-plane): drop async-grpo TQ scaffolding from sync PR
ZhiyuLi-Nvidia May 14, 2026
351916b
refactor(data-plane): consolidate producer codec, caller mints keys
ZhiyuLi-Nvidia May 14, 2026
53be031
test(data-plane): align codec tests with current contract
ZhiyuLi-Nvidia May 14, 2026
09099f0
refactor(grpo_sync): drop dead batch_cache; make TQPolicy attrs public
ZhiyuLi-Nvidia May 14, 2026
660dd89
refactor(data-plane): extract calibration field filter into named sch…
ZhiyuLi-Nvidia May 15, 2026
dabe37b
refactor(data-plane): make kv_batch_get(select_fields) required
ZhiyuLi-Nvidia May 15, 2026
d9258cd
refactor(sync-rollout-actor): remove unused wrappers; document full l…
ZhiyuLi-Nvidia May 15, 2026
1a937aa
test(data-plane): move data_plane unit tests under tests/unit/ for CI…
ZhiyuLi-Nvidia May 15, 2026
4cfd120
test(data-plane): apply ruff --fix and import-sort to data_plane unit…
ZhiyuLi-Nvidia May 15, 2026
534fb07
docs: fix broken nemo-gym Core Components link
ZhiyuLi-Nvidia May 15, 2026
e49b1ca
chore(grpo): drop stale mypy comments; rename TQPolicy ctor->actor
ZhiyuLi-Nvidia May 15, 2026
5d8de41
fix(data-plane): reject loopback IP; resolve TQ runtime_env pin from …
ZhiyuLi-Nvidia May 15, 2026
b512927
docs(data-plane): rewrite README around sync flow + async proposal
ZhiyuLi-Nvidia May 15, 2026
791671e
docs(data-plane): clarify partition scope and TQ mental model
ZhiyuLi-Nvidia May 15, 2026
30d6ccc
refactor(data-plane): per-row tags on KVBatchMeta; rename slice → dri…
ZhiyuLi-Nvidia May 16, 2026
0f01865
perf(sync-rollout-actor): subset driver_carry via carry_keys
ZhiyuLi-Nvidia May 16, 2026
1bbaa17
refactor(grpo-sync): apply overlong filter post-dynamic-sampling
ZhiyuLi-Nvidia May 16, 2026
52c1394
refactor(grpo-sync): isolate TQ ops behind TQPolicy/KVBatchMeta façades
ZhiyuLi-Nvidia May 16, 2026
63ea762
refactor(data-plane): YAML-only defaults for TQ config (terryk §9)
ZhiyuLi-Nvidia May 16, 2026
1d025f4
docs(data-plane): refresh README around encapsulated TQ path
ZhiyuLi-Nvidia May 16, 2026
c6d0d30
chore: ruff format + pyrefly ignore + underscore-md rename
ZhiyuLi-Nvidia May 16, 2026
1f637ea
docs(data-plane): drop api-lifecycle doc; realistic concrete examples
ZhiyuLi-Nvidia May 16, 2026
b4497f0
docs: align nemo-gym Core Components link with main
ZhiyuLi-Nvidia May 16, 2026
0d0d36b
fix(data-plane): close grad_norm collapse + NCCL desync in DP fsdp2 path
ZhiyuLi-Nvidia May 18, 2026
fb6ccef
refactor(data-plane): drop _tq() lazy wrapper; fail-fast in check_con…
ZhiyuLi-Nvidia May 18, 2026
28e634b
refactor(grpo-sync): mint uids in rollout actor (verl-style per-promp…
ZhiyuLi-Nvidia May 18, 2026
c3c2866
refactor(data-plane): rename KVBatchMeta.keys -> sample_ids (Phase A)
ZhiyuLi-Nvidia May 18, 2026
935c1b5
refactor(data-plane): rename DataPlaneClient kwarg keys -> sample_ids…
ZhiyuLi-Nvidia May 18, 2026
14e75cf
test(data-plane): update KVBatchMeta schema-pin to sample_ids
ZhiyuLi-Nvidia May 18, 2026
23d4353
refactor(data-plane): rename DataPlaneClient verbs kv_batch_* -> {put…
ZhiyuLi-Nvidia May 18, 2026
9474196
refactor(data-plane): tighten clear_samples(None) contract; warn on s…
ZhiyuLi-Nvidia May 18, 2026
fdfade3
chore(data-plane): apply ruff format
ZhiyuLi-Nvidia May 18, 2026
be54ac6
feat(data-plane): align seq-dim across DP ranks via meta-stamped glob…
ZhiyuLi-Nvidia May 18, 2026
2c6c022
test(data-plane): add missing DataPlaneConfig keys to test_seqpack_eq…
ZhiyuLi-Nvidia May 18, 2026
a6b4ab8
refactor(data-plane): remove _PartitionRecord from TQ adapter
ZhiyuLi-Nvidia May 18, 2026
f3a4a04
test(data-plane): remove empty tests/unit/data_plane/conftest.py
ZhiyuLi-Nvidia May 18, 2026
1c8a470
revert(test): restore NUM_MINUTES=150 in prorlv2 recipe sh
ZhiyuLi-Nvidia May 18, 2026
04f410a
test(data-plane): drop test_tq_multinode.py
ZhiyuLi-Nvidia May 18, 2026
9c6d0de
docs(data-plane): document DP-aligned forward pad seqlen in README
ZhiyuLi-Nvidia May 18, 2026
450f8d9
test(data-plane): drop stale import-isolation tests; merge codec_obje…
ZhiyuLi-Nvidia May 18, 2026
0d5bb92
refactor(data-plane): drop drive-by edits from PR scope
ZhiyuLi-Nvidia May 19, 2026
4b866cd
test(data-plane): accept attribute-style data_plane access in invariant
ZhiyuLi-Nvidia May 19, 2026
4c252c6
refactor(data-plane): use attribute-style access on MasterConfig
ZhiyuLi-Nvidia May 19, 2026
d4d9c7c
refactor(data-plane): replace run_grpo dispatch grep with behavioral …
ZhiyuLi-Nvidia May 19, 2026
a775aee
fix(data-plane): use attribute access for loss_fn KL penalty assert
ZhiyuLi-Nvidia May 19, 2026
cd45f8f
fix(data-plane): pre-register fields to dodge TQ controller race
ZhiyuLi-Nvidia May 19, 2026
1e1f0f2
fix(configs): set truncated_importance_sampling_type=tis on recipes t…
ZhiyuLi-Nvidia May 19, 2026
5980c8e
refactor(data-plane): close four cross-boundary leaks
ZhiyuLi-Nvidia May 19, 2026
f1bc4fa
chore(data-plane): apply ruff format to discard_samples
ZhiyuLi-Nvidia May 19, 2026
c34ba36
test(data-plane): consolidate suite under tests/unit/data_plane
ZhiyuLi-Nvidia May 19, 2026
80b5760
fix(data-plane): shrink mooncake_cpu segment defaults to fit CI runners
ZhiyuLi-Nvidia May 19, 2026
90d32a4
test(data-plane): update _apply_dynamic_sampling tests for policy= param
ZhiyuLi-Nvidia May 19, 2026
f6477a4
fix(data-plane): apply pad_to_seqlen to ALL 2D+ tensors in materialize
ZhiyuLi-Nvidia May 20, 2026
2d8115c
test(data-plane): add missing DataPlaneConfig keys to _TQ_CFG in chao…
ZhiyuLi-Nvidia May 20, 2026
3e3e3be
test(data-plane): remove storage-actor-kill chaos test
ZhiyuLi-Nvidia May 20, 2026
1c7d246
fix(data-plane): exclude MESSAGE_LOG_BULK_FIELDS from FP8 calib request
ZhiyuLi-Nvidia May 20, 2026
32be65a
test(data-plane): pin MESSAGE_LOG_BULK_FIELDS in DP_CALIB_EXCLUDED_FI…
ZhiyuLi-Nvidia May 20, 2026
56b78cd
test(data-plane): add missing DataPlaneConfig keys to tq_lifecycle fi…
ZhiyuLi-Nvidia May 20, 2026
42606b6
feat(data-plane): route FP8 KV scales through TQ (sync first cut)
ZhiyuLi-Nvidia May 20, 2026
45233e6
Revert "feat(data-plane): route FP8 KV scales through TQ (sync first …
ZhiyuLi-Nvidia May 20, 2026
0fe15b1
refactor(data-plane): flip calib filter to positive include-list
ZhiyuLi-Nvidia May 20, 2026
ccf5eb8
test(data-plane): add realistic-shape rollout fixtures + cross-file d…
ZhiyuLi-Nvidia May 20, 2026
c958c2a
chore(test): apply ruff isort + blank-line fixes
ZhiyuLi-Nvidia May 20, 2026
68206ef
fix(data-plane): override _is_writeback_leader in DTensor V1 worker
ZhiyuLi-Nvidia May 20, 2026
fb54dc7
test(data-plane): sync grpo_math_1B reference config buffer sizes
ZhiyuLi-Nvidia May 20, 2026
e84b25d
test(data-plane): slim test_architecture_invariants to 2 behavioral t…
ZhiyuLi-Nvidia May 20, 2026
6afdc98
undo unnecessary change
ZhiyuLi-Nvidia May 20, 2026
1a38153
build: resolve mooncake-transfer-engine-cuda13 from PyPI instead of G…
ZhiyuLi-Nvidia May 21, 2026
4183e63
perf(data-plane): skip Ray return of per-token logprob tensors
ZhiyuLi-Nvidia May 21, 2026
ed45e8c
perf(data-plane): worker-side suppress per-token logprob Ray return
ZhiyuLi-Nvidia May 21, 2026
35bb085
refactor(data-plane): drop aggregator path now that logprob workers r…
ZhiyuLi-Nvidia May 21, 2026
e908738
refactor(data-plane): make Ray worker_coords the single source of tru…
ZhiyuLi-Nvidia May 22, 2026
98bf3be
Revert "refactor(data-plane): make Ray worker_coords the single sourc…
ZhiyuLi-Nvidia May 22, 2026
079979a
fix(data-plane): unify leader-gate on NamedSharding.is_axis_zero; fix…
ZhiyuLi-Nvidia May 22, 2026
2b504b5
chore: ruff auto-fix and ruff-format pass post-rebase
ZhiyuLi-Nvidia May 22, 2026
bfb261f
undo unnecessary change
ZhiyuLi-Nvidia May 22, 2026
3dedfd9
build: remove unnecessary setuptools packages.find filter
ZhiyuLi-Nvidia May 22, 2026
ed24395
fix(data-plane): preserve non-tensor leaves in mooncake_cpu 1D wire-p…
ZhiyuLi-Nvidia May 22, 2026
26179fd
chore: ruff-format pass on test_leader_broadcast.py
ZhiyuLi-Nvidia May 23, 2026
7341341
chore: ruff-format test_leader_broadcast.py
ZhiyuLi-Nvidia May 24, 2026
b63c18f
fix(deps): include aarch64 in mooncake-cuda13 marker
ZhiyuLi-Nvidia May 24, 2026
16 changes: 16 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -401,3 +401,19 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
+
+# TransferQueue-mediated data plane for sync GRPO.
+# Off by default — the legacy grpo_train trainer never engages this.
+# Flip enabled=true and run grpo_train_sync to use TQ-mediated bulk
+# transfer between rollout and train. See nemo_rl/data_plane/README.md.
+data_plane:
+  enabled: false
+  impl: transfer_queue
+  backend: "simple"  # TQ storage backend ('simple' or 'mooncake_cpu')
+  storage_capacity: 1000000  # max samples retained per partition
+  num_storage_units: 2  # storage shards
+  claim_meta_poll_interval_s: 0.5  # blocking-claim poll cadence
+  global_segment_size: 549755813888  # 512 GiB — used when backend == "mooncake_cpu"
+  local_buffer_size: 68719476736  # 64 GiB — used when backend == "mooncake_cpu"
+  # observability:  # NotRequired
+  #   enabled: false
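As a sanity check on the sizing comments, here is a minimal sketch (not part of the PR) that confirms the byte values match their GiB labels and mirrors the `enabled` gate. The dict literal stands in for the parsed YAML block, and `GIB` and `describe_data_plane` are hypothetical names introduced only for illustration.

```python
GIB = 2**30  # bytes per GiB

# Stand-in for the parsed `data_plane:` YAML block (values copied from the diff).
data_plane = {
    "enabled": False,
    "impl": "transfer_queue",
    "backend": "simple",
    "storage_capacity": 1_000_000,
    "num_storage_units": 2,
    "claim_meta_poll_interval_s": 0.5,
    "global_segment_size": 549755813888,
    "local_buffer_size": 68719476736,
}


def describe_data_plane(cfg: dict) -> str:
    """Summarize the block. Hypothetical helper, not a nemo_rl API."""
    if not cfg.get("enabled", False):
        return "data plane disabled (legacy grpo_train path)"
    summary = f"{cfg['impl']} with backend={cfg['backend']}"
    if cfg["backend"] == "mooncake_cpu":
        # Per the YAML comments, the size fields apply only to mooncake_cpu.
        seg = cfg["global_segment_size"] // GIB
        buf = cfg["local_buffer_size"] // GIB
        summary += f", segment={seg} GiB, local buffer={buf} GiB"
    return summary


print(describe_data_plane(data_plane))
print(describe_data_plane({**data_plane, "enabled": True, "backend": "mooncake_cpu"}))
```

The second call only overrides two keys, which is roughly what a command-line override of `data_plane.enabled` and `data_plane.backend` would amount to.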
@@ -7,6 +7,7 @@ loss_fn:
   reference_policy_kl_penalty: 0.0
   use_importance_sampling_correction: true
   truncated_importance_sampling_ratio: 2
+  truncated_importance_sampling_type: tis
 checkpointing:
   checkpoint_dir: results/grpo-glm47-flash-4n8g-automodel
 policy:
@@ -12,6 +12,7 @@ loss_fn:
   reference_policy_kl_penalty: 0.0
   use_importance_sampling_correction: true
   truncated_importance_sampling_ratio: 2
+  truncated_importance_sampling_type: tis
   ratio_clip_max: 0.28
   ratio_clip_c: 10
 checkpointing:
@@ -5,6 +5,7 @@ loss_fn:
   reference_policy_kl_penalty: 0.0
   use_importance_sampling_correction: true
   truncated_importance_sampling_ratio: 2
+  truncated_importance_sampling_type: tis
 checkpointing:
   checkpoint_dir: results/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16
 policy:
@@ -5,6 +5,7 @@ loss_fn:
   reference_policy_kl_penalty: 0.0
   use_importance_sampling_correction: true
   truncated_importance_sampling_ratio: 2
+  truncated_importance_sampling_type: tis
 checkpointing:
   checkpoint_dir: results/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-megatron-ep16
 policy:
46 changes: 41 additions & 5 deletions examples/run_grpo.py
@@ -31,6 +31,22 @@
 from nemo_rl.utils.logger import get_next_experiment_dir


+def _select_trainer(master_config: MasterConfig):
+    """Pick the synchronous trainer based on ``data_plane.enabled``.
+
+    Factored out so test_architecture_invariants can verify dispatch
+    without the full setup() path.
+    """
+    dp_cfg = master_config.data_plane or {}
+    if dp_cfg.get("enabled", False):
+        from nemo_rl.algorithms.grpo_sync import grpo_train_sync
+
+        print("🚀 Running synchronous GRPO training (TransferQueue)")
+        return grpo_train_sync
+    print("🚀 Running synchronous GRPO training (legacy)")
+    return grpo_train
+
+
 def parse_args() -> tuple[argparse.Namespace, list[str]]:
     """Parse command line arguments."""
     parser = argparse.ArgumentParser(description="Run GRPO training with configuration")
@@ -100,6 +116,20 @@ def main() -> None:
         val_task_to_env,
     ) = setup_response_data(tokenizer, config.data, config.env)

+    # Pick the policy factory at the launcher level so the legacy trainer
+    # stays data-plane-agnostic (architectural invariant — see
+    # tests/data_plane/unit/test_architecture_invariants.py).
+    _dp_cfg = config.data_plane or {}
+    if _dp_cfg.get("enabled", False):
+        from nemo_rl.models.policy.tq_policy import TQPolicy
+
+        def _make_policy(**kwargs):
+            return TQPolicy(**kwargs, dp_cfg=_dp_cfg)
+
+        _policy_factory = _make_policy
+    else:
+        _policy_factory = None  # setup() defaults to plain Policy
+
     (
         policy,
         policy_generation,
@@ -111,7 +141,13 @@ def main() -> None:
         checkpointer,
         grpo_state,
         master_config,
-    ) = setup(config, tokenizer, dataset, val_dataset)
+    ) = setup(
+        config,
+        tokenizer,
+        dataset,
+        val_dataset,
+        policy_factory=_policy_factory,
+    )

     # Check if async mode is enabled
     if "async_grpo" in config.grpo and config.grpo["async_grpo"]["enabled"]:
@@ -165,10 +201,10 @@ def main() -> None:
             max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
         )
     else:
-        print("🚀 Running synchronous GRPO training")
-
-        # Run standard GRPO training
-        grpo_train(
+        # Two parallel synchronous trainers (verl-style — main_ppo.py vs
+        # main_ppo_sync.py). data_plane.enabled selects which one runs.
+        trainer = _select_trainer(master_config)
+        trainer(
             policy,
             policy_generation,
             dataloader,
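The dispatch added to examples/run_grpo.py can be exercised in isolation. The following is a minimal sketch, assuming stub trainers and a `SimpleNamespace` in place of `MasterConfig`. The names `select_trainer`, `absent`, `disabled`, and `enabled` are illustrative stand-ins, not the PR's code.

```python
from types import SimpleNamespace
from typing import Callable


def grpo_train(*args, **kwargs) -> str:
    """Stub standing in for the legacy trainer."""
    return "legacy"


def grpo_train_sync(*args, **kwargs) -> str:
    """Stub standing in for the TransferQueue-backed trainer."""
    return "transfer_queue"


def select_trainer(master_config) -> Callable[..., str]:
    """Return the trainer that matches data_plane.enabled."""
    # data_plane may be None when the YAML block is absent, hence the `or {}`.
    dp_cfg = master_config.data_plane or {}
    if dp_cfg.get("enabled", False):
        return grpo_train_sync
    return grpo_train


# Three cases: no block at all, block present but off, block switched on.
absent = SimpleNamespace(data_plane=None)
disabled = SimpleNamespace(data_plane={"enabled": False})
enabled = SimpleNamespace(data_plane={"enabled": True})

for cfg in (absent, disabled, enabled):
    print(select_trainer(cfg)())
```

Returning the callable rather than calling it keeps the selection testable without running setup(), which is the motivation the `_select_trainer` docstring gives.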
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/distillation.py
@@ -530,7 +530,7 @@ def distillation_train(
         student_generation = student_policy  # type: ignore
         NEED_REFIT = False
     POLICY_GENERATION_STALE = True  # tracks if generation needs a refit before running
-    assert student_generation is not None  # for mypy type check
+    assert student_generation is not None

     # common config/state items
     current_epoch = distillation_save_state["current_epoch"]  # current epoch
14 changes: 11 additions & 3 deletions nemo_rl/algorithms/grpo.py
@@ -17,7 +17,7 @@
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
-from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
+from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar, cast

 import numpy as np
 import ray
@@ -59,6 +59,7 @@
     get_keys_from_message_log,
 )
 from nemo_rl.data.utils import extract_necessary_env_names, load_dataloader_state
+from nemo_rl.data_plane.interfaces import DataPlaneConfig
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
 from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
@@ -207,6 +208,7 @@ class MasterConfig(BaseModel, extra="allow"):
     logger: GRPOLoggerConfig
     cluster: ClusterConfig
     checkpointing: CheckpointingConfig
+    data_plane: Optional[DataPlaneConfig] = None


 # ===============================================================================
@@ -220,6 +222,7 @@ def setup(
     dataset: AllTaskProcessedDataset | dict[str, AllTaskProcessedDataset],
     val_dataset: Optional[AllTaskProcessedDataset],
     processor: Optional[AutoProcessor] = None,
+    policy_factory: Optional[Callable[..., ColocatablePolicyInterface]] = None,
 ) -> tuple[
     ColocatablePolicyInterface,
     Optional[GenerationInterface],
@@ -580,10 +583,15 @@ def init_train_dataloader(dataset, suffix: str = ""):
             "(reference model is not loaded)."
         )

+    # Caller-supplied factory lets the sync trainer swap in a TQ-mediated
+    # Policy subclass without this shared setup needing to know the data
+    # plane exists. Default is the plain Policy class — legacy behavior.
+    _make_policy = policy_factory if policy_factory is not None else Policy
+
     def init_policy():
         """Initialize policy training workers."""
         t0 = time.perf_counter()
-        p = Policy(
+        p = _make_policy(
             cluster=train_cluster,
             config=policy_config,
             tokenizer=tokenizer,
@@ -1360,7 +1368,7 @@ def grpo_train(
         policy_generation = policy  # type: ignore
         NEED_REFIT = False
     POLICY_GENERATION_STALE = True  # tracks if generation needs a refit before running
-    assert policy_generation is not None  # for mypy type check
+    assert policy_generation is not None

     # Check if we need to sync KV cache scales
     # When fallback to policy as the policy_generation, we use getattr to check.
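The `policy_factory` hook threaded through setup() is a plain factory-injection pattern. The sketch below illustrates it under stated assumptions: `Policy` and `TQPolicy` are stubs, their constructor signatures are simplified, and `build_factory` is a hypothetical helper that mirrors the launcher-side branch rather than reproducing it.

```python
from typing import Any, Callable, Optional


class Policy:
    """Stub for the plain policy class (real constructor takes more arguments)."""

    def __init__(self, cluster: Any, config: Any, **kwargs: Any) -> None:
        self.cluster = cluster
        self.config = config


class TQPolicy(Policy):
    """Stub subclass that additionally receives the data-plane config."""

    def __init__(self, *, dp_cfg: dict, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.dp_cfg = dp_cfg


def setup(policy_factory: Optional[Callable[..., Policy]] = None) -> Policy:
    """Build a policy without knowing which class the caller chose."""
    # Fall back to the plain class when no factory is supplied.
    make_policy = policy_factory if policy_factory is not None else Policy
    return make_policy(cluster="train_cluster", config={"lr": 1e-6})


def build_factory(dp_cfg: Optional[dict]) -> Optional[Callable[..., Policy]]:
    """Launcher-side choice: bind dp_cfg in a closure, or defer to the default."""
    dp_cfg = dp_cfg or {}
    if not dp_cfg.get("enabled", False):
        return None

    def make_policy(**kwargs: Any) -> Policy:
        return TQPolicy(**kwargs, dp_cfg=dp_cfg)

    return make_policy


legacy = setup(build_factory(None))
tq = setup(build_factory({"enabled": True, "backend": "simple"}))
print(type(legacy).__name__, type(tq).__name__)
```

Because the closure binds `dp_cfg`, the shared setup code passes the same keyword arguments in both cases and never needs to import the data-plane module.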