Skip to content

Commit 68b13ee

Browse files
Fix processing LLM response (#324)
* process dict response
* fix
* fix stream state
1 parent 432e9f8 commit 68b13ee

4 files changed

Lines changed: 10 additions & 5 deletions

File tree

lagent/actions/tmux_action.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _start(self) -> None:
112112
capture_output=True,
113113
text=True,
114114
)
115-
if result.returncode != 0 and result.stderr != 'duplicate session: terminus2\n':
115+
if result.returncode != 0 and result.stderr != f'duplicate session: {self._session_name}\n':
116116
raise RuntimeError(f"Failed to start tmux session: {result.stderr!r}")
117117
subprocess.run(
118118
f"tmux set-option -g history-limit {self._history_limit}",

lagent/actions/web_visitor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ async def _read_webpage(self, url: str, goal: str) -> str:
173173
llm_response = (
174174
llm_response.content
175175
if isinstance(llm_response, AgentMessage)
176-
else llm_response.choices[0].message.content
176+
else (
177+
llm_response['choices'][0]['message']['content']
178+
if isinstance(llm_response, dict)
179+
else llm_response.choices[0].message.content
180+
)
177181
)
178182
if not llm_response or len(llm_response) < 10:
179183
tool_response = tool_response[: int(len(tool_response) * 0.7)]

lagent/llms/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import json
33
import random
44
import traceback
5-
65
from typing import Dict, List, Optional, TypedDict, Union
76

87
import aiohttp
98

109
from lagent.llms.openai import AsyncGPTAPI
1110
from lagent.utils import get_logger
11+
1212
logger = get_logger()
1313

1414

@@ -54,7 +54,7 @@ def __init__(
5454
self.max_tool_response_length = max_tool_response_length
5555
self.max_tool_calls_per_turn = max_tool_calls_per_turn
5656

57-
async def chat(self, messages: List[dict], tools=None, **gen_params) -> str:
57+
async def chat(self, messages: List[dict], tools=None, **gen_params) -> dict:
5858
"""Generate completion from a list of templates.
5959
6060
Args:

lagent/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def from_model_response(cls, response: Union[ChatCompletion, dict], sender: str)
136136
raw_content_logprobs=msg.get('raw_content_logprobs'),
137137
extra_info=msg.get('extra_info') or {},
138138
tool_calls=msg.get('tool_calls'),
139-
stream_state=(ModelStatusCode.SESSION_OUT_OF_LIMIT if finish_reason == 'length' else ModelStatusCode.END),
139+
stream_state=choice.get('stream_state')
140+
or (ModelStatusCode.SESSION_OUT_OF_LIMIT if finish_reason == 'length' else ModelStatusCode.END),
140141
finish_reason=finish_reason,
141142
)
142143

0 commit comments

Comments (0)