Skip to content

Commit 73ed4c9

Browse files
authored
[perf] Improve performance for putting jagged tensor (Ascend#36)
## Background

When users input a TensorDict containing jagged tensors (nested tensors), the `put_data` process becomes extremely slow. Specifically, the `_filter_storage_data` function uses `itemgetter(*batch_indexes)(data[fname])` to extract individual items from each tensor in the TensorDict. This indexing approach works efficiently for strided tensors but is extremely inefficient for jagged tensors.

## Root Cause

For jagged tensors, `itemgetter` with multiple batch indexes performs one indexing operation per index, and each such access costs $\mathcal{O}(n)$. When extracting multiple samples, the total cost therefore grows to $\mathcal{O}(n^2)$.

## Solution

We unbind each nested tensor once before accessing individual samples from it.

```python3
# unbind nested tensor
results: dict = {}
for field in sorted(data.keys()):
    field_data = data[field]
    if isinstance(field_data, Tensor) and field_data.is_nested:
        results[field] = field_data.unbind()
    else:
        results[field] = field_data
```

## Simple Reproduction Script

```python3
import torch
import time
from operator import itemgetter

# Create a jagged tensor with 1000 samples
offsets = torch.tensor([0] + list(torch.randint(10, 50, (1001,)).cumsum(0)))
values = torch.randn(offsets[-1].item(), 128)
jagged = torch.nested.as_nested_tensor(
    [values[offsets[i]:offsets[i+1]] for i in range(1000)],
    layout=torch.jagged
)

batch_indexes = list(range(0, 1000, 10))  # 100 indexes

# Method 1: Direct itemgetter on jagged tensor (SLOW)
start = time.perf_counter()
result = itemgetter(*batch_indexes)(jagged)
print(f"Direct itemgetter: {(time.perf_counter() - start)*1000:.2f} ms")

# Method 2: Unbind first, then itemgetter (FAST)
start = time.perf_counter()
field_list = jagged.unbind()
result = itemgetter(*batch_indexes)(field_list)
print(f"Unbind + itemgetter: {(time.perf_counter() - start)*1000:.2f} ms")
```

Output:

```bash
Direct itemgetter: 150.94 ms
Unbind + itemgetter: 1.80 ms
```

---------

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent b7a3b01 commit 73ed4c9

4 files changed

Lines changed: 53 additions & 23 deletions

File tree

transfer_queue/metadata.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import torch
2727
from tensordict import TensorDict
2828
from tensordict.tensorclass import NonTensorData, NonTensorStack
29+
from torch import Tensor
2930

3031
from transfer_queue.utils.enum_utils import ProductionStatus
3132

@@ -815,18 +816,26 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) ->
815816

816817
production_status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED
817818

818-
all_fields = [
819-
{
820-
name: FieldMeta(
821-
name=name,
822-
dtype=getattr(value, "dtype", None),
823-
shape=getattr(value, "shape", None),
819+
# unbind nested tensor
820+
results: dict = {}
821+
for field in tensor_dict.keys():
822+
field_data = tensor_dict[field]
823+
if batch_size > 1 and isinstance(field_data, Tensor) and field_data.is_nested:
824+
results[field] = field_data.unbind()
825+
else:
826+
results[field] = field_data
827+
828+
all_fields = []
829+
for idx in range(batch_size):
830+
dict_of_field_meta = {}
831+
for field_name in results.keys():
832+
dict_of_field_meta[field_name] = FieldMeta(
833+
name=field_name,
834+
dtype=getattr(results[field_name][idx], "dtype", None),
835+
shape=getattr(results[field_name][idx], "shape", None),
824836
production_status=production_status,
825837
)
826-
for name, value in tensor_dict[idx].items()
827-
}
828-
for idx in range(batch_size)
829-
]
838+
all_fields.append(dict_of_field_meta)
830839

831840
return all_fields
832841

transfer_queue/storage/managers/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,14 @@ def _generate_values(data: TensorDict) -> list[Tensor]:
394394
list[Tensor]: Flattened list of tensors, e.g.,
395395
[data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...]
396396
"""
397-
return [row_data for field in sorted(data.keys()) for row_data in data[field]]
397+
results: list[Tensor] = []
398+
for field in sorted(data.keys()):
399+
field_data = data[field]
400+
if isinstance(field_data, Tensor) and field_data.is_nested:
401+
results.extend(field_data.unbind())
402+
else:
403+
results.extend(field_data)
404+
return results
398405

399406
@staticmethod
400407
def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None:

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import zmq
2828
from omegaconf import DictConfig
2929
from tensordict import NonTensorStack, TensorDict
30+
from torch import Tensor
3031

3132
from transfer_queue.metadata import BatchMeta
3233
from transfer_queue.storage.managers.base import TransferQueueStorageManager
@@ -201,10 +202,21 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
201202
metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
202203
)
203204

205+
# unbind nested tensor
206+
results: dict = {}
207+
for field in data.keys():
208+
field_data = data[field]
209+
if data.batch_size[0] > 1 and isinstance(field_data, Tensor) and field_data.is_nested:
210+
results[field] = field_data.unbind()
211+
else:
212+
results[field] = field_data
213+
204214
# send data to each storage unit
205215
tasks = [
206216
self._put_to_single_storage_unit(
207-
meta_group.get_local_indexes(), _filter_storage_data(meta_group, data), target_storage_unit=storage_id
217+
meta_group.get_local_indexes(),
218+
_filter_storage_data(meta_group, results),
219+
target_storage_unit=storage_id,
208220
)
209221
for storage_id, meta_group in storage_meta_groups.items()
210222
]
@@ -221,8 +233,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
221233
per_field_shapes[global_idx] = {}
222234

223235
# For each field, extract dtype and shape for each sample
224-
for field in data.keys():
225-
for i, data_item in enumerate(data[field]):
236+
for field in results.keys():
237+
for i, data_item in enumerate(results[field]):
226238
global_idx = metadata.global_indexes[i]
227239
per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
228240
per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None
@@ -234,7 +246,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
234246

235247
# notify controller that new data is ready
236248
await self.notify_data_update(
237-
partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
249+
partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
238250
)
239251

240252
@dynamic_storage_manager_socket(socket_name="put_get_socket")
@@ -432,20 +444,20 @@ def close(self) -> None:
432444
super().close()
433445

434446

435-
def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) -> dict[str, Any]:
436-
"""Filter batch-aligned data from a TensorDict using batch indexes from a StorageMetaGroup.
447+
def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: dict) -> dict[str, Any]:
448+
"""Filter batch-aligned data from a dict using batch indexes from a StorageMetaGroup.
437449
This helper extracts a subset of items from each field in ``data`` according to the
438450
batch indexes stored in ``storage_meta_group``. The same indexes are applied to every
439-
field in the input ``TensorDict`` so that the returned samples remain aligned across
451+
field in the input dict so that the returned samples remain aligned across
440452
fields.
441453
442454
Args:
443455
storage_meta_group: A :class:`StorageMetaGroup` instance that provides
444456
a sequence of batch indexes via :meth:`get_batch_indexes`. Each index
445457
refers to a position along the batch dimension of the tensors stored
446458
in ``data``.
447-
data: A :class:`tensordict.TensorDict` containing batched data fields. All
448-
fields are expected to be indexable by the batch indexes returned by
459+
data: A dict containing batched data fields. All fields are expected to
460+
be indexable by the batch indexes returned by
449461
``storage_meta_group.get_batch_indexes()``.
450462
Returns:
451463
dict[str, Any]: A dictionary mapping each field name in ``data`` to a list
@@ -461,7 +473,9 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict)
461473
return results
462474

463475
for fname in data.keys():
464-
result = itemgetter(*batch_indexes)(data[fname])
476+
field_data = data[fname]
477+
result = itemgetter(*batch_indexes)(field_data)
478+
465479
if not isinstance(result, tuple):
466480
result = (result,)
467481
results[fname] = list(result)

tutorial/04_understanding_controller.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def demonstrate_partition_isolation():
6969
train_data = TensorDict(
7070
{
7171
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
72-
"labels": torch.tensor([0, 1]),
72+
"labels": torch.tensor([[0], [1]]),
7373
},
7474
batch_size=2,
7575
)
@@ -81,7 +81,7 @@ def demonstrate_partition_isolation():
8181
val_data = TensorDict(
8282
{
8383
"input_ids": torch.tensor([[7, 8, 9], [10, 11, 12]]),
84-
"labels": torch.tensor([2, 3]),
84+
"labels": torch.tensor([[2], [3]]),
8585
},
8686
batch_size=2,
8787
)

0 commit comments

Comments
 (0)