Skip to content

Commit c4c4198

Browse files
committed
fix
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> fix Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> fix Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent c2cd296 commit c4c4198

3 files changed

Lines changed: 236 additions & 24 deletions

File tree

tests/test_tensor_utils.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for transfer_queue.utils.tensor_utils."""
17+
18+
import pytest
19+
import torch
20+
21+
from transfer_queue.utils.tensor_utils import (
22+
allocate_empty_tensors,
23+
compute_stride,
24+
get_nbytes,
25+
merge_contiguous_memory,
26+
)
27+
28+
29+
class TestComputeStride:
30+
"""Tests for compute_stride."""
31+
32+
def test_3d(self):
33+
assert compute_stride((2, 3, 4)) == (12, 4, 1)
34+
35+
def test_1d(self):
36+
assert compute_stride((5,)) == (1,)
37+
38+
def test_scalar(self):
39+
assert compute_stride(()) == ()
40+
41+
def test_2d(self):
42+
assert compute_stride((3, 5)) == (5, 1)
43+
44+
45+
class TestGetNbytes:
46+
"""Tests for get_nbytes."""
47+
48+
def test_basic(self):
49+
dtypes = [torch.float32, torch.int32]
50+
shapes = [(2, 3), (4,)]
51+
result = get_nbytes(dtypes, shapes)
52+
assert result == [2 * 3 * 4, 4 * 4] # float32=4, int32=4
53+
54+
def test_scalar(self):
55+
dtypes = [torch.float64]
56+
shapes = [()]
57+
result = get_nbytes(dtypes, shapes)
58+
assert result == [8] # scalar = 1 element
59+
60+
def test_list_shape(self):
61+
dtypes = [torch.float32]
62+
shapes = [[]] # list instead of tuple
63+
result = get_nbytes(dtypes, shapes)
64+
assert result == [4]
65+
66+
def test_mixed_dtypes(self):
67+
dtypes = [torch.float16, torch.float32, torch.int64]
68+
shapes = [(10,), (10,), (10,)]
69+
result = get_nbytes(dtypes, shapes)
70+
assert result == [10 * 2, 10 * 4, 10 * 8]
71+
72+
73+
class TestAllocateEmptyTensors:
74+
"""Tests for allocate_empty_tensors."""
75+
76+
def test_basic(self):
77+
dtypes = [torch.float32, torch.float32, torch.int32]
78+
shapes = [(2, 3), (4,), (5,)]
79+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
80+
81+
assert len(tensors) == 3
82+
assert len(ptrs) == 3
83+
assert len(region_ptrs) == 2 # float32 group + int32 group
84+
assert len(region_sizes) == 2
85+
86+
# Same dtype tensors share the same underlying storage
87+
assert tensors[0].untyped_storage().data_ptr() == region_ptrs[0]
88+
assert tensors[1].untyped_storage().data_ptr() == region_ptrs[0]
89+
assert tensors[2].untyped_storage().data_ptr() == region_ptrs[1]
90+
91+
# Shapes are correct
92+
assert list(tensors[0].shape) == [2, 3]
93+
assert list(tensors[1].shape) == [4]
94+
assert list(tensors[2].shape) == [5]
95+
96+
def test_scalar(self):
97+
dtypes = [torch.float32, torch.int32]
98+
shapes = [(), ()]
99+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
100+
101+
assert len(tensors) == 2
102+
assert tensors[0].numel() == 1
103+
assert tensors[1].numel() == 1
104+
assert len(region_ptrs) == 2
105+
106+
def test_empty(self):
107+
result = allocate_empty_tensors([], [])
108+
assert result == ([], [], [], [])
109+
110+
def test_regions_complex(self):
111+
"""Mixed dtypes and shapes: verify region counts, sizes, and per-tensor offsets."""
112+
dtypes = [
113+
torch.float32, # group 0: (2, 3) -> 6 elements
114+
torch.int32, # group 1: (4,) -> 4 elements
115+
torch.float32, # group 0: scalar -> 1 element
116+
torch.float64, # group 2: (2, 2) -> 4 elements
117+
torch.int32, # group 1: (3, 2) -> 6 elements
118+
]
119+
shapes = [(2, 3), (4,), (), (2, 2), (3, 2)]
120+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
121+
122+
# 3 dtype groups in insertion order: float32, int32, float64
123+
assert len(region_ptrs) == 3
124+
assert len(region_sizes) == 3
125+
assert len(set(region_ptrs)) == 3 # distinct allocations
126+
127+
# float32 region: 6 + 1 = 7 elements * 4 bytes = 28 bytes
128+
assert region_sizes[0] == 7 * 4
129+
# int32 region: 4 + 6 = 10 elements * 4 bytes = 40 bytes
130+
assert region_sizes[1] == 10 * 4
131+
# float64 region: 4 elements * 8 bytes = 32 bytes
132+
assert region_sizes[2] == 4 * 8
133+
134+
# Per-tensor ptrs must lie inside their respective regions
135+
# tensor 0 (float32, shape (2,3), offset 0)
136+
assert ptrs[0] == region_ptrs[0]
137+
# tensor 1 (int32, shape (4,), offset 0)
138+
assert ptrs[1] == region_ptrs[1]
139+
# tensor 2 (float32, scalar, offset 6)
140+
assert ptrs[2] == region_ptrs[0] + 6 * 4
141+
# tensor 3 (float64, shape (2,2), offset 0)
142+
assert ptrs[3] == region_ptrs[2]
143+
# tensor 4 (int32, shape (3,2), offset 4)
144+
assert ptrs[4] == region_ptrs[1] + 4 * 4
145+
146+
147+
class TestMergeContiguousMemory:
148+
"""Tests for merge_contiguous_memory."""
149+
150+
def test_basic_merge(self):
151+
ptrs = [0, 10, 30]
152+
sizes = [10, 20, 10]
153+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
154+
# 0+10=10 (contiguous with 10), 10+20=30 (contiguous with 30) -> all merge into [0]
155+
assert merged_ptrs == [0]
156+
assert merged_sizes == [40]
157+
158+
def test_no_contiguous(self):
159+
ptrs = [0, 100, 200]
160+
sizes = [50, 50, 50]
161+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
162+
assert merged_ptrs == [0, 100, 200]
163+
assert merged_sizes == [50, 50, 50]
164+
165+
def test_unsorted_input(self):
166+
ptrs = [100, 0, 50]
167+
sizes = [50, 50, 50]
168+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
169+
# After sorting: 0, 50, 100; all contiguous -> merge into [0]
170+
assert merged_ptrs == [0]
171+
assert merged_sizes == [150]
172+
173+
def test_single_region(self):
174+
ptrs = [10]
175+
sizes = [100]
176+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
177+
assert merged_ptrs == [10]
178+
assert merged_sizes == [100]
179+
180+
def test_empty(self):
181+
assert merge_contiguous_memory([], []) == ([], [])
182+
183+
def test_mismatched_lengths_both_empty_not_triggered(self):
184+
# If one list is empty and the other is not, a ValueError must be raised
185+
with pytest.raises(ValueError, match="ptrs and sizes must have the same length"):
186+
merge_contiguous_memory([], [10])
187+
188+
with pytest.raises(ValueError, match="ptrs and sizes must have the same length"):
189+
merge_contiguous_memory([0], [])
190+
191+
def test_three_continuous(self):
192+
ptrs = [0, 10, 20]
193+
sizes = [10, 10, 10]
194+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
195+
assert merged_ptrs == [0]
196+
assert merged_sizes == [30]

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
2626
from transfer_queue.storage.clients.factory import StorageClientFactory
27-
from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_continues_memory
27+
from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory
2828

2929
logger = logging.getLogger(__name__)
3030
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
@@ -144,8 +144,8 @@ def put(self, keys: list[str], values: list[Any]) -> None:
144144
def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]):
145145
"""Worker thread for putting batch of tensors to MooncakeStore."""
146146

147-
batch_ptrs, batch_sizes, contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors)
148-
batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes)
147+
batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors)
148+
batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, batch_sizes)
149149
self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced)
150150

151151
try:
@@ -154,19 +154,19 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[
154154
failed_indices = [j for j, r in enumerate(results) if r != 0]
155155
error_codes = [results[j] for j in failed_indices]
156156
raise RuntimeError(
157-
f"batch_put_tensor failed for indices {failed_indices} with error codes: {error_codes}"
157+
f"batch_upsert_from failed for indices {failed_indices} with error codes: {error_codes}"
158158
)
159159
finally:
160160
self._unregister_all_buffers(batch_ptr_reduced)
161161

162-
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[bytes]):
162+
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]):
163163
"""Worker thread for putting batch of non-tensors to MooncakeStore."""
164164

165165
batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values]
166166

167167
ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config)
168168
if ret != 0:
169-
raise RuntimeError(f"put_batch failed with error code: {ret}")
169+
raise RuntimeError(f"upsert_batch failed with error code: {ret}")
170170

171171
def get(
172172
self,
@@ -232,9 +232,11 @@ def _get_tensors_thread_worker(
232232
self, batch_keys: list[str], batch_shapes: list[tuple], batch_dtypes: list[torch.dtype], indexes: list[int]
233233
) -> tuple[list[Tensor], list[int]]:
234234
batch_nbytes = get_nbytes(batch_dtypes, batch_shapes)
235-
batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes)
235+
batch_buffer_tensors, batch_buffer_ptrs, region_ptrs, region_sizes = allocate_empty_tensors(
236+
batch_dtypes, batch_shapes
237+
)
236238

237-
self._register_all_buffers(batch_buffer_ptrs, batch_nbytes)
239+
self._register_all_buffers(region_ptrs, region_sizes)
238240
try:
239241
ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes)
240242
if len(ret_codes) != len(batch_keys):
@@ -243,7 +245,7 @@ def _get_tensors_thread_worker(
243245
if ret < 0:
244246
raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}")
245247
finally:
246-
self._unregister_all_buffers(batch_buffer_ptrs)
248+
self._unregister_all_buffers(region_ptrs)
247249

248250
return batch_buffer_tensors, indexes
249251

@@ -283,6 +285,11 @@ def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[A
283285
size_list = []
284286
tensor_list = [] # hold reference for the contiguous tensor
285287
for t in values:
288+
# TODO: support GPU Direct RDMA and use a separate data path for GPU tensors.
289+
# For GPU, it's more reasonable to perform data copy since
290+
# the buffer-registration overhead is much higher than on CPU.
291+
if t.device.type == "cuda":
292+
t = t.cpu()
286293
t = t.contiguous()
287294
tensor_list.append(t)
288295
ptr_list.append(t.data_ptr())

transfer_queue/utils/tensor_utils.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
2626

2727

28-
def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tuple[list[Tensor], list[int]]:
28+
def allocate_empty_tensors(
29+
dtypes: list[torch.dtype], shapes: list[tuple]
30+
) -> tuple[list[Tensor], list[int], list[int], list[int]]:
2931
"""Allocate empty tensors, grouping same dtypes into shared memory blocks.
3032
3133
Instead of allocating each tensor separately, this function groups tensors
@@ -40,17 +42,19 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
4042
A tuple containing:
4143
- List of tensors sharing memory within their dtype groups.
4244
- List of memory pointers (data_ptr) for each tensor.
45+
- List of base pointers for each allocated memory region (one per dtype).
46+
- List of total bytes for each allocated memory region (one per dtype).
4347
4448
Example:
4549
>>> dtypes = [torch.float32, torch.float32, torch.int32, torch.float32]
4650
>>> shapes = [(10,), (20,), (5,), (15,)]
47-
>>> tensors, ptrs = allocate_empty_tensors(dtypes, shapes)
51+
>>> tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
4852
>>> # tensors[0], [1], [3] share the same dtype and memory block
4953
"""
5054
assert len(dtypes) == len(shapes), "dtypes and shapes must have the same length"
5155

5256
if len(dtypes) == 0:
53-
return [], []
57+
return [], [], [], []
5458

5559
# Group indices by dtype
5660
dtype_groups: dict[torch.dtype, list[int]] = {}
@@ -61,6 +65,8 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
6165

6266
tensor_list = [torch.empty(()) for _ in range(len(dtypes))]
6367
ptr_list = [0] * len(dtypes)
68+
region_ptrs: list[int] = []
69+
region_sizes: list[int] = []
6470

6571
# For each dtype group, allocate one big tensor and create views
6672
for dtype, indices in dtype_groups.items():
@@ -69,13 +75,15 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
6975
shape_info = [] # Store (index, shape, num_elements, offset)
7076

7177
for idx in indices:
72-
shape = shapes[idx]
73-
num_elements = reduce(operator.mul, shape)
78+
shape = tuple(shapes[idx])
79+
num_elements = reduce(operator.mul, shape, 1)
7480
shape_info.append((idx, shape, num_elements, total_elements))
7581
total_elements += num_elements
7682

7783
# Allocate one big contiguous memory block for this dtype
7884
big_tensor = torch.empty(total_elements, dtype=dtype)
85+
region_ptrs.append(big_tensor.data_ptr())
86+
region_sizes.append(big_tensor.nbytes)
7987

8088
# Create views into the big tensor for each small tensor
8189
for idx, shape, num_elements, offset in shape_info:
@@ -84,7 +92,7 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
8492
tensor_list[idx] = small_tensor
8593
ptr_list[idx] = small_tensor.data_ptr()
8694

87-
return tensor_list, ptr_list
95+
return tensor_list, ptr_list, region_ptrs, region_sizes
8896

8997

9098
def compute_stride(shape: tuple[int, ...]) -> tuple[int, ...]:
@@ -115,36 +123,37 @@ def get_nbytes(dtypes, shapes) -> list[int]:
115123
nbytes = []
116124
for i in range(len(dtypes)):
117125
elem_size = torch.tensor([], dtype=dtypes[i]).element_size()
118-
numel = reduce(operator.mul, shapes[i])
126+
shape = tuple(shapes[i])
127+
numel = reduce(operator.mul, shape, 1)
119128
nbytes.append(elem_size * numel)
120129

121130
return nbytes
122131

123132

124-
def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]:
125-
"""Merge continuous memory regions to reduce register_buffer overhead
133+
def merge_contiguous_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]:
134+
"""Merge contiguous memory regions to reduce register_buffer overhead
126135
127136
Args:
128137
ptrs: List of memory pointers (starting addresses).
129138
sizes: List of memory region sizes corresponding to each pointer.
130139
131140
Returns:
132-
A tuple of (merged_ptrs, merged_sizes) where continuous regions
141+
A tuple of (merged_ptrs, merged_sizes) where contiguous regions
133142
have been merged into single regions.
134143
135144
Example:
136-
>>> merge_continues_memory([0, 10, 30], [10, 20, 10])
145+
>>> merge_contiguous_memory([0, 10, 30], [10, 20, 10])
137146
([0], [40])
138147
139-
>>> merge_continues_memory([0, 5, 20], [5, 5, 10])
148+
>>> merge_contiguous_memory([0, 5, 20], [5, 5, 10])
140149
([0, 20], [10, 10])
141150
"""
142-
if not ptrs or not sizes:
143-
return [], []
144-
145151
if len(ptrs) != len(sizes):
146152
raise ValueError("ptrs and sizes must have the same length")
147153

154+
if not ptrs:
155+
return [], []
156+
148157
# Create list of (ptr, size) pairs and sort by pointer address
149158
regions = sorted(zip(ptrs, sizes, strict=False), key=lambda x: x[0])
150159

0 commit comments

Comments
 (0)