@@ -44,9 +44,9 @@ class TestAnchorGenerator(unittest.TestCase):
4444 @parameterized .expand (TEST_CASES_2D )
4545 def test_anchor_2d (self , input_param , image_shape , feature_maps_shapes ):
4646 torch_anchor_utils , _ = optional_import ("torchvision.models.detection.anchor_utils" )
47- image_list , _ = optional_import ("torchvision.models.detection.image_list" )
4847
49- # test it behaves the same with torchvision for 2d
48+ # test it behaves for new functionality of centered anchors
49+ # pytorch does not follow this functionality
5050 anchor = AnchorGenerator (** input_param , indexing = "xy" )
5151 anchor_ref = torch_anchor_utils .AnchorGenerator (** input_param )
5252 for a , a_f in zip (anchor .cell_anchors , anchor_ref .cell_anchors ):
@@ -56,15 +56,18 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
5656
5757 grid_sizes = [[2 , 2 ], [1 , 1 ]]
5858 strides = [[torch .tensor (1 ), torch .tensor (2 )], [torch .tensor (2 ), torch .tensor (4 )]]
59- for a , a_f in zip (anchor .grid_anchors (grid_sizes , strides ), anchor_ref .grid_anchors (grid_sizes , strides )):
60- assert_allclose (a , a_f , type_test = True , device_test = False , atol = 1e-3 )
6159
62- images = torch .rand (image_shape )
63- feature_maps = tuple (torch .rand (fs ) for fs in feature_maps_shapes )
64- result = anchor (images , feature_maps )
65- result_ref = anchor_ref (image_list .ImageList (images , ([123 , 122 ],)), feature_maps )
66- for a , a_f in zip (result , result_ref ):
67- assert_allclose (a , a_f , type_test = True , device_test = False , atol = 0.1 )
60+ monai_anchors = anchor .grid_anchors (grid_sizes , strides )
61+ torchvision_anchors = anchor_ref .grid_anchors (grid_sizes , strides )
62+
63+ for a , a_f , s in zip (monai_anchors , torchvision_anchors , strides ):
64+ stride_y , stride_x = s
65+
66+ offset_x = stride_x // 2
67+ offset_y = stride_y // 2
68+ offset = torch .tensor ([offset_x , offset_y , offset_x , offset_y ], dtype = a_f .dtype , device = a_f .device )
69+
70+ assert_allclose (a , a_f + offset , type_test = True , device_test = False , atol = 1e-3 )
6871
6972 @parameterized .expand (TEST_CASES_2D )
7073 def test_script_2d (self , input_param , image_shape , feature_maps_shapes ):
0 commit comments