Skip to content

Commit 20d3ba7

Browse files
committed
fix
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent c2cd296 commit 20d3ba7

2 files changed

Lines changed: 49 additions & 22 deletions

File tree

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
2626
from transfer_queue.storage.clients.factory import StorageClientFactory
27-
from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_continues_memory
27+
from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory
2828

2929
logger = logging.getLogger(__name__)
3030
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
@@ -145,7 +145,7 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[
145145
"""Worker thread for putting batch of tensors to MooncakeStore."""
146146

147147
batch_ptrs, batch_sizes, contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors)
148-
batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes)
148+
batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, batch_sizes)
149149
self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced)
150150

151151
try:
@@ -159,14 +159,14 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[
159159
finally:
160160
self._unregister_all_buffers(batch_ptr_reduced)
161161

162-
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[bytes]):
162+
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]):
163163
"""Worker thread for putting batch of non-tensors to MooncakeStore."""
164164

165165
batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values]
166166

167167
ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config)
168168
if ret != 0:
169-
raise RuntimeError(f"put_batch failed with error code: {ret}")
169+
raise RuntimeError(f"upsert_batch failed with error code: {ret}")
170170

171171
def get(
172172
self,
@@ -232,9 +232,11 @@ def _get_tensors_thread_worker(
232232
self, batch_keys: list[str], batch_shapes: list[tuple], batch_dtypes: list[torch.dtype], indexes: list[int]
233233
) -> tuple[list[Tensor], list[int]]:
234234
batch_nbytes = get_nbytes(batch_dtypes, batch_shapes)
235-
batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes)
235+
batch_buffer_tensors, batch_buffer_ptrs, region_ptrs, region_sizes = allocate_empty_tensors(
236+
batch_dtypes, batch_shapes
237+
)
236238

237-
self._register_all_buffers(batch_buffer_ptrs, batch_nbytes)
239+
self._register_all_buffers(region_ptrs, region_sizes)
238240
try:
239241
ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes)
240242
if len(ret_codes) != len(batch_keys):
@@ -243,7 +245,7 @@ def _get_tensors_thread_worker(
243245
if ret < 0:
244246
raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}")
245247
finally:
246-
self._unregister_all_buffers(batch_buffer_ptrs)
248+
self._unregister_all_buffers(region_ptrs)
247249

248250
return batch_buffer_tensors, indexes
249251

@@ -283,6 +285,11 @@ def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[A
283285
size_list = []
284286
tensor_list = [] # hold reference for the contiguous tensor
285287
for t in values:
288+
# TODO: support gpu direct rdma and use different data paths.
289+
# For GPU, it's more reasonable to perform data copy since
290+
# The register overhead is much higher than CPU
291+
if t.device.type == "cuda":
292+
t = t.cpu()
286293
t = t.contiguous()
287294
tensor_list.append(t)
288295
ptr_list.append(t.data_ptr())

transfer_queue/utils/tensor_utils.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import operator
1818
import os
19+
import warnings
1920
from functools import reduce
2021

2122
import torch
@@ -25,7 +26,9 @@
2526
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
2627

2728

28-
def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tuple[list[Tensor], list[int]]:
29+
def allocate_empty_tensors(
30+
dtypes: list[torch.dtype], shapes: list[tuple]
31+
) -> tuple[list[Tensor], list[int], list[int], list[int]]:
2932
"""Allocate empty tensors, grouping same dtypes into shared memory blocks.
3033
3134
Instead of allocating each tensor separately, this function groups tensors
@@ -40,17 +43,19 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
4043
A tuple containing:
4144
- List of tensors sharing memory within their dtype groups.
4245
- List of memory pointers (data_ptr) for each tensor.
46+
- List of base pointers for each allocated memory region (one per dtype).
47+
- List of total bytes for each allocated memory region (one per dtype).
4348
4449
Example:
4550
>>> dtypes = [torch.float32, torch.float32, torch.int32, torch.float32]
4651
>>> shapes = [(10,), (20,), (5,), (15,)]
47-
>>> tensors, ptrs = allocate_empty_tensors(dtypes, shapes)
52+
>>> tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
4853
>>> # tensors[0], [1], [3] share the same dtype and memory block
4954
"""
5055
assert len(dtypes) == len(shapes), "dtypes and shapes must have the same length"
5156

5257
if len(dtypes) == 0:
53-
return [], []
58+
return [], [], [], []
5459

5560
# Group indices by dtype
5661
dtype_groups: dict[torch.dtype, list[int]] = {}
@@ -61,6 +66,8 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
6166

6267
tensor_list = [torch.empty(()) for _ in range(len(dtypes))]
6368
ptr_list = [0] * len(dtypes)
69+
region_ptrs: list[int] = []
70+
region_sizes: list[int] = []
6471

6572
# For each dtype group, allocate one big tensor and create views
6673
for dtype, indices in dtype_groups.items():
@@ -69,13 +76,15 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
6976
shape_info = [] # Store (index, shape, num_elements, offset)
7077

7178
for idx in indices:
72-
shape = shapes[idx]
73-
num_elements = reduce(operator.mul, shape)
79+
shape = tuple(shapes[idx])
80+
num_elements = reduce(operator.mul, shape, 1)
7481
shape_info.append((idx, shape, num_elements, total_elements))
7582
total_elements += num_elements
7683

7784
# Allocate one big contiguous memory block for this dtype
7885
big_tensor = torch.empty(total_elements, dtype=dtype)
86+
region_ptrs.append(big_tensor.data_ptr())
87+
region_sizes.append(big_tensor.nbytes)
7988

8089
# Create views into the big tensor for each small tensor
8190
for idx, shape, num_elements, offset in shape_info:
@@ -84,7 +93,7 @@ def allocate_empty_tensors(dtypes: list[torch.dtype], shapes: list[tuple]) -> tu
8493
tensor_list[idx] = small_tensor
8594
ptr_list[idx] = small_tensor.data_ptr()
8695

87-
return tensor_list, ptr_list
96+
return tensor_list, ptr_list, region_ptrs, region_sizes
8897

8998

9099
def compute_stride(shape: tuple[int, ...]) -> tuple[int, ...]:
@@ -115,36 +124,37 @@ def get_nbytes(dtypes, shapes) -> list[int]:
115124
nbytes = []
116125
for i in range(len(dtypes)):
117126
elem_size = torch.tensor([], dtype=dtypes[i]).element_size()
118-
numel = reduce(operator.mul, shapes[i])
127+
shape = tuple(shapes[i])
128+
numel = reduce(operator.mul, shape, 1)
119129
nbytes.append(elem_size * numel)
120130

121131
return nbytes
122132

123133

124-
def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]:
125-
"""Merge continuous memory regions to reduce register_buffer overhead
134+
def merge_contiguous_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]:
135+
"""Merge contiguous memory regions to reduce register_buffer overhead
126136
127137
Args:
128138
ptrs: List of memory pointers (starting addresses).
129139
sizes: List of memory region sizes corresponding to each pointer.
130140
131141
Returns:
132-
A tuple of (merged_ptrs, merged_sizes) where continuous regions
142+
A tuple of (merged_ptrs, merged_sizes) where contiguous regions
133143
have been merged into single regions.
134144
135145
Example:
136-
>>> merge_continues_memory([0, 10, 30], [10, 20, 10])
146+
>>> merge_contiguous_memory([0, 10, 30], [10, 20, 10])
137147
([0, 30], [30, 10])
138148
139-
>>> merge_continues_memory([0, 5, 20], [5, 5, 10])
149+
>>> merge_contiguous_memory([0, 5, 20], [5, 5, 10])
140150
([0, 20], [10, 10])
141151
"""
142-
if not ptrs or not sizes:
143-
return [], []
144-
145152
if len(ptrs) != len(sizes):
146153
raise ValueError("ptrs and sizes must have the same length")
147154

155+
if not ptrs:
156+
return [], []
157+
148158
# Create list of (ptr, size) pairs and sort by pointer address
149159
regions = sorted(zip(ptrs, sizes, strict=False), key=lambda x: x[0])
150160

@@ -171,3 +181,13 @@ def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int]
171181
merged_sizes.append(current_size)
172182

173183
return merged_ptrs, merged_sizes
184+
185+
186+
def merge_continues_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]:
187+
"""Deprecated alias for :func:`merge_contiguous_memory`."""
188+
warnings.warn(
189+
"merge_continues_memory is deprecated, use merge_contiguous_memory instead",
190+
DeprecationWarning,
191+
stacklevel=2,
192+
)
193+
return merge_contiguous_memory(ptrs, sizes)

0 commit comments

Comments
 (0)