@@ -385,51 +385,55 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) {
385385 << " AffineGrid expects the target size to be a Shape, while the given one is "
386386 << call->args [1 ]->GetTypeKey ();
387387 }
388- if (size_ty->ndim != 2 ) {
388+ // 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D, H, W]).
389+ if (size_ty->ndim != 2 && size_ty->ndim != 3 ) {
389390 TVM_FFI_VISIT_THROW (ValueError, call)
390- << " AffineGrid expects the target size to be a 2-dim shape, while the given "
391+ << " AffineGrid expects the target size to be a 2-dim or 3-dim shape, while the given "
391392 " one has ndim "
392393 << size_ty->ndim ;
393394 }
395+ const int spatial = size_ty->ndim ;
394396
395- // data should be 3-D: [batch, 2, 3]
397+ // data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N, 3, 4]).
396398 if (data_ty->ndim != -1 && data_ty->ndim != 3 ) {
397- TVM_FFI_VISIT_THROW (ValueError, call)
398- << " AffineGrid expects the input data to be 3-D (batch, 2, 3 ), but got ndim "
399- << data_ty->ndim ;
399+ TVM_FFI_VISIT_THROW (ValueError, call) << " AffineGrid expects the input data to be 3-D (batch, "
400+ " spatial, spatial + 1 ), but got ndim "
401+ << data_ty->ndim ;
400402 }
401403
402404 const auto * data_shape = data_ty->shape .as <ShapeExprNode>();
403405 if (data_shape != nullptr ) {
404- // Check that the affine matrix has shape [batch, 2, 3]
405406 if (data_shape->values .size () >= 2 ) {
406407 auto * dim1 = data_shape->values [1 ].as <IntImmNode>();
407- if (dim1 != nullptr && dim1->value != 2 ) {
408+ if (dim1 != nullptr && dim1->value != spatial ) {
408409 TVM_FFI_VISIT_THROW (ValueError, call)
409- << " AffineGrid expects the second dimension of input to be 2, but got " << dim1->value ;
410+ << " AffineGrid expects the second dimension of input to be " << spatial << " , but got "
411+ << dim1->value ;
410412 }
411413 }
412414 if (data_shape->values .size () >= 3 ) {
413415 auto * dim2 = data_shape->values [2 ].as <IntImmNode>();
414- if (dim2 != nullptr && dim2->value != 3 ) {
416+ if (dim2 != nullptr && dim2->value != spatial + 1 ) {
415417 TVM_FFI_VISIT_THROW (ValueError, call)
416- << " AffineGrid expects the third dimension of input to be 3, but got " << dim2->value ;
418+ << " AffineGrid expects the third dimension of input to be " << spatial + 1
419+ << " , but got " << dim2->value ;
417420 }
418421 }
419422 }
420423
421424 DataType out_dtype = data_ty->dtype ;
422425
423426 if (data_shape == nullptr || size_value == nullptr ) {
424- return TensorType (out_dtype, /* ndim=*/ 4 , data_ty->vdevice );
427+ return TensorType (out_dtype, /* ndim=*/ spatial + 2 , data_ty->vdevice );
425428 }
426429
427- // Output shape: [batch, 2, target_height, target_width]
430+ // Output shape: [batch, spatial, *target_spatial_dims].
428431 ffi::Array<PrimExpr> out_shape;
429- out_shape.push_back (data_shape->values [0 ]); // batch
430- out_shape.push_back (IntImm::Int64 (2 )); // 2 (spatial dimensions)
431- out_shape.push_back (size_value->values [0 ]); // target_height
432- out_shape.push_back (size_value->values [1 ]); // target_width
432+ out_shape.push_back (data_shape->values [0 ]); // batch
433+ out_shape.push_back (IntImm::Int64 (spatial)); // number of spatial coordinates
434+ for (int i = 0 ; i < spatial; ++i) {
435+ out_shape.push_back (size_value->values [i]); // target spatial dim
436+ }
433437
434438 return TensorType (ShapeExpr (out_shape), out_dtype, data_ty->vdevice );
435439}
0 commit comments