Skip to content

Commit 71f16ed

Browse files
hachoj and ericspod authored
7951 anchors are not centered on grid cells (#8475)
This pull request resolves issue #7951 by correctly centering anchor boxes within their grid cells.

### Description

**Changes Made:**
- Modified the anchor generation logic in `monai/apps/detection/utils/anchor_utils.py` to add a `stride // 2` offset.
- Refactored the corresponding unit test in `tests/apps/detection/utils/test_anchor_box.py` to validate this new, correct behavior against the torchvision baseline by accounting for the offset.

Fixes #7951

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [ ] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
- [ ]

---------

Signed-off-by: hachoj <hjchojnowski@gmail.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 3bb9d8e commit 71f16ed

2 files changed

Lines changed: 14 additions & 11 deletions

File tree

monai/apps/detection/utils/anchor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
253253
# compute anchor centers regarding to the image.
254254
# shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
255255
shifts_centers = [
256-
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
256+
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + stride[axis] // 2
257257
for axis in range(self.spatial_dims)
258258
]
259259

tests/apps/detection/utils/test_anchor_box.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ class TestAnchorGenerator(unittest.TestCase):
4444
@parameterized.expand(TEST_CASES_2D)
4545
def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
4646
torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils")
47-
image_list, _ = optional_import("torchvision.models.detection.image_list")
4847

49-
# test it behaves the same with torchvision for 2d
48+
# test it behaves for new functionality of centered anchors
49+
# pytorch does not follow this functionality
5050
anchor = AnchorGenerator(**input_param, indexing="xy")
5151
anchor_ref = torch_anchor_utils.AnchorGenerator(**input_param)
5252
for a, a_f in zip(anchor.cell_anchors, anchor_ref.cell_anchors):
@@ -56,15 +56,18 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
5656

5757
grid_sizes = [[2, 2], [1, 1]]
5858
strides = [[torch.tensor(1), torch.tensor(2)], [torch.tensor(2), torch.tensor(4)]]
59-
for a, a_f in zip(anchor.grid_anchors(grid_sizes, strides), anchor_ref.grid_anchors(grid_sizes, strides)):
60-
assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)
6159

62-
images = torch.rand(image_shape)
63-
feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
64-
result = anchor(images, feature_maps)
65-
result_ref = anchor_ref(image_list.ImageList(images, ([123, 122],)), feature_maps)
66-
for a, a_f in zip(result, result_ref):
67-
assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1)
60+
monai_anchors = anchor.grid_anchors(grid_sizes, strides)
61+
torchvision_anchors = anchor_ref.grid_anchors(grid_sizes, strides)
62+
63+
for a, a_f, s in zip(monai_anchors, torchvision_anchors, strides):
64+
stride_y, stride_x = s
65+
66+
offset_x = stride_x // 2
67+
offset_y = stride_y // 2
68+
offset = torch.tensor([offset_x, offset_y, offset_x, offset_y], dtype=a_f.dtype, device=a_f.device)
69+
70+
assert_allclose(a, a_f + offset, type_test=True, device_test=False, atol=1e-3)
6871

6972
@parameterized.expand(TEST_CASES_2D)
7073
def test_script_2d(self, input_param, image_shape, feature_maps_shapes):

0 commit comments

Comments
 (0)