Skip to content

Commit 416df7c

Browse files
refactor(data-plane): make kv_batch_get(select_fields) required
Silent over-fetch was possible when callers omitted select_fields: the noop adapter fetched every registered field via set intersection, and the TQ adapter forwarded None to the backend. Bulk schemas are wide, and fetching everything is the most expensive shape the wire can take.

select_fields is now a required list[str] on DataPlaneClient.kv_batch_get and on all concrete implementations. Callers must name what they read; fetch-all is still possible by passing list(meta.fields) explicitly.

Also, the internal call sites in worker_mixin now pass list(meta.fields) directly, so a None meta.fields fails loudly with a TypeError rather than silently producing an empty TensorDict.

Per yuki-97 PR review (#6).

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent b38b473 commit 416df7c

5 files changed

Lines changed: 13 additions & 12 deletions

File tree

nemo_rl/data_plane/adapters/noop.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -199,16 +199,12 @@ def kv_batch_get(
199199
self,
200200
keys: list[str],
201201
partition_id: str,
202-
select_fields: list[str] | None = None,
202+
select_fields: list[str],
203203
) -> TensorDict:
204204
rec = self._partitions[partition_id]
205205
if not keys:
206206
return TensorDict({}, batch_size=(0,))
207207

208-
if select_fields is None:
209-
available = set.intersection(*(set(rec.rows[k].keys()) for k in keys))
210-
select_fields = sorted(available)
211-
212208
out: dict[str, list[torch.Tensor]] = {f: [] for f in select_fields}
213209
for key in keys:
214210
row = rec.rows[key]

nemo_rl/data_plane/adapters/transfer_queue.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -587,14 +587,14 @@ def kv_batch_get(
587587
self,
588588
keys: list[str],
589589
partition_id: str,
590-
select_fields: list[str] | None = None,
590+
select_fields: list[str],
591591
) -> TensorDict:
592592
if not keys:
593593
return TensorDict({}, batch_size=(0,))
594594
td = self._tq.kv_batch_get(
595595
keys=list(keys),
596596
partition_id=partition_id,
597-
select_fields=list(select_fields) if select_fields else None,
597+
select_fields=select_fields,
598598
)
599599
if self._promote_1d:
600600
td = _from_wire(td)

nemo_rl/data_plane/interfaces.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -312,17 +312,22 @@ def kv_batch_get(
312312
self,
313313
keys: list[str],
314314
partition_id: str,
315-
select_fields: list[str] | None = None,
315+
select_fields: list[str],
316316
) -> TensorDict:
317317
"""Direct fetch by uids.
318318
319319
Used by per-DP-rank slice fetches. Does NOT advance any per-task
320320
consumption cursor — that only happens via :meth:`claim_meta`.
321321
322+
``select_fields`` is required (no implicit "fetch every field"
323+
fallback): bulk schemas are wide and silent over-fetch is the
324+
most expensive shape the wire can take. Callers must name what
325+
they read.
326+
322327
Args:
323328
keys: Uids to fetch.
324329
partition_id: Partition the keys live in.
325-
select_fields: Subset of fields; ``None`` fetches every registered field.
330+
select_fields: Subset of fields to fetch.
326331
327332
Returns:
328333
``TensorDict`` keyed by field name, batched along ``keys``.

nemo_rl/data_plane/observability.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -308,7 +308,7 @@ def kv_batch_put(self, keys, partition_id, fields=None, tags=None):
308308
self._record_put(partition_id, keys_list, n_bytes)
309309
return out
310310

311-
def kv_batch_get(self, keys, partition_id, select_fields=None):
311+
def kv_batch_get(self, keys, partition_id, select_fields):
312312
return self._run(
313313
"get",
314314
partition_id,

nemo_rl/data_plane/worker_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -221,7 +221,7 @@ def _fetch(
221221
td = self._require_dp_client().kv_batch_get(
222222
keys=meta.keys,
223223
partition_id=meta.partition_id,
224-
select_fields=list(meta.fields) if meta.fields else None,
224+
select_fields=list(meta.fields),
225225
)
226226
data = materialize(
227227
td,
@@ -246,7 +246,7 @@ def _fetch(
246246
td = self._require_dp_client().kv_batch_get(
247247
keys=meta.keys,
248248
partition_id=meta.partition_id,
249-
select_fields=list(meta.fields) if meta.fields else None,
249+
select_fields=list(meta.fields),
250250
)
251251
data = materialize(
252252
td,

0 commit comments

Comments
 (0)