|
8 | 8 | from transformers import AutoModelForCausalLM |
9 | 9 | from transformers.models.llama4.configuration_llama4 import Llama4Config |
10 | 10 |
|
| 11 | +from tensorrt_llm._torch.auto_deploy import LlmArgs |
11 | 12 | from tensorrt_llm._torch.auto_deploy.models.hf import ( |
12 | 13 | AutoModelForCausalLMFactory, |
13 | 14 | hf_load_state_dict_with_device, |
@@ -134,6 +135,78 @@ def test_recursive_update_config(mock_factory): |
134 | 135 | assert config.text_config.rope_scaling["type"] == "linear" |
135 | 136 |
|
136 | 137 |
|
def test_create_factory_threads_trust_remote_code():
    """create_factory must thread the top-level trust_remote_code setting into
    the factory, overriding a conflicting value given via tokenizer_kwargs."""
    llm_args = LlmArgs(
        model="dummy_model",
        trust_remote_code=False,
        tokenizer_kwargs={"trust_remote_code": True},
    )
    factory = llm_args.create_factory()

    assert isinstance(factory, AutoModelForCausalLMFactory)
    assert factory.trust_remote_code is False
    # The top-level flag wins even though tokenizer_kwargs asked for True.
    assert factory.tokenizer_kwargs["trust_remote_code"] is False
| 148 | + |
| 149 | + |
def test_get_model_config_uses_factory_trust_remote_code(mock_factory):
    """_get_model_config must forward the factory's trust_remote_code flag to
    AutoConfig.from_pretrained rather than any other source."""
    prefetch_patch = patch.object(AutoModelForCausalLMFactory, "prefetch_checkpoint")
    from_pretrained_patch = patch(
        "tensorrt_llm._torch.auto_deploy.models.hf.AutoConfig.from_pretrained",
        return_value=(Llama4Config(), {}),
    )
    with prefetch_patch, from_pretrained_patch as from_pretrained_mock:
        mock_factory._get_model_config()

    # The factory was built with trust_remote_code=False; that value must reach HF.
    assert from_pretrained_mock.call_args.kwargs["trust_remote_code"] is False
| 161 | + |
| 162 | + |
def test_init_tokenizer_uses_factory_trust_remote_code(mock_factory):
    """init_tokenizer must pass the factory's trust_remote_code flag to
    AutoTokenizer.from_pretrained."""
    tokenizer_patch = patch(
        "tensorrt_llm._torch.auto_deploy.models.hf.AutoTokenizer.from_pretrained",
        return_value=MagicMock(),
    )
    with tokenizer_patch as from_pretrained_mock:
        mock_factory.init_tokenizer()

    # The factory-level setting (False) must be what the tokenizer load sees.
    assert from_pretrained_mock.call_args.kwargs["trust_remote_code"] is False
| 171 | + |
| 172 | + |
def test_build_model_uses_factory_trust_remote_code(mock_factory):
    """build_model must override any trust_remote_code coming back from the
    model-config kwargs with the factory's own setting, while passing other
    kwargs through untouched."""
    stub_model = SimpleModel()
    stub_model.config = MagicMock()

    config_patch = patch.object(
        AutoModelForCausalLMFactory,
        "_get_model_config",
        return_value=(Llama4Config(), {"trust_remote_code": True, "foo": "bar"}),
    )
    from_config_patch = patch.object(
        AutoModelForCausalLM, "from_config", return_value=stub_model
    )
    with config_patch, from_config_patch as from_config_mock:
        mock_factory.build_model(device="meta")

    forwarded = from_config_mock.call_args.kwargs
    # Unrelated kwargs survive the call ...
    assert forwarded["foo"] == "bar"
    # ... but trust_remote_code is forced back to the factory's False.
    assert forwarded["trust_remote_code"] is False
| 186 | + |
| 187 | + |
def test_build_and_load_model_uses_factory_trust_remote_code(mock_factory):
    """build_and_load_model must call AutoModelForCausalLM.from_pretrained with
    the factory's trust_remote_code flag (not the config's), plus the expected
    tp_plan/dtype defaults."""
    stub_model = SimpleModel()

    config_patch = patch.object(
        AutoModelForCausalLMFactory,
        "_get_model_config",
        return_value=(Llama4Config(), {"trust_remote_code": True}),
    )
    prefetch_patch = patch.object(AutoModelForCausalLMFactory, "prefetch_checkpoint")
    load_patch = patch.object(
        AutoModelForCausalLM,
        "from_pretrained",
        return_value=stub_model,
    )
    with config_patch, prefetch_patch, load_patch as from_pretrained_mock:
        mock_factory.build_and_load_model(device="cuda")

    forwarded = from_pretrained_mock.call_args.kwargs
    # Factory setting (False) overrides the config kwargs' True.
    assert forwarded["trust_remote_code"] is False
    assert forwarded["tp_plan"] == "auto"
    assert forwarded["dtype"] == "auto"
| 208 | + |
| 209 | + |
137 | 210 | def test_register_custom_model_cls(): |
138 | 211 | config_cls_name = "FooConfig" |
139 | 212 | custom_model_cls = MagicMock(spec=AutoModelForCausalLM) |
|
0 commit comments