Skip to content

Commit f1614a6

Browse files
authored
[Custom Descriptors] Better handle unreachable ref.cast_desc (#7937)
Validate that the descriptor operand of ref.cast_desc is always a descriptor reference (or a nullref or unreachable), even if the ref operand and therefore the ref.cast_desc are unreachable. This removes an edge case passes might otherwise have to consider where the descriptor operand is not actually a descriptor. Update GTO to preserve this new validation invariant. Also print an unreachable ref.cast_desc instruction when we can determine its type from a concrete, non-null descriptor operand rather than bailing out and printing an unreachable block. This allows us to preserve such IR through text round trips, which can help e.g. test case reduction.
1 parent 6ff2f94 commit f1614a6

4 files changed

Lines changed: 376 additions & 71 deletions

File tree

src/passes/Print.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,8 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
321321
void printUnreachableReplacement(Expression* curr);
322322
bool maybePrintUnreachableReplacement(Expression* curr, Type type);
323323
void visitRefCast(RefCast* curr) {
324-
if (!maybePrintUnreachableReplacement(curr, curr->type)) {
324+
if ((curr->desc && curr->desc->type != Type::unreachable) ||
325+
!maybePrintUnreachableReplacement(curr, curr->type)) {
325326
visitExpression(curr);
326327
}
327328
}
@@ -2224,7 +2225,20 @@ struct PrintExpressionContents
22242225
} else {
22252226
printMedium(o, "ref.cast ");
22262227
}
2227-
printType(curr->type);
2228+
if (curr->type != Type::unreachable) {
2229+
printType(curr->type);
2230+
} else {
2231+
// We can still recover a valid result type from the type of the
2232+
// descriptor.
2233+
auto described = curr->desc->type.getHeapType().getDescribedType();
2234+
if (described) {
2235+
printType(
2236+
Type(*described, NonNullable, curr->desc->type.getExactness()));
2237+
} else {
2238+
// Invalid, so it doesn't matter what we print.
2239+
printType(Type::unreachable);
2240+
}
2241+
}
22282242
}
22292243
void visitRefGetDesc(RefGetDesc* curr) {
22302244
printMedium(o, "ref.get_desc ");

src/wasm/wasm-validator.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,6 +2944,18 @@ void FunctionValidator::visitRefTest(RefTest* curr) {
29442944
void FunctionValidator::visitRefCast(RefCast* curr) {
29452945
shouldBeTrue(
29462946
getModule()->features.hasGC(), curr, "ref.cast requires gc [--enable-gc]");
2947+
2948+
// Require descriptors to be valid even if the ref is unreachable.
2949+
if (curr->desc && curr->desc->type != Type::unreachable) {
2950+
auto descType = curr->desc->type;
2951+
bool isNull = descType.isNull();
2952+
bool isDescriptor =
2953+
descType.isRef() && descType.getHeapType().getDescribedType();
2954+
shouldBeTrue(isNull || isDescriptor,
2955+
curr,
2956+
"ref.cast_desc descriptor must be a descriptor reference");
2957+
}
2958+
29472959
if (curr->type == Type::unreachable) {
29482960
return;
29492961
}
@@ -3006,11 +3018,7 @@ void FunctionValidator::visitRefCast(RefCast* curr) {
30063018
}
30073019

30083020
auto described = descriptor.getDescribedType();
3009-
if (!shouldBeTrue(bool(described),
3010-
curr,
3011-
"ref.cast_desc descriptor should have a described type")) {
3012-
return;
3013-
}
3021+
assert(described && "already checked descriptor");
30143022
shouldBeEqual(*described,
30153023
curr->type.getHeapType(),
30163024
curr,

test/gtest/validator.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "support/string.h"
2323
#include "wasm-binary.h"
2424
#include "wasm-builder.h"
25+
#include "wasm-features.h"
2526
#include "wasm-validator.h"
2627
#include "gtest/gtest.h"
2728

@@ -65,3 +66,27 @@ TEST(ValidatorTest, ReturnUnreachable) {
6566
WasmValidator::FlagValues::Globally | WasmValidator::FlagValues::Quiet;
6667
EXPECT_FALSE(WasmValidator{}.validate(func.get(), module, flags));
6768
}
69+
70+
TEST(ValidatorTest, UnreachableCastDesc) {
71+
// The parser will error trying to parse a ref.cast_desc with a non-matching
72+
// descriptor type, so we must construct the IR directly to test the
73+
// validator.
74+
Module module;
75+
module.features = FeatureSet::All;
76+
Builder builder(module);
77+
78+
auto func = builder.makeFunction(
79+
"func",
80+
{},
81+
Signature(Type::none, Type::none),
82+
{},
83+
builder.makeDrop(builder.makeRefCast(builder.makeUnreachable(),
84+
builder.makeStructNew(Struct{}, {}),
85+
Type::unreachable)));
86+
87+
ASSERT_EQ(func->body->type, Type(Type::unreachable));
88+
89+
auto flags =
90+
WasmValidator::FlagValues::Globally | WasmValidator::FlagValues::Quiet;
91+
EXPECT_FALSE(WasmValidator{}.validate(func.get(), module, flags));
92+
}

0 commit comments

Comments
 (0)