Skip to content

Commit a12180c

Browse files
committed
fix comment
1 parent dcf3dc3 commit a12180c

2 files changed

Lines changed: 19 additions & 4 deletions

File tree

recipe/simple_use_case/async_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client):
6666
output = TensorDict(
6767
{
6868
"generate_sequences_ids": output,
69-
"non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]),
69+
"non_tensor_data": NonTensorData(["test_str" for _ in range(output.size(0))]),
7070
"nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(output.size(0))]),
7171
},
7272
batch_size=output.size(0),
@@ -118,7 +118,7 @@ async def generate(self, data_meta):
118118
output = TensorDict(
119119
{
120120
"generate_sequences_ids": data,
121-
"non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]),
121+
"non_tensor_data": NonTensorData(["test_str" for _ in range(data.size(0))]),
122122
"nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(data.size(0))]),
123123
},
124124
batch_size=data.size(0),

transfer_queue/storage/simple_backend.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[st
8686
if len(local_indexes) == 1:
8787
# The unsqueeze op make the shape from n to (1, n)
8888
gathered_item = self.field_data[field][local_indexes[0]]
89+
if gathered_item is None:
90+
raise ValueError(f"Missing data for field '{field}' at index {local_indexes[0]}")
8991
if not isinstance(gathered_item, torch.Tensor):
9092
result[field] = NonTensorStack(gathered_item)
9193
else:
@@ -94,13 +96,18 @@ def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[st
9496
gathered_items = list(itemgetter(*local_indexes)(self.field_data[field]))
9597

9698
if gathered_items:
99+
if any(x is None for x in gathered_items):
100+
missing = [i for i, x in zip(local_indexes, gathered_items) if x is None]
101+
raise ValueError(f"Missing data for field '{field}' at indexes {missing}")
97102
all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
98103
if all_tensors:
99104
result[field] = torch.nested.as_nested_tensor(gathered_items)
100105
else:
101106
result[field] = NonTensorStack(*gathered_items)
102107

103-
return TensorDict(result)
108+
# Explicit batch size for stability
109+
bs = 0 if not fields or not local_indexes else len(local_indexes)
110+
return TensorDict(result, batch_size=bs)
104111

105112
def put_data(self, field_data: TensorDict[str, Any], local_indexes: list[int]) -> None:
106113
"""
@@ -110,7 +117,15 @@ def put_data(self, field_data: TensorDict[str, Any], local_indexes: list[int]) -
110117
field_data: Dict with field names as keys, corresponding data in the field as values.
111118
local_indexes: Local indexes used for putting data.
112119
"""
113-
extracted_data = dict(field_data)
120+
# Accept TensorDict or plain dict[str, list-like]
121+
if isinstance(field_data, TensorDict):
122+
extracted_data = field_data.to_dict()
123+
elif isinstance(field_data, dict):
124+
extracted_data = field_data
125+
else:
126+
raise TypeError(
127+
f"field_data must be a TensorDict or dict[str, list-like], got {type(field_data)}"
128+
)
114129

115130
for f, values in extracted_data.items():
116131
if f not in self.field_data:

0 commit comments

Comments
 (0)