Skip to content

Commit 35eddd9

Browse files
committed
[Relax][ONNX] Support 3D AffineGrid
1 parent 4650887 commit 35eddd9

6 files changed

Lines changed: 91 additions & 59 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3319,15 +3319,14 @@ def _impl_v20(cls, bb, inputs, attr, params):
33193319
else:
33203320
raise NotImplementedError(f"Dynamic size of type {type(size)} is not supported")
33213321

3322-
# Only 2D is supported: size = [N, C, H, W]
3323-
if len(size_vals) != 4:
3324-
raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is supported")
3325-
target_h, target_w = size_vals[2], size_vals[3]
3326-
3327-
# Relax affine_grid outputs [N, 2, H, W]
3328-
grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w), align_corners))
3329-
# Permute to ONNX convention [N, H, W, 2]
3330-
return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
3322+
if len(size_vals) not in (4, 5):
3323+
raise ValueError("AffineGrid expects size to be [N,C,H,W] (2D) or [N,C,D,H,W] (3D)")
3324+
3325+
# relax affine_grid outputs [N, K, *spatial_dims], where K is the number of spatial
3326+
# dims (2 or 3); move the coordinate axis last to match the ONNX layout [N, *spatial_dims, K].
3327+
grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:]), align_corners))
3328+
axes = [0, *range(2, len(size_vals)), 1]
3329+
return bb.emit(relax.op.permute_dims(grid, axes=axes))
33313330

33323331

33333332
class Einsum(OnnxOpConverter):

python/tvm/relax/op/image/image.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def affine_grid(
239239
size: Expr | SizeLike,
240240
align_corners: bool = True,
241241
) -> Expr:
242-
"""Generate a 2D sampling grid using an affine transformation matrix.
242+
"""Generate a 2D or 3D sampling grid using an affine transformation matrix.
243243
244244
This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
245245
It generates a uniform sampling grid within the target shape, normalizes it
@@ -248,11 +248,13 @@ def affine_grid(
248248
Parameters
249249
----------
250250
data : relax.Expr
251-
The input affine matrix tensor with shape [batch, 2, 3].
251+
The input affine matrix tensor with shape [batch, 2, 3] for 2D or
252+
[batch, 3, 4] for 3D.
252253
253-
size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
254-
The target output spatial shape (H, W). If a single integer or PrimExpr
255-
is provided, it is interpreted as a square output shape (size, size).
254+
size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]]
255+
The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If a
256+
single integer or PrimExpr is provided, it is interpreted as a square 2D
257+
output shape (size, size).
256258
257259
align_corners : bool
258260
If True, normalized grid coordinates map to corner pixels; if False, to
@@ -261,7 +263,8 @@ def affine_grid(
261263
Returns
262264
-------
263265
result : relax.Expr
264-
The output grid tensor with shape [batch, 2, H, W].
266+
The output grid tensor with shape [batch, 2, H, W] for 2D or
267+
[batch, 3, D, H, W] for 3D.
265268
"""
266269
if isinstance(size, int | PrimExpr):
267270
size = (size, size)

python/tvm/topi/image/grid_sample.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
def affine_grid(data, target_shape, align_corners=True):
24-
"""affine_grid operator that generates 2D sampling grid.
24+
"""affine_grid operator that generates a 2D or 3D sampling grid.
2525
2626
This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
2727
sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
@@ -30,10 +30,10 @@ def affine_grid(data, target_shape, align_corners=True):
3030
Parameters
3131
----------
3232
data : tvm.Tensor
33-
3-D with shape [batch, 2, 3]. The affine matrix.
33+
The affine matrix: a rank-3 tensor with shape [batch, 2, 3] for a 2D grid or [batch, 3, 4] for a 3D grid.
3434
35-
target_shape: list/tuple of two int
36-
Specifies the output shape (H, W).
35+
target_shape: list/tuple of int
36+
Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D.
3737
3838
align_corners : bool
3939
If True, normalized coordinates map to corner pixels; if False, to pixel centers
@@ -42,35 +42,35 @@ def affine_grid(data, target_shape, align_corners=True):
4242
Returns
4343
-------
4444
Output : tvm.Tensor
45-
4-D with shape [batch, 2, target_height, target_width]
45+
[batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D.
4646
"""
47-
assert target_shape is not None
48-
assert len(target_shape) == 2
47+
assert len(target_shape) in (2, 3)
4948
if align_corners:
50-
assert target_shape[0] > 1 and target_shape[1] > 1, (
51-
"target height/width should be greater than 1 when align_corners is True"
49+
assert all(s > 1 for s in target_shape), (
50+
"target spatial dims should be greater than 1 when align_corners is True"
5251
)
5352

5453
dtype = data.dtype
55-
height, width = target_shape[0], target_shape[1]
5654
if align_corners:
57-
y_step = tirx.const((2.0 - 1e-7) / (height - 1), dtype=dtype)
58-
x_step = tirx.const((2.0 - 1e-7) / (width - 1), dtype=dtype)
59-
y_start = tirx.const(-1.0, dtype=dtype)
60-
x_start = tirx.const(-1.0, dtype=dtype)
55+
starts = [tirx.const(-1.0, dtype=dtype) for _ in target_shape]
56+
steps = [tirx.const((2.0 - 1e-7) / (s - 1), dtype=dtype) for s in target_shape]
6157
else:
6258
# Pixel centers: coordinate i maps to (2 * i + 1) / size - 1.
63-
y_step = tirx.const(2.0 / height, dtype=dtype)
64-
x_step = tirx.const(2.0 / width, dtype=dtype)
65-
y_start = tirx.const(-1.0 + 1.0 / height, dtype=dtype)
66-
x_start = tirx.const(-1.0 + 1.0 / width, dtype=dtype)
59+
starts = [tirx.const(-1.0 + 1.0 / s, dtype=dtype) for s in target_shape]
60+
steps = [tirx.const(2.0 / s, dtype=dtype) for s in target_shape]
6761

68-
def _compute(n, dim, i, j):
69-
y = y_start + i * y_step
70-
x = x_start + j * x_step
71-
return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
62+
ndim = len(target_shape)
7263

73-
oshape = (data.shape[0], len(target_shape), *target_shape)
64+
def _compute(n, dim, *coords):
65+
# coords follow target_shape order, e.g. (d, h, w) indices for 3D, while the affine
66+
# matrix columns are ordered (x, y, z) = (w, h, d), so column ndim - 1 - r pairs with coords[r].
67+
val = data[n, dim, ndim] # translation column
68+
for r in range(ndim):
69+
coord = starts[r] + coords[r] * steps[r]
70+
val += data[n, dim, ndim - 1 - r] * coord
71+
return val
72+
73+
oshape = (data.shape[0], ndim, *target_shape)
7474
return te.compute(oshape, _compute, tag="affine_grid")
7575

7676

src/relax/op/image/resize.cc

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -389,51 +389,55 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) {
389389
<< "AffineGrid expects the target size to be a Shape, while the given one is "
390390
<< call->args[1]->GetTypeKey();
391391
}
392-
if (size_ty->ndim != 2) {
392+
// 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D, H, W]).
393+
if (size_ty->ndim != 2 && size_ty->ndim != 3) {
393394
TVM_FFI_VISIT_THROW(ValueError, call)
394-
<< "AffineGrid expects the target size to be a 2-dim shape, while the given "
395+
<< "AffineGrid expects the target size to be a 2-dim or 3-dim shape, while the given "
395396
"one has ndim "
396397
<< size_ty->ndim;
397398
}
399+
const int spatial = size_ty->ndim;
398400

399-
// data should be 3-D: [batch, 2, 3]
401+
// data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N, 3, 4]).
400402
if (data_ty->ndim != -1 && data_ty->ndim != 3) {
401-
TVM_FFI_VISIT_THROW(ValueError, call)
402-
<< "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got ndim "
403-
<< data_ty->ndim;
403+
TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input data to be 3-D (batch, "
404+
"spatial, spatial + 1), but got ndim "
405+
<< data_ty->ndim;
404406
}
405407

406408
const auto* data_shape = data_ty->shape.as<ShapeExprNode>();
407409
if (data_shape != nullptr) {
408-
// Check that the affine matrix has shape [batch, 2, 3]
409410
if (data_shape->values.size() >= 2) {
410411
auto* dim1 = data_shape->values[1].as<IntImmNode>();
411-
if (dim1 != nullptr && dim1->value != 2) {
412+
if (dim1 != nullptr && dim1->value != spatial) {
412413
TVM_FFI_VISIT_THROW(ValueError, call)
413-
<< "AffineGrid expects the second dimension of input to be 2, but got " << dim1->value;
414+
<< "AffineGrid expects the second dimension of input to be " << spatial << ", but got "
415+
<< dim1->value;
414416
}
415417
}
416418
if (data_shape->values.size() >= 3) {
417419
auto* dim2 = data_shape->values[2].as<IntImmNode>();
418-
if (dim2 != nullptr && dim2->value != 3) {
420+
if (dim2 != nullptr && dim2->value != spatial + 1) {
419421
TVM_FFI_VISIT_THROW(ValueError, call)
420-
<< "AffineGrid expects the third dimension of input to be 3, but got " << dim2->value;
422+
<< "AffineGrid expects the third dimension of input to be " << spatial + 1
423+
<< ", but got " << dim2->value;
421424
}
422425
}
423426
}
424427

425428
DataType out_dtype = data_ty->dtype;
426429

427430
if (data_shape == nullptr || size_value == nullptr) {
428-
return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice);
431+
return TensorType(out_dtype, /*ndim=*/spatial + 2, data_ty->vdevice);
429432
}
430433

431-
// Output shape: [batch, 2, target_height, target_width]
434+
// Output shape: [batch, spatial, *target_spatial_dims].
432435
ffi::Array<PrimExpr> out_shape;
433-
out_shape.push_back(data_shape->values[0]); // batch
434-
out_shape.push_back(IntImm::Int64(2)); // 2 (spatial dimensions)
435-
out_shape.push_back(size_value->values[0]); // target_height
436-
out_shape.push_back(size_value->values[1]); // target_width
436+
out_shape.push_back(data_shape->values[0]); // batch
437+
out_shape.push_back(IntImm::Int64(spatial)); // number of spatial coordinates
438+
for (int i = 0; i < spatial; ++i) {
439+
out_shape.push_back(size_value->values[i]); // target spatial dim
440+
}
437441

438442
return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice);
439443
}

tests/python/relax/test_frontend_onnx.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5579,6 +5579,32 @@ def test_affine_grid(align_corners):
55795579
check_correctness(model, opset=20)
55805580

55815581

5582+
def test_affine_grid_3d():
5583+
affine_grid_node = helper.make_node(
5584+
"AffineGrid",
5585+
["theta", "size"],
5586+
["grid"],
5587+
align_corners=1,
5588+
)
5589+
5590+
graph = helper.make_graph(
5591+
[affine_grid_node],
5592+
"affine_grid_3d_test",
5593+
inputs=[
5594+
helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 3, 4]),
5595+
],
5596+
initializer=[
5597+
helper.make_tensor("size", TensorProto.INT64, [5], [2, 3, 8, 16, 16]),
5598+
],
5599+
outputs=[
5600+
helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 8, 16, 16, 3]),
5601+
],
5602+
)
5603+
5604+
model = helper.make_model(graph, producer_name="affine_grid_3d_test")
5605+
check_correctness(model, opset=20)
5606+
5607+
55825608
@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
55835609
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
55845610
@pytest.mark.parametrize("align_corners", [0, 1])

tests/python/relax/test_transform_legalize_ops_image.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,12 @@ def affine_grid(var_theta: T.handle, var_compute: T.handle):
126126
with T.sblock("root"):
127127
T.reads()
128128
T.writes()
129-
for n, dim, i, j in T.grid(T.int64(2), T.int64(2), T.int64(16), T.int64(16)):
129+
for n, dim, i0, i1 in T.grid(T.int64(2), T.int64(2), T.int64(16), T.int64(16)):
130130
with T.sblock("compute"):
131-
v_n, v_dim, v_i, v_j = T.axis.remap("SSSS", [n, dim, i, j])
131+
v_n, v_dim, v_i0, v_i1 = T.axis.remap("SSSS", [n, dim, i0, i1])
132132
T.reads(theta[v_n, v_dim, T.int64(0):T.int64(3)])
133-
T.writes(compute[v_n, v_dim, v_i, v_j])
134-
compute[v_n, v_dim, v_i, v_j] = theta[v_n, v_dim, T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_j) * T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(1)] * (T.float32(-1.0) + T.Cast("float32", v_i) * T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(2)]
133+
T.writes(compute[v_n, v_dim, v_i0, v_i1])
134+
compute[v_n, v_dim, v_i0, v_i1] = theta[v_n, v_dim, T.int64(2)] + theta[v_n, v_dim, T.int64(1)] * (T.float32(-1.0) + T.Cast("float32", v_i0) * T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_i1) * T.float32(0.13333332666666667))
135135
# fmt: on
136136

137137
mod = LegalizeOps()(AffineGrid)

0 commit comments

Comments
 (0)