Skip to content

Commit 115fc48

Browse files
committed
Add check for input rank
Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent 1fc5843 commit 115fc48

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

modelopt/onnx/autocast/referencerunner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,13 @@ def _validate_inputs(self, data_loader):
9292
# Get model and data shapes as numpy arrays
9393
inp_shape_model = np.array(self.input_shapes[inp_name])
9494
inp_shape_data = np.array(inp_shape.shape)
95-
# Skip check for unknown dimensions (shape = -1)
96-
mask = (inp_shape_model != -1) & (inp_shape_data != -1)
97-
if np.any(inp_shape_model[mask] != inp_shape_data[mask]):
95+
# Compare input rank
96+
raise_value_error = len(inp_shape_model) != len(inp_shape_data)
97+
if not raise_value_error:
98+
# Compare input shape, skipping check for unknown dimensions (shape = -1)
99+
mask = (inp_shape_model != -1) & (inp_shape_data != -1)
100+
raise_value_error = np.any(inp_shape_model[mask] != inp_shape_data[mask])
101+
if raise_value_error:
98102
raise ValueError(
99103
f"Input shape from '{inp_name}' does not match provided input shape: "
100104
f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. "

0 commit comments

Comments
 (0)