Skip to content

Commit 8f704ef

Browse files
committed
[Relax][ONNX] Support 3D AffineGrid
1 parent a9ce41e commit 8f704ef

5 files changed

Lines changed: 89 additions & 47 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3322,15 +3322,17 @@ def _impl_v20(cls, bb, inputs, attr, params):
33223322
else:
33233323
raise NotImplementedError(f"Dynamic size of type {type(size)} is not supported")
33243324

3325-
# Only 2D is supported: size = [N, C, H, W]
3326-
if len(size_vals) != 4:
3327-
raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is supported")
3328-
target_h, target_w = size_vals[2], size_vals[3]
3329-
3330-
# Relax affine_grid outputs [N, 2, H, W]
3331-
grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w)))
3332-
# Permute to ONNX convention [N, H, W, 2]
3333-
return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
3325+
if len(size_vals) == 4:
3326+
# 2D: size = [N, C, H, W]; relax affine_grid outputs [N, 2, H, W].
3327+
grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:])))
3328+
# Permute to ONNX convention [N, H, W, 2].
3329+
return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
3330+
if len(size_vals) == 5:
3331+
# 3D: size = [N, C, D, H, W]; relax affine_grid outputs [N, 3, D, H, W].
3332+
grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:])))
3333+
# Permute to ONNX convention [N, D, H, W, 3].
3334+
return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 4, 1]))
3335+
raise ValueError("AffineGrid expects size to be [N,C,H,W] (2D) or [N,C,D,H,W] (3D)")
33343336

33353337

33363338
class Einsum(OnnxOpConverter):

python/tvm/relax/op/image/image.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def affine_grid(
238238
data: Expr,
239239
size: Expr | SizeLike,
240240
) -> Expr:
241-
"""Generate a 2D sampling grid using an affine transformation matrix.
241+
"""Generate a 2D or 3D sampling grid using an affine transformation matrix.
242242
243243
This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
244244
It generates a uniform sampling grid within the target shape, normalizes it
@@ -247,16 +247,19 @@ def affine_grid(
247247
Parameters
248248
----------
249249
data : relax.Expr
250-
The input affine matrix tensor with shape [batch, 2, 3].
250+
The input affine matrix tensor with shape [batch, 2, 3] for 2D or
251+
[batch, 3, 4] for 3D.
251252
252-
size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
253-
The target output spatial shape (H, W). If a single integer or PrimExpr
254-
is provided, it is interpreted as a square output shape (size, size).
253+
size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]]
254+
The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If a
255+
single integer or PrimExpr is provided, it is interpreted as a square 2D
256+
output shape (size, size).
255257
256258
Returns
257259
-------
258260
result : relax.Expr
259-
The output grid tensor with shape [batch, 2, H, W].
261+
The output grid tensor with shape [batch, 2, H, W] for 2D or
262+
[batch, 3, D, H, W] for 3D.
260263
261264
Note
262265
----

python/tvm/topi/image/grid_sample.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
def affine_grid(data, target_shape):
24-
"""affine_grid operator that generates 2D sampling grid.
24+
"""affine_grid operator that generates a 2D or 3D sampling grid.
2525
2626
This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
2727
sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
@@ -30,31 +30,38 @@ def affine_grid(data, target_shape):
3030
Parameters
3131
----------
3232
data : tvm.Tensor
33-
3-D with shape [batch, 2, 3]. The affine matrix.
33+
3-D with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D. The affine matrix.
3434
35-
target_shape: list/tuple of two int
36-
Specifies the output shape (H, W).
35+
target_shape: list/tuple of int
36+
Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D.
3737
3838
Returns
3939
-------
4040
Output : tvm.Tensor
41-
4-D with shape [batch, 2, target_height, target_width]
41+
[batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D.
4242
"""
4343
assert target_shape is not None
44-
assert len(target_shape) == 2
45-
assert target_shape[0] > 1 and target_shape[1] > 1, (
46-
"target height/width should be greater than 1"
47-
)
44+
assert len(target_shape) in (2, 3)
45+
assert all(s > 1 for s in target_shape), "target spatial dims should be greater than 1"
4846

4947
dtype = data.dtype
50-
y_step = tirx.const((2.0 - 1e-7) / (target_shape[0] - 1), dtype=dtype)
51-
x_step = tirx.const((2.0 - 1e-7) / (target_shape[1] - 1), dtype=dtype)
5248
start = tirx.const(-1.0, dtype=dtype)
49+
steps = [tirx.const((2.0 - 1e-7) / (s - 1), dtype=dtype) for s in target_shape]
50+
51+
if len(target_shape) == 2:
52+
53+
def _compute(n, dim, i, j):
54+
y = start + i * steps[0]
55+
x = start + j * steps[1]
56+
return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
57+
58+
else:
5359

54-
def _compute(n, dim, i, j):
55-
y = start + i * y_step
56-
x = start + j * x_step
57-
return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
60+
def _compute(n, dim, k, i, j):
61+
z = start + k * steps[0]
62+
y = start + i * steps[1]
63+
x = start + j * steps[2]
64+
return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2] * z + data[n, dim, 3]
5865

5966
oshape = (data.shape[0], len(target_shape), *target_shape)
6067
return te.compute(oshape, _compute, tag="affine_grid")

src/relax/op/image/resize.cc

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -385,51 +385,55 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) {
385385
<< "AffineGrid expects the target size to be a Shape, while the given one is "
386386
<< call->args[1]->GetTypeKey();
387387
}
388-
if (size_ty->ndim != 2) {
388+
// 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D, H, W]).
389+
if (size_ty->ndim != 2 && size_ty->ndim != 3) {
389390
TVM_FFI_VISIT_THROW(ValueError, call)
390-
<< "AffineGrid expects the target size to be a 2-dim shape, while the given "
391+
<< "AffineGrid expects the target size to be a 2-dim or 3-dim shape, while the given "
391392
"one has ndim "
392393
<< size_ty->ndim;
393394
}
395+
const int spatial = size_ty->ndim;
394396

395-
// data should be 3-D: [batch, 2, 3]
397+
// data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N, 3, 4]).
396398
if (data_ty->ndim != -1 && data_ty->ndim != 3) {
397-
TVM_FFI_VISIT_THROW(ValueError, call)
398-
<< "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got ndim "
399-
<< data_ty->ndim;
399+
TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input data to be 3-D (batch, "
400+
"spatial, spatial + 1), but got ndim "
401+
<< data_ty->ndim;
400402
}
401403

402404
const auto* data_shape = data_ty->shape.as<ShapeExprNode>();
403405
if (data_shape != nullptr) {
404-
// Check that the affine matrix has shape [batch, 2, 3]
405406
if (data_shape->values.size() >= 2) {
406407
auto* dim1 = data_shape->values[1].as<IntImmNode>();
407-
if (dim1 != nullptr && dim1->value != 2) {
408+
if (dim1 != nullptr && dim1->value != spatial) {
408409
TVM_FFI_VISIT_THROW(ValueError, call)
409-
<< "AffineGrid expects the second dimension of input to be 2, but got " << dim1->value;
410+
<< "AffineGrid expects the second dimension of input to be " << spatial << ", but got "
411+
<< dim1->value;
410412
}
411413
}
412414
if (data_shape->values.size() >= 3) {
413415
auto* dim2 = data_shape->values[2].as<IntImmNode>();
414-
if (dim2 != nullptr && dim2->value != 3) {
416+
if (dim2 != nullptr && dim2->value != spatial + 1) {
415417
TVM_FFI_VISIT_THROW(ValueError, call)
416-
<< "AffineGrid expects the third dimension of input to be 3, but got " << dim2->value;
418+
<< "AffineGrid expects the third dimension of input to be " << spatial + 1
419+
<< ", but got " << dim2->value;
417420
}
418421
}
419422
}
420423

421424
DataType out_dtype = data_ty->dtype;
422425

423426
if (data_shape == nullptr || size_value == nullptr) {
424-
return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice);
427+
return TensorType(out_dtype, /*ndim=*/spatial + 2, data_ty->vdevice);
425428
}
426429

427-
// Output shape: [batch, 2, target_height, target_width]
430+
// Output shape: [batch, spatial, *target_spatial_dims].
428431
ffi::Array<PrimExpr> out_shape;
429-
out_shape.push_back(data_shape->values[0]); // batch
430-
out_shape.push_back(IntImm::Int64(2)); // 2 (spatial dimensions)
431-
out_shape.push_back(size_value->values[0]); // target_height
432-
out_shape.push_back(size_value->values[1]); // target_width
432+
out_shape.push_back(data_shape->values[0]); // batch
433+
out_shape.push_back(IntImm::Int64(spatial)); // number of spatial coordinates
434+
for (int i = 0; i < spatial; ++i) {
435+
out_shape.push_back(size_value->values[i]); // target spatial dim
436+
}
433437

434438
return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice);
435439
}

tests/python/relax/test_frontend_onnx.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5581,6 +5581,32 @@ def test_affine_grid():
55815581
check_correctness(model, opset=20)
55825582

55835583

5584+
def test_affine_grid_3d():
5585+
affine_grid_node = helper.make_node(
5586+
"AffineGrid",
5587+
["theta", "size"],
5588+
["grid"],
5589+
align_corners=1,
5590+
)
5591+
5592+
graph = helper.make_graph(
5593+
[affine_grid_node],
5594+
"affine_grid_3d_test",
5595+
inputs=[
5596+
helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 3, 4]),
5597+
],
5598+
initializer=[
5599+
helper.make_tensor("size", TensorProto.INT64, [5], [2, 3, 8, 16, 16]),
5600+
],
5601+
outputs=[
5602+
helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 8, 16, 16, 3]),
5603+
],
5604+
)
5605+
5606+
model = helper.make_model(graph, producer_name="affine_grid_3d_test")
5607+
check_correctness(model, opset=20)
5608+
5609+
55845610
@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
55855611
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
55865612
@pytest.mark.parametrize("align_corners", [0, 1])

0 commit comments

Comments
 (0)