Skip to content

Commit f655c51

Browse files
authored
Speed up named data map + add uint32 support
Differential Revision: D92447071 Pull Request resolved: #17257
1 parent 7823792 commit f655c51

2 files changed

Lines changed: 20 additions & 3 deletions

File tree

exir/_serialize/_named_data_store.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,32 @@
77
# pyre-strict
88

99
import hashlib
10-
1110
from dataclasses import dataclass
1211
from typing import Dict, List, Optional, Union
1312

1413
import torch
15-
1614
from executorch.exir._serialize.data_serializer import DataEntry
1715
from executorch.exir.tensor_layout import TensorLayout
1816

1917

18+
def _tensor_to_bytes(tensor: torch.Tensor) -> bytes:
19+
"""Convert tensor to bytes using the fastest method available.
20+
21+
Uses numpy().tobytes() which is faster than bytes(untyped_storage())
22+
for C-contiguous tensors. Falls back to untyped_storage() for
23+
non-contiguous tensors (e.g., channels_last) to preserve memory layout.
24+
"""
25+
if not tensor.is_contiguous():
26+
# For non-C-contiguous tensors (e.g., channels_last), use untyped_storage
27+
# to preserve the actual memory layout
28+
return bytes(tensor.untyped_storage())
29+
if tensor.dtype == torch.bfloat16:
30+
# BFloat16 is not supported by numpy, extract raw bytes via view
31+
return tensor.view(torch.uint16).numpy().tobytes()
32+
else:
33+
return tensor.numpy().tobytes()
34+
35+
2036
@dataclass
2137
class NamedDataStoreOutput:
2238
"""
@@ -169,7 +185,7 @@ def add_named_data(
169185
f"Tensor {key} is a torch.Tensor, with tensor_layout {real_tensor_layout}. The provided tensor layout {tensor_layout} does not match."
170186
)
171187
tensor_layout = real_tensor_layout
172-
byte_data = bytes(data.untyped_storage())
188+
byte_data = _tensor_to_bytes(data)
173189
else:
174190
byte_data = data
175191

exir/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
289289
torch.bfloat16: ScalarType.BFLOAT16,
290290
torch.quint4x2: ScalarType.QUINT4x2,
291291
torch.uint16: ScalarType.UINT16,
292+
torch.uint32: ScalarType.UINT32,
292293
}
293294

294295

0 commit comments

Comments
 (0)