Skip to content
This repository was archived by the owner on May 27, 2026. It is now read-only.

Commit 3e05485

Browse files
authored
feat: implement dynamic max chunks per batch from env variable (#898)
- max chunks per batch when sending using "gemini" - add ability to configure this even more arbitrarily using env vars Signed-off-by: Nick Crews <nicholas.b.crews@gmail.com>
1 parent 46cdf69 commit 3e05485

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

projects/pgai/pgai/vectorizer/embedders/litellm.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from collections.abc import AsyncGenerator, Callable
23
from typing import Any, Literal
34

@@ -61,6 +62,14 @@ def _max_chunks_per_batch(self) -> int:
6162
# Note: deferred import to avoid import overhead
6263
import litellm
6364

65+
if result := os.getenv("PGAI_LITELLM_MAX_CHUNKS_PER_BATCH"):
66+
try:
67+
return int(result)
68+
except ValueError:
69+
logger.warn(
70+
"Value for PGAI_LITELLM_MAX_CHUNKS_PER_BATCH is not a valid int. Continuing with default provider value"
71+
)
72+
6473
_, custom_llm_provider, _, _ = litellm.get_llm_provider(self.model) # type: ignore
6574
match custom_llm_provider:
6675
case "cohere":
@@ -71,6 +80,8 @@ def _max_chunks_per_batch(self) -> int:
7180
return 2048 # https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#verify-inputs-dont-exceed-the-maximum-length
7281
case "bedrock":
7382
return 96 # NOTE: currently (Jan 2025) Bedrock only supports embeddings with Cohere or Titan models. The Titan API only processes one input per request, which LiteLLM already handles under the hood. We assume that the Cohere API has the same input limits as above.
83+
case "gemini":
84+
return 250 # https://docs.cloud.google.com/vertex-ai/docs/quotas#text-embedding-limits
7485
case "huggingface":
7586
return 2048 # NOTE: There is not documented limit. In testing we got a response for a request with 10k (short) inputs.
7687
case "mistral":
@@ -80,10 +91,11 @@ def _max_chunks_per_batch(self) -> int:
8091
case "voyage":
8192
return 128 # see https://docs.voyageai.com/reference/embeddings-api
8293
case _:
94+
fallback = 5
8395
logger.warn(
84-
f"unknown provider '{custom_llm_provider}', falling back to conservative max chunks per batch"
96+
f"unknown provider '{custom_llm_provider}', falling back to {fallback} max chunks per batch"
8597
)
86-
return 5
98+
return fallback
8799

88100
@override
89101
def _max_tokens_per_batch(self) -> int | None:

0 commit comments

Comments
 (0)