Skip to content

Commit 54d919d

Browse files
committed
Add multi-thread support for concurrent data preprocessing & transfer
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent d30829c commit 54d919d

1 file changed

Lines changed: 86 additions & 84 deletions

File tree

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 86 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import os
1818
import pickle
19+
from concurrent.futures import ThreadPoolExecutor, as_completed
1920
from typing import Any, Optional
2021

2122
import torch
@@ -35,7 +36,8 @@
3536
except ImportError:
3637
MOONCAKE_STORE_IMPORTED = False
3738

38-
BATCH_SIZE_LIMIT: int = 500
39+
BATCH_SIZE_LIMIT: int = 200
40+
MAX_WORKER_THREADS = 4
3941

4042

4143
@StorageClientFactory.register("MooncakeStoreClient")
@@ -81,7 +83,7 @@ def __init__(self, config: dict[str, Any]):
8183
self.metadata_server = self.metadata_server + "/metadata"
8284

8385
self.replica_config = ReplicateConfig()
84-
# FIXME: hard_pin is not supported yet
86+
# FIXME: hard_pin is not supported yet; enable the line below once it is
8587
# self.replica_config.with_hard_pin = True
8688

8789
self._store = MooncakeDistributedStore()
@@ -97,7 +99,7 @@ def __init__(self, config: dict[str, Any]):
9799
if ret != 0:
98100
raise RuntimeError(f"Mooncake store setup failed with error code: {ret}")
99101

100-
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
102+
def put(self, keys: list[str], values: list[Any]) -> None:
101103
"""Stores multiple key-value pairs to MooncakeStore.
102104
103105
Args:
@@ -121,43 +123,51 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
121123
tensor_values.append(value)
122124
else:
123125
non_tensor_keys.append(key)
124-
non_tensor_values.append(pickle.dumps(value))
126+
non_tensor_values.append(value)
125127

126-
if tensor_keys:
127-
self._batch_put_tensors(tensor_keys, tensor_values)
128+
futures = []
129+
with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor:
130+
for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT):
131+
batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT]
132+
batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT]
133+
futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors))
128134

129-
if non_tensor_keys:
130-
self._batch_put_bytes(non_tensor_keys, non_tensor_values)
135+
for i in range(0, len(non_tensor_keys), BATCH_SIZE_LIMIT):
136+
batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT]
137+
batch_values = non_tensor_values[i : i + BATCH_SIZE_LIMIT]
138+
futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values))
139+
140+
for future in as_completed(futures):
141+
future.result()
131142

132143
return None
133144

134-
def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]):
135-
for i in range(0, len(keys), BATCH_SIZE_LIMIT):
136-
batch_keys = keys[i : i + BATCH_SIZE_LIMIT]
137-
batch_tensors = tensors[i : i + BATCH_SIZE_LIMIT]
145+
def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]):
146+
"""Worker thread for putting batch of tensors to MooncakeStore."""
138147

139-
batch_ptrs, batch_sizes = self._preprocess_tensors_for_put(batch_tensors)
140-
batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes)
141-
self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced)
148+
batch_ptrs, batch_sizes, contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors)
149+
batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes)
150+
self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced)
142151

152+
try:
143153
results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config)
144154
if not all(r == 0 for r in results):
145155
failed_indices = [j for j, r in enumerate(results) if r != 0]
146156
error_codes = [results[j] for j in failed_indices]
147157
raise RuntimeError(
148158
f"batch_put_tensor failed for indices {failed_indices} with error codes: {error_codes}"
149159
)
150-
160+
finally:
151161
self._unregister_all_buffers(batch_ptr_reduced)
152162

153-
def _batch_put_bytes(self, keys: list[str], values: list[bytes]):
154-
for i in range(0, len(keys), BATCH_SIZE_LIMIT):
155-
batch_keys = keys[i : i + BATCH_SIZE_LIMIT]
156-
batch_values = values[i : i + BATCH_SIZE_LIMIT]
163+
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[bytes]):
164+
"""Worker thread for putting batch of non-tensors to MooncakeStore."""
157165

158-
ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config)
159-
if ret != 0:
160-
raise RuntimeError(f"put_batch failed with error code: {ret}")
166+
batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values]
167+
168+
ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config)
169+
if ret != 0:
170+
raise RuntimeError(f"put_batch failed with error code: {ret}")
161171

162172
def get(
163173
self,
@@ -194,71 +204,61 @@ def get(
194204

195205
results = [None] * len(keys)
196206

197-
if tensor_indices:
198-
tensor_keys = [keys[i] for i in tensor_indices]
199-
tensor_shapes = [shapes[i] for i in tensor_indices]
200-
tensor_dtypes = [dtypes[i] for i in tensor_indices]
201-
tensor_results = self._batch_get_tensors(tensor_keys, tensor_shapes, tensor_dtypes)
202-
# TODO: optimize these for loops
203-
for idx, tensor in zip(tensor_indices, tensor_results, strict=True):
204-
results[idx] = tensor
205-
206-
if non_tensor_indices:
207-
non_tensor_keys = [keys[i] for i in non_tensor_indices]
208-
non_tensor_results = self._batch_get_bytes(non_tensor_keys)
209-
for idx, data in zip(non_tensor_indices, non_tensor_results, strict=True):
210-
results[idx] = pickle.loads(data)
211-
212-
return results
213-
214-
def _batch_get_tensors(self, keys: list[str], shapes: list, dtypes: list) -> list[Tensor]:
215-
tensors = [None] * len(keys)
207+
futures = []
208+
with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor:
209+
for i in range(0, len(tensor_indices), BATCH_SIZE_LIMIT):
210+
batch_indexes = tensor_indices[i : i + BATCH_SIZE_LIMIT]
211+
batch_keys = [keys[i] for i in batch_indexes]
212+
batch_shapes = [shapes[i] for i in batch_indexes]
213+
batch_dtypes = [dtypes[i] for i in batch_indexes]
214+
futures.append(
215+
executor.submit(
216+
self._get_tensors_thread_worker, batch_keys, batch_shapes, batch_dtypes, batch_indexes
217+
)
218+
)
216219

217-
for i in range(0, len(keys), BATCH_SIZE_LIMIT):
218-
batch_keys = keys[i : i + BATCH_SIZE_LIMIT]
219-
batch_shapes = shapes[i : i + BATCH_SIZE_LIMIT]
220-
batch_dtypes = dtypes[i : i + BATCH_SIZE_LIMIT]
220+
for i in range(0, len(non_tensor_indices), BATCH_SIZE_LIMIT):
221+
batch_indexes = non_tensor_indices[i : i + BATCH_SIZE_LIMIT]
222+
batch_keys = [keys[i] for i in batch_indexes]
223+
futures.append(executor.submit(self._get_bytes_thread_worker, batch_keys, batch_indexes))
221224

222-
batch_nbytes = get_nbytes(batch_dtypes, batch_shapes)
223-
batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes)
225+
for future in as_completed(futures):
226+
retrieved_values, batch_indexes = future.result()
227+
for idx, val in zip(batch_indexes, retrieved_values, strict=True):
228+
results[idx] = val
224229

225-
batch_ptrs = batch_buffer_ptrs
230+
return results
226231

227-
self._register_all_buffers(batch_ptrs, batch_nbytes)
228-
ret_codes = self._store.batch_get_into(batch_keys, batch_ptrs, batch_nbytes)
229-
self._unregister_all_buffers(batch_ptrs)
232+
def _get_tensors_thread_worker(
233+
self, batch_keys: list[str], batch_shapes: list[tuple], batch_dtypes: list[torch.dtype], indexes: list[int]
234+
) -> tuple[list[Tensor], list[int]]:
235+
batch_nbytes = get_nbytes(batch_dtypes, batch_shapes)
236+
batch_buffer_tensors, batch_buffer_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes)
230237

238+
self._register_all_buffers(batch_buffer_ptrs, batch_nbytes)
239+
try:
240+
ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes)
231241
if len(ret_codes) != len(batch_keys):
232242
raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}")
243+
for i, ret in enumerate(ret_codes):
244+
if ret < 0:
245+
raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}")
246+
finally:
247+
self._unregister_all_buffers(batch_buffer_ptrs)
233248

234-
# Check result codes and validate tensors
235-
# Note: Positive values indicate success (bytes read), negative values indicate error
236-
for j, (tensor, shape, dtype, ret_code) in enumerate(
237-
zip(batch_buffer_tensors, batch_shapes, batch_dtypes, ret_codes, strict=True)
238-
):
239-
if ret_code < 0:
240-
raise RuntimeError(f"batch_get_into failed for key '{batch_keys[j]}' with error code: {ret_code}")
241-
if tensor.shape != torch.Size(shape):
242-
raise RuntimeError(
243-
f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}"
244-
)
245-
if tensor.dtype != dtype:
246-
raise RuntimeError(
247-
f"Dtype mismatch for key '{batch_keys[j]}': expected {dtype}, got {tensor.dtype}"
248-
)
249-
tensors[i + j] = tensor
250-
251-
return tensors
249+
return batch_buffer_tensors, indexes
252250

253-
def _batch_get_bytes(self, keys: list[str]) -> list[bytes]:
251+
def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]:
254252
results = []
255-
for i in range(0, len(keys), BATCH_SIZE_LIMIT):
256-
batch_keys = keys[i : i + BATCH_SIZE_LIMIT]
257-
batch_results = self._store.get_batch(batch_keys)
258-
if len(batch_results) != len(batch_keys):
259-
raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}")
260-
results.extend(batch_results)
261-
return results
253+
254+
batch_results = self._store.get_batch(batch_keys)
255+
if len(batch_results) != len(batch_keys):
256+
raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}")
257+
258+
batch_results = [pickle.loads(result) for result in batch_results]
259+
results.extend(batch_results)
260+
261+
return results, indexes
262262

263263
def clear(self, keys: list[str], custom_backend_meta=None):
264264
"""Deletes multiple keys from MooncakeStore.
@@ -267,10 +267,10 @@ def clear(self, keys: list[str], custom_backend_meta=None):
267267
keys (List[str]): List of keys to remove.
268268
custom_backend_meta (List[Any], optional): ...
269269
"""
270-
rets = self._store.batch_remove(keys, force=True)
271-
for i, ret in enumerate(rets):
270+
ret_codes = self._store.batch_remove(keys, force=True)
271+
for i, ret in enumerate(ret_codes):
272272
if not (ret == 0 or ret == -704):
273-
logger.error(f"remove failed for key '{keys[i]}' with error code: {ret}")
273+
logger.error(f"remove failed for key `{keys[i]}` with error code: {ret}")
274274

275275
def close(self):
276276
"""Closes MooncakeStore."""
@@ -279,17 +279,19 @@ def close(self):
279279
self._store = None
280280

281281
@staticmethod
282-
def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[Any]]:
282+
def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[Any], list[Tensor]]:
283283
ptr_list = []
284284
size_list = []
285+
tensor_list = [] # hold reference for the contiguous tensor
285286
for t in values:
286287
t = t.contiguous()
288+
tensor_list.append(t)
287289
ptr_list.append(t.data_ptr())
288290
size_list.append(t.nbytes)
289-
return ptr_list, size_list
291+
return ptr_list, size_list, tensor_list
290292

291293
def _register_all_buffers(self, ptrs, sizes):
292-
for ptr, size in zip(ptrs, sizes, strict=False):
294+
for ptr, size in zip(ptrs, sizes, strict=True):
293295
self._store.register_buffer(ptr, size)
294296

295297
def _unregister_all_buffers(self, ptrs):

0 commit comments

Comments
 (0)