Commit e7f6a91

Browse files
style: fix ruff lint errors and apply ruff format

Fix 39 ruff lint violations (F401 unused imports, D208/D209 docstring
formatting, D205 missing blank line after summary) and reformat 28 files
with ruff format.

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

1 parent b041dd2 · commit e7f6a91
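For context on the rule codes, here is a minimal made-up illustration (not code from this commit) of what each violation and its fix looks like:

```python
# Hypothetical before/after for the rule codes; not code from this commit.

# Before: F401 (the `os` import is never used) plus D205/D208/D209 in the docstring.
import os


def scale(x: float) -> float:
    """Scale a value. Applies the configured factor
        to the input and returns the result."""
    return x * 2.0


# After `ruff check --fix` and `ruff format`: unused import removed, blank line
# between summary and description (D205), no over-indented continuation (D208),
# closing quotes on their own line (D209).
def scale(x: float) -> float:
    """Scale a value.

    Applies the configured factor to the input and returns the result.
    """
    return x * 2.0
```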

30 files changed · 900 additions & 344 deletions

examples/run_grpo.py · 5 additions & 1 deletion

```diff
@@ -108,6 +108,7 @@ def main() -> None:
 
         def _make_policy(**kwargs):
             return TQPolicy(**kwargs, dp_cfg=_dp_cfg)
+
        _policy_factory = _make_policy
    else:
        _policy_factory = None  # setup() defaults to plain Policy
@@ -124,7 +125,10 @@ def _make_policy(**kwargs):
        grpo_state,
        master_config,
    ) = setup(
-        config, tokenizer, dataset, val_dataset,
+        config,
+        tokenizer,
+        dataset,
+        val_dataset,
        policy_factory=_policy_factory,
    )
 
```

nemo_rl/algorithms/grpo_sync.py · 52 additions & 28 deletions

```diff
@@ -97,12 +97,19 @@ def _apply_dynamic_sampling(
     max_gen_batches: int,
     dp_client: DataPlaneClient,
 ) -> tuple[
-    Optional[KVBatchMeta], Optional[_DSlice],
-    list[torch.Tensor], bool, dict[str, Any], Optional[torch.Tensor],
+    Optional[KVBatchMeta],
+    Optional[_DSlice],
+    list[torch.Tensor],
+    bool,
+    dict[str, Any],
+    Optional[torch.Tensor],
 ]:
-    """One iteration. Returns (pending_meta, pending_slice, pending_rewards,
+    """One iteration.
+
+    Returns (pending_meta, pending_slice, pending_rewards,
     is_complete, ds_metrics, unfiltered_for_log). When complete, the returned
-    pending_* IS the training batch."""
+    pending_* IS the training batch.
+    """
     # Cumulative unfiltered total_reward for legacy metrics["reward"]
     # parity. Reference-only append (no copy) — slice tensors are
     # produced fresh per iteration, not aliased to TQ-owned bulk.
@@ -145,7 +152,9 @@ def _apply_dynamic_sampling(
         )
         pending_meta = pending_meta.slice(0, train_prompts_size)
         pending_slice = pending_slice.slice(0, train_prompts_size)
-        ds_metrics["dynamic_sampling_num_discarded_valid_samples"] = n - train_prompts_size
+        ds_metrics["dynamic_sampling_num_discarded_valid_samples"] = (
+            n - train_prompts_size
+        )
 
         unfiltered_for_log = torch.cat(pending_unfiltered_rewards)[:train_prompts_size]
         return pending_meta, pending_slice, [], True, ds_metrics, unfiltered_for_log
@@ -404,9 +413,7 @@ def grpo_train_sync(
            # partition exists with the expected schema.
            policy.prepare_step(
                num_samples=int(repeated_batch.size),
-                group_size=master_config["grpo"][
-                    "num_generations_per_prompt"
-                ],
+                group_size=master_config["grpo"]["num_generations_per_prompt"],
            )
 
            # ── Rollout 1-hop put: actor runs rollout + flatten +
@@ -462,11 +469,13 @@ def grpo_train_sync(
            # touched by any of these ops).
            with timer.time("reward_calculation"):
                slice_data = scale_rewards(
-                    slice_data, master_config["grpo"]["reward_scaling"],
+                    slice_data,
+                    master_config["grpo"]["reward_scaling"],
                )
                if master_config["grpo"]["reward_shaping"]["enabled"]:
                    slice_data = apply_reward_shaping(
-                        slice_data, master_config["grpo"]["reward_shaping"],
+                        slice_data,
+                        master_config["grpo"]["reward_shaping"],
                    )
                if master_config["grpo"]["overlong_filtering"]:
                    lm = slice_data["loss_multiplier"].clone()
@@ -495,9 +504,11 @@ def grpo_train_sync(
                    * master_config["grpo"]["num_generations_per_prompt"]
                )
                (
-                    pending_meta, pending_slice,
+                    pending_meta,
+                    pending_slice,
                    pending_unfiltered_rewards,
-                    is_complete, ds_metrics,
+                    is_complete,
+                    ds_metrics,
                    unfiltered_rewards_for_logging,
                ) = _apply_dynamic_sampling(
                    meta=meta,
@@ -571,7 +582,8 @@ def grpo_train_sync(
                    "skip_reference_policy_logprobs_calculation"
                ):
                    _ref_lp = policy.get_reference_policy_logprobs_from_meta(
-                        meta, timer=timer,
+                        meta,
+                        timer=timer,
                    )
                    reference_policy_logprobs = _ref_lp["reference_logprobs"]
                else:
@@ -582,7 +594,8 @@ def grpo_train_sync(
                    # output_ids, attention_mask, position_ids) stays in
                    # TQ — workers will fetch it via ``train_presharded``.
                    extras_bdd = read_columns(
-                        policy._dp_client, meta,
+                        policy._dp_client,
+                        meta,
                        select_fields=["generation_logprobs", "token_mask"],
                        pad_value_dict=_pad_dict,
                    )
@@ -658,7 +671,8 @@ def grpo_train_sync(
                    # sample_mask under the same meta.keys so workers fetch
                    # the union via train_presharded.
                    write_columns(
-                        policy._dp_client, meta,
+                        policy._dp_client,
+                        meta,
                        fields={
                            "advantages": advantages,
                            "sample_mask": sample_mask,
@@ -696,20 +710,27 @@ def grpo_train_sync(
                        # mask / adv columns added later are irrelevant
                        # here.
                        _calib_fields = [
-                            f for f in (meta.fields or [])
-                            if f not in (
-                                "generation_logprobs", "token_mask",
-                                "sample_mask", "prev_logprobs",
-                                "reference_policy_logprobs", "advantages",
+                            f
+                            for f in (meta.fields or [])
+                            if f
+                            not in (
+                                "generation_logprobs",
+                                "token_mask",
+                                "sample_mask",
+                                "prev_logprobs",
+                                "reference_policy_logprobs",
+                                "advantages",
                            )
                        ]
                        calibration_data = read_columns(
-                            policy._dp_client, meta,
+                            policy._dp_client,
+                            meta,
                            select_fields=_calib_fields,
                            pad_value_dict=_pad_dict,
                        )
                        kv_scales_cache = policy.calibrate_qkv_fp8_scales(
-                            calibration_data, include_q=True,
+                            calibration_data,
+                            include_q=True,
                        )["layers"]
                        POLICY_GENERATION_STALE = True
 
@@ -726,15 +747,18 @@ def grpo_train_sync(
            if "content" in (meta.fields or []):
                _log_select.append("content")
            _log_extras = read_columns(
-                policy._dp_client, meta, select_fields=_log_select,
+                policy._dp_client,
+                meta,
+                select_fields=_log_select,
                pad_value_dict=_pad_dict,
            )
            _log_input_ids = _log_extras["input_ids"]
            _log_content = _log_extras.get("content")
 
            # ── Step-end TQ cleanup ────────────────────────────────
            policy._dp_client.kv_clear(
-                keys=meta.keys, partition_id=meta.partition_id,
+                keys=meta.keys,
+                partition_id=meta.partition_id,
            )
 
            is_last_step = total_steps + 1 >= max_num_steps
@@ -779,9 +803,7 @@ def grpo_train_sync(
 
            # advantages and token_mask are in scope from the
            # advantage / masking blocks above. No need to re-fetch.
-            response_advantages = torch.masked_select(
-                advantages, token_mask.bool()
-            )
+            response_advantages = torch.masked_select(advantages, token_mask.bool())
 
            memory_tracker.snapshot_start_of_stage("Metrics", dir())
            metrics = {
@@ -1033,7 +1055,9 @@ def grpo_train_sync(
            print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
            if master_config["grpo"]["use_dynamic_sampling"]:
                print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}")
-                print(f" • Avg Total Reward: {np.mean(unfiltered_rewards.numpy()):.4f}")
+                print(
+                    f" • Avg Total Reward: {np.mean(unfiltered_rewards.numpy()):.4f}"
+                )
            else:
                print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
            print(
```

nemo_rl/data_plane/adapters/transfer_queue.py · 9 additions & 5 deletions

```diff
@@ -110,8 +110,9 @@ def _mooncake_transport_config() -> dict:
 
 
 def _connect_existing() -> None:
-    """Worker-process path: connect this process's client to the
-    already-running named controller actor in the Ray cluster. Mirrors
+    """Worker-process path: connect this process's client to the Ray cluster.
+
+    Connects to the already-running named controller actor. Mirrors
     rl-arena/arena/dataplane_client.py's `tq.init()` (no args) call.
     """
     _tq().init()
@@ -121,9 +122,10 @@ def _connect_existing() -> None:
 
 
 def _patch_tq_actor_runtime_env() -> None:
-    """Inject Ray ``runtime_env={"pip": ["TransferQueue==0.1.6"]}`` into the
-    ``.options()`` calls on TQ's internal actor classes (``SimpleStorageUnit``,
-    ``TransferQueueController``).
+    """Inject Ray ``runtime_env`` into TQ's internal actor class ``.options()`` calls.
+
+    Injects ``{"pip": ["TransferQueue==0.1.6"]}`` into ``.options()`` for
+    ``SimpleStorageUnit`` and ``TransferQueueController``.
 
     **Why**: TQ spawns these actors via ``Cls.options(...).remote(...)`` with
     no runtime_env. They inherit the *job-level* runtime_env that the driver
```
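A minimal sketch of the injection pattern this docstring describes, assuming the TQ actor classes import as below (the import path and wrapper are illustrative, not this commit's code):

```python
# Illustrative sketch of injecting a runtime_env into `.options()` calls.
# The import path below is an assumption; TQ's real module layout may differ.
from transfer_queue import SimpleStorageUnit, TransferQueueController

_RUNTIME_ENV = {"pip": ["TransferQueue==0.1.6"]}


def _inject_runtime_env(cls) -> None:
    original_options = cls.options

    def options(*args, **kwargs):
        # Add the pinned runtime_env unless the caller supplied one.
        kwargs.setdefault("runtime_env", _RUNTIME_ENV)
        return original_options(*args, **kwargs)

    cls.options = options


for _cls in (SimpleStorageUnit, TransferQueueController):
    _inject_runtime_env(_cls)
```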
```diff
@@ -317,6 +319,7 @@ def _to_wire(td: TensorDict) -> TensorDict:
     # metadata-recorded shape. materialize squeezes the trailing 1
     # back on read so consumers see (N,).
     from nemo_rl.data_plane.codec import _KV_PROMOTE_1D as _promote_1d
+
     if _promote_1d:
         new_dict: dict[str, torch.Tensor] = {}
         changed = False
@@ -391,6 +394,7 @@ def __init__(self, cfg: DataPlaneConfig, *, bootstrap: bool = True) -> None:
         os.environ["MC_TCP_BIND_ADDRESS"] = local_ip
         os.environ.setdefault("MC_STORE_MEMCPY", "0")
         from nemo_rl.data_plane.codec import set_kv_promote_1d
+
         set_kv_promote_1d(True)
 
         if bootstrap:
```
nemo_rl/data_plane/codec.py · 29 additions & 33 deletions

```diff
@@ -13,28 +13,28 @@
 # limitations under the License.
 """Wire <-> trainer codec — jagged-on-the-wire bridge.
 
-    * Writer side: variable-length fields are encoded as
-      ``torch.nested.nested_tensor`` with ``layout=torch.jagged`` before
-      ``kv_batch_put``. Padding tax is paid only when a consumer needs a
-      rectangular tensor.
-
-    * Reader side: :func:`materialize` accepts the wire TensorDict and,
-      when ``layout='padded'``, calls
-      :func:`torch.nested.to_padded_tensor` on any nested leaves using
-      the per-field padding value supplied in ``pad_value_dict``. Trainer
-      code consumes the padded BatchedDataDict unchanged.
-
-    * Worker write-backs that produce ``response``-shaped outputs use
-      :func:`response_from_nested` to extract the response slice from a
-      (prompt+response) nested tensor.
-
-    * Non-tensor object fields (verl-style ``np.ndarray(dtype=object)``)
-      ride the same wire as variable-length tensors: each row is pickled
-      to ``bytes`` and packed into a jagged uint8 nested tensor via
-      :func:`pack_object_array`. Reader unpacks via
-      :func:`unpack_object_array` and emits the field as an object array
-      in the materialized BatchedDataDict. Backends see only tensors —
-      no per-backend non-tensor support required.
+* Writer side: variable-length fields are encoded as
+  ``torch.nested.nested_tensor`` with ``layout=torch.jagged`` before
+  ``kv_batch_put``. Padding tax is paid only when a consumer needs a
+  rectangular tensor.
+
+* Reader side: :func:`materialize` accepts the wire TensorDict and,
+  when ``layout='padded'``, calls
+  :func:`torch.nested.to_padded_tensor` on any nested leaves using
+  the per-field padding value supplied in ``pad_value_dict``. Trainer
+  code consumes the padded BatchedDataDict unchanged.
+
+* Worker write-backs that produce ``response``-shaped outputs use
+  :func:`response_from_nested` to extract the response slice from a
+  (prompt+response) nested tensor.
+
+* Non-tensor object fields (verl-style ``np.ndarray(dtype=object)``)
+  ride the same wire as variable-length tensors: each row is pickled
+  to ``bytes`` and packed into a jagged uint8 nested tensor via
+  :func:`pack_object_array`. Reader unpacks via
+  :func:`unpack_object_array` and emits the field as an object array
+  in the materialized BatchedDataDict. Backends see only tensors —
+  no per-backend non-tensor support required.
 """
 
 from __future__ import annotations
```
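As a quick sanity check of the wire format described above, here is a minimal jagged round trip using the standard `torch.nested` API (recent PyTorch; values illustrative):

```python
import torch

# Writer side: two variable-length rows, no padding paid at encode time.
rows = [torch.arange(3), torch.arange(5)]
jagged = torch.nested.nested_tensor(rows, layout=torch.jagged)

# Reader side: pay the padding tax only when a rectangular tensor is needed.
padded = torch.nested.to_padded_tensor(jagged, padding=0)  # shape (2, 5)
```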
```diff
@@ -101,8 +101,10 @@ def to_nested_by_length(
 
 
 def set_kv_promote_1d(enabled: bool) -> None:
-    """Adapter hook: when True, writer unsqueezes 1D bulk fields to
-    (N, 1) and reader squeezes the trailing 1 in :func:`materialize`.
+    """Adapter hook: enable/disable 1D→(N,1) promotion for bulk fields.
+
+    When True, writer unsqueezes 1D bulk fields to (N, 1) and reader
+    squeezes the trailing 1 in :func:`materialize`.
 
     Required by backends that go through TQ's KVStorageManager path
     (mooncake_cpu) — see ``_KV_PROMOTE_1D`` above for the schema/data
```
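The shape contract behind this hook, as a standalone sketch (illustrative, not the adapter code):

```python
import torch

bulk = torch.tensor([0.1, 0.2, 0.3])  # 1D bulk field, shape (N,)
wire = bulk.unsqueeze(-1)             # writer promotes to (N, 1) for the KV path
read = wire.squeeze(-1)               # materialize() squeezes back to (N,)
assert torch.equal(read, bulk)
```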
```diff
@@ -157,9 +159,7 @@ def pack_object_array(arr: "np.ndarray | list[Any]") -> torch.Tensor:
     """
     if isinstance(arr, np.ndarray):
         if arr.dtype != object:
-            raise TypeError(
-                f"pack_object_array expects dtype=object; got {arr.dtype}"
-            )
+            raise TypeError(f"pack_object_array expects dtype=object; got {arr.dtype}")
         items: list[Any] = list(arr)
     elif isinstance(arr, list):
         items = arr
@@ -173,9 +173,7 @@ def pack_object_array(arr: "np.ndarray | list[Any]") -> torch.Tensor:
         b = pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)
         # np.frombuffer + .copy() avoids the "non-writable buffer" warning
         # and severs the lifetime tie to the bytes object.
-        rows.append(
-            torch.from_numpy(np.frombuffer(b, dtype=np.uint8).copy())
-        )
+        rows.append(torch.from_numpy(np.frombuffer(b, dtype=np.uint8).copy()))
     return torch.nested.as_nested_tensor(rows, layout=torch.jagged)
 
 
```
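A hypothetical end-to-end round trip for this packing; the unpacking side is assumed to mirror what the module docstring says `unpack_object_array` does:

```python
import pickle

import numpy as np
import torch

items = [{"id": 1}, "text", [1, 2, 3]]  # arbitrary Python objects, one per row
rows = [
    torch.from_numpy(
        np.frombuffer(
            pickle.dumps(o, protocol=pickle.HIGHEST_PROTOCOL), dtype=np.uint8
        ).copy()
    )
    for o in items
]
packed = torch.nested.as_nested_tensor(rows, layout=torch.jagged)

# Assumed inverse of the packing above (see `unpack_object_array`).
unpacked = [pickle.loads(bytes(row.numpy())) for row in packed.unbind()]
assert unpacked == items
```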

```diff
@@ -263,9 +261,7 @@ def response_from_nested(
     response_list = []
     for resp_len, seq_offset in zip(response_lens, offsets[1:], strict=True):
         # left-shift output by one token for log_probs / values
-        response_list.append(
-            values[seq_offset - resp_len - 1 : seq_offset - 1]
-        )
+        response_list.append(values[seq_offset - resp_len - 1 : seq_offset - 1])
     return torch.nested.as_nested_tensor(response_list, layout=torch.jagged)
 
 
```
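A worked example of the left-shift slice above (positions, not real logprobs):

```python
import torch

values = torch.arange(10)     # packed prompt+response values for one sequence
seq_offset, resp_len = 10, 4  # sequence ends at offset 10; response is 4 tokens
# The value for response token t lives at position t - 1, hence the shift:
resp = values[seq_offset - resp_len - 1 : seq_offset - 1]
print(resp)  # tensor([5, 6, 7, 8])
```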

nemo_rl/data_plane/driver_io.py · 5 additions & 7 deletions

```diff
@@ -96,12 +96,8 @@ def write_columns(
     from nemo_rl.data_plane.codec import maybe_pack_jagged, pack_object_array
 
     seq_lens = meta.sequence_lengths
-    lengths = (
-        torch.tensor(seq_lens, dtype=torch.long) if seq_lens is not None else None
-    )
-    registered_objects = set(
-        (meta.extra_info or {}).get(META_OBJECT_FIELDS, ())
-    )
+    lengths = torch.tensor(seq_lens, dtype=torch.long) if seq_lens is not None else None
+    registered_objects = set((meta.extra_info or {}).get(META_OBJECT_FIELDS, ()))
 
     packed: dict[str, torch.Tensor] = {}
     for k, v in fields.items():
@@ -127,5 +123,7 @@ def write_columns(
 
     td = TensorDict(packed, batch_size=[len(meta.keys)])
     dp_client.kv_batch_put(
-        keys=meta.keys, partition_id=meta.partition_id, fields=td,
+        keys=meta.keys,
+        partition_id=meta.partition_id,
+        fields=td,
     )
```
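The packing contract in `write_columns`, as a minimal sketch (field names and the commented call are illustrative; the `TensorDict` batch dimension must match `len(meta.keys)`):

```python
import torch
from tensordict import TensorDict

keys = ["row-0", "row-1"]  # one KV key per batch row (illustrative)
packed = {
    "advantages": torch.zeros(2, 8),
    "sample_mask": torch.ones(2),
}
td = TensorDict(packed, batch_size=[len(keys)])  # batch dim == len(keys)
# dp_client.kv_batch_put(keys=keys, partition_id=0, fields=td)
```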

nemo_rl/data_plane/interfaces.py · 3 additions & 1 deletion

```diff
@@ -112,7 +112,9 @@ def _replace(
             task_name=self.task_name,
             keys=list(keys),
             fields=self.fields,
-            sequence_lengths=list(sequence_lengths) if sequence_lengths is not None else None,
+            sequence_lengths=list(sequence_lengths)
+            if sequence_lengths is not None
+            else None,
             extra_info=dict(self.extra_info or {}),
         )
 
```