@@ -145,9 +145,13 @@ def get_inx(
145145 start_x = 0 ,
146146 end_x = - 1 ,
147147 use_int_div = False ,
148+ scale_x_override = None ,
148149):
149150 """Infer input x from output x with various coordinate transformation methods"""
150- scale_x = te .div (image_width .astype ("float" ), target_width .astype ("float" ))
151+ if scale_x_override is not None :
152+ scale_x = scale_x_override
153+ else :
154+ scale_x = te .div (image_width .astype ("float" ), target_width .astype ("float" ))
151155 if coordinate_transformation_mode == "half_pixel" :
152156 in_x = (x + 0.5 ) * scale_x - 0.5
153157 elif coordinate_transformation_mode == "align_corners" :
@@ -237,6 +241,7 @@ def _resize_1d(
237241 alpha = - 0.5 ,
238242 exclude_outside = 0 ,
239243 out_dtype = None ,
244+ scale_x = None ,
240245):
241246 """Perform resize operation on the data with selected method and options.
242247
@@ -315,7 +320,15 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
315320 if boxes is not None :
316321 # TODO(mbrookhart): Find an example of this
317322 raise NotImplementedError ("resize1d with image boxes not yet implemented" )
318- in_x = get_inx (x , image_width , target_width , coordinate_transformation_mode , roi [0 ], roi [1 ])
323+ in_x = get_inx (
324+ x ,
325+ image_width ,
326+ target_width ,
327+ coordinate_transformation_mode ,
328+ roi [0 ],
329+ roi [1 ],
330+ scale_x_override = scale_x ,
331+ )
319332
320333 if method == "nearest_neighbor" :
321334 if rounding_method == "" :
@@ -383,6 +396,7 @@ def resize1d(
383396 extrapolation_value = 0.0 ,
384397 out_dtype = None ,
385398 output_shape = None ,
399+ scales = None ,
386400):
387401 """Perform resize operation on the data.
388402
@@ -472,6 +486,8 @@ def resize1d(
472486 if isinstance (size [i ], int ):
473487 size [i ] = tvm .tirx .IntImm ("int32" , size [i ])
474488
489+ scale_x = (1.0 / scales [0 ]) if scales is not None else None
490+
475491 def compute_func (* indices ):
476492 return _resize_1d (
477493 indices ,
@@ -487,6 +503,7 @@ def compute_func(*indices):
487503 exclude_outside = bicubic_exclude ,
488504 extrapolation_value = extrapolation_value ,
489505 out_dtype = out_dtype ,
506+ scale_x = scale_x ,
490507 )
491508
492509 return te .compute (output_shape , compute_func , name = "resize" , tag = tag .INJECTIVE )
@@ -510,6 +527,8 @@ def _resize_2d(
510527 alpha = - 0.5 ,
511528 exclude_outside = 0 ,
512529 out_dtype = None ,
530+ scale_h = None ,
531+ scale_w = None ,
513532):
514533 """Perform resize operation on the data with selected method and options.
515534
@@ -618,6 +637,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
618637 roi [1 ],
619638 roi [3 ],
620639 width_use_int_div ,
640+ scale_x_override = scale_w ,
621641 )
622642 in_y = get_inx (
623643 y ,
@@ -627,6 +647,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
627647 roi [0 ],
628648 roi [2 ],
629649 height_use_int_div ,
650+ scale_x_override = scale_h ,
630651 )
631652
632653 if method == "nearest_neighbor" :
@@ -756,6 +777,7 @@ def resize2d(
756777 extrapolation_value = 0.0 ,
757778 out_dtype = None ,
758779 output_shape = None ,
780+ scales = None ,
759781):
760782 """Perform resize operation on the data.
761783
@@ -839,6 +861,9 @@ def resize2d(
839861 if isinstance (size [i ], int ):
840862 size [i ] = tvm .tirx .IntImm ("int32" , size [i ])
841863
864+ scale_h = (1.0 / scales [0 ]) if scales is not None else None
865+ scale_w = (1.0 / scales [1 ]) if scales is not None else None
866+
842867 def compute_func (* indices ):
843868 return _resize_2d (
844869 indices ,
@@ -856,6 +881,8 @@ def compute_func(*indices):
856881 exclude_outside = bicubic_exclude ,
857882 extrapolation_value = extrapolation_value ,
858883 out_dtype = out_dtype ,
884+ scale_h = scale_h ,
885+ scale_w = scale_w ,
859886 )
860887
861888 return te .compute (output_shape , compute_func , name = "resize" , tag = tag .INJECTIVE )
@@ -976,6 +1003,9 @@ def _resize_3d(
9761003 alpha = - 0.5 ,
9771004 exclude_outside = 0 ,
9781005 out_dtype = None ,
1006+ scale_d = None ,
1007+ scale_h = None ,
1008+ scale_w = None ,
9791009):
9801010 """Perform resize operation on the data with selected method and options.
9811011
@@ -1066,9 +1096,33 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
10661096 if boxes is not None :
10671097 # TODO(mbrookhart): Find an example of this
10681098 raise NotImplementedError ("resize1d with image boxes not yet implemented" )
1069- in_z = get_inx (z , image_depth , target_depth , coordinate_transformation_mode , roi [2 ], roi [5 ])
1070- in_y = get_inx (y , image_height , target_height , coordinate_transformation_mode , roi [1 ], roi [4 ])
1071- in_x = get_inx (x , image_width , target_width , coordinate_transformation_mode , roi [0 ], roi [3 ])
1099+ in_z = get_inx (
1100+ z ,
1101+ image_depth ,
1102+ target_depth ,
1103+ coordinate_transformation_mode ,
1104+ roi [2 ],
1105+ roi [5 ],
1106+ scale_x_override = scale_d ,
1107+ )
1108+ in_y = get_inx (
1109+ y ,
1110+ image_height ,
1111+ target_height ,
1112+ coordinate_transformation_mode ,
1113+ roi [1 ],
1114+ roi [4 ],
1115+ scale_x_override = scale_h ,
1116+ )
1117+ in_x = get_inx (
1118+ x ,
1119+ image_width ,
1120+ target_width ,
1121+ coordinate_transformation_mode ,
1122+ roi [0 ],
1123+ roi [3 ],
1124+ scale_x_override = scale_w ,
1125+ )
10721126
10731127 if method == "nearest_neighbor" :
10741128 if rounding_method == "" :
@@ -1225,6 +1279,7 @@ def resize3d(
12251279 extrapolation_value = 0.0 ,
12261280 out_dtype = None ,
12271281 output_shape = None ,
1282+ scales = None ,
12281283):
12291284 """Perform resize operation on the data.
12301285
@@ -1302,6 +1357,10 @@ def resize3d(
13021357 if isinstance (size [i ], int ):
13031358 size [i ] = tvm .tirx .IntImm ("int32" , size [i ])
13041359
1360+ scale_d = (1.0 / scales [0 ]) if scales is not None else None
1361+ scale_h = (1.0 / scales [1 ]) if scales is not None else None
1362+ scale_w = (1.0 / scales [2 ]) if scales is not None else None
1363+
13051364 def compute_func (* indices ):
13061365 return _resize_3d (
13071366 indices ,
@@ -1321,6 +1380,9 @@ def compute_func(*indices):
13211380 exclude_outside = bicubic_exclude ,
13221381 extrapolation_value = extrapolation_value ,
13231382 out_dtype = out_dtype ,
1383+ scale_d = scale_d ,
1384+ scale_h = scale_h ,
1385+ scale_w = scale_w ,
13241386 )
13251387
13261388 return te .compute (output_shape , compute_func , name = "resize" , tag = tag .INJECTIVE )
0 commit comments