Commit 3e8fc7b
Quant in checkpoint dtype (#18781)
Switches the order in etLLM so we quantize in the checkpoint dtype and then cast to the dtype override. This can prevent underflow of the quantization scales. Also exposes the ability to turn HQQ on/off.

Export:

```
python -m extension.llm.export.export_llm \
  base.model_class=phi_4_mini \
  base.params=examples/models/phi_4_mini/config/config.json \
  model.use_kv_cache=true \
  model.use_sdpa_with_kv_cache=true \
  model.dtype_override=fp32 \
  export.output_dir=/tmp/phi_4_mini_no_hqq \
  export.output_name=model.pte \
  export.max_seq_length=2048 \
  export.max_context_length=2048 \
  quantization.qmode=8da4w \
  quantization.group_size=32 \
  "quantization.embedding_quantize='8,0'" \
  quantization.use_hqq=False \
  backend.xnnpack.enabled=true \
  backend.xnnpack.extended_ops=true
```

Phi4 output:

```
<|im_start|>system
You are a highly capable, helpful, and honest AI assistant designed to provide clear, accurate, and thoughtful responses to a wide range of questions. Your primary goal is to assist users by offering information, explanations, and guidance in a manner that is respectful, unbiased, and safe. Always strive to be as helpful as possible, but never provide content that is harmful, unethical, offensive, or illegal. If a question is unclear, nonsensical, or based on incorrect premises, politely explain the issue rather than attempting to answer inaccurately. If you do not know the answer to a question, it is better to admit uncertainty than to provide false or misleading information. When appropriate, include examples, analogies, or step-by-step reasoning to enhance understanding. Your responses should be positive, inclusive, and supportive, fostering a constructive and informative interaction.<|im_end|>
<|im_start|>user
Please answer the following question in detail and provide relevant context, examples, and explanations where possible: What are some of the most important considerations when designing a machine learning system for real-world applications? Discuss potential challenges, best practices, and how to ensure ethical and responsible use.<|im_end|>
<|im_start|>assistant
Designing a machine learning system for real-world applications involves various considerations to ensure the system is effective, fair, and secure. Some of the most important considerations include data quality and sourcing, model choice and design, evaluation and validation, interpretability and transparency, and ensuring fairness and avoiding biases. Data quality and sourcing involve ensuring data is of high quality, representative of the target application, and properly curated and preprocessed to remove noise and biases. Model choice and design involve selecting an appropriate model for the application, understanding the strengths and limitations of different models, and understanding the application domain and data. Model evaluation and validation involve properly training and tuning the model on a training set and properly validating and testing the model on a separate validation set to avoid data leakage and
```

Related work: an improvement to torchao's HQQ algorithm that helps with Phi4's model distribution: pytorch/ao#4259
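The underflow the commit message refers to can be seen in isolation. The following is a standalone sketch (not ExecuTorch code, and the weight magnitude is an assumed illustrative value): symmetric int4 quantization computes `scale = max|w| / 7` per group, and if that division happens after the model has already been cast to fp16, the scale can flush to zero, which would then make `w / scale` produce inf/nan quantized values.

```python
import torch

# Plausible magnitude for a near-zero weight group (assumed for illustration).
max_abs = 1e-7

# New order: compute the scale in the checkpoint dtype (fp32) first.
s = torch.tensor(max_abs, dtype=torch.float32) / 7  # small but nonzero

# Old order: cast to the dtype override first, then quantize.
# 1e-7 is still representable as an fp16 subnormal, but dividing by 7
# lands below the smallest fp16 subnormal and flushes to zero.
s16 = torch.tensor(max_abs, dtype=torch.float16) / 7

print(s.item(), s16.item())
```

A zero scale is unrecoverable: the quantized integers computed from it are garbage, so computing scales in the checkpoint dtype before any cast sidesteps the problem.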
1 parent 74403e2 commit 3e8fc7b

File tree

3 files changed
+18
-9
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 6 deletions

```diff
@@ -743,10 +743,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
         )
 
-    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
-
-    # We want to quantize (in the source transforms) the weights of the model
-    # in the checkpoint dtype.
+    # Quantize weights in checkpoint dtype for accuracy, then cast to
+    # dtype_override afterward. IntxUnpackedToInt8Tensor.to() properly
+    # propagates the dtype change to scale/zero_point/output dtype.
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
@@ -791,9 +790,14 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             local_global_attention=llm_config.model.local_global_attention,
             use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
             use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
+            quantize_with_hqq=llm_config.quantization.use_hqq,
         )
     )
 
+    # Now cast to the dtype override after quantization, so non-quantized
+    # components use the desired computation dtype.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+
     return edge_manager
 
 
@@ -1736,8 +1740,7 @@ def _get_source_transforms(  # noqa
         get_quant_embedding_transform(
             embedding_quantize,
             use_shared_embedding,
-            checkpoint_dtype,
-            quantize_with_hqq,
+            quantize_with_hqq=quantize_with_hqq,
         )
     )
 
```
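The reordering in this file can be reduced to a small order-of-operations sketch. Here `quantize_` is a hypothetical stand-in for the real source transforms; the point is only that quantization must observe weights in the checkpoint dtype, with the cast to the dtype override happening afterward:

```python
import torch
import torch.nn as nn

def quantize_(model: nn.Module) -> None:
    # Placeholder for the real source transforms: scales would be computed
    # here, so every float weight should still be in the checkpoint dtype.
    for p in model.parameters():
        assert p.dtype == torch.float32, "quantize in checkpoint dtype"

checkpoint_dtype, dtype_override = torch.float32, torch.float16

model = nn.Linear(8, 8).to(checkpoint_dtype)
quantize_(model)                  # 1) quantize in checkpoint dtype
model = model.to(dtype_override)  # 2) only then cast to dtype_override
```

With the old order the `.to(dtype_override)` call came before `quantize_`, so scale computation ran in the (possibly lower-precision) override dtype.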

examples/models/llama/source_transformation/quantize.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -755,14 +755,21 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
         )
 
+    def _apply(self, fn, recurse=True):
+        """Override _apply to update self.dtype when the module is cast via .to(dtype)."""
+        super()._apply(fn, recurse)
+        # Probe the new dtype from the scales buffer, which gets cast by super()._apply.
+        if self.scales is not None:
+            self.dtype = self.scales.dtype
+        return self
+
 
 ############################ Source Transform Start #######################
 
 
 def get_quant_embedding_transform(
     embedding_quantize: str,
     use_shared_embedding: bool = False,
-    dtype_override: Optional[DType] = None,
     quantize_with_hqq: bool = True,
 ):
     if embedding_quantize.startswith("torchao:"):
@@ -817,13 +824,11 @@ def _torchao_embedding_quantizer(model):
     else:
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
     return lambda model: EmbeddingQuantHandler(
         model,
         bitwidth=bitwidth,
         group_size=group_size,
         packed=(bitwidth in [2, 4]),
-        precision=torch_dtype,
         quantize_with_hqq=quantize_with_hqq,
     ).quantized_model()
 
```

extension/llm/export/config/llm_config.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -429,6 +429,7 @@ class QuantizationConfig:
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"
+    use_hqq: bool = True
 
     def __post_init__(self):
         if self.qmode:
```
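The new flag is a plain dataclass field with a `True` default, so HQQ stays on unless explicitly disabled (as `quantization.use_hqq=False` does in the export command above). A reduced sketch of the config shape, using only field names visible in the diff (the surrounding `LlmConfig` wiring is omitted):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantizationConfig:
    qmode: Optional[str] = None
    group_size: Optional[int] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: str = "Once upon a time"
    use_hqq: bool = True  # new: toggle HQQ during weight quantization

# Mirrors the export example: 8da4w quantization with HQQ turned off.
cfg = QuantizationConfig(qmode="8da4w", group_size=32, use_hqq=False)
print(cfg.use_hqq)
```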
