Skip to content

Commit 877bb67

Browse files
committed
more fix
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
1 parent 9e85ccf commit 877bb67

10 files changed

Lines changed: 90 additions & 29 deletions

File tree

examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ grpo:
3838
seq_logprob_error_threshold: null
3939

4040
async_grpo:
41-
enabled: false # Set to true to enable async training mode
41+
enabled: true # Set to true to enable async training mode
4242
# Max age (in training steps) for trajectories used in training
43-
max_trajectory_age_steps: 1
44-
in_flight_weight_updates: false # Set to true to enable in-flight weight updates
43+
max_trajectory_age_steps: 2
44+
in_flight_weight_updates: true # Set to true to enable in-flight weight updates
4545
recompute_kv_cache_after_weight_updates: false # Set to true to recompute kv cache after in-flight-weight-updates
4646

4747
loss_fn:
@@ -55,7 +55,7 @@ loss_fn:
5555
# (default off) loss formulation improvements (docs/guides/grpo.md#loss)
5656
use_on_policy_kl_approximation: false
5757
truncated_importance_sampling_ratio: null
58-
use_importance_sampling_correction: false
58+
use_importance_sampling_correction: true
5959
token_level_loss: true
6060

6161
checkpointing:
@@ -234,15 +234,14 @@ policy:
234234
# Workplace assistant uses 26 tools, so we enable auto_tools.
235235
# For Nemotron Nano v2, we use the dedicated `nemotron_json` tool parser
236236
enable_auto_tools: true
237-
tool_parser: nemotron_json
237+
tool_parser: hermes
238+
reasoning_parser: qwen3
238239
vllm_kwargs:
239240
compilation_config:
240241
# when enforce_eager is False, set ++policy.generation.vllm_kwargs.compilation_config.backend=eager for better accuracy,
241242
# with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile
242243
# for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998
243244
backend: eager
244-
# We need the Mamba cache to be set to fp32 for Nemotron Nano v2
245-
mamba_ssm_cache_dtype: "float32"
246245
colocated:
247246
# true: generation shares training GPUs
248247
# false: uses dedicated generation resources
@@ -297,10 +296,10 @@ env:
297296
responses_api_models:
298297
vllm_model:
299298
# Disable reasoning!
300-
uses_reasoning_parser: false
299+
uses_reasoning_parser: true
301300
extra_body:
302301
chat_template_kwargs:
303-
enable_thinking: false
302+
enable_thinking: true
304303
code_gen:
305304
resources_servers:
306305
code_gen:
@@ -328,3 +327,4 @@ logger:
328327
cluster:
329328
gpus_per_node: 8
330329
num_nodes: 1 # Single node by default; set to 2+ for multi-node training
330+

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def collect_trajectories(
9494
input_batch=val_batch,
9595
tokenizer=tokenizer,
9696
task_to_env=val_task_to_env,
97-
max_seq_len=None,
97+
max_seq_len=master_config["policy"]["max_total_sequence_length"],
9898
generation_config=generation_config,
9999
max_rollout_turns=None,
100100
greedy=False,

nemo_rl/algorithms/async_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,9 @@ def _run_prompt_group_worker(
657657
input_batch=repeated_batch,
658658
tokenizer=self.tokenizer,
659659
task_to_env=self.task_to_env,
660-
max_seq_len=None,
660+
max_seq_len=self.master_config["policy"][
661+
"max_total_sequence_length"
662+
],
661663
generation_config=generation_config,
662664
max_rollout_turns=None,
663665
greedy=False,

nemo_rl/algorithms/grpo.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ def init_dynamo():
779779

780780
# if it is not colocated inference, initialize collective communication for update weights
781781
# Dynamo backend does not support weight updates — skip collective init and refit.
782-
if not colocated_inference and backend != "dynamo":
782+
if not colocated_inference and backend not in ("dynamo", "vllm"):
783783
t0 = time.perf_counter()
784784
ip, port = train_cluster.get_master_address_and_port()
785785
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
@@ -800,7 +800,7 @@ def init_dynamo():
800800

801801
# prepare refit info
802802
state_dict_info = policy.prepare_refit_info()
803-
if policy_generation is not None and backend != "dynamo":
803+
if policy_generation is not None and backend not in ("dynamo", "vllm"):
804804
policy_generation.prepare_refit_info(state_dict_info)
805805

806806
# Calculate total setup time
@@ -1393,7 +1393,7 @@ def grpo_train(
13931393
if policy_generation is None:
13941394
policy_generation = policy # type: ignore
13951395
NEED_REFIT = False
1396-
elif master_config["policy"]["generation"]["backend"] == "dynamo":
1396+
elif master_config["policy"]["generation"]["backend"] in ("dynamo", "vllm"):
13971397
NEED_REFIT = False
13981398
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
13991399
assert policy_generation is not None # for mypy type check
@@ -1579,7 +1579,9 @@ def grpo_train(
15791579
input_batch=repeated_batch,
15801580
tokenizer=tokenizer,
15811581
task_to_env=task_to_env,
1582-
max_seq_len=None,
1582+
max_seq_len=master_config["policy"][
1583+
"max_total_sequence_length"
1584+
],
15831585
generation_config=generation_config,
15841586
max_rollout_turns=None,
15851587
greedy=False,
@@ -2316,7 +2318,7 @@ def validate(
23162318
input_batch=val_batch,
23172319
tokenizer=tokenizer,
23182320
task_to_env=val_task_to_env,
2319-
max_seq_len=None,
2321+
max_seq_len=master_config["policy"]["max_total_sequence_length"],
23202322
generation_config=generation_config,
23212323
max_rollout_turns=None,
23222324
greedy=False,
@@ -2489,7 +2491,7 @@ def async_grpo_train(
24892491
if policy_generation is None:
24902492
policy_generation = policy
24912493
NEED_REFIT = False
2492-
elif master_config["policy"]["generation"]["backend"] == "dynamo":
2494+
elif master_config["policy"]["generation"]["backend"] in ("dynamo", "vllm"):
24932495
NEED_REFIT = False
24942496
POLICY_GENERATION_STALE = True
24952497
assert policy_generation is not None
@@ -2950,6 +2952,11 @@ def async_grpo_train(
29502952
weight_version += 1
29512953
trajectory_collector.set_weight_version.remote(weight_version)
29522954
trajectory_collector.resume_after_refit.remote()
2955+
else:
2956+
# Advance the trajectory collector's weight version even when refit is skipped
2957+
# so that the replay buffer can sample trajectories targeted for subsequent steps.
2958+
weight_version += 1
2959+
trajectory_collector.set_weight_version.remote(weight_version)
29532960

29542961
# Clear logger metrics after each refit (weight sync), starting a new logging cycle
29552962
if policy_generation is not None:

nemo_rl/distributed/model_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,13 +2151,18 @@ def forward(
21512151
# calculate the logprobs for the last token and then return the logprobs
21522152
vocab_start_index = tp_rank * (self.vocab_size // tp_size)
21532153
vocab_end_index = min((tp_rank + 1) * (self.vocab_size // tp_size), self.vocab_size)
2154-
output_weight_layer = self.output_layer.weight
2154+
# For models with tied embeddings (e.g. Qwen3), self.output_layer.weight is None —
2155+
# the real weight lives on the embedding and must be fetched via
2156+
# shared_embedding_or_output_weight().
2157+
output_weight_layer = (
2158+
self.shared_embedding_or_output_weight()
2159+
if self.share_embeddings_and_output_weights
2160+
else self.output_layer.weight
2161+
)
21552162
logprobs = from_parallel_hidden_states_to_logprobs(
21562163
hidden_states, # .transpose(0, 1).contiguous(),
21572164
output_weight_layer,
2158-
self.shared_embedding_or_output_weight()
2159-
if self.share_embeddings_and_output_weights
2160-
else self.output_layer.weight,
2165+
output_weight_layer,
21612166
runtime_gather_output,
21622167
labels,
21632168
vocab_start_index=vocab_start_index,

nemo_rl/models/generation/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ def configure_generation_config(
4444
# vllm setting
4545
if config["backend"] == "vllm":
4646
config = cast(VllmConfig, config)
47-
# set load_format
48-
config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy"
47+
# set load_format (respect user override if they set it explicitly,
48+
# e.g. to force "auto" for benchmarking without a Megatron refit).
49+
if config["vllm_cfg"].get("load_format") is None:
50+
config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy"
4951
speculative_config = config.get("vllm_kwargs", {}).get("speculative_config")
5052
if speculative_config:
5153
# Speculative decoding needs real startup weights unless the draft

nemo_rl/models/generation/dynamo/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class DynamoCfg(TypedDict, total=False):
4444
namespace: str
4545
enable_planner: bool # Launch planner + VirtualConnectorClient for autoscaling
4646
initial_dp_size: int # Workers at startup (must be <= cluster.world_size() // tp_size)
47+
tool_call_parser: str # Dynamo parser name, or "none" to disable
48+
reasoning_parser: str # Dynamo parser name, or "none" to disable
4749

4850

4951
class DynamoVllmConfig(GenerationConfig):

nemo_rl/models/generation/dynamo/dynamo_generation.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def _start_planner(self) -> None:
326326
"pre_deployment_sweeping_mode": "none",
327327
"decode_engine_num_gpu": self.tp_size,
328328
"ttft": 500.0,
329-
"itl": 50.0,
329+
"itl": 1.0,
330330
"max_gpu_budget": self._inference_gpu_count,
331331
"min_endpoint": 1,
332332
"load_adjustment_interval": 5,
@@ -366,7 +366,8 @@ def _start_frontend(self) -> None:
366366
"--router-mode", "kv",
367367
"--active-decode-blocks-threshold", "1000.0",
368368
"--active-prefill-tokens-threshold", "1000000000000",
369-
"--active-prefill-tokens-threshold-frac", "1000.0"
369+
"--active-prefill-tokens-threshold-frac", "1000.0",
370+
"--router-predict-on-route"
370371
],
371372
env=env,
372373
)
@@ -543,6 +544,24 @@ def shutdown(self) -> bool:
543544
self._pool.shutdown()
544545
return True
545546

547+
# ------------------------------------------------------------------
548+
# Serialization support (for Ray actor pickling in async GRPO)
549+
# ------------------------------------------------------------------
550+
551+
def __getstate__(self):
552+
state = self.__dict__.copy()
553+
for attr, _, _ in _SUBPROCESS_REGISTRY:
554+
state[attr] = None
555+
state["_vc_stop"] = None
556+
state["_vc_thread"] = None
557+
state["_pool"] = None
558+
return state
559+
560+
def __setstate__(self, state):
561+
self.__dict__.update(state)
562+
self._external = True
563+
self._vc_stop = threading.Event()
564+
546565
# ------------------------------------------------------------------
547566
# Unsupported weight-update methods
548567
# ------------------------------------------------------------------

nemo_rl/models/generation/dynamo/dynamo_worker.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@
3434
from nemo_rl.models.generation.dynamo.config import DynamoVllmConfig
3535

3636

37+
def _normalize_parser_name(value: Optional[str], default: Optional[str]) -> Optional[str]:
38+
"""Normalize parser names from config/env, treating empty/none as disabled."""
39+
if value is None:
40+
return default
41+
normalized = value.strip()
42+
if not normalized or normalized.lower() == "none":
43+
return None
44+
return normalized
45+
46+
3747
def _build_vllm_cli_args(
3848
model_name: str,
3949
vllm_cfg: dict[str, Any],
@@ -276,10 +286,22 @@ def __init__(
276286
kv_events_config_json=kv_events_json,
277287
seed=seed,
278288
),
279-
"--dyn-tool-call-parser", "hermes",
280-
"--dyn-reasoning-parser", "qwen3"
281289
]
282290

291+
dynamo_cfg = config.get("dynamo_cfg", {})
292+
tool_call_parser = _normalize_parser_name(
293+
os.environ.get("DYNAMO_TOOL_CALL_PARSER"),
294+
_normalize_parser_name(dynamo_cfg.get("tool_call_parser"), "hermes"),
295+
)
296+
reasoning_parser = _normalize_parser_name(
297+
os.environ.get("DYNAMO_REASONING_PARSER"),
298+
_normalize_parser_name(dynamo_cfg.get("reasoning_parser"), "qwen3"),
299+
)
300+
if tool_call_parser is not None:
301+
cmd.extend(["--dyn-tool-call-parser", tool_call_parser])
302+
if reasoning_parser is not None:
303+
cmd.extend(["--dyn-reasoning-parser", reasoning_parser])
304+
283305
# --- Subprocess environment ---
284306
env = os.environ.copy()
285307
env["CUDA_VISIBLE_DEVICES"] = cuda_visible
@@ -315,7 +337,9 @@ def __init__(
315337
print(
316338
f" [DynamoVllmWorker] Launched dynamo.vllm (pid={self._process.pid}, "
317339
f"CUDA_VISIBLE_DEVICES={cuda_visible}, "
318-
f"TP={tp_size})",
340+
f"TP={tp_size}, "
341+
f"tool_call_parser={tool_call_parser or 'disabled'}, "
342+
f"reasoning_parser={reasoning_parser or 'disabled'})",
319343
flush=True,
320344
)
321345

0 commit comments

Comments
 (0)