
Commit f31b179

Respect AutoDeploy trust_remote_code
Signed-off-by: Jordan Mecom <jm@squareup.com>
1 parent 7ee9e8b commit f31b179
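
For orientation, here is a minimal sketch (modeled on the new test added below; the model path is a placeholder) of what this change means for callers: the trust_remote_code value set on LlmArgs is now threaded through create_factory() into the model factory, instead of the factory unconditionally passing trust_remote_code=True to the Hugging Face loaders.

    from tensorrt_llm._torch.auto_deploy import LlmArgs

    # A checkpoint that ships custom modeling code (placeholder name).
    args = LlmArgs(model="org/custom-model", trust_remote_code=True)
    factory = args.create_factory()

    # The factory now remembers the user's choice and forwards it to the
    # AutoConfig / AutoTokenizer / AutoModelForCausalLM calls.
    assert factory.trust_remote_code is True
    assert factory.tokenizer_kwargs["trust_remote_code"] is True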

File tree

4 files changed: +82 -4 lines


tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 1 addition & 0 deletions
@@ -360,6 +360,7 @@ def create_factory(self) -> ModelFactory:
             model_kwargs=self.model_kwargs,
             tokenizer=None if self.tokenizer is None else str(self.tokenizer),
             tokenizer_kwargs=self.tokenizer_kwargs,
+            trust_remote_code=self.trust_remote_code,
             skip_loading_weights=self.skip_loading_weights,
             max_seq_len=self.max_seq_len,
             # Extra kwargs consumed by EagleOneModelFactory (ignored by others via **kwargs)

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 2 additions & 0 deletions
@@ -119,6 +119,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer: Optional[str] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        trust_remote_code: bool = False,
         skip_loading_weights: bool = False,
         max_seq_len: int = 512,
         **kwargs,
@@ -127,6 +128,7 @@ def __init__(
         self.model_kwargs = copy.deepcopy(model_kwargs or {})
         self._tokenizer = tokenizer
         self.tokenizer_kwargs = copy.deepcopy(tokenizer_kwargs or {})
+        self.trust_remote_code = trust_remote_code
         self.skip_loading_weights = skip_loading_weights
         self.max_seq_len = max_seq_len
         self._prefetched_model_path: Optional[str] = None
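
The default here is deliberately False, matching the usual transformers default, so remote code is never enabled unless a caller opts in. Below is a simplified, hypothetical illustration of the pattern (not the real ModelFactory API): the base class grows a keyword with a safe default, and the HF subclass makes its own flag win over anything passed via tokenizer_kwargs.

    # Hypothetical, simplified sketch; names and signatures are illustrative only.
    from typing import Any, Dict, Optional


    class BaseFactory:
        def __init__(
            self,
            model: str,
            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
            trust_remote_code: bool = False,  # conservative default: callers must opt in
            **kwargs: Any,                    # unknown kwargs stay tolerated, as before
        ):
            self.model = model
            self.tokenizer_kwargs = dict(tokenizer_kwargs or {})
            self.trust_remote_code = trust_remote_code


    class HFFactory(BaseFactory):
        def __init__(self, *args: Any, **kwargs: Any):
            super().__init__(*args, **kwargs)
            # The factory-level flag wins over anything smuggled in via tokenizer_kwargs.
            self.tokenizer_kwargs["trust_remote_code"] = self.trust_remote_code


    f = HFFactory("org/model", tokenizer_kwargs={"trust_remote_code": True})
    assert f.trust_remote_code is False
    assert f.tokenizer_kwargs["trust_remote_code"] is False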

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 6 additions & 4 deletions
@@ -107,7 +107,6 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
         "legacy": False,
         "padding_side": "left",
         "truncation_side": "left",
-        "trust_remote_code": True,
         "use_fast": True,
     }
 
@@ -124,6 +123,7 @@ def __init__(self, *args, **kwargs):
         self._quant_config_reader: QuantConfigReader | None = None
         # Ingest defaults for tokenizer and model kwargs
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
+        self.tokenizer_kwargs["trust_remote_code"] = self.trust_remote_code
         self.model_kwargs = deep_merge_dicts(
             self._model_defaults,
             self.model_kwargs or {},
@@ -201,7 +201,9 @@ def _get_model_config(self) -> Tuple[PretrainedConfig, Dict[str, Any]]:
         # the entire subconfig will be overwritten.
         # we want to recursively update model_config from model_kwargs here.
         model_config, unused_kwargs = AutoConfig.from_pretrained(
-            self.model, return_unused_kwargs=True, trust_remote_code=True
+            self.model,
+            return_unused_kwargs=True,
+            trust_remote_code=self.trust_remote_code,
         )
         model_config, nested_unused_kwargs = self._recursive_update_config(
             model_config, self.model_kwargs
@@ -231,8 +233,8 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         model = self.automodel_cls.from_config(
             model_config,
             **{
-                "trust_remote_code": True,
                 **unused_kwargs,
+                "trust_remote_code": self.trust_remote_code,
             },
         )
 
@@ -309,9 +311,9 @@ def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
             self.model,
             config=model_config,
             **{
-                "trust_remote_code": True,
                 "tp_plan": "auto",
                 **unused_kwargs,
+                "trust_remote_code": self.trust_remote_code,
                 "dtype": "auto",  # takes precedence over unused_kwargs!
             },
         )
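
One detail worth calling out in the **{...} hunks above: the "trust_remote_code" entry moved from before **unused_kwargs to after it. In a dict literal the last occurrence of a key wins, so the factory's flag now overrides any trust_remote_code value that came back from AutoConfig.from_pretrained in unused_kwargs, the same trick the existing "dtype": "auto" comment relies on. A small standalone sketch of that Python behavior:

    # Plain Python: in a dict literal, the last occurrence of a key wins.
    unused_kwargs = {"trust_remote_code": True, "foo": "bar"}

    # Keys written before **unused_kwargs can be overridden by it ...
    before = {"trust_remote_code": False, **unused_kwargs}
    assert before["trust_remote_code"] is True

    # ... keys written after it take precedence, which is what the hunks above rely on.
    after = {**unused_kwargs, "trust_remote_code": False}
    assert after["trust_remote_code"] is False
    assert after["foo"] == "bar"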

tests/unittest/auto_deploy/singlegpu/models/test_hf.py

Lines changed: 73 additions & 0 deletions
@@ -8,6 +8,7 @@
 from transformers import AutoModelForCausalLM
 from transformers.models.llama4.configuration_llama4 import Llama4Config
 
+from tensorrt_llm._torch.auto_deploy import LlmArgs
 from tensorrt_llm._torch.auto_deploy.models.hf import (
     AutoModelForCausalLMFactory,
     hf_load_state_dict_with_device,
@@ -134,6 +135,78 @@ def test_recursive_update_config(mock_factory):
     assert config.text_config.rope_scaling["type"] == "linear"
 
 
+def test_create_factory_threads_trust_remote_code():
+    factory = LlmArgs(
+        model="dummy_model",
+        trust_remote_code=False,
+        tokenizer_kwargs={"trust_remote_code": True},
+    ).create_factory()
+
+    assert isinstance(factory, AutoModelForCausalLMFactory)
+    assert factory.trust_remote_code is False
+    assert factory.tokenizer_kwargs["trust_remote_code"] is False
+
+
+def test_get_model_config_uses_factory_trust_remote_code(mock_factory):
+    with patch.object(
+        AutoModelForCausalLMFactory,
+        "prefetch_checkpoint",
+    ), patch(
+        "tensorrt_llm._torch.auto_deploy.models.hf.AutoConfig.from_pretrained",
+        return_value=(Llama4Config(), {}),
+    ) as mock_from_pretrained:
+        mock_factory._get_model_config()
+
+    assert mock_from_pretrained.call_args.kwargs["trust_remote_code"] is False
+
+
+def test_init_tokenizer_uses_factory_trust_remote_code(mock_factory):
+    with patch(
+        "tensorrt_llm._torch.auto_deploy.models.hf.AutoTokenizer.from_pretrained",
+        return_value=MagicMock(),
+    ) as mock_from_pretrained:
+        mock_factory.init_tokenizer()
+
+    assert mock_from_pretrained.call_args.kwargs["trust_remote_code"] is False
+
+
+def test_build_model_uses_factory_trust_remote_code(mock_factory):
+    dummy_model = SimpleModel()
+    dummy_model.config = MagicMock()
+
+    with patch.object(
+        AutoModelForCausalLMFactory,
+        "_get_model_config",
+        return_value=(Llama4Config(), {"trust_remote_code": True, "foo": "bar"}),
+    ), patch.object(AutoModelForCausalLM, "from_config", return_value=dummy_model) as from_config:
+        mock_factory.build_model(device="meta")
+
+    assert from_config.call_args.kwargs["foo"] == "bar"
+    assert from_config.call_args.kwargs["trust_remote_code"] is False
+
+
+def test_build_and_load_model_uses_factory_trust_remote_code(mock_factory):
+    dummy_model = SimpleModel()
+
+    with patch.object(
+        AutoModelForCausalLMFactory,
+        "_get_model_config",
+        return_value=(Llama4Config(), {"trust_remote_code": True}),
+    ), patch.object(
+        AutoModelForCausalLMFactory,
+        "prefetch_checkpoint",
+    ), patch.object(
+        AutoModelForCausalLM,
+        "from_pretrained",
+        return_value=dummy_model,
+    ) as from_pretrained:
+        mock_factory.build_and_load_model(device="cuda")
+
+    assert from_pretrained.call_args.kwargs["trust_remote_code"] is False
+    assert from_pretrained.call_args.kwargs["tp_plan"] == "auto"
+    assert from_pretrained.call_args.kwargs["dtype"] == "auto"
+
+
 def test_register_custom_model_cls():
     config_cls_name = "FooConfig"
     custom_model_cls = MagicMock(spec=AutoModelForCausalLM)
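
To run just the new tests locally, an invocation along these lines should work from the repo root (the -k filter matches the five new test names):

    # Illustrative: select only the new trust_remote_code tests.
    import pytest

    pytest.main([
        "tests/unittest/auto_deploy/singlegpu/models/test_hf.py",
        "-k", "trust_remote_code",
    ])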
