
Commit e6cf0c1

fix(client): batch embed retries, transcription reraise, and batch result typing
- Skip recursive micro-batch split on retryable errors in embed_batch
- Use bare raise in transcription batch error handler
- Define OnBatchResult as PEP 695 generic type alias

Made-with: Cursor
1 parent 8cca367 commit e6cf0c1

4 files changed

Lines changed: 51 additions & 2 deletions

File tree

- src/infermesh/_embedding.py
- src/infermesh/_transcription.py
- src/infermesh/types.py
- tests/test_client_batch.py

src/infermesh/_embedding.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -34,6 +34,12 @@ def _validate_micro_batch_size(micro_batch_size: int) -> None:
         raise ValueError("``micro_batch_size`` must be a positive integer.")
 
 
+def _should_isolate_embedding_failure(self: LMClient, exc: Exception) -> bool:
+    """Return whether a failed micro-batch should be recursively isolated."""
+
+    return not isinstance(exc, self._retryable_exceptions)
+
+
 async def _aembed_one(
     self: LMClient,
     input_data: str,
@@ -105,6 +111,10 @@ async def _resolve_embedding_chunk_capture(
     except Exception as exc:
         if len(input_data) == 1:
             return [(start_index, None, exc)]
+        if not _should_isolate_embedding_failure(self, exc):
+            return [
+                (start_index + offset, None, exc) for offset in range(len(input_data))
+            ]
         midpoint = len(input_data) // 2
         left = await _resolve_embedding_chunk_capture(
             self,
```
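This change stops the recursive bisection early when the failure is retryable. Previously, any failing micro-batch was split in half recursively to isolate the offending item; a retryable error such as a rate limit typically affects the whole batch, so splitting only replays the same failure while multiplying requests. Below is a minimal standalone sketch of the pattern, with a stub `embed` call and an illustrative `RateLimitError` standing in for infermesh's real client machinery:

```python
import asyncio


class RateLimitError(Exception):
    """Stand-in for a retryable provider error."""


RETRYABLE_EXCEPTIONS = (RateLimitError,)


async def embed(texts: list[str]) -> list[list[float]]:
    # Stub provider call: always rate limited, as in the new test below.
    raise RateLimitError("rate limited")


async def resolve_chunk(
    texts: list[str], start: int
) -> list[tuple[int, list[float] | None, Exception | None]]:
    """Return (index, embedding, error) triples for one micro-batch."""
    try:
        vectors = await embed(texts)
    except Exception as exc:
        if len(texts) == 1:
            return [(start, None, exc)]
        if isinstance(exc, RETRYABLE_EXCEPTIONS):
            # A retryable failure hits the whole batch; splitting would only
            # replay the same error, so attribute it to every item at once.
            return [(start + i, None, exc) for i in range(len(texts))]
        # Non-retryable: bisect to isolate the offending item(s).
        mid = len(texts) // 2
        left = await resolve_chunk(texts[:mid], start)
        right = await resolve_chunk(texts[mid:], start + mid)
        return left + right
    return [(start + i, vec, None) for i, vec in enumerate(vectors)]


for index, vector, error in asyncio.run(resolve_chunk(["a", "b", "c", "d"], 0)):
    print(index, vector, error)
```

With four inputs the sketch makes one call and attributes the single rate-limit error to all four indices; without the short-circuit, the bisection tree would replay the same failure seven times, one call per node.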

src/infermesh/_transcription.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -110,7 +110,7 @@ def admit_inputs() -> int:
                 pending = list(active_tasks)
                 active_tasks.clear()
                 await cancel_tasks(pending)
-                raise exc
+                raise
             assert errors is not None
             errors[index] = exc
             if on_result is not None:
```
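The one-line change is about traceback fidelity: both spellings re-raise the same exception object, but `raise exc` appends the handler's own line as an extra frame in the traceback, while a bare `raise` leaves the traceback pointing at the original failure site. A small self-contained demonstration (not infermesh code):

```python
import traceback


def fail() -> None:
    raise ValueError("boom")


def reraise_bare() -> None:
    try:
        fail()
    except ValueError:
        raise  # traceback still ends at fail(); no extra frame added


def reraise_named() -> None:
    try:
        fail()
    except ValueError as exc:
        raise exc  # traceback gains this line as an extra entry


for fn in (reraise_bare, reraise_named):
    try:
        fn()
    except ValueError:
        print(f"--- {fn.__name__} ---")
        traceback.print_exc()
```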

src/infermesh/types.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -678,7 +678,7 @@ def __len__(self) -> int:
 TranscriptionBatchResult: TypeAlias = BatchResult[TranscriptionResult]
 """Type alias for a batch of transcription results."""
 
-OnBatchResult: TypeAlias = Callable[[int, T | None, BaseException | None], None] | None
+type OnBatchResult[T] = Callable[[int, T | None, BaseException | None], None] | None
 """Generic callback type for per-result notifications in batch methods.
 
 Called as ``on_result(index, result, error)`` each time a single request
```
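The new spelling uses the PEP 695 `type` statement (Python 3.12+), which scopes the type parameter `T` to the alias itself instead of a module-level `TypeVar`, so `OnBatchResult` can be subscripted like any generic alias. A short sketch of caller-side usage, where the `report` callback is hypothetical:

```python
from collections.abc import Callable

type OnBatchResult[T] = Callable[[int, T | None, BaseException | None], None] | None


def report(index: int, result: str | None, error: BaseException | None) -> None:
    # Matches OnBatchResult[str]: called once per finished request.
    status = "ok" if error is None else f"failed: {error!r}"
    print(f"item {index}: {status}")


callback: OnBatchResult[str] = report  # parameterized like any generic
if callback is not None:
    callback(0, "hello", None)
```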

tests/test_client_batch.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -242,6 +242,45 @@ async def test_aembed_batch_recursively_isolates_bad_items(
     assert result.errors[2] is None
 
 
+@pytest.mark.asyncio
+async def test_aembed_batch_does_not_split_retryable_failures(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    calls: list[list[str]] = []
+
+    class RetryableEmbeddingFakeLiteLLM(FakeLiteLLM):
+        async def aembedding(self, **kwargs: Any) -> dict[str, Any]:
+            payload = list(kwargs["input"])
+            calls.append(payload)
+            raise self.RateLimitError("rate limited")
+
+    monkeypatch.setattr(
+        LMClient,
+        "_create_litellm_module",
+        lambda self: RetryableEmbeddingFakeLiteLLM(),
+    )
+    client = LMClient(
+        model="openai/test",
+        api_base="http://localhost",
+        max_retries=0,
+    )
+
+    batch = await client.aembed_batch(
+        ["a", "b", "c", "d"],
+        micro_batch_size=4,
+        return_exceptions=True,
+    )
+
+    assert calls == [["a", "b", "c", "d"]]
+    assert batch.errors is not None
+    assert all(result is None for result in batch.results)
+    assert all(
+        isinstance(error, RetryableEmbeddingFakeLiteLLM.RateLimitError)
+        for error in batch.errors
+    )
+    client.close()
+
+
 @pytest.mark.asyncio
 async def test_aembed_batch_return_exceptions_false_raises(
     failing_fake_client: LMClient,
```
