Skip to content

Commit d11f808

Browse files
authored
[feat] Add retransmission mechanism for MooncakeStoreClient (#94)
## Background When using `MooncakeStore` as TQ's backend, we observe occasional transmission errors during verl e2e runs: ``` E0508 17:18:06.011560 731271 tcp_transport.cpp:708] TcpTransport::getConnection failed to create connection to 61.28.30.25:16181. Error: connect: Connection timed out E0508 17:18:06.011600 731277 tcp_transport.cpp:886] TcpTransport::startTransfer failed to get connection to 61.28.30.25:15816 E0508 17:18:06.011888 731271 transfer_task.cpp:281] Batch 281200032997056 completed with task failures: task_ids=[0] E0508 17:18:06.011895 731271 client_service.cpp:1100] Transfer failed for key: 68108@uid with error: -800 E0508 17:18:06.011996 731271 real_client.cpp:2253] BatchGet failed for key '68108@uid': TRANSFER_FAIL ``` These `Connection timed out` / `TRANSFER_FAIL` (`error: -800`) errors are **transient network issues** that typically resolve on a subsequent attempt. However, the previous client implementation had no retry logic whatsoever: - On the **tensor path**, any single key returning a negative status code would trigger an immediate `RuntimeError`, failing the entire batch and crashing the training job. - On the **bytes path**, the failure was far worse: `get_batch` returns `b""` for keys that encountered a transfer failure, and the client blindly passed these empty bytes through `pickle.loads(... if result != b"" else None)`, treating them as legitimate `None` values. **This leads to silent content corruption.** A training worker could proceed with corrupted or missing data without ever knowing that a transmission failure occurred, compromising model correctness. This PR addresses all failure modes on both the **read (`get`) and write (`put`) paths** by adding controlled retries that isolate failed keys and attempt retransmission before giving up. ## Summary This PR introduces a retry mechanism for transient failures in `MooncakeStoreClient`, covering both **read** (`get`) and **write** (`put`) operations, for both tensor and non-tensor data paths. Previously, the client had **zero tolerance for transient errors**: - **Tensor read** (`_get_tensors_thread_worker`): a single key failure (`ret < 0`) would immediately raise `RuntimeError`, causing the entire batch to fail. - **Non-tensor read** (`_get_bytes_thread_worker`): no failure detection at all. Empty byte strings (`b""`) — which MooncakeStore returns on transmission failures — were silently deserialized as `None`, making it impossible for callers to distinguish between "value is None" and "transfer failed". - **Tensor write** (`_put_tensors_thread_worker`): any single key returning a non-zero status would immediately abort the entire batch with `RuntimeError`. - **Non-tensor write** (`_put_bytes_thread_worker`): a single `upsert_batch` failure would immediately abort the batch with `RuntimeError`. This change adds **up to 3 retries with 1-second backoff** across all four paths. For paths that expose per-key status codes, only the failed subset of keys is retried on each attempt. ## Future Work - Replace the `b""` heuristic in `_get_bytes_thread_worker` with proper per-key error codes once MooncakeStore exposes them, then upgrade the exhausted-retry path from `logger.error` to `raise RuntimeError`. - When MooncakeStore supports **per-key status codes for `upsert_batch` and `get_batch`**, switch the bytes write/read paths from whole-batch retry to per-key selective retry, matching the tensor-path behaviour. --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 270ea73 commit d11f808

1 file changed

Lines changed: 188 additions & 21 deletions

File tree

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 188 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import pickle
17+
import time
1718
from concurrent.futures import ThreadPoolExecutor, as_completed
1819
from typing import Any
1920

@@ -33,8 +34,10 @@
3334
except ImportError:
3435
MOONCAKE_STORE_IMPORTED = False
3536

36-
BATCH_SIZE_LIMIT: int = 200
37+
BATCH_SIZE_LIMIT: int = 400
3738
MAX_WORKER_THREADS = 4
39+
MAX_RETRIES = 3
40+
RETRY_DELAY_SECONDS = 1.0
3841

3942

4043
@StorageClientFactory.register("MooncakeStoreClient")
@@ -147,23 +150,96 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[
147150

148151
try:
149152
results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config)
150-
if not all(r == 0 for r in results):
151-
failed_indices = [j for j, r in enumerate(results) if r != 0]
152-
error_codes = [results[j] for j in failed_indices]
153+
if len(results) != len(batch_keys):
154+
raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}")
155+
156+
failed_indices = [j for j, r in enumerate(results) if r != 0]
157+
if not failed_indices:
158+
return
159+
160+
current_failed_keys = [batch_keys[i] for i in failed_indices]
161+
current_failed_codes = [results[i] for i in failed_indices]
162+
current_failed_indices = failed_indices
163+
164+
logger.error(
165+
f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. "
166+
f"Retrying up to {MAX_RETRIES} times..."
167+
)
168+
169+
for attempt in range(1, MAX_RETRIES + 1):
170+
retry_ptrs = [batch_ptrs[i] for i in current_failed_indices]
171+
retry_sizes = [batch_sizes[i] for i in current_failed_indices]
172+
173+
retry_results = self._store.batch_upsert_from(
174+
current_failed_keys, retry_ptrs, retry_sizes, config=self.replica_config
175+
)
176+
177+
next_failed_indices = []
178+
next_failed_keys = []
179+
next_failed_codes = []
180+
181+
for i, ret in enumerate(retry_results):
182+
if ret != 0:
183+
next_failed_indices.append(current_failed_indices[i])
184+
next_failed_keys.append(current_failed_keys[i])
185+
next_failed_codes.append(ret)
186+
187+
if not next_failed_indices:
188+
logger.info("batch_upsert_from succeeded after retransmission.")
189+
break # All retries in this attempt succeeded.
190+
191+
logger.error(
192+
f"batch_upsert_from retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys "
193+
f"with error codes {next_failed_codes}."
194+
)
195+
196+
current_failed_indices = next_failed_indices
197+
current_failed_keys = next_failed_keys
198+
current_failed_codes = next_failed_codes
199+
200+
if attempt < MAX_RETRIES:
201+
time.sleep(RETRY_DELAY_SECONDS)
202+
else:
153203
raise RuntimeError(
154-
f"batch_upsert_from failed for indices {failed_indices} with error codes: {error_codes}"
204+
f"batch_upsert_from failed for keys {current_failed_keys} with error codes "
205+
f"{current_failed_codes} after retrying {MAX_RETRIES} times."
155206
)
207+
156208
finally:
157209
self._unregister_all_buffers(batch_ptr_reduced)
158210

159211
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]):
160212
"""Worker thread for putting batch of non-tensors to MooncakeStore."""
161213

162-
batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values]
214+
serialized_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values]
163215

164-
ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config)
165-
if ret != 0:
166-
raise RuntimeError(f"upsert_batch failed with error code: {ret}")
216+
# FIXME: When MooncakeStore supports per-key status codes for upsert_batch and get_batch,
217+
# switch the bytes write/read paths from whole-batch retry to per-key selective retry,
218+
# matching the tensor-path behaviour.
219+
ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config)
220+
if ret == 0:
221+
return
222+
223+
logger.error(
224+
f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret}. "
225+
f"Retrying up to {MAX_RETRIES} times..."
226+
)
227+
228+
for attempt in range(1, MAX_RETRIES + 1):
229+
ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config)
230+
if ret == 0:
231+
logger.info("upsert_batch succeeded after retransmission.")
232+
return
233+
234+
logger.error(
235+
f"upsert_batch retry {attempt}/{MAX_RETRIES} failed for {len(batch_keys)} keys with error code: {ret}."
236+
)
237+
if attempt < MAX_RETRIES:
238+
time.sleep(RETRY_DELAY_SECONDS)
239+
240+
raise RuntimeError(
241+
f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret} after retrying {MAX_RETRIES} times."
242+
)
167243

168244
def get(
169245
self,
@@ -238,25 +314,116 @@ def _get_tensors_thread_worker(
238314
ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes)
239315
if len(ret_codes) != len(batch_keys):
240316
raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}")
241-
for i, ret in enumerate(ret_codes):
242-
if ret < 0:
243-
raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}")
317+
318+
failed_indices = [i for i, ret in enumerate(ret_codes) if ret < 0]
319+
if not failed_indices:
320+
return batch_buffer_tensors, indexes
321+
322+
# error handling
323+
current_failed_keys = [batch_keys[i] for i in failed_indices]
324+
current_failed_codes = [ret_codes[i] for i in failed_indices]
325+
current_failed_indices = failed_indices
326+
327+
logger.error(
328+
f"batch_get_into failed for keys {current_failed_keys} with error codes {current_failed_codes}. "
329+
f"Retrying up to {MAX_RETRIES} times..."
330+
)
331+
332+
for attempt in range(1, MAX_RETRIES + 1):
333+
# Reuse the originally allocated pointers; no need to allocate/register new buffers.
334+
retry_ptrs = [batch_buffer_ptrs[i] for i in current_failed_indices]
335+
retry_nbytes = [batch_nbytes[i] for i in current_failed_indices]
336+
337+
retry_codes = self._store.batch_get_into(current_failed_keys, retry_ptrs, retry_nbytes)
338+
339+
next_failed_indices = []
340+
next_failed_keys = []
341+
next_failed_codes = []
342+
343+
for i, ret in enumerate(retry_codes):
344+
if ret < 0:
345+
next_failed_indices.append(current_failed_indices[i])
346+
next_failed_keys.append(current_failed_keys[i])
347+
next_failed_codes.append(ret)
348+
349+
if not next_failed_indices:
350+
logger.info("batch_get_into succeeded after retransmission.")
351+
break # All retries in this attempt succeeded.
352+
353+
logger.error(
354+
f"batch_get_into retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys "
355+
f"with error codes {next_failed_codes}."
356+
)
357+
358+
# Narrow down to still-failed items for the next retry attempt.
359+
current_failed_indices = next_failed_indices
360+
current_failed_keys = next_failed_keys
361+
current_failed_codes = next_failed_codes
362+
363+
if attempt < MAX_RETRIES:
364+
time.sleep(RETRY_DELAY_SECONDS)
365+
else:
366+
# All retries exhausted.
367+
raise RuntimeError(
368+
f"batch_get_into failed for keys {current_failed_keys} with error codes "
369+
f"{current_failed_codes} after retrying {MAX_RETRIES} times."
370+
)
371+
244372
finally:
245373
self._unregister_all_buffers(region_ptrs)
246374

247375
return batch_buffer_tensors, indexes
248376

249377
def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]:
250-
results = []
251-
252-
batch_results = self._store.get_batch(batch_keys)
253-
if len(batch_results) != len(batch_keys):
254-
raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}")
255-
256-
batch_results = [pickle.loads(result) if result != b"" else None for result in batch_results]
257-
results.extend(batch_results)
378+
raw_results = self._store.get_batch(batch_keys)
379+
if len(raw_results) != len(batch_keys):
380+
raise RuntimeError(f"get_batch returned {len(raw_results)} items, expected {len(batch_keys)}")
381+
382+
# FIXME: Use MooncakeStore provided ret codes to detect transmission failures when supported
383+
# Currently we rely on empty bytes (b'') to detect transmission failures because
384+
# MooncakeStore does not currently return a separate status code per key.
385+
failed_indices = [i for i, result in enumerate(raw_results) if result == b""]
386+
if failed_indices:
387+
current_failed_keys = [batch_keys[i] for i in failed_indices]
388+
current_failed_indices = failed_indices
389+
390+
logger.error(f"get_batch failed for keys {current_failed_keys}. Retrying up to {MAX_RETRIES} times...")
391+
392+
for attempt in range(1, MAX_RETRIES + 1):
393+
retry_results = self._store.get_batch(current_failed_keys)
394+
395+
next_failed_keys = []
396+
next_failed_indices = []
397+
398+
for i, result in enumerate(retry_results):
399+
original_idx = current_failed_indices[i]
400+
if result == b"":
401+
next_failed_keys.append(current_failed_keys[i])
402+
next_failed_indices.append(original_idx)
403+
else:
404+
# Write the successfully retried value back to its original slot immediately.
405+
raw_results[original_idx] = result
406+
407+
if not next_failed_indices:
408+
logger.info("get_batch succeeded after retransmission.")
409+
break # All retries in this attempt succeeded.
410+
411+
logger.error(f"get_batch retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys.")
412+
413+
# Narrow down to still-failed items for the next retry attempt.
414+
current_failed_keys = next_failed_keys
415+
current_failed_indices = next_failed_indices
416+
417+
if attempt < MAX_RETRIES:
418+
time.sleep(RETRY_DELAY_SECONDS)
419+
else:
420+
# All retries exhausted.
421+
raise RuntimeError(
422+
f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times."
423+
)
258424

259-
return results, indexes
425+
deserialized_results = [pickle.loads(result) if result != b"" else None for result in raw_results]
426+
return deserialized_results, indexes
260427

261428
def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None:
262429
"""Deletes multiple keys from MooncakeStore.

0 commit comments

Comments
 (0)