Skip to content

Commit 06de915

Browse files
committed
[fix] Allow None values in _pack_field_values and fallback to NonTensorStack
- Modify _pack_field_values to tolerate None placeholders in the values list, falling back to NonTensorStack instead of raising ValueError. - Pure tensor lists (no None) still use torch.stack or nested tensor. - Update docstring to reflect the new None-tolerant behavior.
1 parent e04cc05 commit 06de915

1 file changed

Lines changed: 30 additions & 15 deletions

File tree

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -387,26 +387,41 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack:
387387
"""
388388
Pack a list of per-sample values into a batched container.
389389
390-
For tensor values, this performs a memory copy via stacking or nested tensor creation.
391-
Non-tensor values are grouped into a ``NonTensorStack`` without copying.
390+
For pure tensor lists (no None), this performs a memory copy via stacking
391+
or nested tensor creation. Mixed types, non-tensor values, or lists
392+
containing None placeholders are grouped into a ``NonTensorStack``.
393+
394+
Args:
395+
values: List of per-sample values to pack. May contain None for
396+
unfilled batch positions.
397+
398+
Returns:
399+
A stacked ``torch.Tensor`` (or nested tensor) when all values are
400+
tensors, otherwise a ``NonTensorStack``.
401+
402+
Raises:
403+
ValueError: If *values* is empty.
392404
"""
393405
if not values:
394406
raise ValueError("_pack_field_values received empty values list; caller should filter empty batches")
395-
if any(v is None for v in values):
396-
raise ValueError("_pack_field_values received None in values list; some batch positions were not filled")
397-
if all(isinstance(v, torch.Tensor) for v in values):
398-
if all(v.shape == values[0].shape for v in values):
399-
return torch.stack(values)
400-
try:
401-
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
402-
except (RuntimeError, TypeError) as e:
403-
logger.warning(
404-
f"Failed to pack nested tensor with jagged layout. "
405-
f"Falling back to strided layout. Detailed error: {e}"
406-
)
407-
return torch.nested.as_nested_tensor(values, layout=torch.strided)
407+
non_none = [v for v in values if v is not None]
408+
if non_none and all(isinstance(v, torch.Tensor) for v in non_none):
409+
if not any(v is None for v in values):
410+
# Pure tensor list — try stacking / nested tensor
411+
if all(v.shape == values[0].shape for v in values):
412+
return torch.stack(values)
413+
try:
414+
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
415+
except (RuntimeError, TypeError) as e:
416+
logger.warning(
417+
f"Failed to pack nested tensor with jagged layout. "
418+
f"Falling back to strided layout. Detailed error: {e}"
419+
)
420+
return torch.nested.as_nested_tensor(values, layout=torch.strided)
421+
# Mixed tensor + None — cannot stack, fall through to NonTensorStack
408422
return NonTensorStack(*values)
409423

424+
410425
async def get_data(self, metadata: BatchMeta) -> TensorDict:
411426
"""
412427
Retrieve data from remote StorageUnit based on metadata.

0 commit comments

Comments
 (0)