|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | 9 | import hashlib |
10 | | - |
11 | 10 | from dataclasses import dataclass |
12 | 11 | from typing import Dict, List, Optional, Union |
13 | 12 |
|
14 | 13 | import torch |
15 | | - |
16 | 14 | from executorch.exir._serialize.data_serializer import DataEntry |
17 | 15 | from executorch.exir.tensor_layout import TensorLayout |
18 | 16 |
|
19 | 17 |
|
| 18 | +def _tensor_to_bytes(tensor: torch.Tensor) -> bytes: |
| 19 | + """Convert tensor to bytes using the fastest method available. |
| 20 | +
|
| 21 | + Uses numpy().tobytes() which is faster than bytes(untyped_storage()) |
| 22 | + for C-contiguous tensors. Falls back to untyped_storage() for |
| 23 | + non-contiguous tensors (e.g., channels_last) to preserve memory layout. |
| 24 | + """ |
| 25 | + if not tensor.is_contiguous(): |
| 26 | + # For non-C-contiguous tensors (e.g., channels_last), use untyped_storage |
| 27 | + # to preserve the actual memory layout |
| 28 | + return bytes(tensor.untyped_storage()) |
| 29 | + if tensor.dtype == torch.bfloat16: |
| 30 | + # BFloat16 is not supported by numpy, extract raw bytes via view |
| 31 | + return tensor.view(torch.uint16).numpy().tobytes() |
| 32 | + else: |
| 33 | + return tensor.numpy().tobytes() |
| 34 | + |
| 35 | + |
20 | 36 | @dataclass |
21 | 37 | class NamedDataStoreOutput: |
22 | 38 | """ |
@@ -169,7 +185,7 @@ def add_named_data( |
169 | 185 | f"Tensor {key} is a torch.Tensor, with tensor_layout {real_tensor_layout}. The provided tensor layout {tensor_layout} does not match." |
170 | 186 | ) |
171 | 187 | tensor_layout = real_tensor_layout |
172 | | - byte_data = bytes(data.untyped_storage()) |
| 188 | + byte_data = _tensor_to_bytes(data) |
173 | 189 | else: |
174 | 190 | byte_data = data |
175 | 191 |
|
|
0 commit comments