|
1 | 1 | # coding=utf-8 |
2 | 2 | import base64 |
3 | 3 | from concurrent.futures import ThreadPoolExecutor |
4 | | -from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping |
| 4 | +from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping, AsyncIterator |
5 | 5 |
|
6 | 6 | from langchain_core.language_models import LanguageModelInput |
7 | 7 | from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \ |
@@ -102,7 +102,7 @@ def get_num_tokens_from_messages( |
102 | 102 | with ThreadPoolExecutor(max_workers=1) as executor: |
103 | 103 | future = executor.submit(super().get_num_tokens_from_messages, messages, tools) |
104 | 104 | try: |
105 | | - response = future.result() |
| 105 | + response = future.result(timeout=timeout) |
106 | 106 | maxkb_logger.info("请求成功(未超时)") |
107 | 107 | return response |
108 | 108 | except Exception as e: |
@@ -131,6 +131,13 @@ def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]: |
131 | 131 | self.usage_metadata = chunk.message.usage_metadata |
132 | 132 | yield chunk |
133 | 133 |
|
| 134 | + async def _astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[ChatGenerationChunk]: |
| 135 | + kwargs['stream_usage'] = True |
| 136 | + async for chunk in super()._astream(*args, **kwargs): |
| 137 | + if chunk.message.usage_metadata is not None: |
| 138 | + self.usage_metadata = chunk.message.usage_metadata |
| 139 | + yield chunk |
| 140 | + |
134 | 141 | def _convert_chunk_to_generation_chunk( |
135 | 142 | self, |
136 | 143 | chunk: dict, |
|
0 commit comments