|
20 | 20 | from chainlit.types import ThreadDict |
21 | 21 | import chainlit.data as cl_data |
22 | 22 | from litellm import acompletion |
| 23 | +import litellm |
23 | 24 | from db import DatabaseManager |
24 | 25 |
|
25 | 26 | # Load environment variables |
|
40 | 41 | # Set the logging level for the logger |
41 | 42 | logger.setLevel(log_level) |
42 | 43 |
|
| 44 | +# Configure litellm same as in llm.py |
| 45 | +litellm.set_verbose = False |
| 46 | +litellm.success_callback = [] |
| 47 | +litellm._async_success_callback = [] |
| 48 | +litellm.callbacks = [] |
| 49 | +litellm.drop_params = True |
| 50 | +litellm.modify_params = True |
| 51 | +litellm.suppress_debug_messages = True |
| 52 | + |
43 | 53 | CHAINLIT_AUTH_SECRET = os.getenv("CHAINLIT_AUTH_SECRET") |
44 | 54 |
|
45 | 55 | if not CHAINLIT_AUTH_SECRET: |
|
55 | 65 |
|
56 | 66 | deleted_thread_ids = [] # type: List[str] |
57 | 67 |
|
| 68 | +def _build_completion_params(model_name, **override_params): |
| 69 | + """Build parameters for litellm completion calls with proper model handling""" |
| 70 | + params = { |
| 71 | + "model": model_name, |
| 72 | + } |
| 73 | + |
| 74 | + # Override with any provided parameters |
| 75 | + params.update(override_params) |
| 76 | + |
| 77 | + return params |
| 78 | + |
58 | 79 | def save_setting(key: str, value: str): |
59 | 80 | """Saves a setting to the database. |
60 | 81 | |
@@ -237,12 +258,12 @@ async def main(message: cl.Message): |
237 | 258 | msg = cl.Message(content="") |
238 | 259 | await msg.send() |
239 | 260 |
|
240 | | - # Prepare the completion parameters |
241 | | - completion_params = { |
242 | | - "model": model_name, |
243 | | - "messages": message_history, |
244 | | - "stream": True, |
245 | | - } |
| 261 | + # Prepare the completion parameters using the helper function |
| 262 | + completion_params = _build_completion_params( |
| 263 | + model_name, |
| 264 | + messages=message_history, |
| 265 | + stream=True, |
| 266 | + ) |
246 | 267 |
|
247 | 268 | # If an image is uploaded, include it in the message |
248 | 269 | if image: |
@@ -344,9 +365,11 @@ async def main(message: cl.Message): |
344 | 365 | logger.error(f"Failed to parse function arguments: {function_args}") |
345 | 366 |
|
346 | 367 | second_response = await acompletion( |
347 | | - model=model_name, |
348 | | - stream=True, |
349 | | - messages=messages, |
| 368 | + **_build_completion_params( |
| 369 | + model_name, |
| 370 | + stream=True, |
| 371 | + messages=messages, |
| 372 | + ) |
350 | 373 | ) |
351 | 374 | logger.debug(f"Second LLM response: {second_response}") |
352 | 375 |
|
|
0 commit comments