Skip to content

Commit ed1c88b

Browse files
authored
1 parent 3180927 commit ed1c88b

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -794,10 +794,6 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
794794
)
795795
)
796796

797-
# Now cast to the dtype override after quantization, so non-quantized
798-
# components use the desired computation dtype.
799-
edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
800-
801797
return edge_manager
802798

803799

@@ -1857,6 +1853,12 @@ def _get_source_transforms( # noqa
18571853
)
18581854
)
18591855

1856+
# Cast to dtype_override after quantization transforms, so non-quantized
1857+
# components use the desired computation dtype. This must happen before
1858+
# _convert_model_for_aarch64, which converts IntxUnpackedToInt8Tensor to
1859+
# IntxOpaqueTensor (which doesn't support .to()).
1860+
transforms.append(lambda m: m.to(dtype=dtype_override.to_torch_dtype()))
1861+
18601862
if any([use_torchao_kernels_linear, use_torchao_kernels_tied_embedding]):
18611863
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
18621864

0 commit comments

Comments
 (0)