File tree Expand file tree Collapse file tree 1 file changed +21
-6
lines changed
Expand file tree Collapse file tree 1 file changed +21
-6
lines changed Original file line number Diff line number Diff line change @@ -212,12 +212,27 @@ def setUpClass(cls):
212212 def _assert_outputs_equal (self , model_output , ref_output ):
213213 self .assertTrue (len (ref_output ) == len (model_output ))
214214 for i in range (len (ref_output )):
215- self .assertTrue (
216- torch .allclose (
217- model_output [i ], ref_output [i ], atol = self .atol , rtol = self .rtol
218- ),
219- msg = f"ref_output:\n { ref_output [i ]} \n \n model_output:\n { model_output [i ]} " ,
220- )
215+ if model_output [i ].dtype == torch .bool or ref_output [i ].dtype == torch .bool :
216+ model_bool = model_output [i ].to (torch .bool )
217+ ref_bool = ref_output [i ].to (torch .bool )
218+ model_bool , ref_bool = torch .broadcast_tensors (model_bool , ref_bool )
219+ self .assertTrue (
220+ torch .equal (model_bool , ref_bool ),
221+ msg = f"Output { i } does not match reference output.\n "
222+ f"\t Output tensor shape: { model_output [i ].shape } , dtype: { model_output [i ].dtype } \n "
223+ f"\t Reference tensor shape: { ref_output [i ].shape } , dtype: { ref_output [i ].dtype } \n "
224+ f"\t Mismatch count: { torch .count_nonzero (model_bool ^ ref_bool ).item ()} / { model_bool .numel ()} \n " ,
225+ )
226+ else :
227+ self .assertTrue (
228+ torch .allclose (
229+ model_output [i ],
230+ ref_output [i ],
231+ atol = self .atol ,
232+ rtol = self .rtol ,
233+ ),
234+ msg = f"ref_output:\n { ref_output [i ]} \n \n model_output:\n { model_output [i ]} " ,
235+ )
221236
222237 def _save_model_and_expected_output (
223238 self ,
You can’t perform that action at this time.
0 commit comments