Skip to content

Commit 4b161a0

Browse files
authored
Merge branch 'dev' into fix/8239-cldice-loss-enhancements
2 parents 4f91217 + 71f16ed commit 4b161a0

2 files changed

Lines changed: 14 additions & 11 deletions

File tree

monai/apps/detection/utils/anchor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
253253
# compute anchor centers with respect to the image.
254254
# shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
255255
shifts_centers = [
256-
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
256+
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + stride[axis] // 2
257257
for axis in range(self.spatial_dims)
258258
]
259259

tests/apps/detection/utils/test_anchor_box.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ class TestAnchorGenerator(unittest.TestCase):
4444
@parameterized.expand(TEST_CASES_2D)
4545
def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
4646
torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils")
47-
image_list, _ = optional_import("torchvision.models.detection.image_list")
4847

49-
# test it behaves the same with torchvision for 2d
48+
# test the new centered-anchor behavior
49+
# torchvision's AnchorGenerator does not implement this centering
5050
anchor = AnchorGenerator(**input_param, indexing="xy")
5151
anchor_ref = torch_anchor_utils.AnchorGenerator(**input_param)
5252
for a, a_f in zip(anchor.cell_anchors, anchor_ref.cell_anchors):
@@ -56,15 +56,18 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
5656

5757
grid_sizes = [[2, 2], [1, 1]]
5858
strides = [[torch.tensor(1), torch.tensor(2)], [torch.tensor(2), torch.tensor(4)]]
59-
for a, a_f in zip(anchor.grid_anchors(grid_sizes, strides), anchor_ref.grid_anchors(grid_sizes, strides)):
60-
assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)
6159

62-
images = torch.rand(image_shape)
63-
feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
64-
result = anchor(images, feature_maps)
65-
result_ref = anchor_ref(image_list.ImageList(images, ([123, 122],)), feature_maps)
66-
for a, a_f in zip(result, result_ref):
67-
assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1)
60+
monai_anchors = anchor.grid_anchors(grid_sizes, strides)
61+
torchvision_anchors = anchor_ref.grid_anchors(grid_sizes, strides)
62+
63+
for a, a_f, s in zip(monai_anchors, torchvision_anchors, strides):
64+
stride_y, stride_x = s
65+
66+
offset_x = stride_x // 2
67+
offset_y = stride_y // 2
68+
offset = torch.tensor([offset_x, offset_y, offset_x, offset_y], dtype=a_f.dtype, device=a_f.device)
69+
70+
assert_allclose(a, a_f + offset, type_test=True, device_test=False, atol=1e-3)
6871

6972
@parameterized.expand(TEST_CASES_2D)
7073
def test_script_2d(self, input_param, image_shape, feature_maps_shapes):

0 commit comments

Comments
 (0)