diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py
index 99e97aaf9d0..4f980c8e6be 100644
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -360,6 +360,7 @@ def create_factory(self) -> ModelFactory:
             model_kwargs=self.model_kwargs,
             tokenizer=None if self.tokenizer is None else str(self.tokenizer),
             tokenizer_kwargs=self.tokenizer_kwargs,
+            trust_remote_code=self.trust_remote_code,
             skip_loading_weights=self.skip_loading_weights,
             max_seq_len=self.max_seq_len,
             # Extra kwargs consumed by EagleOneModelFactory (ignored by others via **kwargs)
diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py
index d04c35e5fe0..63ccc99bb3d 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/factory.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -119,6 +119,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer: Optional[str] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        trust_remote_code: bool = False,
         skip_loading_weights: bool = False,
         max_seq_len: int = 512,
         **kwargs,
@@ -127,6 +128,7 @@
         self.model_kwargs = copy.deepcopy(model_kwargs or {})
         self._tokenizer = tokenizer
         self.tokenizer_kwargs = copy.deepcopy(tokenizer_kwargs or {})
+        self.trust_remote_code = trust_remote_code
         self.skip_loading_weights = skip_loading_weights
         self.max_seq_len = max_seq_len
         self._prefetched_model_path: Optional[str] = None
diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py
index e01ec6fe2c0..645e07d2f0c 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/hf.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -107,7 +107,6 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
         "legacy": False,
         "padding_side": "left",
         "truncation_side": "left",
-        "trust_remote_code": True,
         "use_fast": True,
     }
 
@@ -124,6 +123,7 @@ def __init__(self, *args, **kwargs):
         self._quant_config_reader: QuantConfigReader | None = None
         # Ingest defaults for tokenizer and model kwargs
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
+        self.tokenizer_kwargs["trust_remote_code"] = self.trust_remote_code
         self.model_kwargs = deep_merge_dicts(
             self._model_defaults,
             self.model_kwargs or {},
@@ -201,7 +201,9 @@ def _get_model_config(self) -> Tuple[PretrainedConfig, Dict[str, Any]]:
         # the entire subconfig will be overwritten.
         # we want to recursively update model_config from model_kwargs here.
         model_config, unused_kwargs = AutoConfig.from_pretrained(
-            self.model, return_unused_kwargs=True, trust_remote_code=True
+            self.model,
+            return_unused_kwargs=True,
+            trust_remote_code=self.trust_remote_code,
         )
         model_config, nested_unused_kwargs = self._recursive_update_config(
             model_config, self.model_kwargs
@@ -231,8 +233,8 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         model = self.automodel_cls.from_config(
             model_config,
             **{
-                "trust_remote_code": True,
                 **unused_kwargs,
+                "trust_remote_code": self.trust_remote_code,
             },
         )
 
@@ -309,9 +311,9 @@ def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
             self.model,
             config=model_config,
             **{
-                "trust_remote_code": True,
                 "tp_plan": "auto",
                 **unused_kwargs,
+                "trust_remote_code": self.trust_remote_code,
                 "dtype": "auto",  # takes precedence over unused_kwargs!
             },
         )
diff --git a/tests/unittest/auto_deploy/singlegpu/models/test_hf.py b/tests/unittest/auto_deploy/singlegpu/models/test_hf.py
index d6e63b6433d..d28cedcf60a 100644
--- a/tests/unittest/auto_deploy/singlegpu/models/test_hf.py
+++ b/tests/unittest/auto_deploy/singlegpu/models/test_hf.py
@@ -8,6 +8,7 @@
 from transformers import AutoModelForCausalLM
 from transformers.models.llama4.configuration_llama4 import Llama4Config
 
+from tensorrt_llm._torch.auto_deploy import LlmArgs
 from tensorrt_llm._torch.auto_deploy.models.hf import (
     AutoModelForCausalLMFactory,
     hf_load_state_dict_with_device,
@@ -134,6 +135,78 @@ def test_recursive_update_config(mock_factory):
     assert config.text_config.rope_scaling["type"] == "linear"
 
 
+def test_create_factory_threads_trust_remote_code():
+    factory = LlmArgs(
+        model="dummy_model",
+        trust_remote_code=False,
+        tokenizer_kwargs={"trust_remote_code": True},
+    ).create_factory()
+
+    assert isinstance(factory, AutoModelForCausalLMFactory)
+    assert factory.trust_remote_code is False
+    assert factory.tokenizer_kwargs["trust_remote_code"] is False
+
+
+def test_get_model_config_uses_factory_trust_remote_code(mock_factory):
+    with patch.object(
+        AutoModelForCausalLMFactory,
+        "prefetch_checkpoint",
+    ), patch(
+        "tensorrt_llm._torch.auto_deploy.models.hf.AutoConfig.from_pretrained",
+        return_value=(Llama4Config(), {}),
+    ) as mock_from_pretrained:
+        mock_factory._get_model_config()
+
+    assert mock_from_pretrained.call_args.kwargs["trust_remote_code"] is False
+
+
+def test_init_tokenizer_uses_factory_trust_remote_code(mock_factory):
+    with patch(
+        "tensorrt_llm._torch.auto_deploy.models.hf.AutoTokenizer.from_pretrained",
+        return_value=MagicMock(),
+    ) as mock_from_pretrained:
+        mock_factory.init_tokenizer()
+
+    assert mock_from_pretrained.call_args.kwargs["trust_remote_code"] is False
+
+
+def test_build_model_uses_factory_trust_remote_code(mock_factory):
+    dummy_model = SimpleModel()
+    dummy_model.config = MagicMock()
+
+    with patch.object(
+        AutoModelForCausalLMFactory,
+        "_get_model_config",
+        return_value=(Llama4Config(), {"trust_remote_code": True, "foo": "bar"}),
+    ), patch.object(AutoModelForCausalLM, "from_config", return_value=dummy_model) as from_config:
+        mock_factory.build_model(device="meta")
+
+    assert from_config.call_args.kwargs["foo"] == "bar"
+    assert from_config.call_args.kwargs["trust_remote_code"] is False
+
+
+def test_build_and_load_model_uses_factory_trust_remote_code(mock_factory):
+    dummy_model = SimpleModel()
+
+    with patch.object(
+        AutoModelForCausalLMFactory,
+        "_get_model_config",
+        return_value=(Llama4Config(), {"trust_remote_code": True}),
+    ), patch.object(
+        AutoModelForCausalLMFactory,
+        "prefetch_checkpoint",
+    ), patch.object(
+        AutoModelForCausalLM,
+        "from_pretrained",
+        return_value=dummy_model,
+    ) as from_pretrained:
+        mock_factory.build_and_load_model(device="cuda")
+
+    assert from_pretrained.call_args.kwargs["trust_remote_code"] is False
+    assert from_pretrained.call_args.kwargs["tp_plan"] == "auto"
+    assert from_pretrained.call_args.kwargs["dtype"] == "auto"
+
+
 def test_register_custom_model_cls():
     config_cls_name = "FooConfig"
     custom_model_cls = MagicMock(spec=AutoModelForCausalLM)
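
Usage sketch (illustrative, not part of the patch): with this change, trust_remote_code defaults to False and is threaded from LlmArgs through the factory into the downstream AutoConfig/AutoTokenizer/AutoModel calls, replacing the previously hard-coded trust_remote_code=True; the factory-level flag also overrides any conflicting value passed via tokenizer_kwargs, as the new tests above verify for the False case. The checkpoint name below is hypothetical.

from tensorrt_llm._torch.auto_deploy import LlmArgs

# Hypothetical checkpoint that ships custom modeling code; loading it now
# requires an explicit opt-in instead of relying on the old implicit True.
args = LlmArgs(
    model="my-org/custom-model",
    trust_remote_code=True,  # default is now False
)
factory = args.create_factory()

# The factory forwards the flag to every HF from_pretrained/from_config call
# and overwrites any conflicting setting smuggled in via tokenizer_kwargs.
assert factory.trust_remote_code is True
assert factory.tokenizer_kwargs["trust_remote_code"] is True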