Skip to content

Commit fe04f31

Browse files
kouroshHakha and claude
authored and committed
[serve][LLM] Fix HF config loading for models with custom rope_scaling (ray-project#62464)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 22e5efb commit fe04f31

2 files changed

Lines changed: 67 additions & 30 deletions

File tree

python/ray/llm/_internal/serve/core/configs/llm_config.py

Lines changed: 43 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -256,50 +256,59 @@ def validate_server_cls(cls, value):
256256
_engine_config: EngineConfigType = PrivateAttr(None)
257257
_callback_instance: Optional[CallbackBase] = PrivateAttr(None)
258258

259-
def _infer_supports_vision(self, model_id_or_path: str) -> None:
260-
"""Called in llm node initializer together with other transformers calls. It
261-
loads the model config from huggingface and sets the supports_vision
262-
attribute based on whether the config has `vision_config`. All LVM models has
263-
`vision_config` setup.
259+
def _load_hf_config(self, model_id_or_path: str, trust_remote_code: bool = False):
260+
"""Load the HuggingFace config for a model.
261+
262+
Uses AutoConfig which loads the model-specific config class (e.g.
263+
DeepseekV3Config) instead of the generic PretrainedConfig. The generic
264+
base class can fail for models whose config.json contains fields (like
265+
``rope_scaling``) that require model-specific post-init logic.
264266
"""
265267
try:
266-
hf_config = transformers.PretrainedConfig.from_pretrained(model_id_or_path)
267-
self._supports_vision = hasattr(hf_config, "vision_config")
268+
return transformers.AutoConfig.from_pretrained(
269+
model_id_or_path, trust_remote_code=trust_remote_code
270+
)
268271
except Exception as e:
269272
raise ValueError(
270-
f"Failed to load Hugging Face config for model_id='{model_id_or_path}'.\
271-
Ensure `model_id` is a valid Hugging Face repo or a local path that \
272-
contains a valid `config.json` file. "
273-
f"Original error: {repr(e)}"
273+
f"Failed to load Hugging Face config for "
274+
f"model_id='{model_id_or_path}'. Ensure `model_id` is a valid "
275+
f"Hugging Face repo or a local path that contains a valid "
276+
f"`config.json` file. Original error: {repr(e)}"
274277
) from e
275278

279+
def _infer_supports_vision(
280+
self, model_id_or_path: str, trust_remote_code: bool = False
281+
) -> None:
282+
"""Called in llm node initializer together with other transformers calls. It
283+
loads the model config from huggingface and sets the supports_vision
284+
attribute based on whether the config has `vision_config`. All LVM models has
285+
`vision_config` setup.
286+
"""
287+
hf_config = self._load_hf_config(
288+
model_id_or_path, trust_remote_code=trust_remote_code
289+
)
290+
self._supports_vision = hasattr(hf_config, "vision_config")
291+
276292
def _set_model_architecture(
277293
self,
278294
model_id_or_path: Optional[str] = None,
279295
model_architecture: Optional[str] = None,
296+
trust_remote_code: bool = False,
280297
) -> None:
281298
"""Called in llm node initializer together with other transformers calls. It
282299
loads the model config from huggingface and sets the model_architecture
283300
attribute based on whether the config has `architectures`.
284301
"""
285302
if model_id_or_path:
286-
try:
287-
hf_config = transformers.PretrainedConfig.from_pretrained(
288-
model_id_or_path
289-
)
290-
if (
291-
hf_config
292-
and hasattr(hf_config, "architectures")
293-
and hf_config.architectures
294-
):
295-
self._model_architecture = hf_config.architectures[0]
296-
except Exception as e:
297-
raise ValueError(
298-
f"Failed to load Hugging Face config for model_id='{model_id_or_path}'.\
299-
Ensure `model_id` is a valid Hugging Face repo or a local path that \
300-
contains a valid `config.json` file. "
301-
f"Original error: {repr(e)}"
302-
) from e
303+
hf_config = self._load_hf_config(
304+
model_id_or_path, trust_remote_code=trust_remote_code
305+
)
306+
if (
307+
hf_config
308+
and hasattr(hf_config, "architectures")
309+
and hf_config.architectures
310+
):
311+
self._model_architecture = hf_config.architectures[0]
303312

304313
if model_architecture:
305314
self._model_architecture = model_architecture
@@ -308,8 +317,12 @@ def apply_checkpoint_info(
308317
self, model_id_or_path: str, trust_remote_code: bool = False
309318
) -> None:
310319
"""Apply the checkpoint info to the model config."""
311-
self._infer_supports_vision(model_id_or_path)
312-
self._set_model_architecture(model_id_or_path)
320+
self._infer_supports_vision(
321+
model_id_or_path, trust_remote_code=trust_remote_code
322+
)
323+
self._set_model_architecture(
324+
model_id_or_path, trust_remote_code=trust_remote_code
325+
)
313326

314327
def get_or_create_callback(self) -> Optional[CallbackBase]:
315328
"""Get or create the callback instance for this process.

python/ray/llm/tests/serve/cpu/configs/test_models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
22
from pathlib import Path
3+
from unittest.mock import MagicMock, patch
34

45
import pydantic
56
import pytest
@@ -418,5 +419,28 @@ def test_requires_deferred_placement_group(self):
418419
assert tpu_accel_with_topo.requires_deferred_placement_group is True
419420

420421

422+
class TestCheckpointInfo:
423+
def test_apply_checkpoint_info_uses_autoconfig_and_threads_trust_remote_code(self):
424+
"""apply_checkpoint_info uses AutoConfig (not PretrainedConfig) and forwards
425+
trust_remote_code to every HF config load call."""
426+
llm_config = LLMConfig(
427+
model_loading_config=ModelLoadingConfig(model_id="test_model")
428+
)
429+
mock_hf_config = MagicMock(spec=["architectures", "vision_config"])
430+
mock_hf_config.architectures = ["LlavaForCausalLM"]
431+
432+
with patch(
433+
"transformers.AutoConfig.from_pretrained", return_value=mock_hf_config
434+
) as mock_auto:
435+
llm_config.apply_checkpoint_info("vision/model", trust_remote_code=True)
436+
437+
assert all(
438+
call.kwargs["trust_remote_code"] is True
439+
for call in mock_auto.call_args_list
440+
)
441+
assert llm_config._supports_vision is True
442+
assert llm_config._model_architecture == "LlavaForCausalLM"
443+
444+
421445
if __name__ == "__main__":
422446
sys.exit(pytest.main(["-v", __file__]))

0 commit comments

Comments
 (0)