Commit f6ce7b3

Fix comments

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 05c33b2, commit f6ce7b3

2 files changed: 38 additions & 14 deletions

modelopt/onnx/utils.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -1368,10 +1368,27 @@ def _convert_constant_values(constant_node: onnx.NodeProto, cast_node: onnx.NodeProto):
     cast_to_type = get_cast_to_type(cast_node)
     for attr in constant_node.attribute:
         if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
-            np_array = onnx.numpy_helper.to_array(attr.t)
-            target_np_type = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type)
-            new_array = np_array.astype(target_np_type)
-            new_tensor = onnx.numpy_helper.from_array(new_array, attr.t.name)
+            # Read input tensor: bfloat16 tensors use raw_data and need special handling
+            if attr.t.data_type == onnx.TensorProto.BFLOAT16:
+                np_array = read_f16_tensor_as_fp32(attr.t)
+            else:
+                np_array = onnx.numpy_helper.to_array(attr.t)
+
+            # Write output tensor: bfloat16 cannot use numpy_helper.from_array
+            if cast_to_type == onnx.TensorProto.BFLOAT16:
+                import ml_dtypes
+
+                new_tensor = onnx.TensorProto()
+                new_tensor.dims.extend(np_array.shape)
+                new_tensor.name = attr.t.name
+                new_tensor.data_type = onnx.TensorProto.BFLOAT16
+                bf16_bytes = np_array.astype(np.float32).astype(ml_dtypes.bfloat16)
+                new_tensor.raw_data = bf16_bytes.view(np.uint16).tobytes()
+            else:
+                target_np_type = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type)
+                new_array = np_array.astype(target_np_type)
+                new_tensor = onnx.numpy_helper.from_array(new_array, attr.t.name)
+
             attr.t.CopyFrom(new_tensor)
             break
 
```
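
For reference, a minimal standalone sketch of the bfloat16 round-trip this hunk implements, assuming onnx, numpy, and ml_dtypes are installed; the tensor name and values are illustrative, not from the commit:

```python
import ml_dtypes
import numpy as np
import onnx

fp32 = np.array([[0.15625, -2.0], [3.5, 1e-3]], dtype=np.float32)

# numpy has no native bfloat16 dtype, which is why the hunk above avoids
# onnx.numpy_helper.from_array: the bf16 bytes are written into raw_data by hand.
t = onnx.TensorProto()
t.dims.extend(fp32.shape)
t.name = "example_constant"  # hypothetical tensor name
t.data_type = onnx.TensorProto.BFLOAT16
t.raw_data = fp32.astype(ml_dtypes.bfloat16).view(np.uint16).tobytes()

# Reading back: reinterpret the raw bytes as bfloat16, then upcast to fp32
# (the counterpart of read_f16_tensor_as_fp32 used in the hunk above).
decoded = (
    np.frombuffer(t.raw_data, dtype=ml_dtypes.bfloat16)
    .astype(np.float32)
    .reshape(tuple(t.dims))
)
print(decoded)  # bfloat16 keeps only ~7 mantissa bits, so 1e-3 rounds slightly
```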

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 17 additions & 10 deletions
```diff
@@ -32,6 +32,7 @@
 from onnxconverter_common import convert_float_to_float16
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
+from modelopt.onnx.autocast.convert import convert_to_f16
 from modelopt.onnx.export import (
     FP8QuantExporter,
     INT4QuantExporter,
@@ -578,16 +579,22 @@ def get_onnx_bytes_and_metadata(
     if dq_only:
         onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
 
-    if weights_dtype == "fp16":
-        onnx_opt_graph = convert_float_to_float16(
-            onnx_opt_graph,
-            keep_io_types=False,
-            disable_shape_infer=True,
-            check_fp16_ready=False,
-            op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
-        )
-        # Change FP32 cast nodes feeding into Concat/Add to FP16
-        onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
+    if weights_dtype in ["fp16", "bf16"]:
+        if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model):
+            assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
+            onnx_opt_graph = convert_float_to_float16(
+                onnx_opt_graph,
+                keep_io_types=False,
+                disable_shape_infer=True,
+                check_fp16_ready=False,
+                op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
+            )
+            # Change FP32 cast nodes feeding into Concat/Add to FP16
+            onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
+        else:
+            onnx_opt_graph = convert_to_f16(
+                onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
+            )
 
     onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)
```
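
The net effect: quantized graphs (FP8/MXFP8/INT4) keep the existing onnxconverter_common fp16 path with QuantizeLinear/DequantizeLinear blocked, while non-quantized graphs go through autocast's convert_to_f16, which also accepts "bf16". A hedged usage sketch of the new path, with the model paths invented for illustration and only the three arguments shown in the diff:

```python
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

model = onnx.load("model.onnx")  # hypothetical input path

# Non-quantized graphs: the same call handles both "fp16" and "bf16".
model_bf16 = convert_to_f16(model, low_precision_type="bf16", keep_io_types=False)
onnx.save(model_bf16, "model_bf16.onnx")
```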
