Skip to content

Commit fe20993

Browse files
authored
Disable qdq to mnb fusion in test_mnb_to_qdq (microsoft#2429)
## Describe your changes

Latest ORT has QDQ to MatMulNBits rules for more cases now. We want to disable this fusion in the `test_mnb_to_qdq` test since we are trying to compare the original MNB model with the replacement QDQ model.

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
1 parent 39d795f commit fe20993

1 file changed

Lines changed: 6 additions & 10 deletions

File tree

test/passes/onnx/test_mnb_to_qdq.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,25 +148,21 @@ def test_mnb_to_qdq(create_mnb_model, nodes_to_exclude, add_zero_point, use_sign
148148
# validate
149149
original_session = onnxruntime.InferenceSession(str(mnb_path))
150150
original_session.disable_fallback()
151+
# disable qdq to mnb fusion so we can test the output of the DQ nodes directly
152+
disabled_optimizers = ["QDQSelectorActionTransformer"]
151153
if is_symmetric and use_signed_int and not add_zero_point and use_transpose_op:
152154
# there seems to be a bug in ORT graph optimization which changes the int4 DQ to uint8 DQ
153155
with pytest.raises(Exception, match="uint8"):
154-
onnxruntime.InferenceSession(str(qdq_model.model_path))
156+
onnxruntime.InferenceSession(str(qdq_model.model_path), disabled_optimizers=disabled_optimizers)
155157
return
156158
else:
157-
qdq_session = onnxruntime.InferenceSession(str(qdq_model.model_path))
159+
qdq_session = onnxruntime.InferenceSession(str(qdq_model.model_path), disabled_optimizers=disabled_optimizers)
158160
qdq_session.disable_fallback()
159161

160162
input_data = {"input": np.random.randn(1, 1, in_dim).astype(np.float32)}
161163
original_output = original_session.run(None, input_data)[0]
162164
qdq_output = qdq_session.run(None, input_data)[0]
163165
assert original_output.shape == qdq_output.shape
164166
assert original_output.dtype == qdq_output.dtype
165-
if bits == 4 and not use_transpose_op:
166-
# Pre transposed DQ model does not match the expected output on x64 CPU
167-
# check for assertion failure so we know when the test is fixed
168-
with pytest.raises(AssertionError):
169-
np.testing.assert_allclose(original_output, qdq_output, atol=1e-4)
170-
else:
171-
# acc level 4 is used for 8 bit, so the tolerance is higher
172-
np.testing.assert_allclose(original_output, qdq_output, atol=1e-2 if bits == 8 else 1e-4)
167+
# acc level 4 is used for 8 bit, so the tolerance is higher
168+
np.testing.assert_allclose(original_output, qdq_output, atol=1e-2 if bits == 8 else 1e-4)

0 commit comments

Comments (0)