Skip to content

Commit 3a9fa87

Browse files
Merge pull request #565 from MervinPraison/claude/issue-97-20250531_205330
fix: implement litellm handling in code.py same as llm.py
2 parents 6fc50bd + 1f7440a commit 3a9fa87

1 file changed

Lines changed: 32 additions & 9 deletions

File tree

src/praisonai/praisonai/ui/code.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from chainlit.types import ThreadDict
2121
import chainlit.data as cl_data
2222
from litellm import acompletion
23+
import litellm
2324
from db import DatabaseManager
2425

2526
# Load environment variables
@@ -40,6 +41,15 @@
4041
# Set the logging level for the logger
4142
logger.setLevel(log_level)
4243

44+
# Configure litellm same as in llm.py
45+
litellm.set_verbose = False
46+
litellm.success_callback = []
47+
litellm._async_success_callback = []
48+
litellm.callbacks = []
49+
litellm.drop_params = True
50+
litellm.modify_params = True
51+
litellm.suppress_debug_messages = True
52+
4353
CHAINLIT_AUTH_SECRET = os.getenv("CHAINLIT_AUTH_SECRET")
4454

4555
if not CHAINLIT_AUTH_SECRET:
@@ -55,6 +65,17 @@
5565

5666
deleted_thread_ids = [] # type: List[str]
5767

68+
def _build_completion_params(model_name, **override_params):
69+
"""Build parameters for litellm completion calls with proper model handling"""
70+
params = {
71+
"model": model_name,
72+
}
73+
74+
# Override with any provided parameters
75+
params.update(override_params)
76+
77+
return params
78+
5879
def save_setting(key: str, value: str):
5980
"""Saves a setting to the database.
6081
@@ -237,12 +258,12 @@ async def main(message: cl.Message):
237258
msg = cl.Message(content="")
238259
await msg.send()
239260

240-
# Prepare the completion parameters
241-
completion_params = {
242-
"model": model_name,
243-
"messages": message_history,
244-
"stream": True,
245-
}
261+
# Prepare the completion parameters using the helper function
262+
completion_params = _build_completion_params(
263+
model_name,
264+
messages=message_history,
265+
stream=True,
266+
)
246267

247268
# If an image is uploaded, include it in the message
248269
if image:
@@ -344,9 +365,11 @@ async def main(message: cl.Message):
344365
logger.error(f"Failed to parse function arguments: {function_args}")
345366

346367
second_response = await acompletion(
347-
model=model_name,
348-
stream=True,
349-
messages=messages,
368+
**_build_completion_params(
369+
model_name,
370+
stream=True,
371+
messages=messages,
372+
)
350373
)
351374
logger.debug(f"Second LLM response: {second_response}")
352375

0 commit comments

Comments
 (0)