diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py index 4eb485a158..27c3bee5d9 100644 --- a/modelopt/onnx/autocast/referencerunner.py +++ b/modelopt/onnx/autocast/referencerunner.py @@ -89,7 +89,16 @@ def _validate_inputs(self, data_loader): if sorted(self.input_names) != sorted(data_loader[0].keys()): raise ValueError("Input names from ONNX model do not match provided input names.") for inp_name, inp_shape in data_loader[0].items(): - if self.input_shapes[inp_name] != list(inp_shape.shape): + # Get model and data shapes as numpy arrays + inp_shape_model = np.array(self.input_shapes[inp_name]) + inp_shape_data = np.array(inp_shape.shape) + # Compare input rank + raise_value_error = len(inp_shape_model) != len(inp_shape_data) + if not raise_value_error: + # Compare input shape, skipping check for unknown dimensions + mask = inp_shape_model > 0 + raise_value_error = np.any(inp_shape_model[mask] != inp_shape_data[mask]) + if raise_value_error: raise ValueError( f"Input shape from '{inp_name}' does not match provided input shape: " f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. "