Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3319,15 +3319,14 @@ def _impl_v20(cls, bb, inputs, attr, params):
else:
raise NotImplementedError(f"Dynamic size of type {type(size)} is not supported")

# Only 2D is supported: size = [N, C, H, W]
if len(size_vals) != 4:
raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is supported")
target_h, target_w = size_vals[2], size_vals[3]

# Relax affine_grid outputs [N, 2, H, W]
grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w), align_corners))
# Permute to ONNX convention [N, H, W, 2]
return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
if len(size_vals) not in (4, 5):
raise ValueError("AffineGrid expects size to be [N,C,H,W] (2D) or [N,C,D,H,W] (3D)")

# relax affine_grid outputs [N, spatial, *spatial_dims]; move the coord axis
# last to match the ONNX convention [N, *spatial_dims, spatial].
grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:]), align_corners))
axes = [0, *range(2, len(size_vals)), 1]
return bb.emit(relax.op.permute_dims(grid, axes=axes))


class Einsum(OnnxOpConverter):
Expand Down
15 changes: 9 additions & 6 deletions python/tvm/relax/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def affine_grid(
size: Expr | SizeLike,
align_corners: bool = True,
) -> Expr:
"""Generate a 2D sampling grid using an affine transformation matrix.
"""Generate a 2D or 3D sampling grid using an affine transformation matrix.

This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
It generates a uniform sampling grid within the target shape, normalizes it
Expand All @@ -248,11 +248,13 @@ def affine_grid(
Parameters
----------
data : relax.Expr
The input affine matrix tensor with shape [batch, 2, 3].
The input affine matrix tensor with shape [batch, 2, 3] for 2D or
[batch, 3, 4] for 3D.

size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
The target output spatial shape (H, W). If a single integer or PrimExpr
is provided, it is interpreted as a square output shape (size, size).
size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]]
The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If a
single integer or PrimExpr is provided, it is interpreted as a square 2D
output shape (size, size).

align_corners : bool
If True, normalized grid coordinates map to corner pixels; if False, to
Expand All @@ -261,7 +263,8 @@ def affine_grid(
Returns
-------
result : relax.Expr
The output grid tensor with shape [batch, 2, H, W].
The output grid tensor with shape [batch, 2, H, W] for 2D or
[batch, 3, D, H, W] for 3D.
"""
if isinstance(size, int | PrimExpr):
size = (size, size)
Expand Down
46 changes: 23 additions & 23 deletions python/tvm/topi/image/grid_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def affine_grid(data, target_shape, align_corners=True):
"""affine_grid operator that generates 2D sampling grid.
"""affine_grid operator that generates a 2D or 3D sampling grid.

This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
Expand All @@ -30,10 +30,10 @@ def affine_grid(data, target_shape, align_corners=True):
Parameters
----------
data : tvm.Tensor
3-D with shape [batch, 2, 3]. The affine matrix.
3-D with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D. The affine matrix.

target_shape: list/tuple of two int
Specifies the output shape (H, W).
target_shape: list/tuple of int
Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D.

align_corners : bool
If True, normalized coordinates map to corner pixels; if False, to pixel centers
Expand All @@ -42,35 +42,35 @@ def affine_grid(data, target_shape, align_corners=True):
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, 2, target_height, target_width]
[batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D.
"""
assert target_shape is not None
assert len(target_shape) == 2
assert len(target_shape) in (2, 3)
if align_corners:
assert target_shape[0] > 1 and target_shape[1] > 1, (
"target height/width should be greater than 1 when align_corners is True"
assert all(s > 1 for s in target_shape), (
"target spatial dims should be greater than 1 when align_corners is True"
)

dtype = data.dtype
height, width = target_shape[0], target_shape[1]
if align_corners:
y_step = tirx.const((2.0 - 1e-7) / (height - 1), dtype=dtype)
x_step = tirx.const((2.0 - 1e-7) / (width - 1), dtype=dtype)
y_start = tirx.const(-1.0, dtype=dtype)
x_start = tirx.const(-1.0, dtype=dtype)
starts = [tirx.const(-1.0, dtype=dtype) for _ in target_shape]
steps = [tirx.const((2.0 - 1e-7) / (s - 1), dtype=dtype) for s in target_shape]
else:
# Pixel centers: coordinate i maps to (2 * i + 1) / size - 1.
y_step = tirx.const(2.0 / height, dtype=dtype)
x_step = tirx.const(2.0 / width, dtype=dtype)
y_start = tirx.const(-1.0 + 1.0 / height, dtype=dtype)
x_start = tirx.const(-1.0 + 1.0 / width, dtype=dtype)
starts = [tirx.const(-1.0 + 1.0 / s, dtype=dtype) for s in target_shape]
steps = [tirx.const(2.0 / s, dtype=dtype) for s in target_shape]

def _compute(n, dim, i, j):
y = y_start + i * y_step
x = x_start + j * x_step
return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
ndim = len(target_shape)

oshape = (data.shape[0], len(target_shape), *target_shape)
def _compute(n, dim, *coords):
# coords are ordered slowest-to-fastest (e.g. (k, i, j)); the affine matrix
# columns are fastest-to-slowest (x, y, z), so index it in reverse.
val = data[n, dim, ndim] # translation column
for r in range(ndim):
coord = starts[r] + coords[r] * steps[r]
val += data[n, dim, ndim - 1 - r] * coord
return val

oshape = (data.shape[0], ndim, *target_shape)
return te.compute(oshape, _compute, tag="affine_grid")


Expand Down
38 changes: 21 additions & 17 deletions src/relax/op/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,51 +389,55 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) {
<< "AffineGrid expects the target size to be a Shape, while the given one is "
<< call->args[1]->GetTypeKey();
}
if (size_ty->ndim != 2) {
// 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D, H, W]).
if (size_ty->ndim != 2 && size_ty->ndim != 3) {
TVM_FFI_VISIT_THROW(ValueError, call)
<< "AffineGrid expects the target size to be a 2-dim shape, while the given "
<< "AffineGrid expects the target size to be a 2-dim or 3-dim shape, while the given "
"one has ndim "
<< size_ty->ndim;
}
const int spatial = size_ty->ndim;

// data should be 3-D: [batch, 2, 3]
// data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N, 3, 4]).
if (data_ty->ndim != -1 && data_ty->ndim != 3) {
TVM_FFI_VISIT_THROW(ValueError, call)
<< "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got ndim "
<< data_ty->ndim;
TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input data to be 3-D (batch, "
"spatial, spatial + 1), but got ndim "
<< data_ty->ndim;
}

const auto* data_shape = data_ty->shape.as<ShapeExprNode>();
if (data_shape != nullptr) {
// Check that the affine matrix has shape [batch, 2, 3]
if (data_shape->values.size() >= 2) {
auto* dim1 = data_shape->values[1].as<IntImmNode>();
if (dim1 != nullptr && dim1->value != 2) {
if (dim1 != nullptr && dim1->value != spatial) {
TVM_FFI_VISIT_THROW(ValueError, call)
<< "AffineGrid expects the second dimension of input to be 2, but got " << dim1->value;
<< "AffineGrid expects the second dimension of input to be " << spatial << ", but got "
<< dim1->value;
}
}
if (data_shape->values.size() >= 3) {
auto* dim2 = data_shape->values[2].as<IntImmNode>();
if (dim2 != nullptr && dim2->value != 3) {
if (dim2 != nullptr && dim2->value != spatial + 1) {
TVM_FFI_VISIT_THROW(ValueError, call)
<< "AffineGrid expects the third dimension of input to be 3, but got " << dim2->value;
<< "AffineGrid expects the third dimension of input to be " << spatial + 1
<< ", but got " << dim2->value;
}
}
}

DataType out_dtype = data_ty->dtype;

if (data_shape == nullptr || size_value == nullptr) {
return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice);
return TensorType(out_dtype, /*ndim=*/spatial + 2, data_ty->vdevice);
}

// Output shape: [batch, 2, target_height, target_width]
// Output shape: [batch, spatial, *target_spatial_dims].
ffi::Array<PrimExpr> out_shape;
out_shape.push_back(data_shape->values[0]); // batch
out_shape.push_back(IntImm::Int64(2)); // 2 (spatial dimensions)
out_shape.push_back(size_value->values[0]); // target_height
out_shape.push_back(size_value->values[1]); // target_width
out_shape.push_back(data_shape->values[0]); // batch
out_shape.push_back(IntImm::Int64(spatial)); // number of spatial coordinates
for (int i = 0; i < spatial; ++i) {
out_shape.push_back(size_value->values[i]); // target spatial dim
}

return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice);
}
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5579,6 +5579,37 @@ def test_affine_grid(align_corners):
check_correctness(model, opset=20)


def test_affine_grid_3d():
    """Import a 3D ONNX AffineGrid at opset 20 and check the emitted ops and output shape."""
    # theta is the [N, 3, 4] affine matrix; size is a constant [N, C, D, H, W].
    theta_info = helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 3, 4])
    size_init = helper.make_tensor("size", TensorProto.INT64, [5], [2, 3, 8, 16, 16])
    # The ONNX output keeps the coordinate axis last: [N, D, H, W, 3].
    grid_info = helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 8, 16, 16, 3])

    node = helper.make_node("AffineGrid", ["theta", "size"], ["grid"], align_corners=1)
    graph = helper.make_graph(
        [node],
        "affine_grid_3d_test",
        inputs=[theta_info],
        initializer=[size_init],
        outputs=[grid_info],
    )
    model = helper.make_model(graph, producer_name="affine_grid_3d_test")

    tvm_model = from_onnx(model, opset=20, keep_params_in_input=True)
    main_fn = tvm_model["main"]

    # The importer must emit affine_grid followed by a layout permutation.
    ops = collect_relax_call_ops(main_fn)
    for expected_op in ("relax.image.affine_grid", "relax.permute_dims"):
        assert expected_op in ops
    assert [int(d) for d in main_fn.ret_ty.shape] == [2, 8, 16, 16, 3]


@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
@pytest.mark.parametrize("align_corners", [0, 1])
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relax/test_transform_legalize_ops_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ def affine_grid(var_theta: T.handle, var_compute: T.handle):
with T.sblock("root"):
T.reads()
T.writes()
for n, dim, i, j in T.grid(T.int64(2), T.int64(2), T.int64(16), T.int64(16)):
for n, dim, i0, i1 in T.grid(T.int64(2), T.int64(2), T.int64(16), T.int64(16)):
with T.sblock("compute"):
v_n, v_dim, v_i, v_j = T.axis.remap("SSSS", [n, dim, i, j])
v_n, v_dim, v_i0, v_i1 = T.axis.remap("SSSS", [n, dim, i0, i1])
T.reads(theta[v_n, v_dim, T.int64(0):T.int64(3)])
T.writes(compute[v_n, v_dim, v_i, v_j])
compute[v_n, v_dim, v_i, v_j] = theta[v_n, v_dim, T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_j) * T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(1)] * (T.float32(-1.0) + T.Cast("float32", v_i) * T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(2)]
T.writes(compute[v_n, v_dim, v_i0, v_i1])
compute[v_n, v_dim, v_i0, v_i1] = theta[v_n, v_dim, T.int64(2)] + theta[v_n, v_dim, T.int64(1)] * (T.float32(-1.0) + T.Cast("float32", v_i0) * T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_i1) * T.float32(0.13333332666666667))
# fmt: on

mod = LegalizeOps()(AffineGrid)
Expand Down
Loading