File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -794,10 +794,6 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
794794 )
795795 )
796796
797- # Now cast to the dtype override after quantization, so non-quantized
798- # components use the desired computation dtype.
799- edge_manager .model = edge_manager .model .to (dtype = dtype_override .to_torch_dtype ())
800-
801797 return edge_manager
802798
803799
@@ -1857,6 +1853,12 @@ def _get_source_transforms( # noqa
18571853 )
18581854 )
18591855
1856+ # Cast to dtype_override after the quantization transforms, so that
1857+ # non-quantized components use the desired computation dtype. This must happen
1858+ # before _convert_model_for_aarch64, which converts IntxUnpackedToInt8Tensor
1859+ # to IntxOpaqueTensor (which doesn't support .to()).
1860+ transforms .append (lambda m : m .to (dtype = dtype_override .to_torch_dtype ()))
1861+
18601862 if any ([use_torchao_kernels_linear , use_torchao_kernels_tied_embedding ]):
18611863 from torchao .prototype .tensor_conversion .api import _convert_model_for_aarch64
18621864
You can’t perform that action at this time.
0 commit comments