Skip to content

Commit eb59ad0

Browse files
author
ascend-robot
committed
[optim] Add layout=jagged to irregular nested tensors
Co-authored-by: tianyi-huawei<getianyi1@huawei.com> # message auto-generated for no-merge-commit merge: !16 merge feat/jagged into main [optim] Add layout=jagged to irregular nested tensors Created-by: tianyi-huawei Commit-by: tianyi-huawei Merged-by: ascend-robot Description: add layout=jagged to nested tensors because torch warning suggests this usage before: ![image.png](https://raw.gitcode.com/user-images/assets/8886051/b1de5079-8366-4f24-a9fb-e1df0320564f/image.png 'image.png') after: ![image.png](https://raw.gitcode.com/user-images/assets/8886051/9d10c041-245e-4e91-9386-83a66efe448f/image.png 'image.png') See merge request: Ascend/TransferQueue!16
1 parent 4424e2d commit eb59ad0

3 files changed

Lines changed: 5 additions & 3 deletions

File tree

transfer_queue/storage/managers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def _merge_tensors_to_tensordict(metadata: BatchMeta, values: list[Tensor]) -> T
396396
except RuntimeError:
397397
try:
398398
# Fallback to nested tensor if shapes are irregular
399-
merged_data[field] = torch.nested.as_nested_tensor(data_list)
399+
merged_data[field] = torch.nested.as_nested_tensor(data_list, layout=torch.jagged)
400400
except Exception:
401401
merged_data[field] = NonTensorStack(*data_list)
402402
else:

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:
313313
and all(isinstance(item, torch.Tensor) for item in v)
314314
and all(item.shape == v[0].shape for item in v)
315315
else (
316-
torch.nested.as_nested_tensor(v)
316+
torch.nested.as_nested_tensor(v, layout=torch.jagged)
317317
if v and all(isinstance(item, torch.Tensor) for item in v)
318318
else NonTensorStack(*v)
319319
)

transfer_queue/utils/serial_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ def deserialization(data: list[bytestr] | bytestr) -> Any:
264264
tensors[i] = single_tensors[current_idx]
265265
current_idx += 1
266266
else:
267-
tensors[i] = torch.nested.as_nested_tensor(single_tensors[current_idx : current_idx + tensor_num])
267+
tensors[i] = torch.nested.as_nested_tensor(
268+
single_tensors[current_idx : current_idx + tensor_num], layout=torch.strided
269+
)
268270
current_idx += tensor_num
269271

270272
return _internal_rpc_pickler.deserialize(pickled_bytes, tensors)

0 commit comments

Comments
 (0)