Skip to content

Commit c1b9150

Browse files
committed
fix: refactor _convert_delta_to_message_chunk and update _convert_chunk_to_generation_chunk for improved functionality
1 parent c34d7a3 commit c1b9150

File tree

1 file changed

+119
-113
lines changed

1 file changed

+119
-113
lines changed

apps/models_provider/impl/base_chat_open_ai.py

Lines changed: 119 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -13,74 +13,71 @@
1313
from langchain_core.runnables import RunnableConfig, ensure_config
1414
from langchain_core.tools import BaseTool
1515
from langchain_openai import ChatOpenAI
16-
from langchain_openai.chat_models.base import _create_usage_metadata, _convert_delta_to_message_chunk
16+
from langchain_openai.chat_models.base import _create_usage_metadata
1717

1818
from common.config.tokenizer_manage_config import TokenizerManage
1919
from common.utils.logger import maxkb_logger
2020

2121
def custom_get_token_ids(text: str):
2222
tokenizer = TokenizerManage.get_tokenizer()
2323
return tokenizer.encode(text)
24-
#
25-
#
26-
# def _convert_delta_to_message_chunk(
27-
# _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
28-
# ) -> BaseMessageChunk:
29-
# id_ = _dict.get("id")
30-
# role = cast(str, _dict.get("role"))
31-
# content = cast(str, _dict.get("content") or "")
32-
# additional_kwargs: dict = {}
33-
# if 'reasoning_content' in _dict:
34-
# additional_kwargs['reasoning_content'] = _dict.get('reasoning_content')
35-
# if _dict.get("function_call"):
36-
# function_call = dict(_dict["function_call"])
37-
# if "name" in function_call and function_call["name"] is None:
38-
# function_call["name"] = ""
39-
# additional_kwargs["function_call"] = function_call
40-
# tool_call_chunks = []
41-
# if raw_tool_calls := _dict.get("tool_calls"):
42-
# additional_kwargs["tool_calls"] = raw_tool_calls
43-
# try:
44-
# tool_call_chunks = [
45-
# tool_call_chunk(
46-
# name=rtc["function"].get("name"),
47-
# args=rtc["function"].get("arguments"),
48-
# id=rtc.get("id"),
49-
# index=rtc["index"],
50-
# )
51-
# for rtc in raw_tool_calls
52-
# ]
53-
# except KeyError:
54-
# pass
55-
#
56-
# if role == "user" or default_class == HumanMessageChunk:
57-
# return HumanMessageChunk(content=content, id=id_)
58-
# elif role == "assistant" or default_class == AIMessageChunk:
59-
# return AIMessageChunk(
60-
# content=content,
61-
# additional_kwargs=additional_kwargs,
62-
# id=id_,
63-
# tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
64-
# )
65-
# elif role in ("system", "developer") or default_class == SystemMessageChunk:
66-
# if role == "developer":
67-
# additional_kwargs = {"__openai_role__": "developer"}
68-
# else:
69-
# additional_kwargs = {}
70-
# return SystemMessageChunk(
71-
# content=content, id=id_, additional_kwargs=additional_kwargs
72-
# )
73-
# elif role == "function" or default_class == FunctionMessageChunk:
74-
# return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
75-
# elif role == "tool" or default_class == ToolMessageChunk:
76-
# return ToolMessageChunk(
77-
# content=content, tool_call_id=_dict["tool_call_id"], id=id_
78-
# )
79-
# elif role or default_class == ChatMessageChunk:
80-
# return ChatMessageChunk(content=content, role=role, id=id_)
81-
# else:
82-
# return default_class(content=content, id=id_) # type: ignore
83-
#
24+
25+
def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a raw streaming delta dict into a LangChain message chunk.

    Vendored from ``langchain_openai.chat_models.base`` so that the extra
    ``reasoning_content`` field streamed by reasoning models is preserved in
    ``additional_kwargs`` (the upstream helper drops it).

    :param _dict: one ``choices[n].delta`` mapping from a streamed completion chunk.
    :param default_class: chunk class to fall back to when ``role`` is absent.
    :return: the message chunk matching the delta's role.
    """
    id_ = _dict.get("id")
    role = cast(str, _dict.get("role"))
    content = cast(str, _dict.get("content") or "")
    additional_kwargs: dict = {}
    # MaxKB extension: surface the model's reasoning stream to callers.
    if 'reasoning_content' in _dict:
        additional_kwargs['reasoning_content'] = _dict.get('reasoning_content')
    if _dict.get("function_call"):
        function_call = dict(_dict["function_call"])
        # OpenAI may stream a null name on continuation fragments; normalize to "".
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call
    tool_call_chunks = []
    if raw_tool_calls := _dict.get("tool_calls"):
        # Keep the raw payload in additional_kwargs as well (upstream behaviour,
        # and what the pre-refactor implementation did) so consumers reading
        # message.additional_kwargs["tool_calls"] keep working.
        additional_kwargs["tool_calls"] = raw_tool_calls
        try:
            tool_call_chunks = [
                tool_call_chunk(
                    name=rtc["function"].get("name"),
                    args=rtc["function"].get("arguments"),
                    id=rtc.get("id"),
                    index=rtc["index"],
                )
                for rtc in raw_tool_calls
            ]
        except KeyError:
            # Malformed fragment (missing "function"/"index"): fall back to the
            # raw additional_kwargs copy only.
            pass

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content, id=id_)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            id=id_,
            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
        )
    if role in ("system", "developer") or default_class == SystemMessageChunk:
        # "developer" is OpenAI's replacement for the system role; record it so
        # the request serializer can round-trip the original role.
        if role == "developer":
            additional_kwargs = {"__openai_role__": "developer"}
        else:
            additional_kwargs = {}
        return SystemMessageChunk(
            content=content, id=id_, additional_kwargs=additional_kwargs
        )
    if role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
    if role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(
            content=content, tool_call_id=_dict["tool_call_id"], id=id_
        )
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role, id=id_)
    return default_class(content=content, id=id_)  # type: ignore[call-arg]
8481

8582
class BaseChatOpenAI(ChatOpenAI):
8683
usage_metadata: dict = {}
@@ -131,58 +128,67 @@ def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
131128
self.usage_metadata = chunk.message.usage_metadata
132129
yield chunk
133130

134-
# def _convert_chunk_to_generation_chunk(
135-
# self,
136-
# chunk: dict,
137-
# default_chunk_class: type,
138-
# base_generation_info: Optional[dict],
139-
# ) -> Optional[ChatGenerationChunk]:
140-
# if chunk.get("type") == "content.delta": # from beta.chat.completions.stream
141-
# return None
142-
# token_usage = chunk.get("usage")
143-
# choices = (
144-
# chunk.get("choices", [])
145-
# # from beta.chat.completions.stream
146-
# or chunk.get("chunk", {}).get("choices", [])
147-
# )
148-
#
149-
# usage_metadata: Optional[UsageMetadata] = (
150-
# _create_usage_metadata(token_usage) if token_usage and token_usage.get("prompt_tokens") else None
151-
# )
152-
# if len(choices) == 0:
153-
# # logprobs is implicitly None
154-
# generation_chunk = ChatGenerationChunk(
155-
# message=default_chunk_class(content="", usage_metadata=usage_metadata)
156-
# )
157-
# return generation_chunk
158-
#
159-
# choice = choices[0]
160-
# if choice["delta"] is None:
161-
# return None
162-
#
163-
# message_chunk = _convert_delta_to_message_chunk(
164-
# choice["delta"], default_chunk_class
165-
# )
166-
# generation_info = {**base_generation_info} if base_generation_info else {}
167-
#
168-
# if finish_reason := choice.get("finish_reason"):
169-
# generation_info["finish_reason"] = finish_reason
170-
# if model_name := chunk.get("model"):
171-
# generation_info["model_name"] = model_name
172-
# if system_fingerprint := chunk.get("system_fingerprint"):
173-
# generation_info["system_fingerprint"] = system_fingerprint
174-
#
175-
# logprobs = choice.get("logprobs")
176-
# if logprobs:
177-
# generation_info["logprobs"] = logprobs
178-
#
179-
# if usage_metadata and isinstance(message_chunk, AIMessageChunk):
180-
# message_chunk.usage_metadata = usage_metadata
181-
#
182-
# generation_chunk = ChatGenerationChunk(
183-
# message=message_chunk, generation_info=generation_info or None
184-
# )
185-
# return generation_chunk
131+
def _convert_chunk_to_generation_chunk(
    self,
    chunk: dict,
    default_chunk_class: type,
    base_generation_info: dict | None,
) -> ChatGenerationChunk | None:
    """Turn one raw streamed completion chunk into a ``ChatGenerationChunk``.

    Override of the langchain-openai base implementation, using the local
    ``_convert_delta_to_message_chunk`` so ``reasoning_content`` survives.

    :param chunk: raw chunk dict from the streaming API.
    :param default_chunk_class: message-chunk class used when no role is present.
    :param base_generation_info: generation info shared across the stream, or None.
    :return: the generation chunk, or ``None`` for chunks that carry no delta.
    """
    if chunk.get("type") == "content.delta":  # From beta.chat.completions.stream
        return None
    token_usage = chunk.get("usage")
    choices = (
        chunk.get("choices", [])
        # From beta.chat.completions.stream
        or chunk.get("chunk", {}).get("choices", [])
    )

    # NOTE(review): passes service_tier as a second positional argument; older
    # langchain-openai versions define _create_usage_metadata(oai_token_usage)
    # with one parameter — confirm against the pinned langchain-openai version.
    # NOTE(review): unlike the previous in-repo implementation, usage metadata
    # is built even when token_usage lacks "prompt_tokens" — verify providers
    # that stream partial usage dicts still behave.
    usage_metadata: UsageMetadata | None = (
        _create_usage_metadata(token_usage, chunk.get("service_tier"))
        if token_usage
        else None
    )
    if len(choices) == 0:
        # Usage-only / keep-alive chunk: emit an empty message so the stream
        # still surfaces usage_metadata to the caller.
        # logprobs is implicitly None
        generation_chunk = ChatGenerationChunk(
            message=default_chunk_class(content="", usage_metadata=usage_metadata),
            generation_info=base_generation_info,
        )
        # NOTE(review): self.output_version exists only on newer ChatOpenAI
        # releases — confirm the attribute is available on the pinned version.
        if self.output_version == "v1":
            generation_chunk.message.content = []
            generation_chunk.message.response_metadata["output_version"] = "v1"

        return generation_chunk

    choice = choices[0]
    if choice["delta"] is None:
        return None

    message_chunk = _convert_delta_to_message_chunk(
        choice["delta"], default_chunk_class
    )
    # Copy so per-chunk keys below never mutate the shared base dict.
    generation_info = {**base_generation_info} if base_generation_info else {}

    if finish_reason := choice.get("finish_reason"):
        generation_info["finish_reason"] = finish_reason
    if model_name := chunk.get("model"):
        generation_info["model_name"] = model_name
    if system_fingerprint := chunk.get("system_fingerprint"):
        generation_info["system_fingerprint"] = system_fingerprint
    if service_tier := chunk.get("service_tier"):
        generation_info["service_tier"] = service_tier

    logprobs = choice.get("logprobs")
    if logprobs:
        generation_info["logprobs"] = logprobs

    # Usage may arrive on the same chunk as the delta; attach it to the message.
    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
        message_chunk.usage_metadata = usage_metadata

    message_chunk.response_metadata["model_provider"] = "openai"
    return ChatGenerationChunk(
        message=message_chunk, generation_info=generation_info or None
    )
186192

187193
def invoke(
188194
self,

0 commit comments

Comments
 (0)