diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py
index 8d873c5081..f4ec875588 100644
--- a/lmdeploy/lite/apis/calibrate.py
+++ b/lmdeploy/lite/apis/calibrate.py
@@ -86,8 +86,105 @@
     'MistralForCausalLM': 'lm_head',
 }
 
+STR_TO_TORCH_DTYPE = {
+    'float16': torch.float16,
+    'bfloat16': torch.bfloat16,
+    'float32': torch.float32,
+}
+
+# Inverse table, derived so the two mappings cannot drift apart.
+TORCH_DTYPE_TO_STR = {v: k for k, v in STR_TO_TORCH_DTYPE.items()}
+
+
+def _set_use_cache(model):
+    """Disable ``use_cache`` on the model config and on its language sub-config."""
+    model.config.use_cache = False
+    # Only one language sub-config is updated; `text_config` takes precedence.
+    if hasattr(model.config, 'text_config'):
+        model.config.text_config.use_cache = False
+    elif hasattr(model.config, 'llm_config'):
+        model.config.llm_config.use_cache = False
+
+
+def _get_torch_dtype(config):
+    """Resolve the dtype declared by a model config as a ``torch.dtype``.
+
+    A dtype set on ``text_config`` (or, when that attribute is absent, on
+    ``llm_config``) overrides the top-level one. ``bfloat16`` is used when no
+    dtype is declared at all.
+
+    Args:
+        config: The model config to inspect.
+
+    Returns:
+        torch.dtype: The resolved dtype.
+
+    Raises:
+        ValueError: If the config holds a string that names no torch dtype.
+    """
+
+    def _resolve_dtype(cfg):
+        # `torch_dtype` wins over `dtype` when both are set.
+        dtype = getattr(cfg, 'torch_dtype', None)
+        if dtype is None:
+            dtype = getattr(cfg, 'dtype', None)
+        return dtype
+
+    dtype = _resolve_dtype(config)
+
+    if hasattr(config, 'text_config'):
+        sub_dtype = _resolve_dtype(config.text_config)
+        if sub_dtype is not None:
+            dtype = sub_dtype
+    elif hasattr(config, 'llm_config'):
+        sub_dtype = _resolve_dtype(config.llm_config)
+        if sub_dtype is not None:
+            dtype = sub_dtype
+
+    if dtype is None:
+        dtype = 'bfloat16'
+
+    if isinstance(dtype, torch.dtype):
+        return dtype
+
+    # Accept both the 'float16' and the 'torch.float16' spelling.
+    name = str(dtype)
+    if name.startswith('torch.'):
+        name = name[len('torch.'):]
+    if name in STR_TO_TORCH_DTYPE:
+        return STR_TO_TORCH_DTYPE[name]
+    # Any other dtype torch knows by that name (e.g. 'float64') is accepted too.
+    resolved = getattr(torch, name, None)
+    if not isinstance(resolved, torch.dtype):
+        raise ValueError(f'Unsupported dtype {dtype!r} in model config')
+    return resolved
+
+
+def _set_config_dtype(model, torch_dtype):
+    """Write ``torch_dtype``, as a string, into the model config and its sub-configs.
+
+    Only configs that already expose a ``dtype`` attribute are updated.
+
+    Args:
+        model: The model whose ``config`` is updated in place.
+        torch_dtype (torch.dtype): The dtype to record.
+    """
+    # A dtype outside the table falls back to its torch name without the
+    # 'torch.' prefix rather than raising KeyError.
+    dtype = TORCH_DTYPE_TO_STR.get(torch_dtype, str(torch_dtype).replace('torch.', '', 1))
+    configs = [model.config]
+
+    for name in ['text_config', 'llm_config', 'vision_config', 'ts_config']:
+        sub_config = getattr(model.config, name, None)
+        if sub_config is not None:
+            configs.append(sub_config)
+
+    for config in configs:
+        if hasattr(config, 'dtype'):
+            config.dtype = dtype
+
 
-def check_vl_llm(backend: str, config: dict) -> bool:
+def check_vl_llm(config: dict) -> bool:
     """Check if the model is a vl model from model config."""
     if 'auto_map' in config:
         for _, v in config['auto_map'].items():
@@ -121,11 +218,11 @@ def check_vl_llm(backend: str, config: dict) -> bool:
     return False
 
 
-def get_task(backend: str, model_path: str):
+def get_task(model_path: str, trust_remote_code: bool = False) -> str:
     """Get pipeline type and pipeline class from model config."""
 
-    _, config = get_model_arch(model_path)
-    if check_vl_llm(backend, config.to_dict()):
+    _, config = get_model_arch(model_path, trust_remote_code)
+    if check_vl_llm(config.to_dict()):
         return 'vlm'
 
     # default task
@@ -203,11 +300,11 @@ class name or the class type itself.
 
 
 # TODO to be removed
-def make_compatible_internvl_config(model_path):
+def make_compatible_internvl_config(model_path, trust_remote_code: bool = False):
     """Patch model.config since after transformers v4.45.0, InternVL models
     can't use `save_pretrained`"""
     from lmdeploy.archs import get_model_arch
-    arch, _ = get_model_arch(model_path)
+    arch, _ = get_model_arch(model_path, trust_remote_code)
     if arch == 'InternVLChatModel':
         import transformers
         from packaging import version
@@ -257,8 +354,8 @@ def load_model_and_tokenizer(model: str,
                              work_dir: str = './work_dir',
                              trust_remote_code: bool = False):
     """Load model and tokenizer."""
-    model_type = get_task(backend='turbomind', model_path=model)
-    make_compatible_internvl_config(model)
+    model_type = get_task(model, trust_remote_code)
+    make_compatible_internvl_config(model, trust_remote_code)
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code)
@@ -276,14 +373,12 @@ def load_model_and_tokenizer(model: str,
             model = vl_model.language_model
         if hasattr(vl_model, 'llm'):  # MiniCPMV, ...
             model = vl_model.llm
-    model.config.use_cache = False
-    if hasattr(model.config, 'text_config'):
-        model.config.text_config.use_cache = False
-    elif hasattr(model.config, 'llm_config'):
-        model.config.llm_config.use_cache = False
-    if dtype == 'float16' or (dtype == 'auto' and original_config.torch_dtype == torch.float16):
+    _set_use_cache(model)
+    torch_dtype = _get_torch_dtype(original_config)
+    _set_config_dtype(model, torch_dtype)
+    if dtype == 'float16' or (dtype == 'auto' and torch_dtype == torch.float16):
         model.half()
-    elif dtype == 'bfloat16' or (dtype == 'auto' and original_config.torch_dtype == torch.bfloat16):
+    elif dtype == 'bfloat16' or (dtype == 'auto' and torch_dtype == torch.bfloat16):
         assert torch.cuda.is_bf16_supported(
         ), 'your device does not support bfloat16 please set --dtype float16'  # noqa
         model.to(torch.bfloat16)
diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index a7d4365044..33db25ec16 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -100,7 +100,6 @@ def smooth_quant(model: str,
 
     patterns = []
     skipped_modules = []
-    arch = model.config.architectures[0]
     rebuilder = MODELS.get(arch)
     if rebuilder:
         patterns = rebuilder.skipped_modules()