5 changes: 4 additions & 1 deletion apps/setting/models_provider/base_model_provider.py
@@ -106,7 +106,10 @@ def filter_optional_params(model_kwargs):
         optional_params = {}
         for key, value in model_kwargs.items():
             if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']:
-                optional_params[key] = value
+                if key == 'extra_body' and isinstance(value, dict):
+                    optional_params = {**optional_params, **value}
+                else:
+                    optional_params[key] = value
         return optional_params


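Note on the `filter_optional_params` change above: when `model_kwargs` carries an `extra_body` dict, its entries are now merged flat into the returned params instead of being kept nested under the `extra_body` key. A minimal sketch of the resulting behavior (the sample input below, including the `enable_thinking` flag, is hypothetical):

```python
def filter_optional_params(model_kwargs):
    # Mirrors the updated implementation above.
    optional_params = {}
    for key, value in model_kwargs.items():
        if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']:
            if key == 'extra_body' and isinstance(value, dict):
                # A dict under 'extra_body' is flattened into the result.
                optional_params = {**optional_params, **value}
            else:
                optional_params[key] = value
    return optional_params


print(filter_optional_params({
    'model_id': '42',                        # reserved key -> dropped
    'temperature': 0.7,                      # ordinary param -> kept as-is
    'extra_body': {'enable_thinking': True}  # dict -> merged flat
}))
# {'temperature': 0.7, 'enable_thinking': True}
```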
@@ -15,9 +15,8 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             model_name=model_name,
             openai_api_key=model_credential.get('api_key'),
             openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
-            # stream_options={"include_usage": True},
             streaming=True,
             stream_usage=True,
-            **optional_params,
+            extra_body=optional_params
         )
         return chat_tong_yi
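The same substitution repeats across the provider files below. Previously the filtered params were spread into the constructor, where LangChain forwards unknown keys as top-level arguments to the OpenAI client, and the v1 openai SDK rejects arguments it does not recognize; `extra_body` is the SDK's sanctioned escape hatch for appending nonstandard fields to the request JSON, which is presumably the motivation here. A minimal sketch of the pattern, with a placeholder credential and a hypothetical `enable_thinking` flag:

```python
from langchain_openai import ChatOpenAI

optional_params = {'enable_thinking': True}  # hypothetical provider-specific flag

chat = ChatOpenAI(
    model='qwen-max',
    api_key='sk-xxx',  # placeholder credential
    base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
    streaming=True,
    # Sent verbatim in the HTTP request body instead of being passed as
    # (and potentially rejected as) unknown keyword arguments to client.create().
    extra_body=optional_params,
)
```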
@@ -20,5 +20,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key'),
-            **optional_params
+            extra_body=optional_params
         )
238 changes: 134 additions & 104 deletions apps/setting/models_provider/impl/base_chat_open_ai.py
@@ -1,15 +1,16 @@
 # coding=utf-8
-import warnings
-from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union
+from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping
 
 import openai
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
-from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
+    SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
 from langchain_core.runnables import RunnableConfig, ensure_config
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.tools import BaseTool
 from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _create_usage_metadata
 
 from common.config.tokenizer_manage_config import TokenizerManage
@@ -19,14 +20,78 @@ def custom_get_token_ids(text: str):
     return tokenizer.encode(text)
 
 
+def _convert_delta_to_message_chunk(
+        _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    id_ = _dict.get("id")
+    reasoning_content = cast(str, _dict.get("reasoning_content") or "")
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: dict = {'reasoning_content': reasoning_content}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
+        try:
+            tool_call_chunks = [
+                tool_call_chunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content, id=id_)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            id=id_,
+            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+        )
+    elif role in ("system", "developer") or default_class == SystemMessageChunk:
+        if role == "developer":
+            additional_kwargs = {"__openai_role__": "developer"}
+        else:
+            additional_kwargs = {}
+        return SystemMessageChunk(
+            content=content, id=id_, additional_kwargs=additional_kwargs
+        )
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(
+            content=content, tool_call_id=_dict["tool_call_id"], id=id_
+        )
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role, id=id_)
+    else:
+        return default_class(content=content, id=id_)  # type: ignore
+
+
 class BaseChatOpenAI(ChatOpenAI):
     usage_metadata: dict = {}
     custom_get_token_ids = custom_get_token_ids
 
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
         return self.usage_metadata
 
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+            self,
+            messages: list[BaseMessage],
+            tools: Optional[
+                Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
+            ] = None,
+    ) -> int:
         if self.usage_metadata is None or self.usage_metadata == {}:
             try:
                 return super().get_num_tokens_from_messages(messages)
@@ -44,114 +109,77 @@ def get_num_tokens(self, text: str) -> int:
             return len(tokenizer.encode(text))
         return self.get_last_generation_info().get('output_tokens', 0)
 
-    def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        kwargs["stream"] = True
-        kwargs["stream_options"] = {"include_usage": True}
-        """Set default stream_options."""
-        stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
-        # Note: stream_options is not a valid parameter for Azure OpenAI.
-        # To support users proxying Azure through ChatOpenAI, here we only specify
-        # stream_options if include_usage is set to True.
-        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-        # for release notes.
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
-
-        payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        base_generation_info = {}
-
-        if "response_format" in payload and is_basemodel_subclass(
-                payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = self._generate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
-            )
-            return
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
-        else:
-            response = self.client.create(**payload)
-        with response:
-            is_first_chunk = True
-            for chunk in response:
-                if not isinstance(chunk, dict):
-                    chunk = chunk.model_dump()
-
-                generation_chunk = super()._convert_chunk_to_generation_chunk(
-                    chunk,
-                    default_chunk_class,
-                    base_generation_info if is_first_chunk else {},
-                )
-                if generation_chunk is None:
-                    continue
-
-                # custom code
-                if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']:
-                    generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][
-                        'reasoning_content']
-
-                default_chunk_class = generation_chunk.message.__class__
-                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
-                    )
-                is_first_chunk = False
-                # custom code
-                if generation_chunk.message.usage_metadata is not None:
-                    self.usage_metadata = generation_chunk.message.usage_metadata
-                yield generation_chunk
-
-    def _create_chat_result(self,
-                            response: Union[dict, openai.BaseModel],
-                            generation_info: Optional[Dict] = None):
-        result = super()._create_chat_result(response, generation_info)
-        try:
-            reasoning_content = ''
-            reasoning_content_enable = False
-            for res in response.choices:
-                if 'reasoning_content' in res.message.model_extra:
-                    reasoning_content_enable = True
-                    _reasoning_content = res.message.model_extra.get('reasoning_content')
-                    if _reasoning_content is not None:
-                        reasoning_content += _reasoning_content
-            if reasoning_content_enable:
-                result.llm_output['reasoning_content'] = reasoning_content
-        except Exception as e:
-            pass
-        return result
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        kwargs['stream_usage'] = True
+        for chunk in super()._stream(*args, **kwargs):
+            if chunk.message.usage_metadata is not None:
+                self.usage_metadata = chunk.message.usage_metadata
+            yield chunk
+
+    def _convert_chunk_to_generation_chunk(
+            self,
+            chunk: dict,
+            default_chunk_class: type,
+            base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+            if model_name := chunk.get("model"):
+                generation_info["model_name"] = model_name
+            if system_fingerprint := chunk.get("system_fingerprint"):
+                generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
 
     def invoke(
             self,
             input: LanguageModelInput,
             config: Optional[RunnableConfig] = None,
             *,
-            stop: Optional[List[str]] = None,
+            stop: Optional[list[str]] = None,
             **kwargs: Any,
     ) -> BaseMessage:
         config = ensure_config(config)
         chat_result = cast(
-            ChatGeneration,
+            "ChatGeneration",
            self.generate_prompt(
                 [self._convert_input(input)],
                 stop=stop,
@@ -162,7 +190,9 @@ def invoke(
                 run_id=config.pop("run_id", None),
                 **kwargs,
             ).generations[0][0],
+
         ).message
+
         self.usage_metadata = chat_result.response_metadata[
             'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
         return chat_result
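One effect of the rewrite above worth noting: `reasoning_content` from streamed deltas is now attached inside `_convert_delta_to_message_chunk` (every chunk carries it in `additional_kwargs`) rather than patched in after the fact inside `_stream`. A minimal sketch of that path, assuming the module-level function above is in scope and using a hypothetical reasoning-model delta:

```python
from langchain_core.messages import AIMessageChunk

# Hypothetical streamed delta from a reasoning model such as DeepSeek-R1.
delta = {
    "id": "chatcmpl-123",
    "role": "assistant",
    "content": "",
    "reasoning_content": "First, compare the two numbers digit by digit...",
}

chunk = _convert_delta_to_message_chunk(delta, AIMessageChunk)
print(chunk.additional_kwargs["reasoning_content"])
# First, compare the two numbers digit by digit...
```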
@@ -26,6 +26,6 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             model=model_name,
             openai_api_base='https://api.deepseek.com',
             openai_api_key=model_credential.get('api_key'),
-            **optional_params
+            extra_body=optional_params
         )
         return deepseek_chat_open_ai
@@ -21,11 +21,10 @@ def is_cache_model():
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
-
         kimi_chat_open_ai = KimiChatModel(
             openai_api_base=model_credential['api_base'],
             openai_api_key=model_credential['api_key'],
             model_name=model_name,
-            **optional_params
+            extra_body=optional_params,
         )
         return kimi_chat_open_ai
@@ -28,5 +28,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             # stream_options={"include_usage": True},
             streaming=True,
             stream_usage=True,
-            **optional_params,
+            extra_body=optional_params
         )
@@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             # stream_options={"include_usage": True},
             streaming=True,
             stream_usage=True,
-            **optional_params,
+            extra_body=optional_params
         )
@@ -9,7 +9,6 @@
 from typing import List, Dict
 
 from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_openai.chat_models import ChatOpenAI
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -35,9 +34,9 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             streaming = False
         azure_chat_open_ai = OpenAIChatModel(
             model=model_name,
-            openai_api_base=model_credential.get('api_base'),
-            openai_api_key=model_credential.get('api_key'),
-            **optional_params,
+            base_url=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
+            extra_body=optional_params,
             streaming=streaming,
             custom_get_token_ids=custom_get_token_ids
         )
Contributor Author:
There are no apparent issues with the existing code; here is a summary of potential improvements:

  1. The azure_chat_open_ai object can be given a variable name that better represents its purpose in your context (e.g., chat_model).
  2. If possible, add exception handling to better manage errors during API requests.
  3. Consider a more descriptive parameter name than optional_params, such as extra_settings.
  4. Use type hints consistently across the file and update them whenever you refactor function parameters or return types.

Here is how the refactored code could be structured:
@@ -35,8 +34,9 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
     if stream:
         streaming = False
 
-    azure_chat_open_ai = OpenAIChatModel(
+    # Using 'chat_model' instead of 'azure_chat_open_ai'
+    chat_model = OpenAIChatModel(
         model=model_name,
-        openai_api_base=model_credential.get('api_base'),
-        openai_api_key=model_credential.get('api_key'),
+        base_url=model_credential.get('api_base'),
+        api_key=model_credential.get('api_key'),
+        # Update this parameter name depending on what it represents in your app
+        extra_body=optional_settings_dict,
         streaming=streaming,
         custom_get_token_ids=custom_get_token_ids
     )

This change provides clarity about the object being created and enhances readability throughout the codebase. Remember to test these changes thoroughly after making modifications.

@@ -18,9 +18,8 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             model_name=model_name,
             openai_api_key=model_credential.get('api_key'),
             openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
-            # stream_options={"include_usage": True},
             streaming=True,
             stream_usage=True,
-            **optional_params,
+            extra_body=optional_params
         )
         return chat_tong_yi
@@ -26,6 +26,6 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
             streaming=True,
             stream_usage=True,
-            **optional_params,
+            extra_body=optional_params
         )
         return chat_tong_yi
@@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             # stream_options={"include_usage": True},
             streaming=True,
             stream_usage=True,
-            **optional_params,
+            extra_body=optional_params
         )