Skip to content

Commit b266d39

Browse files
[perf] Refactor MooncakeStore backend with zero-copy upsert API (#77)
1. Switch to `batch_upsert_from` & `batch_get_into` for tensor 2. Switch to `upsert_batch` & `get_batch` for non-tensor 3. Use `batch_remove` API for data clearning 4. Set hard-pin flag during data writting 5. Use multi-thread to optimize data preparation & transfer workflow --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> Signed-off-by: ji-huazhong <hzji210@gmail.com> Co-authored-by: Huazhong <hzji210@gmail.com>
1 parent 0a5176b commit b266d39

7 files changed

Lines changed: 543 additions & 120 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ yuanrong = [
118118
"openyuanrong-datasystem"
119119
]
120120
mooncake = [
121-
"mooncake-transfer-engine==0.3.10.post1"
121+
"mooncake-transfer-engine==0.3.10.post2"
122122
]
123123

124124
# If you need to mimic `package_dir={'': '.'}`:

tests/e2e/test_kv_interface_e2e.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def test_kv_put_with_dict_fields(self, controller, tq_api):
212212
expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed
213213
assert_tensor_equal(retrieved["data"], expected)
214214

215-
# delete the key (MooncakeStore does not support updating existing key, so we need to clear it before next test)
216215
tq_api.kv_clear(keys=key, partition_id=partition_id)
217216

218217
def test_kv_put_with_tensordict_fields(self, controller, tq_api):

tests/test_tensor_utils.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for transfer_queue.utils.tensor_utils."""
17+
18+
import pytest
19+
import torch
20+
21+
from transfer_queue.utils.tensor_utils import (
22+
allocate_empty_tensors,
23+
compute_stride,
24+
get_nbytes,
25+
merge_contiguous_memory,
26+
)
27+
28+
29+
class TestComputeStride:
30+
"""Tests for compute_stride."""
31+
32+
def test_3d(self):
33+
assert compute_stride((2, 3, 4)) == (12, 4, 1)
34+
35+
def test_1d(self):
36+
assert compute_stride((5,)) == (1,)
37+
38+
def test_scalar(self):
39+
assert compute_stride(()) == ()
40+
41+
def test_2d(self):
42+
assert compute_stride((3, 5)) == (5, 1)
43+
44+
45+
class TestGetNbytes:
46+
"""Tests for get_nbytes."""
47+
48+
def test_basic(self):
49+
dtypes = [torch.float32, torch.int32]
50+
shapes = [(2, 3), (4,)]
51+
result = get_nbytes(dtypes, shapes)
52+
assert result == [2 * 3 * 4, 4 * 4] # float32=4, int32=4
53+
54+
def test_scalar(self):
55+
dtypes = [torch.float64]
56+
shapes = [()]
57+
result = get_nbytes(dtypes, shapes)
58+
assert result == [8] # scalar = 1 element
59+
60+
def test_list_shape(self):
61+
dtypes = [torch.float32]
62+
shapes = [[]] # list instead of tuple
63+
result = get_nbytes(dtypes, shapes)
64+
assert result == [4]
65+
66+
def test_mixed_dtypes(self):
67+
dtypes = [torch.float16, torch.float32, torch.int64]
68+
shapes = [(10,), (10,), (10,)]
69+
result = get_nbytes(dtypes, shapes)
70+
assert result == [10 * 2, 10 * 4, 10 * 8]
71+
72+
73+
class TestAllocateEmptyTensors:
74+
"""Tests for allocate_empty_tensors."""
75+
76+
def test_basic(self):
77+
dtypes = [torch.float32, torch.float32, torch.int32]
78+
shapes = [(2, 3), (4,), (5,)]
79+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
80+
81+
assert len(tensors) == 3
82+
assert len(ptrs) == 3
83+
assert len(region_ptrs) == 2 # float32 group + int32 group
84+
assert len(region_sizes) == 2
85+
86+
# Same dtype tensors share the same underlying storage
87+
assert tensors[0].untyped_storage().data_ptr() == region_ptrs[0]
88+
assert tensors[1].untyped_storage().data_ptr() == region_ptrs[0]
89+
assert tensors[2].untyped_storage().data_ptr() == region_ptrs[1]
90+
91+
# Shapes are correct
92+
assert list(tensors[0].shape) == [2, 3]
93+
assert list(tensors[1].shape) == [4]
94+
assert list(tensors[2].shape) == [5]
95+
96+
def test_scalar(self):
97+
dtypes = [torch.float32, torch.int32]
98+
shapes = [(), ()]
99+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
100+
101+
assert len(tensors) == 2
102+
assert tensors[0].numel() == 1
103+
assert tensors[1].numel() == 1
104+
assert len(region_ptrs) == 2
105+
106+
def test_empty(self):
107+
result = allocate_empty_tensors([], [])
108+
assert result == ([], [], [], [])
109+
110+
def test_regions_complex(self):
111+
"""Mixed dtypes and shapes: verify region counts, sizes, and per-tensor offsets."""
112+
dtypes = [
113+
torch.float32, # group 0: (2, 3) -> 6 elements
114+
torch.int32, # group 1: (4,) -> 4 elements
115+
torch.float32, # group 0: scalar -> 1 element
116+
torch.float64, # group 2: (2, 2) -> 4 elements
117+
torch.int32, # group 1: (3, 2) -> 6 elements
118+
]
119+
shapes = [(2, 3), (4,), (), (2, 2), (3, 2)]
120+
tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
121+
122+
# 3 dtype groups in insertion order: float32, int32, float64
123+
assert len(region_ptrs) == 3
124+
assert len(region_sizes) == 3
125+
assert len(set(region_ptrs)) == 3 # distinct allocations
126+
127+
# float32 region: 6 + 1 = 7 elements * 4 bytes = 28 bytes
128+
assert region_sizes[0] == 7 * 4
129+
# int32 region: 4 + 6 = 10 elements * 4 bytes = 40 bytes
130+
assert region_sizes[1] == 10 * 4
131+
# float64 region: 4 elements * 8 bytes = 32 bytes
132+
assert region_sizes[2] == 4 * 8
133+
134+
# Per-tensor ptrs must lie inside their respective regions
135+
# tensor 0 (float32, shape (2,3), offset 0)
136+
assert ptrs[0] == region_ptrs[0]
137+
# tensor 1 (int32, shape (4,), offset 0)
138+
assert ptrs[1] == region_ptrs[1]
139+
# tensor 2 (float32, scalar, offset 6)
140+
assert ptrs[2] == region_ptrs[0] + 6 * 4
141+
# tensor 3 (float64, shape (2,2), offset 0)
142+
assert ptrs[3] == region_ptrs[2]
143+
# tensor 4 (int32, shape (3,2), offset 4)
144+
assert ptrs[4] == region_ptrs[1] + 4 * 4
145+
146+
147+
class TestMergeContiguousMemory:
148+
"""Tests for merge_contiguous_memory."""
149+
150+
def test_basic_merge(self):
151+
ptrs = [0, 10, 30]
152+
sizes = [10, 20, 10]
153+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
154+
# 0+10=10 (contiguous with 10), 10+20=30 (contiguous with 30) -> all merge into [0]
155+
assert merged_ptrs == [0]
156+
assert merged_sizes == [40]
157+
158+
def test_no_contiguous(self):
159+
ptrs = [0, 100, 200]
160+
sizes = [50, 50, 50]
161+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
162+
assert merged_ptrs == [0, 100, 200]
163+
assert merged_sizes == [50, 50, 50]
164+
165+
def test_unsorted_input(self):
166+
ptrs = [100, 0, 50]
167+
sizes = [50, 50, 50]
168+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
169+
# After sorting: 0, 50, 100; all contiguous -> merge into [0]
170+
assert merged_ptrs == [0]
171+
assert merged_sizes == [150]
172+
173+
def test_single_region(self):
174+
ptrs = [10]
175+
sizes = [100]
176+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
177+
assert merged_ptrs == [10]
178+
assert merged_sizes == [100]
179+
180+
def test_empty(self):
181+
assert merge_contiguous_memory([], []) == ([], [])
182+
183+
def test_mismatched_lengths_both_empty_not_triggered(self):
184+
# If one is empty and other is not, should raise ValueError
185+
with pytest.raises(ValueError, match="ptrs and sizes must have the same length"):
186+
merge_contiguous_memory([], [10])
187+
188+
with pytest.raises(ValueError, match="ptrs and sizes must have the same length"):
189+
merge_contiguous_memory([0], [])
190+
191+
def test_three_continuous(self):
192+
ptrs = [0, 10, 20]
193+
sizes = [10, 10, 10]
194+
merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes)
195+
assert merged_ptrs == [0]
196+
assert merged_sizes == [30]

0 commit comments

Comments
 (0)