8 changes: 5 additions & 3 deletions README.md
@@ -85,9 +85,11 @@ One failing request does not abort the whole batch. Failed items are `None` in
`batch.results`; the exception is in `batch.errors[i]`. This is deliberate: a single
provider error should not wipe out a long experiment.
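
A minimal sketch of reading a finished batch under this contract (the import path, model name, and `prompts` are illustrative):

```python
from infermesh import LMClient  # import path assumed from the package layout

prompts = ["first prompt", "second prompt", "third prompt"]

with LMClient(model="openai/gpt-4o-mini") as client:
    batch = client.generate_batch(prompts)

for i, (result, error) in enumerate(zip(batch.results, batch.errors)):
    if result is None:
        # This item failed; the exception says why. The rest of the batch is intact.
        print(f"prompt {i} failed: {error!r}")
```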

For large Python batches, set `max_parallel_requests` explicitly. That enables
bounded in-flight scheduling for `generate_batch`; when it is unset, the method
may start work for the full batch up front.
For large Python batches, set `max_parallel_requests` explicitly. `generate_batch`
and `transcribe_batch` both use a bounded in-flight window when it is set; when it
is unset, they start one coroutine per item up front, which can cause memory pressure
for very large inputs. `embed_batch` is always micro-batched regardless of
`max_parallel_requests` — pass `micro_batch_size` to tune chunk size instead.
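
Continuing the sketch above, a hedged illustration of the difference (the embedding model name is illustrative, and whether `micro_batch_size` is a method argument or a client option is an assumption; this sketch passes it to `embed_batch`):

```python
# Bounded scheduling: at most 32 generation requests are in flight at any time.
with LMClient(model="openai/gpt-4o-mini", max_parallel_requests=32) as client:
    batch = client.generate_batch(prompts)

# Embeddings are always chunked; micro_batch_size sets how many inputs go in each provider call.
with LMClient(model="openai/text-embedding-3-small") as embedder:
    embedding_batch = embedder.embed_batch(prompts, micro_batch_size=256)
```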

This code works in Jupyter notebooks without any `asyncio` setup. The sync API runs a
background event loop so you do not have to.
46 changes: 39 additions & 7 deletions docs/guide.md
@@ -49,9 +49,11 @@ By default, one failing request does not abort the whole batch. Failed items are
stored as `None` in `batch.results`, and the corresponding exception is stored
in `batch.errors[i]`.

For large Python batches, set `max_parallel_requests` explicitly. That enables
bounded in-flight scheduling for `generate_batch`; when it is unset, the method
may start work for the full batch up front.
For large Python batches, set `max_parallel_requests` explicitly. `generate_batch`
and `transcribe_batch` both use a bounded in-flight window when it is set; when it
is unset, they start one coroutine per item up front, which can cause memory pressure
for very large inputs. `embed_batch` is always micro-batched regardless of
`max_parallel_requests` — pass `micro_batch_size` to tune chunk size instead.

### Crash-Resilient Batches with `on_result`

@@ -91,12 +93,17 @@ The callback receives:

| Argument | Type | Notes |
|---|---|---|
| `index` | `int` | Position in `input_batch` |
| `result` | `GenerationResult \| None` | `None` on failure |
| `index` | `int` | Position in `input_batch` (global item index, not micro-batch index) |
| `result` | `GenerationResult \| EmbeddingResult \| TranscriptionResult \| None` | `None` on failure |
| `error` | `BaseException \| None` | `None` on success |

To resume from a partial output file, read the completed indices before the
batch starts and filter the input:
The same contract applies to `embed_batch` and `transcribe_batch`. For
`embed_batch`, `index` is always the position in the original input list, even
when the provider call was part of a micro-batch. Per-item error callbacks are
guaranteed when `return_exceptions=True`; with `return_exceptions=False`, a
failed embedding micro-batch may raise before `on_result` is called for the
affected indices.
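
A minimal sketch of that contract, assuming `embed_batch` accepts the same `on_result` and `return_exceptions` keyword arguments as `generate_batch` (the import path, model name, and `texts` are illustrative):

```python
from infermesh import LMClient  # import path assumed from the package layout

texts = ["first document", "second document", "third document"]
failures: dict[int, BaseException] = {}

def record(index: int, result, error) -> None:
    # index is the position in `texts`, even when this item was embedded
    # as part of a larger micro-batch sent to the provider.
    if error is not None:
        failures[index] = error

with LMClient(model="openai/text-embedding-3-small") as client:
    batch = client.embed_batch(texts, on_result=record, return_exceptions=True)
```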

```python
done = set()
@@ -303,6 +310,31 @@ batch = client.transcribe_batch(["recording-a.wav", "recording-b.wav"])
texts = [r.text if r is not None else None for r in batch]
```

`transcribe_batch` supports the same `on_result` and `on_progress` callbacks as
`generate_batch`. Use `on_result` to stream results to disk as each file completes
rather than waiting for the whole batch:

```python
import json

from infermesh import LMClient  # import path assumed from the package layout

audio_paths = ["recording-a.wav", "recording-b.wav"]  # any list of audio inputs

with open("transcripts.jsonl", "w") as out, \
        LMClient(model="whisper-1", max_parallel_requests=4) as client:

    def save(index: int, result, error) -> None:
        # Write one JSON line per file as soon as it finishes, success or failure.
        row = {"index": index}
        if error is not None:
            row["error"] = str(error)
        else:
            row["text"] = result.text
        out.write(json.dumps(row) + "\n")
        out.flush()

    client.transcribe_batch(audio_paths, on_result=save)
```

Set `max_parallel_requests` to bound how many audio files are in-flight at once.
When it is unset, `transcribe_batch` starts all requests up front.

Audio inputs larger than 25 MB are rejected by default. Pass
`max_transcription_bytes=None` only in trusted environments where the server is
expected to accept larger uploads. Disabling the guard means the client may
17 changes: 16 additions & 1 deletion src/infermesh/_generation.py
@@ -170,6 +170,11 @@ def _settle_bounded_generation_task(
        result = task.result()
    except BaseException as exc:  # noqa: BLE001
        if not return_exceptions:
            if on_result is not None:
                on_result(index, None, exc)
            completed += 1
            if on_progress is not None:
                on_progress(completed, total)
            return completed, exc
        assert errors is not None
        errors[index] = exc
@@ -335,7 +340,17 @@ async def progress_wrapper(
    ) -> GenerationResult:
        """Wrap a coroutine with progress and result callbacks."""

        result = await coro
        try:
            result = await coro
        except asyncio.CancelledError:
            raise
        except BaseException as exc:  # noqa: BLE001
            completed[0] += 1
            if on_result is not None:
                on_result(index, None, exc)
            if on_progress is not None:
                on_progress(completed[0], len(coros))
            raise
        completed[0] += 1
        if on_result is not None:
            on_result(index, result, None)
78 changes: 57 additions & 21 deletions src/infermesh/_transcription.py
@@ -47,6 +47,51 @@ async def _atranscribe_one(
    return build_transcription_result(response, metrics=metrics)


async def _consume_transcribe_task(
    task: asyncio.Task[TranscriptionResult],
    index: int,
    *,
    return_exceptions: bool,
    results: list[TranscriptionResult | None],
    errors: list[BaseException | None] | None,
    on_result: OnTranscriptionResult,
    on_progress: Callable[[int, int], None] | None,
    completed: int,
    total: int,
    active_tasks: dict[asyncio.Task[TranscriptionResult], int],
) -> int:
    """Apply one finished transcription task to batch state; return new completed count."""

    try:
        result = task.result()
    except asyncio.CancelledError:
        raise
    except Exception as exc:
        if not return_exceptions:
            if on_result is not None:
                on_result(index, None, exc)
            completed += 1
            if on_progress is not None:
                on_progress(completed, total)
            pending = list(active_tasks)
            active_tasks.clear()
            await cancel_tasks(pending)
            raise
        assert errors is not None
        errors[index] = exc
        if on_result is not None:
            on_result(index, None, exc)
    else:
        results[index] = result
        if on_result is not None:
            on_result(index, result, None)

    completed += 1
    if on_progress is not None:
        on_progress(completed, total)
    return completed


async def _atranscribe_batch(
    self: LMClient,
    input_batch: Sequence[TranscriptionInput],
@@ -101,27 +146,18 @@ def admit_inputs() -> int:
            )
            for task in done:
                index = active_tasks.pop(task)
                try:
                    result = task.result()
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    if not return_exceptions:
                        pending = list(active_tasks)
                        active_tasks.clear()
                        await cancel_tasks(pending)
                        raise
                    assert errors is not None
                    errors[index] = exc
                    if on_result is not None:
                        on_result(index, None, exc)
                else:
                    results[index] = result
                    if on_result is not None:
                        on_result(index, result, None)
                completed += 1
                if on_progress is not None:
                    on_progress(completed, total)
                completed = await _consume_transcribe_task(
                    task,
                    index,
                    return_exceptions=return_exceptions,
                    results=results,
                    errors=errors,
                    on_result=on_result,
                    on_progress=on_progress,
                    completed=completed,
                    total=total,
                    active_tasks=active_tasks,
                )
            admit_inputs()
        except BaseException:
            pending = list(active_tasks)
115 changes: 115 additions & 0 deletions tests/test_client_batch.py
@@ -698,3 +698,118 @@ def test_on_result_captures_error_on_failure(
    assert result is None
    assert isinstance(error, RuntimeError)
    failing_fake_client.close()


@pytest.mark.asyncio
async def test_on_result_fires_for_failing_item_bounded_strict_mode(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """on_result must fire for the failing item even when return_exceptions=False (bounded path)."""
    monkeypatch.setattr(
        LMClient,
        "_create_litellm_module",
        lambda self: FailingFakeLiteLLM(fail_on={"bad"}),
    )
    client = LMClient(
        model="openai/test",
        api_base="http://localhost",
        max_parallel_requests=2,
    )
    calls: list[tuple[int, Any, Any]] = []
    with pytest.raises(RuntimeError, match="Simulated failure"):
        await client.agenerate_batch(
            ["good", "bad"],
            return_exceptions=False,
            on_result=lambda idx, result, error: calls.append((idx, result, error)),
        )
    error_calls = [
        (idx, result, error) for idx, result, error in calls if error is not None
    ]
    assert len(error_calls) == 1
    fail_idx, fail_result, fail_error = error_calls[0]
    assert fail_result is None
    assert isinstance(fail_error, RuntimeError)
    # The failing item must be the one whose input was "bad"
    assert fail_idx == 1
    client.close()


@pytest.mark.asyncio
async def test_on_result_fires_for_failing_item_unbounded_strict_mode(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """on_result must fire for the failing item even when return_exceptions=False (unbounded/TaskGroup path)."""
    gate = asyncio.Event()
    cancelled: list[str] = []

    class GatedFailing(FakeLiteLLM):
        async def acompletion(self, **kwargs: Any) -> Any:
            content = kwargs["messages"][0]["content"]
            if content == "bad":
                await asyncio.sleep(0)
                raise RuntimeError("boom")
            try:
                await gate.wait()
                return await super().acompletion(**kwargs)
            except asyncio.CancelledError:
                cancelled.append(content)
                raise

    monkeypatch.setattr(LMClient, "_create_litellm_module", lambda self: GatedFailing())
    client = LMClient(model="openai/test", api_base="http://localhost")
    calls: list[tuple[int, Any, Any]] = []
    with pytest.raises(RuntimeError, match="boom"):
        await client.agenerate_batch(
            ["slow", "bad"],
            return_exceptions=False,
            on_result=lambda idx, result, error: calls.append((idx, result, error)),
        )
    # "bad" (index 1) must have fired on_result with an error
    error_calls = [
        (idx, result, error) for idx, result, error in calls if error is not None
    ]
    assert len(error_calls) == 1
    fail_idx, fail_result, fail_error = error_calls[0]
    assert fail_idx == 1
    assert fail_result is None
    assert isinstance(fail_error, RuntimeError)
    # "slow" was cancelled — on_result must NOT have been called for it
    assert not any(idx == 0 for idx, _, _ in calls)
    assert cancelled == ["slow"]
    client.close()


@pytest.mark.asyncio
async def test_on_result_fires_for_failing_item_transcribe_strict_mode(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """on_result must fire for the failing item even when return_exceptions=False (transcription path)."""

    class FailingTranscriptionFakeLiteLLM(FakeLiteLLM):
        async def atranscription(self, **kwargs: Any) -> dict[str, Any]:
            if kwargs["file"] == b"bad":
                raise RuntimeError("transcription boom")
            return await super().atranscription(**kwargs)

    monkeypatch.setattr(
        LMClient,
        "_create_litellm_module",
        lambda self: FailingTranscriptionFakeLiteLLM(),
    )
    client = LMClient(model="openai/test", api_base="http://localhost")
    calls: list[tuple[int, Any, Any]] = []
    with pytest.raises(RuntimeError, match="transcription boom"):
        await client.atranscribe_batch(
            [b"good", b"bad"],
            return_exceptions=False,
            on_result=lambda idx, result, error: calls.append((idx, result, error)),
        )
    error_calls = [
        (idx, result, error) for idx, result, error in calls if error is not None
    ]
    assert len(error_calls) == 1
    fail_idx, fail_result, fail_error = error_calls[0]
    assert fail_idx == 1
    assert fail_result is None
    assert isinstance(fail_error, RuntimeError)
    client.close()