Skip to content

Commit 6e77a83

Browse files
ajrasanekevalmorabia97
authored andcommitted
[BugFix][5997203] Update Sqrt casts to FP16 (#1084)
### What does this PR do? Type of change: Bug fix Change cast nodes before Sqrt to FP16 ### Testing ``` python torch_quant_to_onnx.py --quantize_mode=mxfp8 --timm_model_name=vit_base_patch16_224 --onnx_save_path=vit_base_patch16_224.mxfp8.onnx --calibration_data_size=512 python evaluate.py --onnx_path=vit_base_patch16_224.mxfp8.onnx --model_name=vit_base_patch16_224 --eval_data_size=100 ``` ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ❌ - Casts before Sqrt are now FP16 instead of FP32 - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ - Did you write any new necessary tests?:N/A - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved ONNX model export for quantized models with reduced precision (fp16/bf16) by enhancing type casting handling during the export process. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent bd188a9 commit 6e77a83

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def get_onnx_bytes_and_metadata(
608608
op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
609609
)
610610
# Change FP32 cast nodes feeding into Concat/Add to FP16
611-
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
611+
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add", "Sqrt"])
612612
else:
613613
onnx_opt_graph = convert_to_f16(
614614
onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False

0 commit comments

Comments
 (0)