Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions swift/infer_engine/infer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ def _parse_stream_data(data: bytes) -> Optional[str]:
data = data.strip()
if len(data) == 0:
return
assert data.startswith('data:'), f'data: {data}'
return data[5:].strip()
if data.startswith('data:'):
return data[5:].strip()
return data

async def infer_async(
self,
Expand All @@ -138,6 +139,15 @@ async def infer_async(
async def _gen_stream() -> AsyncIterator[ChatCompletionStreamResponse]:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp:
if resp.status >= 400 or resp.content_type != 'text/event-stream':
data = await resp.text()
try:
resp_obj = json.loads(data)
except json.JSONDecodeError:
raise HTTPError(data)
if resp_obj.get('object') == 'error':
raise HTTPError(resp_obj['message'])
raise HTTPError(data)
async for data in resp.content:
data = self._parse_stream_data(data)
if data == '[DONE]':
Expand Down
6 changes: 3 additions & 3 deletions swift/infer_engine/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,8 @@ async def _run_async_iter():
async for item in await async_iter:
queue.put(item)
except Exception as e:
if getattr(self, 'strict', True):
raise
queue.put(e)
else:
finally:
queue.put(None)

try:
Expand All @@ -103,6 +101,8 @@ async def _run_async_iter():
if output is None or isinstance(output, Exception):
prog_bar.update()
self._update_metrics(pre_output, metrics)
if isinstance(output, Exception) and getattr(self, 'strict', True):
raise output
return
pre_output = output
yield output
Expand Down
62 changes: 58 additions & 4 deletions swift/infer_engine/transformers_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,41 @@ def _update_batched_logprobs(batched_logprobs: List[torch.Tensor], logits_stream
for logprobs, new_logprobs in zip(batched_logprobs, new_batched_logprobs):
logprobs += new_logprobs

@staticmethod
def _extract_reasoning_content(response: str) -> tuple[Optional[str], str]:
if '<think>' not in response:
return None, response
_, suffix = response.split('<think>', 1)
if '</think>' not in suffix:
return suffix.lstrip('\n'), ''
reasoning_content, content = suffix.split('</think>', 1)
return reasoning_content.lstrip('\n'), content.lstrip('\n')

@classmethod
def _extract_reasoning_delta(cls, previous_text: str, current_text: str) -> tuple[Optional[str], Optional[str]]:
    """Compute the streaming deltas of the reasoning part and the answer part.

    Both texts are decomposed with ``_extract_reasoning_content``.  For each
    part, the delta is the suffix that *current* adds on top of *previous*;
    when *current* does not extend *previous* (e.g. the decoded text changed),
    the full current part is re-emitted.  An empty delta is reported as
    ``None``.
    """
    prev_reasoning, prev_content = cls._extract_reasoning_content(previous_text)
    cur_reasoning, cur_content = cls._extract_reasoning_content(current_text)

    def _delta(prev: str, cur: str) -> Optional[str]:
        # Incremental suffix when cur extends prev, otherwise resend cur whole.
        piece = cur[len(prev):] if cur.startswith(prev) else cur
        return piece or None

    reasoning_delta = None if cur_reasoning is None else _delta(prev_reasoning or '', cur_reasoning)
    content_delta = _delta(prev_content, cur_content) if cur_content else None
    return reasoning_delta, content_delta

def _infer_stream(self, inputs: Dict[str, Any], *, generation_config: GenerationConfig,
adapter_request: Optional[AdapterRequest], request_config: RequestConfig,
**kwargs) -> Iterator[List[Optional[ChatCompletionStreamResponse]]]:
Expand Down Expand Up @@ -251,6 +286,7 @@ def _model_generate(**kwargs):
infer_streamers = [InferStreamer(self.template) for _ in range(batch_size)]
request_id_list = [f'chatcmpl-{random_uuid()}' for _ in range(batch_size)]
token_idxs = [0] * batch_size
response_texts = [''] * batch_size

raw_batched_generate_ids = None # or torch.Tensor: [batch_size, seq_len]
batched_logprobs = [[] for _ in range(batch_size)]
Expand Down Expand Up @@ -295,17 +331,30 @@ def _model_generate(**kwargs):
logprobs = self._get_logprobs(logprobs_list, generate_ids[token_idxs[i]:], request_config.top_logprobs)
token_idxs[i] = len(generate_ids)

previous_text = response_texts[i]
response_texts[i] = previous_text + (delta_text or '')
delta_reasoning_content, delta_content = self._extract_reasoning_delta(previous_text, response_texts[i])
if not delta_content and not delta_reasoning_content and not is_finished[i]:
res.append(None)
continue

usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
toolcall = None
if is_finished[i]:
toolcall = self._get_toolcall(self.template.decode(generate_ids))
response = self.template.decode(generate_ids)
_, content = self._extract_reasoning_content(response)
toolcall = self._get_toolcall(content or response)
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, usage_info.completion_tokens,
is_finished[i])

choices = [
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
delta=DeltaMessage(
role='assistant',
content=delta_content,
reasoning_content=delta_reasoning_content,
tool_calls=toolcall),
finish_reason=finish_reason,
logprobs=logprobs)
]
Expand Down Expand Up @@ -423,13 +472,18 @@ def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationCo
logprobs = self._get_logprobs(logprobs_list, generate_ids, request_config.top_logprobs)
usage_info = self._update_usage_info(usage_info, len(generate_ids))
response = self.template.decode(generate_ids, template_inputs=template_inputs[i])
reasoning_content, content = self._extract_reasoning_content(response)
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True)
toolcall = self._get_toolcall(response)
toolcall = self._get_toolcall(content or response)
token_ids = generate_ids if request_config.return_details else None
choices.append(
ChatCompletionResponseChoice(
index=j,
message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
message=ChatMessage(
role='assistant',
content=content,
reasoning_content=reasoning_content,
tool_calls=toolcall),
finish_reason=finish_reason,
logprobs=logprobs,
token_ids=token_ids))
Expand Down
9 changes: 8 additions & 1 deletion swift/pipelines/infer/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,16 @@ def _post_process(self, request_info, response, return_cmpl_response: bool = Fal

is_finished = all(response.choices[i].finish_reason for i in range(len(response.choices)))
if 'stream' in response.__class__.__name__.lower():
request_info['response'] += response.choices[0].delta.content
delta = response.choices[0].delta
if delta.content:
request_info['response'] += delta.content
if getattr(delta, 'reasoning_content', None):
request_info.setdefault('reasoning_content', '')
request_info['reasoning_content'] += delta.reasoning_content
else:
request_info['response'] = response.choices[0].message.content
if getattr(response.choices[0].message, 'reasoning_content', None):
request_info['reasoning_content'] = response.choices[0].message.reasoning_content
if return_cmpl_response:
response = response.to_cmpl_response()
if is_finished:
Expand Down
34 changes: 32 additions & 2 deletions swift/ui/llm_infer/llm_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class LLMInfer(BaseUI):
'en': 'Port'
},
},
'api_key': {
'label': {
'zh': '接口token',
'en': 'API key'
},
'info': {
'zh': '部署服务使用的API key,聊天时会自动复用',
'en': 'API key used by the deployed service and reused for chat requests'
}
},
'llm_infer': {
'label': {
'zh': 'LLM推理',
Expand Down Expand Up @@ -140,6 +150,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
scale=8)
infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4)
gr.Textbox(elem_id='port', lines=1, value='8000', scale=4)
gr.Textbox(elem_id='api_key', lines=1, scale=6)
chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height')
with gr.Row(equal_height=True):
prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True)
Expand Down Expand Up @@ -388,12 +399,16 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au
infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + prompt

_, args = Runtime.parse_info_from_cmdline(running_task)
if 'port' not in args:
raise gr.Error('Please select a valid running deployment first.')
request_config = RequestConfig(
temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
request_config.stream = True
request_config.stop = ['Observation:']
request_config.max_tokens = max_new_tokens
stream_resp_with_history = ''
stream_reasoning_content = ''
stream_response_content = ''
response = ''
i = len(infer_request.messages) - 1
for i in range(len(infer_request.messages) - 1, -1, -1):
Expand All @@ -412,14 +427,29 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au
if infer_model_type:
model_kwargs = {'model': infer_model_type}
gen_list = InferClient(
port=args['port'], ).infer(
port=args['port'],
api_key=args.get('api_key', 'EMPTY'),
).infer(
infer_requests=[_infer_request], request_config=request_config, **model_kwargs)
if infer_request.messages[-1]['role'] != 'assistant':
infer_request.messages.append({'role': 'assistant', 'content': ''})
for chunk in gen_list[0]:
if chunk is None:
continue
stream_resp_with_history += chunk.choices[0].delta.content if chat else chunk.choices[0].text
if chat:
delta = chunk.choices[0].delta
if delta.reasoning_content:
stream_reasoning_content += delta.reasoning_content
if delta.content:
stream_response_content += delta.content
if stream_reasoning_content and stream_response_content:
stream_resp_with_history = f'<think>\n{stream_reasoning_content}</think>\n{stream_response_content}'
elif stream_reasoning_content:
stream_resp_with_history = f'<think>\n{stream_reasoning_content}'
else:
stream_resp_with_history = stream_response_content
else:
stream_resp_with_history += chunk.choices[0].text
infer_request.messages[-1]['content'] = stream_resp_with_history
chatbot_content = cls._replace_tag_with_media(infer_request)
chatbot_content = cls.parse_text(chatbot_content)
Expand Down
70 changes: 54 additions & 16 deletions swift/ui/llm_infer/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os.path
import psutil
import re
import shlex
import subprocess
import sys
import time
Expand Down Expand Up @@ -122,8 +123,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
def break_log_event(cls, task):
if not task:
return
pid, all_args = cls.parse_info_from_cmdline(task)
cls.log_event[all_args['log_file']] = True
_, all_args = cls.parse_info_from_cmdline(task)
log_file = all_args.get('log_file')
if not log_file:
return
cls.log_event[log_file] = True

@classmethod
def update_log(cls):
Expand All @@ -134,7 +138,9 @@ def wait(cls, task):
if not task:
return [None]
_, args = cls.parse_info_from_cmdline(task)
log_file = args['log_file']
log_file = args.get('log_file')
if not log_file:
return [None]
cls.log_event[log_file] = False
offset = 0
latest_data = ''
Expand Down Expand Up @@ -230,29 +236,61 @@ def construct_running_task(proc):
@classmethod
def parse_info_from_cmdline(cls, task):
pid = None
for i in range(3):
slash = task.find('/')
if i == 0:
pid = task[:slash].split(':')[1]
task = task[slash + 1:]
args = task.split(f'swift {cls.cmd}')[1]
args = [arg.strip() for arg in args.split('--') if arg.strip()]
if not isinstance(task, str) or not task:
return pid, {}

pid_match = re.search(r'(?:^|/)pid:(\d+)', task)
if pid_match:
pid = pid_match.group(1)

cmdline = task.split('/cmd:', 1)[1] if '/cmd:' in task else task
args = None
if f'swift {cls.cmd}' in cmdline:
args = cmdline.split(f'swift {cls.cmd}', 1)[1]
else:
deploy_match = re.search(rf'\S*{re.escape(cls.cmd)}\.py(?=\s|$)', cmdline)
if deploy_match:
args = cmdline[deploy_match.end():]
if args is None:
return pid, {}

try:
tokens = shlex.split(args)
except ValueError:
return pid, {}

all_args = {}
for i in range(len(args)):
space = args[i].find(' ')
splits = args[i][:space], args[i][space + 1:]
all_args[splits[0]] = splits[1]
i = 0
while i < len(tokens):
token = tokens[i]
if not token.startswith('--'):
i += 1
continue
key = token[2:]
i += 1
values = []
while i < len(tokens) and not tokens[i].startswith('--'):
values.append(tokens[i])
i += 1
all_args[key] = ' '.join(values) if values else 'true'
return pid, all_args

@classmethod
def kill_task(cls, task):
if task:
pid, all_args = cls.parse_info_from_cmdline(task)
log_file = all_args['log_file']
log_file = all_args.get('log_file')
if sys.platform == 'win32':
if not pid:
return [cls.refresh_tasks()] + [gr.update(value=None)]
command = ['taskkill', '/f', '/t', '/pid', pid]
else:
command = ['pkill', '-9', '-f', log_file]
if log_file:
command = ['pkill', '-9', '-f', log_file]
elif pid:
command = ['kill', '-9', pid]
else:
return [cls.refresh_tasks()] + [gr.update(value=None)]
try:
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode == 0
Expand Down