Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/framework.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ sortedcontainers>=1.5.9
tensorboard
tiktoken
tqdm
transformers>=4.33,<5.8.0
transformers>=4.33,<5.9.0
transformers_stream_generator
trl>=0.15,<1.0
uvicorn
Expand Down
30 changes: 27 additions & 3 deletions swift/infer_engine/transformers_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
import torch
import torch.nn.functional as F
from contextlib import contextmanager
from copy import deepcopy
from PIL import Image
from queue import Queue
Expand All @@ -21,7 +22,7 @@
from swift.model import get_model_processor
from swift.template import Template
from swift.tuners import Swift
from swift.utils import get_last_valid_indices, safe_snapshot_download, to_device
from swift.utils import get_last_valid_indices, safe_snapshot_download, to_device, use_hf_hub
from .infer_engine import InferEngine
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
Expand Down Expand Up @@ -240,7 +241,8 @@ def _infer_stream(self, inputs: Dict[str, Any], *, generation_config: Generation
def _model_generate(**kwargs):
if is_torch_npu_available():
torch.npu.set_device(self.model.device)
self.template.generate(self.model, **kwargs)
with self._patch_kernels():
self.template.generate(self.model, **kwargs)

generate_kwargs = self.template.prepare_generate_kwargs(generate_kwargs, model=self.model)
thread = Thread(target=_model_generate, kwargs=generate_kwargs)
Expand Down Expand Up @@ -384,6 +386,27 @@ def _infer_forward(self, inputs: Dict[str, Any], adapter_request: Optional[Adapt
res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
return res

@contextmanager
def _patch_kernels(self):
    """Run the wrapped block with the ModelScope hub patch applied, when applicable.

    Nothing is patched when the Hugging Face hub is selected (``self.use_hf``,
    falling back to ``use_hf_hub()`` when that attribute is ``None``) or when
    ``modelscope`` cannot be imported. Otherwise ``patch_hub()`` is called
    before the block runs and ``unpatch_hub()`` is called afterwards, even if
    the block raises.
    """
    prefer_hf = use_hf_hub() if self.use_hf is None else self.use_hf

    hub_hooks = None
    if not prefer_hf:
        try:
            from modelscope import patch_hub, unpatch_hub
        except ImportError:
            # ModelScope is not installed, so there is nothing to patch.
            pass
        else:
            hub_hooks = (patch_hub, unpatch_hub)

    if hub_hooks is None:
        yield
        return

    apply_patch, revert_patch = hub_hooks
    try:
        apply_patch()
    except AttributeError:
        # An AttributeError raised while patching is tolerated: the wrapped
        # block still runs and the revert below is still attempted.
        pass
    try:
        yield
    finally:
        revert_patch()
Comment on lines +390 to +408
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _patch_kernels context manager has two significant issues:

  1. Exception Safety: It lacks a try...finally block around the yield. If an exception occurs during generation (e.g., in self.template.generate), unpatch_hub() will never be called, leaving the global environment in a patched state.
  2. Thread Safety: patch_hub and unpatch_hub modify global state in modelscope/transformers. Since _infer_stream executes generation in a separate thread, concurrent requests can lead to race conditions where one thread's cleanup unpatches the hub while another thread is still using it.

Consider wrapping the yield in a try...finally block for exception safety. For thread safety, if this engine is used in a concurrent environment, a global lock with a reference counter may be necessary; alternatively, patch the hub once during initialization if the mode is consistent for the whole process.

    def _patch_kernels(self):
        """Apply the ModelScope hub patch around the wrapped block (suggested variant).

        Yields once without patching when the Hugging Face hub is selected
        (``self.use_hf``, or ``use_hf_hub()`` when that attribute is ``None``)
        or when ``modelscope`` cannot be imported. Otherwise calls
        ``patch_hub()``, yields, and calls ``unpatch_hub()`` in a ``finally``
        clause.

        NOTE(review): as shown, this snippet carries no ``@contextmanager``
        decorator; a plain generator function cannot be used in a ``with``
        statement, so confirm the decorator is kept when applying it.
        """
        # An explicit per-engine setting wins; the global default is consulted
        # only when the attribute is None.
        use_hf = self.use_hf if self.use_hf is not None else use_hf_hub()
        if use_hf:
            yield
            return

        try:
            from modelscope import patch_hub, unpatch_hub
        except ImportError:
            # modelscope is unavailable: run the wrapped block unpatched.
            yield
            return

        # patch_hub() sits outside any try block here, so an exception it
        # raises propagates to the caller before the wrapped block runs.
        patch_hub()
        try:
            yield
        finally:
            # Runs whether or not the wrapped block raised.
            unpatch_hub()


def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationConfig,
adapter_request: Optional[AdapterRequest], request_config: RequestConfig,
template_inputs) -> List[ChatCompletionResponse]:
Expand All @@ -394,7 +417,8 @@ def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationCo
generate_kwargs['adapter_names'] = adapter_names
num_prompt_tokens = self._get_num_tokens(inputs)
generate_kwargs = self.template.prepare_generate_kwargs(generate_kwargs, model=self.model)
output = dict(self.template.generate(self.model, **generate_kwargs))
with self._patch_kernels():
output = dict(self.template.generate(self.model, **generate_kwargs))
output.pop('past_key_values', None)
batched_generate_ids = output['sequences']
batched_generate_ids = self.template.get_generate_ids(batched_generate_ids, num_prompt_tokens)
Expand Down
Loading