Skip to content

Commit 7bbff19

Browse files
authored
Fix QNN is_nan bool tensor subtraction error
Differential Revision: D99049483 Pull Request resolved: #18660
1 parent 27aa628 commit 7bbff19

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

backends/qualcomm/tests/utils.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,27 @@ def setUpClass(cls):
212212
def _assert_outputs_equal(self, model_output, ref_output):
213213
self.assertTrue(len(ref_output) == len(model_output))
214214
for i in range(len(ref_output)):
215-
self.assertTrue(
216-
torch.allclose(
217-
model_output[i], ref_output[i], atol=self.atol, rtol=self.rtol
218-
),
219-
msg=f"ref_output:\n{ref_output[i]}\n\nmodel_output:\n{model_output[i]}",
220-
)
215+
if model_output[i].dtype == torch.bool or ref_output[i].dtype == torch.bool:
216+
model_bool = model_output[i].to(torch.bool)
217+
ref_bool = ref_output[i].to(torch.bool)
218+
model_bool, ref_bool = torch.broadcast_tensors(model_bool, ref_bool)
219+
self.assertTrue(
220+
torch.equal(model_bool, ref_bool),
221+
msg=f"Output {i} does not match reference output.\n"
222+
f"\tOutput tensor shape: {model_output[i].shape}, dtype: {model_output[i].dtype}\n"
223+
f"\tReference tensor shape: {ref_output[i].shape}, dtype: {ref_output[i].dtype}\n"
224+
f"\tMismatch count: {torch.count_nonzero(model_bool ^ ref_bool).item()} / {model_bool.numel()}\n",
225+
)
226+
else:
227+
self.assertTrue(
228+
torch.allclose(
229+
model_output[i],
230+
ref_output[i],
231+
atol=self.atol,
232+
rtol=self.rtol,
233+
),
234+
msg=f"ref_output:\n{ref_output[i]}\n\nmodel_output:\n{model_output[i]}",
235+
)
221236

222237
def _save_model_and_expected_output(
223238
self,

0 commit comments

Comments
 (0)