Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
# ConstantOfShape is preserved to avoid increasing model size unnecessarily
"ConstantOfShape",
# Quantize/DequantizeLinear are preserved to keep the quantization info
"QuantizeLinear",
"DequantizeLinear",
"DynamicQuantizeLinear",
"QuantizeLinear",
]

DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 8192
Expand Down
30 changes: 30 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,36 @@ def test_node_is_folded_if_specified_as_should_fold(self):
np.ones((42, 42), dtype=np.int64),
)

def test_quantize_linear_is_not_folded(self):
    # A QuantizeLinear node whose inputs are all constants must survive
    # constant folding: folding it would bake in the quantized values and
    # discard the quantization parameters.
    source = """
        <ir_version: 10, opset_import: [ "" : 20]>
        agraph () => (uint8[4] z)
        <float[4] x = {1.0, 2.0, 3.0, 4.0}, float[1] scale = {1.0}, uint8[1] zero_point = {0}>
        {
            z = QuantizeLinear (x, scale, zero_point)
        }
        """
    folded_model = self._fold(ir.from_onnx_text(source))
    remaining_ops = list(map(lambda node: node.op_type, folded_model.graph))
    # The node must still be present after optimization.
    self.assertEqual(remaining_ops, ["QuantizeLinear"])

def test_dequantize_linear_is_not_folded(self):
    # A DequantizeLinear node with all-constant inputs must be preserved by
    # the constant folder so the quantization information is not lost.
    source = """
        <ir_version: 10, opset_import: [ "" : 20]>
        agraph () => (float[4] z)
        <uint8[4] x = {1, 2, 3, 4}, float[1] scale = {1.0}, uint8[1] zero_point = {0}>
        {
            z = DequantizeLinear (x, scale, zero_point)
        }
        """
    folded_model = self._fold(ir.from_onnx_text(source))
    remaining_ops = list(map(lambda node: node.op_type, folded_model.graph))
    # The node must still be present after optimization.
    self.assertEqual(remaining_ops, ["DequantizeLinear"])

def test_multi_graph_identity_output_preserves_output_name(self):
model = """
<ir_version: 10, opset_import: ["" : 20]>
Expand Down
Loading