Skip to content

Commit af69d08

Browse files
authored
Use torch.median for error stats (#16683)
### Summary This is a re-land of #16673 with updated test data. Previously, the error stats logic would throw when a large tensor is passed in due to size limitations of torch.quantile. The above PR switched it to use torch.median, but this has differing behavior for tensors with an even number of elements. Quantile will return the mean of the two medians, but torch.median takes the lower. Either is fine for this use case - we just need to update the test logic accordingly. I verified that the error stats unit tests pass with this change.
1 parent af9ce0b commit af69d08

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

backends/test/harness/error_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
3333
return cls(
3434
shape=tensor.shape,
3535
numel=tensor.numel(),
36-
median=torch.quantile(flattened, q=0.5).item(),
36+
median=torch.median(flattened).item(),
3737
mean=flattened.mean().item(),
3838
max=flattened.max().item(),
3939
min=flattened.min().item(),

backends/test/harness/tests/test_error_statistics.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def test_error_stats_simple(self):
1414
# Check actual tensor statistics
1515
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
1616
self.assertEqual(error_stats.actual_stats.numel, 4)
17-
self.assertEqual(error_stats.actual_stats.median, 2.5)
17+
self.assertEqual(
18+
error_stats.actual_stats.median, 2
19+
) # torch.median takes the lower median
1820
self.assertEqual(error_stats.actual_stats.mean, 2.5)
1921
self.assertEqual(error_stats.actual_stats.max, 4)
2022
self.assertEqual(error_stats.actual_stats.min, 1)
@@ -44,15 +46,15 @@ def test_error_stats_different_shapes(self):
4446
# Check actual tensor statistics
4547
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
4648
self.assertEqual(error_stats.actual_stats.numel, 4)
47-
self.assertEqual(error_stats.actual_stats.median, 2.5)
49+
self.assertEqual(error_stats.actual_stats.median, 2)
4850
self.assertEqual(error_stats.actual_stats.mean, 2.5)
4951
self.assertEqual(error_stats.actual_stats.max, 4)
5052
self.assertEqual(error_stats.actual_stats.min, 1)
5153

5254
# Check reference tensor statistics
5355
self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
5456
self.assertEqual(error_stats.reference_stats.numel, 4)
55-
self.assertEqual(error_stats.reference_stats.median, 3.5)
57+
self.assertEqual(error_stats.reference_stats.median, 3)
5658
self.assertEqual(error_stats.reference_stats.mean, 3.5)
5759
self.assertEqual(error_stats.reference_stats.max, 5)
5860
self.assertEqual(error_stats.reference_stats.min, 2)

0 commit comments

Comments
 (0)