Skip to content

Commit 35d4a00

Browse files
committed
[FIX][Relax] Update reverse_sequence dtype checks
1 parent 1313eb9 commit 35d4a00

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/relax/op/tensor/manipulate.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,14 +2103,15 @@ Type InferTypeReverseSequence(const Call& call, const BlockBuilder& ctx) {
21032103
<< "ReverseSequence requires seq_lengths to be 1-D. However, seq_lengths has ndim "
21042104
<< seq_lengths_ty->ndim;
21052105
}
2106-
if (!seq_lengths_ty->dtype.is_void() && !seq_lengths_ty->dtype.is_int()) {
2106+
PrimType seq_lengths_dtype = seq_lengths_ty->dtype;
2107+
if (!seq_lengths_ty->IsUnknownDtype() && !seq_lengths_dtype.MatchesCode(DLDataTypeCode::kDLInt)) {
21072108
TVM_FFI_VISIT_THROW(ValueError, call)
21082109
<< "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
21092110
"seq_lengths has dtype "
21102111
<< seq_lengths_ty->dtype;
21112112
}
2112-
if (seq_lengths_ty->dtype.is_int() && seq_lengths_ty->dtype.bits() != 32 &&
2113-
seq_lengths_ty->dtype.bits() != 64) {
2113+
if (seq_lengths_dtype.MatchesCode(DLDataTypeCode::kDLInt) &&
2114+
seq_lengths_dtype->dtype.bits != 32 && seq_lengths_dtype->dtype.bits != 64) {
21142115
TVM_FFI_VISIT_THROW(ValueError, call)
21152116
<< "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
21162117
"seq_lengths has dtype "

0 commit comments

Comments
 (0)