Skip to content

Commit 474187a

Browse files
committed
last changes
1 parent 8bae9ab commit 474187a

3 files changed

Lines changed: 13 additions & 6 deletions

File tree

_unittests/ut_tasks/try_export.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def _config_reduction(config, task):
149149
begin = time.perf_counter()
150150
if not os.environ.get("STOPAT", ""):
151151
expected = model_to_export(**inputs)
152+
expected_big = model_to_export(**big_inputs)
152153
else:
153154
expected = None
154155
print(f"-- MODEL RUN IN {time.perf_counter() - begin}")
@@ -266,7 +267,7 @@ def _config_reduction(config, task):
266267
(f"test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}"),
267268
filename,
268269
model_to_export,
269-
export_inputs,
270+
big_inputs, # export_inputs,
270271
verbose=1,
271272
providers=(
272273
["CUDAExecutionProvider", "CPUExecutionProvider"]
@@ -277,7 +278,9 @@ def _config_reduction(config, task):
277278
atol=0.05,
278279
rtol=10,
279280
# ep=pt2_file,
280-
expected=expected,
281+
expected=expected_big,
282+
log_severity_level=0,
283+
log_verbosity_level=0,
281284
)
282285
print(f"-- MODEL VERIFIED IN {time.perf_counter() - begin}")
283286
os.environ["QWEN25ATTENTION"] = qwen25_attention

onnx_diagnostic/ext_test_case.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,7 @@ def tryCall(
12181218
def assert_onnx_disc(
12191219
self,
12201220
test_name: str,
1221-
proto: "onnx.ModelProto", # noqa: F821
1221+
proto: Union[str, "onnx.ModelProto"], # noqa: F821
12221222
model: "torch.nn.Module", # noqa: F821
12231223
inputs: Union[Tuple[Any], Dict[str, Any]],
12241224
verbose: int = 0,
@@ -1264,7 +1264,9 @@ def assert_onnx_disc(
12641264
name = f"{test_name}.onnx"
12651265
if verbose:
12661266
print(f"[{vname}] save the onnx model into {name!r}")
1267+
model_file = None
12671268
if isinstance(proto, str):
1269+
model_file = proto
12681270
name = proto
12691271
proto = onnx.load(name)
12701272
elif not self.unit_test_going():
@@ -1287,11 +1289,15 @@ def assert_onnx_disc(
12871289
options = onnxruntime.SessionOptions()
12881290
if ort_optimized_graph:
12891291
options.optimized_model_filepath = f"{name}.optort.onnx"
1292+
if "log_severity_level" in kwargs:
1293+
options.log_severity_level = kwargs["log_severity_level"]
1294+
if "log_verbosity_level" in kwargs:
1295+
options.log_verbosity_level = kwargs["log_verbosity_level"]
12901296
providers = kwargs.get("providers", ["CPUExecutionProvider"])
12911297
if verbose:
12921298
print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
12931299
sess = onnxruntime.InferenceSession(
1294-
proto.SerializeToString(), options, providers=providers
1300+
model_file or proto.SerializeToString(), options, providers=providers
12951301
)
12961302
if verbose:
12971303
print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,6 @@ def forward(
538538
break
539539

540540
hidden_states = self.merger(hidden_states)
541-
if STOPAT is not None:
542-
return hidden_states
543541
reverse_indices = torch.argsort(window_index)
544542
hidden_states = hidden_states[reverse_indices, :]
545543
return hidden_states

0 commit comments

Comments
 (0)