Skip to content

Commit 2354908

Browse files
authored
Add bounded generation batch scheduling (#2)
* feat: extract generation helpers; add bounded batch execution * docs: update README and guide to include max_parallel_requests for LMClient * chore: update infermesh package details and dependencies to version 0.2.0 * fix: validate max_parallel_requests at client construction
1 parent f773ea9 commit 2354908

8 files changed

Lines changed: 582 additions & 157 deletions

File tree

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ prompts = [
5959

6060
with LMClient(
6161
model="openai/gpt-4.1-mini",
62+
max_parallel_requests=32,
6263
rpm=500,
6364
tpm=100_000,
6465
) as client:
@@ -84,6 +85,10 @@ One failing request does not abort the whole batch. Failed items are `None` in
8485
`batch.results`; the exception is in `batch.errors[i]`. This is deliberate: a single
8586
provider error should not wipe out a long experiment.
8687

88+
For large Python batches, set `max_parallel_requests` explicitly. That enables
89+
bounded in-flight scheduling for `generate_batch`; when it is unset, the method
90+
may start work for the full batch up front.
91+
8792
This code works in Jupyter notebooks without any `asyncio` setup. The sync API runs a
8893
background event loop so you do not have to.
8994

@@ -197,7 +202,7 @@ import json
197202
from infermesh import LMClient
198203

199204
with open("results.jsonl", "w") as out, \
200-
LMClient(model="openai/gpt-4.1-mini") as client:
205+
LMClient(model="openai/gpt-4.1-mini", max_parallel_requests=32) as client:
201206

202207
def save(index: int, result, error) -> None:
203208
row = {"index": index}

docs/guide.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ prompts = [
2222

2323
with LMClient(
2424
model="openai/gpt-4.1-mini",
25+
max_parallel_requests=32,
2526
rpm=500,
2627
tpm=100_000,
2728
) as client:
@@ -48,6 +49,10 @@ By default, one failing request does not abort the whole batch. Failed items are
4849
stored as `None` in `batch.results`, and the corresponding exception is stored
4950
in `batch.errors[i]`.
5051

52+
For large Python batches, set `max_parallel_requests` explicitly. That enables
53+
bounded in-flight scheduling for `generate_batch`; when it is unset, the method
54+
may start work for the full batch up front.
55+
5156
### Crash-Resilient Batches with `on_result`
5257

5358
For large batches, you may want to write results to disk as each request
@@ -65,7 +70,7 @@ from infermesh import LMClient
6570
prompts = [...] # large list
6671

6772
with open("results.jsonl", "w") as out, \
68-
LMClient(model="openai/gpt-4.1-mini") as client:
73+
LMClient(model="openai/gpt-4.1-mini", max_parallel_requests=32) as client:
6974

7075
def save(index: int, result, error) -> None:
7176
row = {"index": index}
@@ -102,7 +107,7 @@ if output_path.exists():
102107
pending = [(i, p) for i, p in enumerate(prompts) if i not in done]
103108

104109
with open(output_path, "a") as out, \
105-
LMClient(model="openai/gpt-4.1-mini") as client:
110+
LMClient(model="openai/gpt-4.1-mini", max_parallel_requests=32) as client:
106111

107112
def save(batch_idx: int, result, error) -> None:
108113
orig_idx = pending[batch_idx][0]

src/infermesh/_client_runtime.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,17 @@ def _validate_init_args(
9393
api_key: str | None,
9494
deployments: dict[str, DeploymentConfig | dict[str, Any]] | None,
9595
endpoint: EndpointType,
96+
max_parallel_requests: int | None,
9697
) -> None:
9798
"""Validate top-level constructor arguments."""
9899

99100
validate_endpoint(endpoint)
100101
if model is None:
101102
raise ValueError("``model`` is required.")
103+
if max_parallel_requests is not None and max_parallel_requests < 1:
104+
raise ValueError(
105+
"``max_parallel_requests`` must be ``None`` or a positive integer."
106+
)
102107
if deployments is not None and (api_base is not None or api_key is not None):
103108
raise ValueError(
104109
"``api_base`` and ``api_key`` cannot be set when ``deployments`` "
@@ -199,11 +204,14 @@ async def _dispatch_with_controls(
199204
request_callable: Any,
200205
request_args: tuple[Any, ...],
201206
request_kwargs: dict[str, Any],
207+
queue_started_at: float | None = None,
202208
) -> tuple[Any, RequestMetrics]:
203209
"""Run a request with concurrency and rate-limiting controls."""
204210

205211
state = self._get_loop_state()
206-
queue_started = time.perf_counter()
212+
queue_started = (
213+
queue_started_at if queue_started_at is not None else time.perf_counter()
214+
)
207215
handle: RateLimiterAcquisitionHandle | None = None
208216
semaphore_context = (
209217
state.semaphore if state.semaphore is not None else _null_async_context()

0 commit comments

Comments
 (0)