From c6a661e86b388321d1c0e4bab7716d3b200dcab0 Mon Sep 17 00:00:00 2001
From: taivu1998 <46636857+taivu1998@users.noreply.github.com>
Date: Fri, 3 Apr 2026 10:52:35 -0700
Subject: [PATCH] Use batched async generation for validation

---
 nemo_rl/algorithms/distillation.py         |   3 +-
 nemo_rl/algorithms/grpo.py                 |   3 +-
 nemo_rl/experience/rollouts.py             | 210 +++++++++++++++++++++
 tests/unit/algorithms/test_distillation.py |  34 ++++
 tests/unit/algorithms/test_grpo.py         |  94 +++++++++
 tests/unit/experience/test_rollouts.py     | 159 ++++++++++++++++
 6 files changed, 501 insertions(+), 2 deletions(-)

diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py
index 28ecda1869..192970c909 100644
--- a/nemo_rl/algorithms/distillation.py
+++ b/nemo_rl/algorithms/distillation.py
@@ -44,6 +44,7 @@
 )
 from nemo_rl.environments.interfaces import EnvironmentInterface
 from nemo_rl.experience.rollouts import (
     run_async_multi_turn_rollout,
     run_multi_turn_rollout,
+    run_multi_turn_rollout_async_generation,
 )
@@ -984,7 +985,7 @@ def validate(
         # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs)
         # Use async rollouts if vLLM async engine is enabled
         if _should_use_async_rollouts(master_config):
-            val_batch, gen_metrics = run_async_multi_turn_rollout(
+            val_batch, gen_metrics = run_multi_turn_rollout_async_generation(
                 policy_generation,
                 val_batch,
                 tokenizer,
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 02e43ae659..514cff9208 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -63,6 +63,7 @@ from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
 from nemo_rl.environments.interfaces import EnvironmentInterface
 from nemo_rl.experience.rollouts import (
     run_async_multi_turn_rollout,
     run_async_nemo_gym_rollout,
     run_multi_turn_rollout,
+    run_multi_turn_rollout_async_generation,
 )
@@ -2263,7 +2264,7 @@ def validate(
             gen_metrics = nemo_gym_rollout_result.rollout_metrics
             additional_metrics_to_report = gen_metrics
         elif _should_use_async_rollouts(master_config):
-            val_batch, gen_metrics = run_async_multi_turn_rollout(
+            val_batch, gen_metrics = run_multi_turn_rollout_async_generation(
                 policy_generation,
                 val_batch,
                 tokenizer,
diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index d5186e868a..00b1b872a5 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -596,6 +596,213 @@ def run_multi_turn_rollout(
     return current_batch, rollout_metrics
 
 
+def run_multi_turn_rollout_async_generation(
+    policy_generation: GenerationInterface,
+    input_batch: BatchedDataDict[DatumSpec],
+    tokenizer: TokenizerType,
+    task_to_env: dict[str, EnvironmentInterface],
+    max_seq_len: int,
+    max_rollout_turns: int = 999999,
+    greedy: bool = False,
+) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
+    """Run a batched multi-turn rollout using async generation.
+
+    This mirrors `run_multi_turn_rollout()`'s batched environment interaction,
+    but swaps synchronous generation for `generate_responses_async()`. It is
+    intended for cases like validation where preserving batched environment
+    evaluation matters more than per-sample pipelining across turns.
+ """ + + async def _async_rollout_implementation(): + current_batch = input_batch.copy() # Work on a copy + batch_size = len(current_batch["message_log"]) + active_indices = torch.arange(batch_size) + total_rewards = torch.zeros(batch_size, dtype=torch.float32) + + # Multi_rewards: number of components inferred from first env_output (1 for single-reward envs) + number_of_rewards: int | None = None + multi_rewards: torch.Tensor | None = None + + # Initialize stop_strings from the initial batch if present + current_stop_strings = current_batch.get("stop_strings", [None] * batch_size) + + # Tracking metrics for each sample + sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_assistant_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_env_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_terminated = torch.zeros(batch_size, dtype=torch.bool) + sample_truncated = torch.zeros(batch_size, dtype=torch.bool) + sample_max_turns_reached = torch.zeros(batch_size, dtype=torch.bool) + + # Tracking per-turn metrics + total_gen_tokens_per_turn = [] + active_samples_per_turn = [] + + for turn in range(max_rollout_turns): + if len(active_indices) == 0: + break + + active_samples_per_turn.append(len(active_indices)) + + active_batch = current_batch.select_indices(active_indices) + active_stop_strings = [ + current_stop_strings[i] for i in active_indices.tolist() + ] + + active_flat_messages: BatchedDataDict[FlatMessagesType] + active_flat_messages, active_input_lengths = ( + batched_message_log_to_flat_message( + active_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + ) + + active_input_ids = active_flat_messages["token_ids"] + generation_input_data = BatchedDataDict[GenerationDatumSpec]( + { + "input_ids": active_input_ids, + "input_lengths": active_input_lengths, + "stop_strings": active_stop_strings, + } + ) + multimodal_data = active_flat_messages.get_multimodal_dict(as_tensors=False) + generation_input_data.update(multimodal_data) + + if "vllm_content" in active_batch: + generation_input_data["vllm_content"] = active_batch["vllm_content"] + if "vllm_images" in active_batch: + generation_input_data["vllm_images"] = active_batch["vllm_images"] + if "vllm_videos" in active_batch: + generation_input_data["vllm_videos"] = active_batch["vllm_videos"] + + active_batch, generated_ids, gen_metrics = await generate_responses_async( + policy_generation, + generation_input_data, + active_batch, + tokenizer, + input_lengths=active_input_lengths, + greedy=greedy, + ) + + response_truncated = gen_metrics.pop("_response_truncated", None) + if response_truncated is not None: + for i, global_idx in enumerate(active_indices.tolist()): + if response_truncated[i]: + sample_truncated[global_idx] = True + + for i, global_idx in enumerate(active_indices.tolist()): + sample_assistant_token_counts[global_idx] += len(generated_ids[i]) + sample_token_counts[global_idx] += len(generated_ids[i]) + + total_gen_tokens_per_turn.append(sum(len(ids) for ids in generated_ids)) + + env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env) + + if number_of_rewards is None: + if env_output.rewards.ndim >= 2: + number_of_rewards = int(env_output.rewards.shape[1]) + multi_rewards = torch.zeros( + batch_size, number_of_rewards, dtype=torch.float32 + ) + else: + number_of_rewards = 1 + + if number_of_rewards > 1: + assert multi_rewards is not None + 
+                multi_rewards[active_indices] += env_output.rewards
+                total_rewards[active_indices] += env_output.rewards.sum(dim=1)
+            else:
+                total_rewards[active_indices] += env_output.rewards
+
+            truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool)
+            for i, global_idx in enumerate(active_indices.tolist()):
+                env_obs_content = env_output.observations[i]["content"]
+                tokenized_obs = tokenizer(
+                    env_obs_content, return_tensors="pt", add_special_tokens=False
+                ).input_ids[0]
+                tokenized_obs = tokenized_obs.to(dtype=torch.int64)
+
+                if (
+                    len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i]
+                    >= max_seq_len
+                ):
+                    tokens_left_for_obs = max_seq_len - (
+                        len(generated_ids[i]) + active_input_lengths[i]
+                    )
+                    assert tokens_left_for_obs >= 0, (
+                        f"tokens_left_for_obs={tokens_left_for_obs} should not be negative. This should not happen if the inference engine respects the max sequence length."
+                    )
+                    tokenized_obs = tokenized_obs[:tokens_left_for_obs]
+                    truncation_mask[i] = True
+                    sample_truncated[global_idx] = True
+
+                tokenized_env_obs_message = {
+                    "role": env_output.observations[i]["role"],
+                    "content": env_obs_content,
+                    "token_ids": tokenized_obs,
+                }
+                current_batch["message_log"][global_idx].append(tokenized_env_obs_message)
+
+                sample_env_token_counts[global_idx] += len(tokenized_obs)
+                sample_token_counts[global_idx] += len(tokenized_obs)
+                sample_turn_counts[global_idx] += 1
+
+            terminateds = env_output.terminateds.bool()
+            done = truncation_mask | terminateds
+            sample_terminated[active_indices] |= terminateds
+
+            active_indices_local_next = torch.where(~done)[0]
+            active_indices = active_indices[active_indices_local_next]
+            continuing_indices_global = active_indices
+            continuing_next_stops = [
+                env_output.next_stop_strings[i] for i in active_indices_local_next.tolist()
+            ]
+            continuing_metadata = [
+                env_output.metadata[i] for i in active_indices_local_next.tolist()
+            ]
+
+            for i, global_idx in enumerate(continuing_indices_global.tolist()):
+                current_stop_strings[global_idx] = continuing_next_stops[i]
+                if continuing_metadata[i] is not None:
+                    current_batch["extra_env_info"][global_idx] = continuing_metadata[i]
+
+        sample_max_turns_reached[active_indices] = True
+
+        current_batch["total_reward"] = total_rewards
+        current_batch["truncated"] = sample_truncated
+        if multi_rewards is not None:
+            num_reward_components = multi_rewards.shape[1]
+            for i in range(num_reward_components):
+                current_batch[f"reward{i + 1}"] = multi_rewards[:, i].clone()
+
+        rollout_metrics = {
+            "total_turns": int(sample_turn_counts.sum().item()),
+            "avg_turns_per_sample": float(sample_turn_counts.float().mean().item()),
+            "max_turns_per_sample": int(sample_turn_counts.max().item()),
+            "natural_termination_rate": float(sample_terminated.float().mean().item()),
+            "truncation_rate": float(sample_truncated.float().mean().item()),
+            "max_turns_reached_rate": float(
+                sample_max_turns_reached.float().mean().item()
+            ),
+            "mean_total_tokens_per_sample": float(
+                sample_token_counts.float().mean().item()
+            ),
+            "mean_gen_tokens_per_sample": float(
+                sample_assistant_token_counts.float().mean().item()
+            ),
+            "max_gen_tokens_per_sample": float(
+                sample_assistant_token_counts.float().max().item()
+            ),
+            "mean_env_tokens_per_sample": float(
+                sample_env_token_counts.float().mean().item()
+            ),
+        }
+        return current_batch, rollout_metrics
+
+    return asyncio.run(_async_rollout_implementation())
+
+
 async def async_generate_response_for_sample_turn(
     policy_generation: GenerationInterface,
     sample_message_log: list[dict],
@@ -872,6 +1079,9 @@ def run_async_multi_turn_rollout(
     Each sample in the batch proceeds through its interaction independently.
     Async generation is used internally when available but the function is
     synchronous.
+    This keeps sample-level pipelining across turns, which is useful for some
+    training paths, but it also means environment evaluation happens from the
+    per-sample loop rather than from a batched rollout loop.
 
     Args:
         policy_generation: The generation interface (policy)
diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py
index ae30ecfe24..1ce39aaf09 100644
--- a/tests/unit/algorithms/test_distillation.py
+++ b/tests/unit/algorithms/test_distillation.py
@@ -294,6 +294,40 @@ def test_validate_function(mock_components):
     # Note: validate() function itself doesn't call logger.log_metrics - that's done by the caller
 
 
+def test_validate_function_uses_batched_async_generation_helper(mock_components):
+    """Async distillation validation should use the batched async-generation helper."""
+    mock_components["master_config"]["policy"]["generation"]["backend"] = "vllm"
+    mock_components["master_config"]["policy"]["generation"]["vllm_cfg"] = {
+        "async_engine": True
+    }
+    mock_components["student_generation"] = MagicMock()
+
+    mock_rollout_metrics = {"mean_gen_tokens_per_sample": 1.0}
+    with patch(
+        "nemo_rl.algorithms.distillation.run_multi_turn_rollout_async_generation"
+    ) as mock_async_validation_rollout:
+        mock_async_validation_rollout.return_value = (
+            next(iter(mock_components["val_dataloader"])),
+            mock_rollout_metrics,
+        )
+        with patch(
+            "nemo_rl.algorithms.distillation.run_async_multi_turn_rollout",
+            side_effect=AssertionError(
+                "Validation should not use run_async_multi_turn_rollout"
+            ),
+        ):
+            validate(
+                mock_components["student_generation"],
+                mock_components["val_dataloader"],
+                mock_components["tokenizer"],
+                mock_components["val_task_to_env"],
+                step=0,
+                master_config=mock_components["master_config"],
+            )
+
+    mock_async_validation_rollout.assert_called_once()
+
+
 def test_check_vocab_equality_pass(monkeypatch):
     student_tokenizer = MagicMock()
     student_tokenizer.get_vocab.return_value = {"a": 0, "b": 1}
diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py
index 2ddbf001c9..8c8f13d75a 100644
--- a/tests/unit/algorithms/test_grpo.py
+++ b/tests/unit/algorithms/test_grpo.py
@@ -2127,6 +2127,100 @@ def test_validate_works_without_logger(self):
         assert "accuracy" in val_metrics
         assert "avg_length" in val_metrics
 
+    def test_validate_uses_batched_async_generation_helper(self):
+        """Async validation should use the batched async-generation rollout helper."""
+        mock_policy_gen = MagicMock()
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.pad_token_id = 0
+
+        mock_batch = BatchedDataDict[DatumSpec](
+            {
+                "message_log": [
+                    [
+                        {
+                            "role": "user",
+                            "content": "test1",
+                            "token_ids": torch.tensor([1, 2, 3]),
+                        },
+                        {
+                            "role": "assistant",
+                            "content": "response1",
+                            "token_ids": torch.tensor([4, 5, 6]),
+                        },
+                    ]
+                ],
+                "task_name": ["math"],
+                "extra_env_info": [{}],
+                "loss_multiplier": torch.tensor([1.0]),
+                "idx": torch.tensor([0]),
+                "length": torch.tensor([6]),
+                "total_reward": torch.tensor([1.0]),
+            }
+        )
+
+        mock_dataloader = MagicMock(spec=StatefulDataLoader)
+        mock_dataloader.__iter__ = MagicMock(return_value=iter([mock_batch]))
+
+        mock_env = MagicMock(spec=EnvironmentInterface)
+        mock_env.global_post_process_and_metrics.return_value = (mock_batch, {})
+
+        mock_config = {
+            "grpo": {
+                "max_val_samples": 10,
+                "val_batch_size": 1,
+                "max_rollout_turns": 1,
+            },
+            "policy": {
+                "max_total_sequence_length": 2048,
+                "generation": {
+                    "temperature": 1.0,
+                    "top_p": 1.0,
+                    "top_k": None,
+                    "backend": "vllm",
+                    "colocated": {"enabled": True},
+                    "vllm_cfg": {"async_engine": True},
+                },
+            },
+            "logger": {
+                "num_val_samples_to_print": 1,
+            },
+        }
+
+        mock_rollout_metrics = {"mean_gen_tokens_per_sample": 10.0}
+
+        with patch(
+            "nemo_rl.algorithms.grpo.run_multi_turn_rollout_async_generation"
+        ) as mock_async_validation_rollout:
+            mock_async_validation_rollout.return_value = (
+                mock_batch,
+                mock_rollout_metrics,
+            )
+            with patch(
+                "nemo_rl.algorithms.grpo.run_async_multi_turn_rollout",
+                side_effect=AssertionError(
+                    "Validation should not use run_async_multi_turn_rollout"
+                ),
+            ):
+                with patch(
+                    "nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False
+                ):
+                    with patch(
+                        "nemo_rl.algorithms.grpo._should_use_async_rollouts",
+                        return_value=True,
+                    ):
+                        with patch("nemo_rl.algorithms.grpo.print_message_log_samples"):
+                            validate(
+                                mock_policy_gen,
+                                mock_dataloader,
+                                mock_tokenizer,
+                                {"math": mock_env},
+                                step=5,
+                                master_config=mock_config,
+                                logger=None,
+                            )
+
+        mock_async_validation_rollout.assert_called_once()
+
     def test_validate_returns_empty_when_no_dataloader(self):
         """Test that validate returns empty dicts when no dataloader is provided."""
         mock_policy_gen = MagicMock()
diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py
index f3486de21e..26facca844 100644
--- a/tests/unit/experience/test_rollouts.py
+++ b/tests/unit/experience/test_rollouts.py
@@ -17,6 +17,8 @@
 import tempfile
 from copy import deepcopy
 from dataclasses import asdict
+from types import SimpleNamespace
+from unittest.mock import MagicMock
 
 import pytest
 import ray
@@ -38,6 +40,7 @@
 )
 from nemo_rl.experience.rollouts import (
     _calculate_single_metric,
     run_async_multi_turn_rollout,
     run_async_nemo_gym_rollout,
     run_multi_turn_rollout,
+    run_multi_turn_rollout_async_generation,
@@ -101,6 +104,162 @@ def test_two_identical_values_returns_zero_stddev(self):
         assert result["test/stddev"] == 0.0
 
 
+class _FakeRemoteMethod:
+    def __init__(self, fn):
+        self._fn = fn
+
+    def remote(self, *args, **kwargs):
+        return self._fn(*args, **kwargs)
+
+
+class _RecordingEnvHandle:
+    def __init__(self, reward_offset: float = 0.0):
+        self.calls = []
+        self.reward_offset = reward_offset
+        self.step = _FakeRemoteMethod(self._step)
+
+    def _step(self, messages, env_info):
+        self.calls.append(len(messages))
+        batch_size = len(messages)
+        rewards = torch.tensor(
+            [self.reward_offset + i + 1 for i in range(batch_size)],
+            dtype=torch.float32,
+        )
+        return (
+            [{"role": "environment", "content": "done"} for _ in range(batch_size)],
+            [None] * batch_size,
+            [None] * batch_size,
+            rewards,
+            torch.ones(batch_size, dtype=torch.bool),
+            [None] * batch_size,
+        )
+
+
+class _FakeTokenizer:
+    pad_token_id = 0
+
+    def __call__(self, text, return_tensors="pt", add_special_tokens=False):
+        del text, return_tensors, add_special_tokens
+        return SimpleNamespace(input_ids=torch.tensor([[1]], dtype=torch.int64))
+
+
+def _make_test_rollout_batch(task_names: list[str]) -> BatchedDataDict[DatumSpec]:
+    message_logs = []
+    for idx, task_name in enumerate(task_names):
+        del task_name
+        message_logs.append(
+            [
+                {
+                    "role": "user",
+                    "content": f"prompt-{idx}",
+                    "token_ids": torch.tensor([idx + 1, idx + 2], dtype=torch.int64),
+                }
+            ]
+        )
+
+    return BatchedDataDict[DatumSpec](
+        {
+            "message_log": message_logs,
+            "task_name": task_names,
+            "extra_env_info": [{} for _ in task_names],
+            "idx": torch.tensor(list(range(len(task_names)))),
+        }
+    )
+
+
+def test_run_multi_turn_rollout_async_generation_batches_same_task(monkeypatch):
+    async def fake_generate_responses_async(
+        policy_generation,
+        generation_input_data,
+        batch,
+        tokenizer,
+        input_lengths,
+        include_logprobs=True,
+        greedy=False,
+    ):
+        del policy_generation, generation_input_data, tokenizer, input_lengths
+        del include_logprobs, greedy
+        generated_ids = []
+        for i, message_log in enumerate(batch["message_log"]):
+            token_ids = torch.tensor([10 + i], dtype=torch.int64)
+            message_log.append(
+                {
+                    "role": "assistant",
+                    "content": f"answer-{i}",
+                    "token_ids": token_ids,
+                }
+            )
+            generated_ids.append(token_ids)
+        return batch, generated_ids, {"mean_generation_length": 1.0}
+
+    monkeypatch.setattr(ray, "get", lambda refs: refs)
+    monkeypatch.setattr(
+        "nemo_rl.experience.rollouts.generate_responses_async",
+        fake_generate_responses_async,
+    )
+
+    env = _RecordingEnvHandle()
+    final_batch, rollout_metrics = run_multi_turn_rollout_async_generation(
+        policy_generation=MagicMock(),
+        input_batch=_make_test_rollout_batch(["math", "math"]),
+        tokenizer=_FakeTokenizer(),
+        task_to_env={"math": env},
+        max_seq_len=32,
+        max_rollout_turns=1,
+    )
+
+    assert env.calls == [2]
+    assert torch.equal(final_batch["total_reward"], torch.tensor([1.0, 2.0]))
+    assert rollout_metrics["total_turns"] == 2
+
+
+def test_run_multi_turn_rollout_async_generation_keeps_task_grouping(monkeypatch):
+    async def fake_generate_responses_async(
+        policy_generation,
+        generation_input_data,
+        batch,
+        tokenizer,
+        input_lengths,
+        include_logprobs=True,
+        greedy=False,
+    ):
+        del policy_generation, generation_input_data, tokenizer, input_lengths
+        del include_logprobs, greedy
+        generated_ids = []
+        for i, message_log in enumerate(batch["message_log"]):
+            token_ids = torch.tensor([20 + i], dtype=torch.int64)
+            message_log.append(
+                {
+                    "role": "assistant",
+                    "content": f"mixed-answer-{i}",
+                    "token_ids": token_ids,
+                }
+            )
+            generated_ids.append(token_ids)
+        return batch, generated_ids, {"mean_generation_length": 1.0}
+
+    monkeypatch.setattr(ray, "get", lambda refs: refs)
+    monkeypatch.setattr(
+        "nemo_rl.experience.rollouts.generate_responses_async",
+        fake_generate_responses_async,
+    )
+
+    math_env = _RecordingEnvHandle(reward_offset=0.0)
+    code_env = _RecordingEnvHandle(reward_offset=10.0)
+    final_batch, _ = run_multi_turn_rollout_async_generation(
+        policy_generation=MagicMock(),
+        input_batch=_make_test_rollout_batch(["math", "code"]),
+        tokenizer=_FakeTokenizer(),
+        task_to_env={"math": math_env, "code": code_env},
+        max_seq_len=32,
+        max_rollout_turns=1,
+    )
+
+    assert math_env.calls == [1]
+    assert code_env.calls == [1]
+    assert torch.equal(final_batch["total_reward"], torch.tensor([1.0, 11.0]))
+
+
 @pytest.fixture(scope="function")
 def rollout_tokenizer():
     """Loads the tokenizer for the tests."""
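
Usage note (not part of the patch): a minimal sketch of how a caller outside grpo.validate() / distillation.validate() would invoke the new helper. The surrounding names (policy_generation, val_batch, tokenizer, val_task_to_env, master_config) are assumed to be in scope, as they are at the call sites above; the config keys mirror the test's mock_config.

    from nemo_rl.experience.rollouts import run_multi_turn_rollout_async_generation

    # One batched step per turn: generate for all active samples with the async
    # engine, then evaluate each task's environment on the whole active batch.
    val_batch, gen_metrics = run_multi_turn_rollout_async_generation(
        policy_generation,  # GenerationInterface with vllm_cfg.async_engine enabled
        val_batch,          # BatchedDataDict[DatumSpec] with "message_log", "task_name", ...
        tokenizer,
        val_task_to_env,    # e.g. {"math": math_env}
        max_seq_len=master_config["policy"]["max_total_sequence_length"],
        max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
        greedy=False,
    )

Because the helper drives its rollout with asyncio.run() internally, it must be called from synchronous code, not from inside an already-running event loop. Unlike run_async_multi_turn_rollout, which pipelines each sample independently across turns, it makes one batched environment call per task per turn, which is what the new tests assert via _RecordingEnvHandle.calls.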