File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -92,9 +92,13 @@ def _validate_inputs(self, data_loader):
9292 # Get model and data shapes as numpy arrays
9393 inp_shape_model = np .array (self .input_shapes [inp_name ])
9494 inp_shape_data = np .array (inp_shape .shape )
95- # Skip check for unknown dimensions (shape = -1)
96- mask = (inp_shape_model != - 1 ) & (inp_shape_data != - 1 )
97- if np .any (inp_shape_model [mask ] != inp_shape_data [mask ]):
95+ # Compare input rank
96+ raise_value_error = len (inp_shape_model ) != len (inp_shape_data )
97+ if not raise_value_error :
98+ # Compare input shape, skipping check for unknown dimensions (shape = -1)
99+ mask = (inp_shape_model != - 1 ) & (inp_shape_data != - 1 )
100+ raise_value_error = np .any (inp_shape_model [mask ] != inp_shape_data [mask ])
101+ if raise_value_error :
98102 raise ValueError (
99103 f"Input shape from '{ inp_name } ' does not match provided input shape: "
100104 f"{ self .input_shapes [inp_name ]} vs { list (inp_shape .shape )} . "
You can’t perform that action at this time.
0 commit comments