Allow transformers < 5.9.0 and run generation under ModelScope's hub patch.

* requirements/framework.txt: raise the transformers upper bound from
  <5.8.0 to <5.9.0.
* swift/infer_engine/transformers_engine.py: add the `_patch_kernels`
  context manager and wrap both `self.template.generate(...)` call sites
  (streaming thread and full inference) in it.

NOTE(review): indentation of context lines was reconstructed from the
Python structure; verify against the target tree before applying.

diff --git a/requirements/framework.txt b/requirements/framework.txt
index cc6f80636f..66ad03f7f4 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -32,7 +32,7 @@ sortedcontainers>=1.5.9
 tensorboard
 tiktoken
 tqdm
-transformers>=4.33,<5.8.0
+transformers>=4.33,<5.9.0
 transformers_stream_generator
 trl>=0.15,<1.0
 uvicorn
diff --git a/swift/infer_engine/transformers_engine.py b/swift/infer_engine/transformers_engine.py
index ac7e963717..6dbbe686e9 100644
--- a/swift/infer_engine/transformers_engine.py
+++ b/swift/infer_engine/transformers_engine.py
@@ -7,6 +7,7 @@
 import time
 import torch
 import torch.nn.functional as F
+from contextlib import contextmanager
 from copy import deepcopy
 from PIL import Image
 from queue import Queue
@@ -21,7 +22,7 @@
 from swift.model import get_model_processor
 from swift.template import Template
 from swift.tuners import Swift
-from swift.utils import get_last_valid_indices, safe_snapshot_download, to_device
+from swift.utils import get_last_valid_indices, safe_snapshot_download, to_device, use_hf_hub
 from .infer_engine import InferEngine
 from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
@@ -240,7 +241,8 @@ def _infer_stream(self, inputs: Dict[str, Any], *, generation_config: Generation
         def _model_generate(**kwargs):
             if is_torch_npu_available():
                 torch.npu.set_device(self.model.device)
-            self.template.generate(self.model, **kwargs)
+            with self._patch_kernels():
+                self.template.generate(self.model, **kwargs)
 
         generate_kwargs = self.template.prepare_generate_kwargs(generate_kwargs, model=self.model)
         thread = Thread(target=_model_generate, kwargs=generate_kwargs)
@@ -384,6 +386,40 @@ def _infer_forward(self, inputs: Dict[str, Any], adapter_request: Optional[Adapt
             res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
         return res
 
+    @contextmanager
+    def _patch_kernels(self):
+        """Run the wrapped block with ModelScope's ``patch_hub()`` applied.
+
+        Nothing is patched when the Hugging Face hub is selected
+        (``self.use_hf``, or ``use_hf_hub()`` when that is ``None``) or when
+        ``patch_hub``/``unpatch_hub`` cannot be imported from ``modelscope``.
+        Otherwise ``patch_hub()`` is called before the block and
+        ``unpatch_hub()`` is always called after it, even if the block raises.
+
+        NOTE(review): presumably this lets hub lookups made during
+        ``generate`` resolve through ModelScope -- confirm.
+        """
+        use_hf = self.use_hf if self.use_hf is not None else use_hf_hub()
+        if use_hf:
+            yield
+            return
+
+        try:
+            # Imported lazily so a missing modelscope API only disables the patch.
+            from modelscope import patch_hub, unpatch_hub
+        except ImportError:
+            yield
+            return
+        try:
+            patch_hub()
+        except AttributeError:
+            # Deliberately ignored: continue and still call unpatch_hub() below.
+            pass
+        try:
+            yield
+        finally:
+            unpatch_hub()
+
     def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationConfig,
                     adapter_request: Optional[AdapterRequest], request_config: RequestConfig,
                     template_inputs) -> List[ChatCompletionResponse]:
@@ -394,7 +430,8 @@ def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationCo
             generate_kwargs['adapter_names'] = adapter_names
         num_prompt_tokens = self._get_num_tokens(inputs)
         generate_kwargs = self.template.prepare_generate_kwargs(generate_kwargs, model=self.model)
-        output = dict(self.template.generate(self.model, **generate_kwargs))
+        with self._patch_kernels():
+            output = dict(self.template.generate(self.model, **generate_kwargs))
         output.pop('past_key_values', None)
         batched_generate_ids = output['sequences']
         batched_generate_ids = self.template.get_generate_ids(batched_generate_ids, num_prompt_tokens)