@@ -86,6 +86,8 @@ def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[st
8686 if len (local_indexes ) == 1 :
8787 # The unsqueeze op make the shape from n to (1, n)
8888 gathered_item = self .field_data [field ][local_indexes [0 ]]
89+ if gathered_item is None :
90+ raise ValueError (f"Missing data for field '{ field } ' at index { local_indexes [0 ]} " )
8991 if not isinstance (gathered_item , torch .Tensor ):
9092 result [field ] = NonTensorStack (gathered_item )
9193 else :
@@ -94,13 +96,18 @@ def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[st
9496 gathered_items = list (itemgetter (* local_indexes )(self .field_data [field ]))
9597
9698 if gathered_items :
99+ if any (x is None for x in gathered_items ):
100+ missing = [i for i , x in zip (local_indexes , gathered_items ) if x is None ]
101+ raise ValueError (f"Missing data for field '{ field } ' at indexes { missing } " )
97102 all_tensors = all (isinstance (x , torch .Tensor ) for x in gathered_items )
98103 if all_tensors :
99104 result [field ] = torch .nested .as_nested_tensor (gathered_items )
100105 else :
101106 result [field ] = NonTensorStack (* gathered_items )
102107
103- return TensorDict (result )
108+ # Explicit batch size for stability
109+ bs = 0 if not fields or not local_indexes else len (local_indexes )
110+ return TensorDict (result , batch_size = bs )
104111
105112 def put_data (self , field_data : TensorDict [str , Any ], local_indexes : list [int ]) -> None :
106113 """
@@ -110,7 +117,15 @@ def put_data(self, field_data: TensorDict[str, Any], local_indexes: list[int]) -
110117 field_data: Dict with field names as keys, corresponding data in the field as values.
111118 local_indexes: Local indexes used for putting data.
112119 """
113- extracted_data = dict (field_data )
120+ # Accept TensorDict or plain dict[str, list-like]
121+ if isinstance (field_data , TensorDict ):
122+ extracted_data = field_data .to_dict ()
123+ elif isinstance (field_data , dict ):
124+ extracted_data = field_data
125+ else :
126+ raise TypeError (
127+ f"field_data must be a TensorDict or dict[str, list-like], got { type (field_data )} "
128+ )
114129
115130 for f , values in extracted_data .items ():
116131 if f not in self .field_data :
0 commit comments