2121
2222
2323def affine_grid (data , target_shape , align_corners = True ):
24- """affine_grid operator that generates 2D sampling grid.
24+ """affine_grid operator that generates a 2D or 3D sampling grid.
2525
2626 This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
2727 sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
@@ -30,10 +30,10 @@ def affine_grid(data, target_shape, align_corners=True):
3030 Parameters
3131 ----------
3232 data : tvm.Tensor
33- 3-D with shape [batch, 2, 3]. The affine matrix.
33+ 3-D with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D . The affine matrix.
3434
35- target_shape: list/tuple of two int
36- Specifies the output shape (H, W).
35+ target_shape: list/tuple of int
36+ Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D .
3737
3838 align_corners : bool
3939 If True, normalized coordinates map to corner pixels; if False, to pixel centers
@@ -42,35 +42,35 @@ def affine_grid(data, target_shape, align_corners=True):
4242 Returns
4343 -------
4444 Output : tvm.Tensor
45- 4-D with shape [batch, 2, target_height, target_width]
45+ [batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D.
4646 """
47- assert target_shape is not None
48- assert len (target_shape ) == 2
47+ assert len (target_shape ) in (2 , 3 )
4948 if align_corners :
50- assert target_shape [ 0 ] > 1 and target_shape [ 1 ] > 1 , (
51- "target height/width should be greater than 1 when align_corners is True"
49+ assert all ( s > 1 for s in target_shape ) , (
50+ "target spatial dims should be greater than 1 when align_corners is True"
5251 )
5352
5453 dtype = data .dtype
55- height , width = target_shape [0 ], target_shape [1 ]
5654 if align_corners :
57- y_step = tirx .const ((2.0 - 1e-7 ) / (height - 1 ), dtype = dtype )
58- x_step = tirx .const ((2.0 - 1e-7 ) / (width - 1 ), dtype = dtype )
59- y_start = tirx .const (- 1.0 , dtype = dtype )
60- x_start = tirx .const (- 1.0 , dtype = dtype )
55+ starts = [tirx .const (- 1.0 , dtype = dtype ) for _ in target_shape ]
56+ steps = [tirx .const ((2.0 - 1e-7 ) / (s - 1 ), dtype = dtype ) for s in target_shape ]
6157 else :
6258 # Pixel centers: coordinate i maps to (2 * i + 1) / size - 1.
63- y_step = tirx .const (2.0 / height , dtype = dtype )
64- x_step = tirx .const (2.0 / width , dtype = dtype )
65- y_start = tirx .const (- 1.0 + 1.0 / height , dtype = dtype )
66- x_start = tirx .const (- 1.0 + 1.0 / width , dtype = dtype )
59+ starts = [tirx .const (- 1.0 + 1.0 / s , dtype = dtype ) for s in target_shape ]
60+ steps = [tirx .const (2.0 / s , dtype = dtype ) for s in target_shape ]
6761
68- def _compute (n , dim , i , j ):
69- y = y_start + i * y_step
70- x = x_start + j * x_step
71- return data [n , dim , 0 ] * x + data [n , dim , 1 ] * y + data [n , dim , 2 ]
62+ ndim = len (target_shape )
7263
73- oshape = (data .shape [0 ], len (target_shape ), * target_shape )
64+ def _compute (n , dim , * coords ):
65+ # coords are ordered slowest-to-fastest (e.g. (k, i, j)); the affine matrix
66+ # columns are fastest-to-slowest (x, y, z), so index it in reverse.
67+ val = data [n , dim , ndim ] # translation column
68+ for r in range (ndim ):
69+ coord = starts [r ] + coords [r ] * steps [r ]
70+ val += data [n , dim , ndim - 1 - r ] * coord
71+ return val
72+
73+ oshape = (data .shape [0 ], ndim , * target_shape )
7474 return te .compute (oshape , _compute , tag = "affine_grid" )
7575
7676
0 commit comments