Commit eecbcc4

refactor(sync-rollout-actor): remove unused wrappers; document full lifecycle
Remove three actor-level wrappers (finish_generation, get_logger_metrics,
clear_logger_metrics) that had zero external callers. The actor's internal
code already calls self.policy_generation.{...} directly at the right points
inside rollout_to_tq; the wrappers added indirection without value.

Rewrite the rollout_to_tq docstring to list all six steps bundled into the
single Ray RPC (reset metrics -> rollout -> flatten -> TQ put -> release GPU
-> capture metrics), making the lifecycle visible without having to read the
method body.

Per yuki-97 PR review (#7, #8).

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent 416df7c commit eecbcc4
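The wrapper removal described above can be illustrated with a minimal sketch. The classes below are hypothetical stand-ins, not the real NeMo-RL actor or generation backend; they only show the before/after shape of the delegation pattern the commit deletes:

```python
class Backend:
    """Stand-in for ``policy_generation`` (hypothetical)."""

    def __init__(self):
        self.finished = False

    def finish_generation(self):
        self.finished = True


class ActorBefore:
    """Pre-commit shape: a one-line forward with no external callers."""

    def __init__(self):
        self.policy_generation = Backend()

    def finish_generation(self):
        # Wrapper of the kind the commit removes: pure indirection.
        if self.policy_generation is not None:
            self.policy_generation.finish_generation()


class ActorAfter:
    """Post-commit shape: the rollout method calls the backend directly."""

    def __init__(self):
        self.policy_generation = Backend()

    def rollout_to_tq(self):
        # ... rollout / flatten / TQ put elided ...
        self.policy_generation.finish_generation()  # direct call, no wrapper


before = ActorBefore()
before.finish_generation()
after = ActorAfter()
after.rollout_to_tq()
# Both paths reach the same backend call; the wrapper adds nothing.
```

Since the wrappers had no external callers, deleting them changes no observable behavior.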

1 file changed

Lines changed: 25 additions & 23 deletions

File tree

nemo_rl/experience/sync_rollout_actor.py

@@ -97,14 +97,31 @@ def rollout_to_tq(
         dict[str, Any],
         Optional[dict[str, Any]],
     ]:
-        """Rollout → flatten + mask + prompt extraction → flat ``kv_batch_put``.
-
-        ``slice`` carries only the small per-sample tensors the driver
-        needs for its own per-sample compute (scale_rewards,
-        reward_shaping, overlong filtering, baseline/std,
-        dynamic_sampling, advantage). The actor handles the bulk-touching
-        ops (flatten / mask / prompt extraction) that require
-        ``message_log`` and would otherwise force bulk onto the driver.
+        """Run the full per-step generation cycle and write bulk data to TQ.
+
+        Bundles six steps into one Ray round-trip so the driver only sees
+        a single RPC instead of separate calls for each:
+
+        1. **Reset metrics** — ``policy_generation.clear_logger_metrics()``
+           clears per-step generation accumulators before the rollout.
+        2. **Rollout** — runs ``run_multi_turn_rollout`` (or the async /
+           nemo-gym variants) to produce ``final_batch``.
+        3. **Flatten + mask + prompt extraction** — converts
+           ``message_log`` layout to flat tensors; builds token mask,
+           sample mask, prompt-only ids, baseline/std.
+        4. **Write bulk to TQ** — ``kv_first_write`` puts every tensor
+           field in one flat ``kv_batch_put``; the driver never touches
+           bulk bytes.
+        5. **Release GPU** — ``policy_generation.finish_generation()``
+           frees KV cache and inference state so the trainer can use the
+           GPU immediately.
+        6. **Capture metrics** — ``policy_generation.get_logger_metrics()``
+           collects generation stats (throughput, etc.) and returns them
+           to the driver in the result tuple.
+
+        The driver receives ``(meta, slice, rollout_metrics,
+        generation_logger_metrics)`` and uses only the small per-sample
+        slice for its own compute (rewards, advantages, dynamic sampling).

         Args:
             input_batch: Per-step prompt batch (already repeat-interleaved).
@@ -290,21 +307,6 @@ def rollout_to_tq(
             gen_metrics = None
         return meta, slice_extras, rollout_metrics, gen_metrics

-    def finish_generation(self) -> None:
-        """Forward to ``policy_generation.finish_generation``."""
-        if self.policy_generation is not None:
-            self.policy_generation.finish_generation()
-
-    def get_logger_metrics(self) -> Optional[dict[str, Any]]:
-        if self.policy_generation is None:
-            return None
-        return self.policy_generation.get_logger_metrics()
-
-    def clear_logger_metrics(self) -> None:
-        if self.policy_generation is None:
-            return
-        self.policy_generation.clear_logger_metrics()
-
     def shutdown(self) -> None:
         try:
             self._dp_client.close()
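The six-step lifecycle the new docstring enumerates can be sketched end to end. Everything below is a simplified stand-in (the real Ray RPC, TQ client, flatten/mask logic, and generation backend are replaced with trivial fakes); it only makes the control flow of the bundled call visible:

```python
from typing import Any, Optional


class FakePolicyGeneration:
    """Stand-in for the generation backend (hypothetical)."""

    def __init__(self) -> None:
        self.metrics: dict[str, Any] = {}
        self.gpu_released = False

    def clear_logger_metrics(self) -> None:
        self.metrics = {}
        self.gpu_released = False

    def generate(self, prompts: list[str]) -> list[str]:
        outs = [p + " -> completion" for p in prompts]
        self.metrics["num_samples"] = len(outs)
        return outs

    def finish_generation(self) -> None:
        self.gpu_released = True  # frees KV cache / inference state

    def get_logger_metrics(self) -> Optional[dict[str, Any]]:
        return dict(self.metrics)


class FakeSyncRolloutActor:
    """Stand-in actor; the real one runs behind a single Ray RPC."""

    def __init__(self) -> None:
        self.policy_generation = FakePolicyGeneration()
        self.tq: dict[str, Any] = {}  # stand-in for the TQ kv store

    def rollout_to_tq(self, input_batch: list[str]):
        pg = self.policy_generation
        pg.clear_logger_metrics()                  # 1. reset metrics
        final_batch = pg.generate(input_batch)     # 2. rollout
        flat = {"responses": final_batch}          # 3. flatten (trivial here)
        self.tq.update(flat)                       # 4. bulk put to TQ
        pg.finish_generation()                     # 5. release GPU
        gen_metrics = pg.get_logger_metrics()      # 6. capture metrics
        meta = {"num_samples": len(final_batch)}
        slice_extras = {"lens": [len(r) for r in final_batch]}
        rollout_metrics = {"turns": 1}
        return meta, slice_extras, rollout_metrics, gen_metrics


actor = FakeSyncRolloutActor()
meta, slice_extras, rollout_metrics, gen_metrics = actor.rollout_to_tq(["hi"])
```

The point of the bundling is that the driver makes one call and gets back only the small per-sample tuple; the bulk tensors never cross the RPC boundary because step 4 writes them to TQ from inside the actor.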

0 commit comments
