@@ -506,24 +506,25 @@ def resize(
506506 else :
507507 raise ValueError (f"resize_mode { resize_mode } is not supported" )
508508
509- elif isinstance (image , torch .Tensor ):
510- torch_mode , use_antialias = TORCH_INTERPOLATION [self .config .resample ]
511- image = torch .nn .functional .interpolate (
512- image ,
513- size = (height , width ),
514- mode = torch_mode ,
515- antialias = use_antialias ,
516- )
517- elif isinstance (image , np .ndarray ):
518- torch_mode , use_antialias = TORCH_INTERPOLATION [self .config .resample ]
519- image = self .numpy_to_pt (image )
509+ elif isinstance (image , (torch .Tensor , np .ndarray )):
510+ resample = self .config .resample
511+ if resample not in TORCH_INTERPOLATION :
512+ logger .warning (
513+ f"The resample mode '{ resample } ' is not supported for torch.Tensor/np.ndarray inputs "
514+ f"and will be ignored. Supported modes are: { list (TORCH_INTERPOLATION .keys ())} . "
515+ "Falling back to default 'nearest' interpolation."
516+ )
517+ torch_mode , use_antialias = "nearest" , False
518+ else :
519+ torch_mode , use_antialias = TORCH_INTERPOLATION [resample ]
520+
521+ if isinstance (image , np .ndarray ):
522+ image = self .numpy_to_pt (image )
520523 image = torch .nn .functional .interpolate (
521- image ,
522- size = (height , width ),
523- mode = torch_mode ,
524- antialias = use_antialias ,
524+ image , size = (height , width ), mode = torch_mode , antialias = use_antialias
525525 )
526- image = self .pt_to_numpy (image )
526+ if isinstance (image , np .ndarray ):
527+ image = self .pt_to_numpy (image )
527528
528529 return image
529530
0 commit comments