@@ -32,6 +32,7 @@
 from onnxconverter_common import convert_float_to_float16
 from torch.nn.parallel import DataParallel, DistributedDataParallel

+from modelopt.onnx.autocast.convert import convert_to_f16
 from modelopt.onnx.export import (
     FP8QuantExporter,
     INT4QuantExporter,
@@ -578,16 +579,22 @@ def get_onnx_bytes_and_metadata(
     if dq_only:
         onnx_opt_graph = qdq_to_dq(onnx_opt_graph)

-    if weights_dtype == "fp16":
-        onnx_opt_graph = convert_float_to_float16(
-            onnx_opt_graph,
-            keep_io_types=False,
-            disable_shape_infer=True,
-            check_fp16_ready=False,
-            op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
-        )
-        # Change FP32 cast nodes feeding into Concat/Add to FP16
-        onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
+    if weights_dtype in ["fp16", "bf16"]:
+        if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model):
+            assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
+            onnx_opt_graph = convert_float_to_float16(
+                onnx_opt_graph,
+                keep_io_types=False,
+                disable_shape_infer=True,
+                check_fp16_ready=False,
+                op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
+            )
+            # Change FP32 cast nodes feeding into Concat/Add to FP16
+            onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
+        else:
+            onnx_opt_graph = convert_to_f16(
+                onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
+            )

     onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)

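For reference, a minimal standalone sketch of the new AutoCast path that this change routes non-quantized fp16/bf16 exports through. The call pattern mirrors the diff above; the file paths are placeholder assumptions, and it assumes `convert_to_f16` accepts and returns an `onnx.ModelProto` as it does in `get_onnx_bytes_and_metadata`.

```python
# Sketch only: exercises modelopt.onnx.autocast.convert.convert_to_f16 with the
# same arguments used in the diff. File paths are hypothetical placeholders.
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

# Load an fp32 ONNX graph (the non-quantized case, i.e. the new `else` branch).
model = onnx.load("model_fp32.onnx")

# "bf16" is the dtype this change newly supports here; "fp16" works the same way.
model_bf16 = convert_to_f16(model, low_precision_type="bf16", keep_io_types=False)

onnx.save(model_bf16, "model_bf16.onnx")
```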