19 changes: 17 additions & 2 deletions README.md
@@ -133,8 +133,11 @@ print(result.request_id) # provider-assigned ID for debugging
 result = client.embed("The quick brown fox")
 print(result.embedding) # list[float]
 
-# Multiple strings → sent in one API call
-batch = client.embed_batch(["sentence one", "sentence two", "sentence three"])
+# Multiple strings → processed in resilient micro-batches by default
+batch = client.embed_batch(
+    ["sentence one", "sentence two", "sentence three"],
+    micro_batch_size=32,
+)
 vectors = [r.embedding for r in batch if r is not None]
 ```
 
@@ -145,8 +148,16 @@ result = client.transcribe("recording.wav") # path, bytes, or file-like object
 print(result.text)
 print(result.language) # detected language code, e.g. "en"
 print(result.duration_s) # audio length in seconds
+
+batch = client.transcribe_batch(["recording-a.wav", "recording-b.wav"])
+texts = [r.text if r is not None else None for r in batch]
 ```
 
+Audio inputs larger than 25 MB are rejected by default. Pass
+`max_transcription_bytes=None` only in trusted environments where the server is
+expected to accept larger uploads. Disabling the guard means the client may
+read and send very large audio files in full.
+
 ## CLI
 
 ```bash
@@ -197,6 +208,9 @@ For long runs, pass `on_result` to write each result to disk as it arrives.
 A crash or interruption only loses the requests that were in-flight at that
 moment — everything already completed is safe on disk.
 
+`generate_batch`, `embed_batch`, and `transcribe_batch` all support the same
+per-item callback contract.
+
 ```python
 import json
 from infermesh import LMClient
@@ -340,6 +354,7 @@ async def main():
     batch = await client.agenerate_batch(["prompt A", "prompt B", "prompt C"])
     emb = await client.aembed("The quick brown fox")
     embs = await client.aembed_batch(["text a", "text b"])
+    txs = await client.atranscribe_batch(["a.wav", "b.wav"])
 
 asyncio.run(main())
 ```
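
A minimal sketch of the per-item callback contract referenced above, assuming the `on_result(index, result, error)` signature documented in docs/guide.md below; the JSONL file name and record layout are illustrative, not part of the library:

```python
import json

from infermesh import LMClient

client = LMClient()

def on_result(index: int, result, error) -> None:
    # Called once per item as it completes; `result` is None when the item
    # failed, and `error` then carries the exception.
    with open("embeddings.jsonl", "a", encoding="utf-8") as f:
        record = {
            "index": index,
            "embedding": result.embedding if result is not None else None,
            "error": repr(error) if error is not None else None,
        }
        f.write(json.dumps(record) + "\n")

batch = client.embed_batch(
    ["sentence one", "sentence two", "sentence three"],
    on_result=on_result,
)
```

Because each record is appended as it arrives, an interrupted run only loses the in-flight items, matching the crash-safety claim above.
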
28 changes: 23 additions & 5 deletions docs/guide.md
@@ -60,6 +60,9 @@ completes rather than waiting for the whole batch to finish. This way a
 process crash or interruption only loses the in-flight requests, not
 everything already completed.
 
+`generate_batch`, `embed_batch`, and `transcribe_batch` all support the same
+`on_result(index, result, error)` contract.
+
 Pass an `on_result` callback to `generate_batch` (or `agenerate_batch`):
 
 ```python
@@ -280,8 +283,11 @@ client = LMClient(
 result = client.embed("The quick brown fox")
 print(result.embedding)
 
-# Multiple strings -> sent in one API call
-batch = client.embed_batch(["sentence one", "sentence two", "sentence three"])
+# Multiple strings -> processed in resilient micro-batches by default
+batch = client.embed_batch(
+    ["sentence one", "sentence two", "sentence three"],
+    micro_batch_size=32,
+)
 vectors = [r.embedding for r in batch if r is not None]
 ```
 
@@ -292,11 +298,16 @@ result = client.transcribe("recording.wav")
 print(result.text)
 print(result.language)
 print(result.duration_s)
+
+batch = client.transcribe_batch(["recording-a.wav", "recording-b.wav"])
+texts = [r.text if r is not None else None for r in batch]
 ```
 
 Audio inputs larger than 25 MB are rejected by default. Pass
-`max_transcription_bytes=None` to disable the limit, or a smaller integer to
-tighten it.
+`max_transcription_bytes=None` only in trusted environments where the server is
+expected to accept larger uploads. Disabling the guard means the client may
+read and send very large audio files in full. Pass a smaller integer to
+tighten the limit.
 
 ## Multimodal / VLM
 
@@ -514,7 +525,14 @@ async def main():
     batch = await client.agenerate_batch(["prompt A", "prompt B", "prompt C"])
     embedding = await client.aembed("The quick brown fox")
     embedding_batch = await client.aembed_batch(["text a", "text b"])
-    print(result.output_text, len(batch), len(embedding.embedding), len(embedding_batch))
+    transcription_batch = await client.atranscribe_batch(["a.wav", "b.wav"])
+    print(
+        result.output_text,
+        len(batch),
+        len(embedding.embedding),
+        len(embedding_batch),
+        len(transcription_batch),
+    )
 
 asyncio.run(main())
 ```
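
As a sketch of the guard described above, assuming `max_transcription_bytes` is accepted by the `LMClient` constructor (this diff does not show where the parameter lives, so treat the placement as an assumption):

```python
from infermesh import LMClient

# Tighten the default 25 MB cap to 5 MB: oversized audio is rejected
# client-side before the file is read and uploaded in full.
client = LMClient(max_transcription_bytes=5 * 1024 * 1024)

result = client.transcribe("recording.wav")
print(result.text)
```
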
17 changes: 17 additions & 0 deletions src/infermesh/_batch_utils.py
@@ -0,0 +1,17 @@
+"""Shared helpers for internal batch runners."""
+
+from __future__ import annotations
+
+import asyncio
+from collections.abc import Sequence
+from typing import Any
+
+
+async def cancel_tasks(tasks: Sequence[asyncio.Task[Any]]) -> None:
+    """Cancel unfinished tasks and await their cleanup."""
+
+    for task in tasks:
+        if not task.done():
+            task.cancel()
+    if tasks:
+        await asyncio.gather(*tasks, return_exceptions=True)
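
A hypothetical sketch of how a batch runner might call `cancel_tasks`; `run_all` is illustrative and not part of this PR:

```python
import asyncio
from collections.abc import Awaitable
from typing import Any

from infermesh._batch_utils import cancel_tasks

async def run_all(coros: list[Awaitable[Any]]) -> list[Any]:
    tasks = [asyncio.ensure_future(c) for c in coros]
    try:
        return await asyncio.gather(*tasks)
    except Exception:
        # On the first failure, cancel whatever is still in flight and wait
        # for the cancellations to settle before re-raising. The
        # return_exceptions=True gather inside cancel_tasks swallows the
        # resulting CancelledError instances.
        await cancel_tasks(tasks)
        raise
```
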
87 changes: 71 additions & 16 deletions src/infermesh/_cli_bench.py
@@ -21,6 +21,54 @@
 DEFAULT_EMBED_BATCH_SIZES = [1, 8, 32, 128, 512]
 
 
+def _embedding_request_key(
+    result: EmbeddingResult, index: int
+) -> tuple[Any, Any, Any, Any]:
+    """Return a best-effort key for one underlying embedding request."""
+
+    key = (
+        result.request_id,
+        id(result.metrics) if result.metrics is not None else None,
+        id(result.raw_response) if result.raw_response is not None else None,
+    )
+    if key == (None, None, None):
+        return (*key, index)
+    return (*key, None)
+
+
+def _accumulate_embed_batch_call(
+    result: BatchResult[EmbeddingResult],
+    *,
+    latencies: list[float],
+    service_times: list[float],
+) -> tuple[int, int, int, int]:
+    """Return per-call embedding benchmark totals and update latency stats."""
+
+    total_submitted = 0
+    succeeded = 0
+    failures = 0
+    total_tokens = 0
+    seen_requests: set[tuple[Any, Any, Any, Any]] = set()
+
+    for index, item in enumerate(result.results):
+        total_submitted += 1
+        if item is None:
+            failures += 1
+            continue
+        succeeded += 1
+        key = _embedding_request_key(item, index)
+        if key in seen_requests:
+            continue
+        seen_requests.add(key)
+        if item.metrics is not None:
+            latencies.append(item.metrics.end_to_end_s)
+            service_times.append(item.metrics.service_time_s)
+        if item.token_usage is not None:
+            total_tokens += item.token_usage.total_tokens
+
+    return total_submitted, succeeded, failures, total_tokens
+
+
 def _run_benchmark(
     *,
     task_name: str,
@@ -321,7 +369,10 @@ def _run_embed_batch_benchmark(
     try:
         warmup_batch = batched_cycle(texts, batch_size_sweep[0])
         for _ in range(warmup):
-            warmup_client.embed_batch(warmup_batch)
+            warmup_client.embed_batch(
+                warmup_batch,
+                micro_batch_size=batch_size_sweep[0],
+            )
     finally:
         warmup_client.close()
 
@@ -333,6 +384,8 @@
         latencies: list[float] = []
         service_times: list[float] = []
         total_tokens = 0
+        total_submitted = 0
+        succeeded = 0
         failures = 0
 
         started_at = time.perf_counter()
@@ -343,31 +396,33 @@
                 file=sys.stderr,
             ) as pbar:
                 for _ in range(requests):
-                    result = client.embed_batch(sample)
-                    first = result.results[0] if result.results else None
-                    if first is None:
-                        failures += 1
-                    else:
-                        if first.metrics is not None:
-                            latencies.append(first.metrics.end_to_end_s)
-                            service_times.append(first.metrics.service_time_s)
-                        if first.token_usage is not None:
-                            total_tokens += first.token_usage.total_tokens
+                    result = client.embed_batch(
+                        sample,
+                        micro_batch_size=batch_size,
+                    )
+                    submitted_delta, succeeded_delta, failures_delta, tokens_delta = (
+                        _accumulate_embed_batch_call(
+                            result,
+                            latencies=latencies,
+                            service_times=service_times,
+                        )
+                    )
+                    total_submitted += submitted_delta
+                    succeeded += succeeded_delta
+                    failures += failures_delta
+                    total_tokens += tokens_delta
                     pbar.update(1)
             elapsed_s = time.perf_counter() - started_at
         finally:
             client.close()
 
-        succeeded = requests - failures
         level: dict[str, Any] = {
             "batch_size": batch_size,
-            "total_submitted": requests,
+            "total_submitted": total_submitted,
             "succeeded": succeeded,
             "failures": failures,
             "elapsed_s": elapsed_s,
-            "vectors_per_second": succeeded * len(sample) / elapsed_s
-            if elapsed_s > 0
-            else 0.0,
+            "vectors_per_second": succeeded / elapsed_s if elapsed_s > 0 else 0.0,
             **_percentile_stats(latencies, service_times),
         }
         if total_tokens:
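
Two things are worth noting about this hunk. First, `succeeded` now counts individual vectors rather than whole calls, so `vectors_per_second` simplifies to `succeeded / elapsed_s`, where the old formula multiplied per-call successes by `len(sample)`. Second, the request-key dedup records latency samples once per underlying request even when several results share it. A hypothetical illustration with stand-in objects (the real argument type is `BatchResult[EmbeddingResult]`, and the function is internal):

```python
from types import SimpleNamespace

from infermesh._cli_bench import _accumulate_embed_batch_call

# Three results that share one request_id and one metrics object, as when a
# micro-batch of three texts is served by a single underlying request.
metrics = SimpleNamespace(end_to_end_s=0.12, service_time_s=0.09)
shared = [
    SimpleNamespace(
        request_id="req-1",
        metrics=metrics,
        raw_response=None,
        token_usage=None,
    )
    for _ in range(3)
]

latencies: list[float] = []
service_times: list[float] = []
totals = _accumulate_embed_batch_call(
    SimpleNamespace(results=shared),
    latencies=latencies,
    service_times=service_times,
)
print(totals)          # (3, 3, 0, 0): three items submitted, all succeeded
print(len(latencies))  # 1: the shared request contributes one latency sample
```
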