Skip to content

Commit 30233d5

Browse files
committed
[Relax][ONNX] Fix Resize coordinate error with non-integer scales
1 parent 15b1d98 commit 30233d5

2 files changed

Lines changed: 97 additions & 26 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,14 +3230,18 @@ def _impl_v18(cls, bb, inputs, attr, params):
32303230

32313231
use_dynamic_roi = roi_dynamic_vec is not None
32323232

3233-
# Convert scales to sizes if needed.
3233+
# Convert scales to sizes if needed, preserving the original spatial scales so
3234+
# the coordinate transformation uses the exact ONNX scale value rather than the
3235+
# lossy ratio derived from floor(input * scale) / input.
3236+
original_spatial_scales = None
32343237
if scales is not None:
32353238
if isinstance(scales, relax.Constant):
32363239
scales = scales.data.numpy()
32373240
elif isinstance(scales, relax.expr.ShapeExpr):
32383241
scales = [int(val.value) for val in scales.values]
32393242
else:
32403243
raise ValueError(f"Type {type(scales)} for scale is currently unsupported.")
3244+
original_spatial_scales = list(scales[2:])
32413245
sizes = []
32423246

32433247
for i, dim in enumerate(x.struct_info.shape):
@@ -3279,33 +3283,38 @@ def _impl_v18(cls, bb, inputs, attr, params):
32793283
cubic_coeff_a,
32803284
exclude_outside,
32813285
extrapolation_value,
3286+
scales=original_spatial_scales,
32823287
)
32833288
elif ndims == 4:
3284-
return relax.op.image.resize2d(
3289+
return bb.emit_te(
3290+
topi.image.resize2d,
32853291
x,
3286-
size=relax.ShapeExpr(sizes),
3287-
roi=roi_static,
3288-
layout="NCHW",
3289-
method=relax_mode,
3290-
coordinate_transformation_mode=coord_mode,
3291-
rounding_method=rounding_method,
3292-
cubic_alpha=cubic_coeff_a,
3293-
cubic_exclude=exclude_outside,
3294-
extrapolation_value=extrapolation_value,
3292+
roi_static,
3293+
sizes,
3294+
"NCHW",
3295+
topi_mode,
3296+
coord_mode,
3297+
rounding_method,
3298+
cubic_coeff_a,
3299+
exclude_outside,
3300+
extrapolation_value,
3301+
scales=original_spatial_scales,
32953302
)
32963303
else: # ndims == 5
32973304
roi3d = _topi_resize3d_roi_from_onnx_ncdhw_spatial(roi_static)
3298-
return relax.op.image.resize3d(
3305+
return bb.emit_te(
3306+
topi.image.resize3d,
32993307
x,
3300-
size=relax.ShapeExpr(sizes),
3301-
roi=roi3d,
3302-
layout="NCDHW",
3303-
method=relax_mode,
3304-
coordinate_transformation_mode=coord_mode,
3305-
rounding_method=rounding_method,
3306-
cubic_alpha=cubic_coeff_a,
3307-
cubic_exclude=exclude_outside,
3308-
extrapolation_value=extrapolation_value,
3308+
roi3d,
3309+
sizes,
3310+
"NCDHW",
3311+
relax_mode,
3312+
coord_mode,
3313+
rounding_method,
3314+
cubic_coeff_a,
3315+
exclude_outside,
3316+
extrapolation_value,
3317+
scales=original_spatial_scales,
33093318
)
33103319

33113320

python/tvm/topi/image/resize.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,13 @@ def get_inx(
145145
start_x=0,
146146
end_x=-1,
147147
use_int_div=False,
148+
scale_x_override=None,
148149
):
149150
"""Infer input x from output x with various coordinate transformation methods"""
150-
scale_x = te.div(image_width.astype("float"), target_width.astype("float"))
151+
if scale_x_override is not None:
152+
scale_x = scale_x_override
153+
else:
154+
scale_x = te.div(image_width.astype("float"), target_width.astype("float"))
151155
if coordinate_transformation_mode == "half_pixel":
152156
in_x = (x + 0.5) * scale_x - 0.5
153157
elif coordinate_transformation_mode == "align_corners":
@@ -237,6 +241,7 @@ def _resize_1d(
237241
alpha=-0.5,
238242
exclude_outside=0,
239243
out_dtype=None,
244+
scale_x=None,
240245
):
241246
"""Perform resize operation on the data with selected method and options.
242247
@@ -315,7 +320,15 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
315320
if boxes is not None:
316321
# TODO(mbrookhart): Find an example of this
317322
raise NotImplementedError("resize1d with image boxes not yet implemented")
318-
in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode, roi[0], roi[1])
323+
in_x = get_inx(
324+
x,
325+
image_width,
326+
target_width,
327+
coordinate_transformation_mode,
328+
roi[0],
329+
roi[1],
330+
scale_x_override=scale_x,
331+
)
319332

320333
if method == "nearest_neighbor":
321334
if rounding_method == "":
@@ -383,6 +396,7 @@ def resize1d(
383396
extrapolation_value=0.0,
384397
out_dtype=None,
385398
output_shape=None,
399+
scales=None,
386400
):
387401
"""Perform resize operation on the data.
388402
@@ -472,6 +486,8 @@ def resize1d(
472486
if isinstance(size[i], int):
473487
size[i] = tvm.tirx.IntImm("int32", size[i])
474488

489+
scale_x = (1.0 / scales[0]) if scales is not None else None
490+
475491
def compute_func(*indices):
476492
return _resize_1d(
477493
indices,
@@ -487,6 +503,7 @@ def compute_func(*indices):
487503
exclude_outside=bicubic_exclude,
488504
extrapolation_value=extrapolation_value,
489505
out_dtype=out_dtype,
506+
scale_x=scale_x,
490507
)
491508

492509
return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
@@ -510,6 +527,8 @@ def _resize_2d(
510527
alpha=-0.5,
511528
exclude_outside=0,
512529
out_dtype=None,
530+
scale_h=None,
531+
scale_w=None,
513532
):
514533
"""Perform resize operation on the data with selected method and options.
515534
@@ -618,6 +637,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
618637
roi[1],
619638
roi[3],
620639
width_use_int_div,
640+
scale_x_override=scale_w,
621641
)
622642
in_y = get_inx(
623643
y,
@@ -627,6 +647,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
627647
roi[0],
628648
roi[2],
629649
height_use_int_div,
650+
scale_x_override=scale_h,
630651
)
631652

632653
if method == "nearest_neighbor":
@@ -756,6 +777,7 @@ def resize2d(
756777
extrapolation_value=0.0,
757778
out_dtype=None,
758779
output_shape=None,
780+
scales=None,
759781
):
760782
"""Perform resize operation on the data.
761783
@@ -839,6 +861,9 @@ def resize2d(
839861
if isinstance(size[i], int):
840862
size[i] = tvm.tirx.IntImm("int32", size[i])
841863

864+
scale_h = (1.0 / scales[0]) if scales is not None else None
865+
scale_w = (1.0 / scales[1]) if scales is not None else None
866+
842867
def compute_func(*indices):
843868
return _resize_2d(
844869
indices,
@@ -856,6 +881,8 @@ def compute_func(*indices):
856881
exclude_outside=bicubic_exclude,
857882
extrapolation_value=extrapolation_value,
858883
out_dtype=out_dtype,
884+
scale_h = scale_h,
885+
scale_w = scale_w,
859886
)
860887

861888
return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
@@ -976,6 +1003,9 @@ def _resize_3d(
9761003
alpha=-0.5,
9771004
exclude_outside=0,
9781005
out_dtype=None,
1006+
scale_d=None,
1007+
scale_h=None,
1008+
scale_w=None,
9791009
):
9801010
"""Perform resize operation on the data with selected method and options.
9811011
@@ -1066,9 +1096,33 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
10661096
if boxes is not None:
10671097
# TODO(mbrookhart): Find an example of this
10681098
raise NotImplementedError("resize1d with image boxes not yet implemented")
1069-
in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode, roi[2], roi[5])
1070-
in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode, roi[1], roi[4])
1071-
in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode, roi[0], roi[3])
1099+
in_z = get_inx(
1100+
z,
1101+
image_depth,
1102+
target_depth,
1103+
coordinate_transformation_mode,
1104+
roi[2],
1105+
roi[5],
1106+
scale_x_override=scale_d,
1107+
)
1108+
in_y = get_inx(
1109+
y,
1110+
image_height,
1111+
target_height,
1112+
coordinate_transformation_mode,
1113+
roi[1],
1114+
roi[4],
1115+
scale_x_override=scale_h,
1116+
)
1117+
in_x = get_inx(
1118+
x,
1119+
image_width,
1120+
target_width,
1121+
coordinate_transformation_mode,
1122+
roi[0],
1123+
roi[3],
1124+
scale_x_override=scale_w,
1125+
)
10721126

10731127
if method == "nearest_neighbor":
10741128
if rounding_method == "":
@@ -1225,6 +1279,7 @@ def resize3d(
12251279
extrapolation_value=0.0,
12261280
out_dtype=None,
12271281
output_shape=None,
1282+
scales=None,
12281283
):
12291284
"""Perform resize operation on the data.
12301285
@@ -1302,6 +1357,10 @@ def resize3d(
13021357
if isinstance(size[i], int):
13031358
size[i] = tvm.tirx.IntImm("int32", size[i])
13041359

1360+
scale_d = (1.0 / scales[0]) if scales is not None else None
1361+
scale_h = (1.0 / scales[1]) if scales is not None else None
1362+
scale_w = (1.0 / scales[2]) if scales is not None else None
1363+
13051364
def compute_func(*indices):
13061365
return _resize_3d(
13071366
indices,
@@ -1321,6 +1380,9 @@ def compute_func(*indices):
13211380
exclude_outside=bicubic_exclude,
13221381
extrapolation_value=extrapolation_value,
13231382
out_dtype=out_dtype,
1383+
scale_d=scale_d,
1384+
scale_h=scale_h,
1385+
scale_w=scale_w,
13241386
)
13251387

13261388
return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)

0 commit comments

Comments
 (0)