@@ -387,26 +387,41 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack:
387387 """
388388 Pack a list of per-sample values into a batched container.
389389
390- For tensor values, this performs a memory copy via stacking or nested tensor creation.
391- Non-tensor values are grouped into a ``NonTensorStack`` without copying.
390+ For pure tensor lists (no None), this performs a memory copy via stacking
391+ or nested tensor creation. Mixed types, non-tensor values, or lists
392+ containing None placeholders are grouped into a ``NonTensorStack``.
393+
394+ Args:
395+ values: List of per-sample values to pack. May contain None for
396+ unfilled batch positions.
397+
398+ Returns:
399+ A stacked ``torch.Tensor`` (or nested tensor) when all values are
400+ tensors, otherwise a ``NonTensorStack``.
401+
402+ Raises:
403+ ValueError: If *values* is empty.
392404 """
393405 if not values :
394406 raise ValueError ("_pack_field_values received empty values list; caller should filter empty batches" )
395- if any (v is None for v in values ):
396- raise ValueError ("_pack_field_values received None in values list; some batch positions were not filled" )
397- if all (isinstance (v , torch .Tensor ) for v in values ):
398- if all (v .shape == values [0 ].shape for v in values ):
399- return torch .stack (values )
400- try :
401- return torch .nested .as_nested_tensor (values , layout = torch .jagged )
402- except (RuntimeError , TypeError ) as e :
403- logger .warning (
404- f"Failed to pack nested tensor with jagged layout. "
405- f"Falling back to strided layout. Detailed error: { e } "
406- )
407- return torch .nested .as_nested_tensor (values , layout = torch .strided )
407+ non_none = [v for v in values if v is not None ]
408+ if non_none and all (isinstance (v , torch .Tensor ) for v in non_none ):
409+ if not any (v is None for v in values ):
410+ # Pure tensor list — try stacking / nested tensor
411+ if all (v .shape == values [0 ].shape for v in values ):
412+ return torch .stack (values )
413+ try :
414+ return torch .nested .as_nested_tensor (values , layout = torch .jagged )
415+ except (RuntimeError , TypeError ) as e :
416+ logger .warning (
417+ f"Failed to pack nested tensor with jagged layout. "
418+ f"Falling back to strided layout. Detailed error: { e } "
419+ )
420+ return torch .nested .as_nested_tensor (values , layout = torch .strided )
421+ # Mixed tensor + None — cannot stack, fall through to NonTensorStack
408422 return NonTensorStack (* values )
409423
424+
410425 async def get_data (self , metadata : BatchMeta ) -> TensorDict :
411426 """
412427 Retrieve data from remote StorageUnit based on metadata.
0 commit comments