2020from litellm import BaseModel
2121from pydantic import Field
2222
23+ from aperag .db .models import APIType
2324from aperag .db .ops import async_db_ops
2425from aperag .flow .base .models import BaseNodeRunner , SystemInput , register_node_runner
2526from aperag .llm .completion .completion_service import CompletionService
2930from aperag .utils .history import BaseChatMessageHistory
3031from aperag .utils .utils import now_unix_milliseconds
3132
32- MAX_CONTEXT_LENGTH = 100000
33+ # Character to token estimation ratio for Chinese/mixed content
34+ # Conservative estimate: 2 characters = 1 token
35+ CHAR_TO_TOKEN_RATIO = 2.0
36+
37+ # Reserve tokens for output generation (default 1000 tokens)
38+ DEFAULT_OUTPUT_TOKENS = 1000
39+
40+ # Fallback max context length if model max_tokens is not available
41+ FALLBACK_MAX_CONTEXT_LENGTH = 50000
3342
3443
3544class Message (BaseModel ):
@@ -81,14 +90,39 @@ class LLMInput(BaseModel):
8190 custom_llm_provider : str = Field (..., description = "Custom LLM provider" )
8291 prompt_template : str = Field (..., description = "Prompt template" )
8392 temperature : float = Field (..., description = "Sampling temperature" )
84- max_tokens : int = Field (..., description = "Max tokens for generation" )
8593 docs : Optional [List [DocumentWithScore ]] = Field (None , description = "Documents" )
8694
8795
8896class LLMOutput (BaseModel ):
8997 text : str
9098
9199
100+ def estimate_token_count (text : str ) -> int :
101+ """
102+ Estimate token count from character count for Chinese/mixed content.
103+ Using conservative ratio: 2 characters = 1 token
104+ """
105+ return int (len (text ) / CHAR_TO_TOKEN_RATIO )
106+
107+
108+ def calculate_max_context_length (model_max_tokens : Optional [int ], output_tokens : int = DEFAULT_OUTPUT_TOKENS ) -> int :
109+ """
110+ Calculate maximum context length based on model's max_tokens limit.
111+ Reserve tokens for output generation.
112+ """
113+ if not model_max_tokens :
114+ return FALLBACK_MAX_CONTEXT_LENGTH
115+
116+ # Reserve tokens for output, convert to character count
117+ max_context_tokens = model_max_tokens - output_tokens
118+ if max_context_tokens <= 0 :
119+ # If model max_tokens is too small, use a minimal context
120+ max_context_tokens = max (model_max_tokens // 2 , 100 )
121+
122+ # Convert tokens to character count
123+ return int (max_context_tokens * CHAR_TO_TOKEN_RATIO )
124+
125+
92126# Database operations interface
93127class LLMRepository :
94128 """Repository interface for LLM database operations"""
@@ -114,7 +148,6 @@ async def generate_response(
114148 custom_llm_provider : str ,
115149 prompt_template : str ,
116150 temperature : float ,
117- max_tokens : int ,
118151 docs : Optional [List [DocumentWithScore ]] = None ,
119152 ) -> Tuple [str , Dict ]:
120153 """Generate LLM response with given parameters"""
@@ -130,23 +163,43 @@ async def generate_response(
130163 except Exception :
131164 raise Exception (f"LLMProvider { model_service_provider } not found" )
132165
166+ # Get model configuration to determine max_tokens
167+ try :
168+ model_config = await async_db_ops .query_llm_provider_model (
169+ provider_name = model_service_provider ,
170+ api = APIType .COMPLETION .value ,
171+ model = model_name
172+ )
173+ model_max_tokens = model_config .max_tokens if model_config else None
174+ except Exception :
175+ model_max_tokens = None
176+
177+ # Calculate dynamic context length based on model's max_tokens
178+ max_context_length = calculate_max_context_length (model_max_tokens )
179+
133180 # Build context and references from documents
134181 context = ""
135182 references = []
136183 if docs :
137184 for doc in docs :
138- if len (context ) + len (doc .text ) > MAX_CONTEXT_LENGTH :
185+ if len (context ) + len (doc .text ) > max_context_length :
139186 break
140187 context += doc .text
141188 references .append ({"text" : doc .text , "metadata" : doc .metadata , "score" : doc .score })
142189
143190 prompt = prompt_template .format (query = query , context = context )
144- output_max_tokens = max_tokens - len (prompt )
145-
146- if output_max_tokens < 0 :
147- raise Exception (
148- "max_tokens %d is too small to hold the prompt which size is %d" % (max_tokens , len (prompt ))
149- )
191+
192+ # Estimate prompt tokens and calculate output tokens
193+ prompt_tokens = estimate_token_count (prompt )
194+ if model_max_tokens :
195+ output_max_tokens = model_max_tokens - prompt_tokens
196+ if output_max_tokens < 100 : # Ensure minimum output tokens
197+ raise Exception (
198+ f"Model max_tokens { model_max_tokens } is too small to hold the prompt which requires approximately { prompt_tokens } tokens"
199+ )
200+ else :
201+ # Use default output tokens if model max_tokens is unknown
202+ output_max_tokens = DEFAULT_OUTPUT_TOKENS
150203
151204 cs = CompletionService (custom_llm_provider , model_name , base_url , api_key , temperature , output_max_tokens )
152205
@@ -193,7 +246,6 @@ async def run(self, ui: LLMInput, si: SystemInput) -> Tuple[LLMOutput, dict]:
193246 custom_llm_provider = ui .custom_llm_provider ,
194247 prompt_template = ui .prompt_template ,
195248 temperature = ui .temperature ,
196- max_tokens = ui .max_tokens ,
197249 docs = ui .docs ,
198250 )
199251
0 commit comments