|
8 | 8 | from lmdeploy.messages import PytorchEngineConfig |
9 | 9 | from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend |
10 | 10 | from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value |
11 | | -from lmdeploy.utils import get_logger |
| 11 | +from lmdeploy.utils import get_logger, is_bf16_supported |
12 | 12 |
|
13 | 13 | logger = get_logger('lmdeploy') |
14 | 14 |
|
15 | 15 |
|
16 | | -def _update_torch_dtype(config: 'ModelConfig', dtype: str): |
| 16 | +def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'): |
17 | 17 | """Update the torch dtype from the model config. |
18 | 18 |
|
19 | 19 | Args: |
20 | 20 | config (ModelConfig): The input model config. |
21 | 21 | dtype (str): user specified data type. Refer to |
22 | 22 | `PyTorchEngineConfig.dtype` for detailed info |
| 23 | + device_type (str): The device type. Refer to `PyTorchEngineConfig.device_type` for detailed info |
23 | 24 | """ |
24 | 25 | quantization_config = getattr(config.hf_config, 'quantization_config', dict()) |
25 | 26 | quant_method = quantization_config.get('quant_method', None) |
@@ -48,6 +49,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): |
48 | 49 | # update hf_config as well |
49 | 50 | setattr(config.hf_config, 'torch_dtype', torch_dtype) |
50 | 51 | else: |
| 52 | + if torch_dtype == 'bfloat16' and not is_bf16_supported(device_type): |
| 53 | + torch_dtype = 'float16' |
51 | 54 | # change to user specified data type if it is not 'auto' |
52 | 55 | if dtype == 'auto': |
53 | 56 | torch_dtype = torch_dtype if torch_dtype in ['float16', 'bfloat16'] else 'float16' |
@@ -356,6 +359,7 @@ def from_pretrained( |
356 | 359 | is_draft_model: bool = False, |
357 | 360 | spec_method: str = None, |
358 | 361 | model_format: str = None, |
| 362 | + device_type: str = 'auto', |
359 | 363 | ): |
360 | 364 | """Instantiate one of the configuration classes of the library from a |
361 | 365 | pretrained model configuration. |
@@ -386,6 +390,7 @@ def from_pretrained( |
386 | 390 | dist_config=dist_config, |
387 | 391 | is_draft_model=is_draft_model, |
388 | 392 | spec_method=spec_method, |
| 393 | + device_type=device_type, |
389 | 394 | ) |
390 | 395 | fp32_lm_head = False |
391 | 396 | if hf_overrides is not None: |
@@ -413,6 +418,7 @@ def from_hf_config( |
413 | 418 | dist_config: DistConfig = None, |
414 | 419 | is_draft_model: bool = False, |
415 | 420 | spec_method: str = None, |
| 421 | + device_type: str = 'auto', |
416 | 422 | ): |
417 | 423 | """From huggingface config.""" |
418 | 424 | from lmdeploy.pytorch.configurations import AutoModelConfigBuilder |
@@ -441,7 +447,7 @@ def from_hf_config( |
441 | 447 | assert tp % model_config.num_key_value_heads == 0 |
442 | 448 |
|
443 | 449 | # should after setting `hf_config` and `model_arch` attributes |
444 | | - model_config = _update_torch_dtype(model_config, dtype) |
| 450 | + model_config = _update_torch_dtype(model_config, dtype, device_type=device_type) |
445 | 451 |
|
446 | 452 | # update eos_token_id to list |
447 | 453 | if isinstance(model_config.eos_token_id, int): |
|
0 commit comments