Skip to content

Commit 3b6369e

Browse files
Add support for (either ...) in wast (#8421)
* Adds support for (either ...) in assert_return in WAST * Also allows customizing the error type in Result, which we use to keep track of the failing lane in assertions. Example failure: ```wast (module (func (export "f") (result i32) (i32.const 1) ) ) (assert_return (invoke "f") (either (i32.const 2) (i32.const 3))) ``` ``` Expected one of (2 | 3) but got 1 ``` Example failure with SIMD from `relaxed_min_max.wast`: ``` Expected one of (canonical f32 | canonical f32 | 0 | 0x00000000) at lane 0 but got i32x4 0x7fc00000 0x7fc00000 0x7fc00000 0x7fc00000 ``` Part of #8261 and #8315. Fixes `i16x8_relaxed_q15mulr_s.wast`, `i8x16_relaxed_swizzle.wast`, `relaxed_madd_nmadd.wast` spec tests, and partially fixes `relaxed_dot_product.wast`, `relaxed_laneselect.wast`, `relaxed_min_max.wast`, and `threads/thread.wast`.
1 parent 7288ed5 commit 3b6369e

File tree

6 files changed

+182
-95
lines changed

6 files changed

+182
-95
lines changed

scripts/test/shared.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def get_tests(test_dir, extensions=[], recursive=False):
396396
# Test invalid
397397
'elem.wast',
398398

399-
# Requires wast `either` support
399+
# Requires scoping of `register` statements within `thread` blocks
400400
'threads/thread.wast',
401401

402402
# Requires better support for multi-threaded tests
@@ -453,12 +453,9 @@ def get_tests(test_dir, extensions=[], recursive=False):
453453
'type-subtyping.wast', # ShellExternalInterface::callTable does not handle subtyping
454454
'memory64.wast', # Requires validations on the max memory size
455455
'imports3.wast', # Requires better checking of exports from the special "spectest" module
456-
'i16x8_relaxed_q15mulr_s.wast', # Requires wast `either` support
457-
'i8x16_relaxed_swizzle.wast', # Requires wast `either` support
458-
'relaxed_dot_product.wast', # Requires wast `either` support
459-
'relaxed_laneselect.wast', # Requires wast `either` support
460-
'relaxed_madd_nmadd.wast', # Requires wast `either` support
461-
'relaxed_min_max.wast', # Requires wast `either` support
456+
'relaxed_dot_product.wast', # i16x8.relaxed_dot_i8x16_i7x16_s instruction not supported
457+
'relaxed_laneselect.wast', # i8x16.relaxed_laneselect instruction not supported
458+
'relaxed_min_max.wast', # Non-canonical NaN from f32x4.relaxed_min
462459
'simd_const.wast', # Hex float constant not recognized as out of range
463460
'simd_conversions.wast', # Promoted NaN should be canonical
464461
'simd_f32x4.wast', # Min of 0 and NaN should give a canonical NaN

src/parser/wast-parser.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,30 @@ Result<ExpectedResult> result(Lexer& in) {
342342
return in.err("unrecognized result");
343343
}
344344

345+
Result<ResultAlternatives> eitherResult(Lexer& in) {
346+
if (in.takeSExprStart("either"sv)) {
347+
ResultAlternatives alternatives;
348+
do {
349+
auto r = result(in);
350+
CHECK_ERR(r);
351+
352+
alternatives.push_back(*std::move(r));
353+
} while (!in.takeRParen());
354+
355+
return alternatives;
356+
}
357+
358+
auto r = result(in);
359+
CHECK_ERR(r);
360+
return ResultAlternatives{*std::move(r)};
361+
}
362+
345363
Result<ExpectedResults> results(Lexer& in) {
346364
ExpectedResults res;
347365
while (!in.peekRParen()) {
348-
auto r = result(in);
366+
auto r = eitherResult(in);
349367
CHECK_ERR(r);
350-
res.emplace_back(std::move(*r));
368+
res.emplace_back(*std::move(r));
351369
}
352370
return res;
353371
}
@@ -666,7 +684,7 @@ Result<WASTScript> wast(Lexer& in) {
666684
return cmds;
667685
}
668686
CHECK_ERR(cmd);
669-
cmds.push_back(ScriptEntry{std::move(*cmd), line});
687+
cmds.push_back(ScriptEntry{*std::move(cmd), line});
670688
}
671689
return cmds;
672690
}

src/parser/wat-parser.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,13 @@ struct LaneResults {
8989
using ExpectedResult =
9090
std::variant<Literal, NullRefResult, RefResult, NaNResult, LaneResults>;
9191

92-
using ExpectedResults = std::vector<ExpectedResult>;
92+
using ResultAlternatives = std::vector<ExpectedResult>;
93+
94+
// The WAST spec states that `either`s maybe be nested arbitrarily e.g.
95+
// (either (either "a" "b") (either "a" "c"))
96+
// but we store this flattened since there's no way to tell the difference
97+
// anyway.
98+
using ExpectedResults = std::vector<ResultAlternatives>;
9399

94100
struct AssertReturn {
95101
Action action;

src/support/result.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,28 +36,30 @@ struct Err {
3636
// Check a Result or MaybeResult for error and return the error if it exists.
3737
#define CHECK_ERR(val) \
3838
if (auto _val = (val); auto err = _val.getErr()) { \
39-
return Err{*err}; \
39+
return (typename decltype(_val)::ErrorType)(*err); \
4040
}
4141

4242
// Represent a result of type T or an error message.
43-
template<typename T = Ok> struct [[nodiscard]] Result {
44-
std::variant<T, Err> val;
45-
46-
Result(Result<T>& other) = default;
47-
Result(Result<T>&& other) = default;
48-
Result(const Err& e) : val(std::in_place_type<Err>, e) {}
49-
Result(Err&& e) : val(std::in_place_type<Err>, std::move(e)) {}
43+
template<typename T = Ok, typename E = Err> struct [[nodiscard]] Result {
44+
using ErrorType = E;
45+
std::variant<T, E> val;
46+
47+
Result(Result<T, E>& other) = default;
48+
Result(Result<T, E>&& other) = default;
49+
Result(const E& e) : val(std::in_place_type<E>, e) {}
50+
Result(E&& e) : val(std::in_place_type<E>, std::move(e)) {}
5051
template<typename U = T>
5152
Result(U&& u) : val(std::in_place_type<T>, std::forward<U>(u)) {}
5253

53-
Err* getErr() { return std::get_if<Err>(&val); }
54+
E* getErr() { return std::get_if<E>(&val); }
5455
T& operator*() { return *std::get_if<T>(&val); }
5556
T* operator->() { return std::get_if<T>(&val); }
5657
};
5758

5859
// Represent an optional result of type T or an error message.
5960
template<typename T = Ok> struct [[nodiscard]] MaybeResult {
6061
std::variant<T, None, Err> val;
62+
using ErrorType = Err;
6163

6264
MaybeResult() : val(None{}) {}
6365
MaybeResult(MaybeResult<T>& other) = default;

src/tools/wasm-shell.cpp

Lines changed: 129 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -319,31 +319,31 @@ struct Shell {
319319
switch (nan.kind) {
320320
case NaNKind::Canonical:
321321
if (val.type != nan.type || !val.isCanonicalNaN()) {
322-
err << "expected canonical " << nan.type << " NaN, got " << val;
322+
err << "canonical " << nan.type;
323323
return Err{err.str()};
324324
}
325325
break;
326326
case NaNKind::Arithmetic:
327327
if (val.type != nan.type || !val.isArithmeticNaN()) {
328-
err << "expected arithmetic " << nan.type << " NaN, got " << val;
328+
err << "arithmetic " << nan.type;
329329
return Err{err.str()};
330330
}
331331
break;
332332
}
333333
return Ok{};
334334
}
335335

336-
Result<> checkLane(Literal val, LaneResult expected, Index index) {
336+
Result<> checkLane(Literal val, LaneResult expected) {
337337
std::stringstream err;
338338
if (auto* e = std::get_if<Literal>(&expected)) {
339339
if (*e != val) {
340-
err << "expected " << *e << ", got " << val << " at lane " << index;
340+
err << *e;
341341
return Err{err.str()};
342342
}
343343
} else if (auto* nan = std::get_if<NaNResult>(&expected)) {
344344
auto check = checkNaN(val, *nan);
345345
if (auto* e = check.getErr()) {
346-
err << e->msg << " at lane " << index;
346+
err << e->msg;
347347
return Err{err.str()};
348348
}
349349
} else {
@@ -352,6 +352,88 @@ struct Shell {
352352
return Ok{};
353353
}
354354

355+
struct AlternativeErr {
356+
std::string expected;
357+
int lane = -1;
358+
};
359+
360+
Result<Ok, AlternativeErr> matchAlternative(const Literal& val,
361+
const ExpectedResult& expected) {
362+
std::stringstream err;
363+
364+
if (auto* v = std::get_if<Literal>(&expected)) {
365+
if (val != *v) {
366+
err << *v;
367+
return AlternativeErr{err.str()};
368+
}
369+
} else if (auto* ref = std::get_if<RefResult>(&expected)) {
370+
if (!val.type.isRef() ||
371+
!HeapType::isSubType(val.type.getHeapType(), ref->type)) {
372+
err << ref->type;
373+
return AlternativeErr{err.str()};
374+
}
375+
} else if ([[maybe_unused]] auto* nullRef =
376+
std::get_if<NullRefResult>(&expected)) {
377+
if (!val.isNull()) {
378+
err << "ref.null";
379+
return AlternativeErr{err.str()};
380+
}
381+
} else if (auto* nan = std::get_if<NaNResult>(&expected)) {
382+
auto check = checkNaN(val, *nan);
383+
if (auto* e = check.getErr()) {
384+
err << e->msg;
385+
return AlternativeErr{err.str()};
386+
}
387+
} else if (auto* laneResults = std::get_if<LaneResults>(&expected)) {
388+
auto check = [&](const auto& vals) -> Result<Ok, AlternativeErr> {
389+
for (size_t i = 0; i < vals.size(); ++i) {
390+
auto check = checkLane(vals[i], laneResults->lanes[i]);
391+
if (auto* e = check.getErr()) {
392+
err << e->msg;
393+
394+
// The number of lanes is small
395+
assert(i <= std::numeric_limits<int>::max());
396+
return AlternativeErr{err.str(), static_cast<int>(i)};
397+
}
398+
}
399+
return Ok{};
400+
};
401+
402+
bool isFloat =
403+
laneResults->type == WATParser::LaneResults::LaneType::Float;
404+
switch (laneResults->lanes.size()) {
405+
// Use unsigned values for the smaller types here to avoid sign
406+
// extension when storing 8/16-bit values in 32-bit ints. This isn't
407+
// needed for i32 and i64.
408+
case 16: {
409+
// There is no f8.
410+
assert(!isFloat && "float8 does not exist");
411+
CHECK_ERR(check(val.getLanesUI8x16()));
412+
break;
413+
}
414+
case 8: {
415+
CHECK_ERR(
416+
check(isFloat ? val.getLanesF16x8() : val.getLanesUI16x8()));
417+
break;
418+
}
419+
case 4: {
420+
CHECK_ERR(check(isFloat ? val.getLanesF32x4() : val.getLanesI32x4()));
421+
break;
422+
}
423+
case 2: {
424+
CHECK_ERR(check(isFloat ? val.getLanesF64x2() : val.getLanesI64x2()));
425+
break;
426+
}
427+
default:
428+
WASM_UNREACHABLE("unexpected number of lanes");
429+
}
430+
431+
} else {
432+
WASM_UNREACHABLE("unexpected expectation");
433+
}
434+
return Ok{};
435+
}
436+
355437
Result<> assertReturn(AssertReturn& assn) {
356438
std::stringstream err;
357439
auto result = doAction(assn.action);
@@ -374,79 +456,53 @@ struct Shell {
374456
return ss.str();
375457
};
376458

377-
Literal val = (*values)[i];
378-
auto& expected = assn.expected[i];
379-
if (auto* v = std::get_if<Literal>(&expected)) {
380-
if (val != *v) {
381-
err << "expected " << *v << ", got " << val << atIndex();
382-
return Err{err.str()};
383-
}
384-
} else if (auto* ref = std::get_if<RefResult>(&expected)) {
385-
if (!val.type.isRef() ||
386-
!HeapType::isSubType(val.type.getHeapType(), ref->type)) {
387-
err << "expected " << ref->type << " reference, got " << val
388-
<< atIndex();
389-
return Err{err.str()};
390-
}
391-
} else if ([[maybe_unused]] auto* nullRef =
392-
std::get_if<NullRefResult>(&expected)) {
393-
if (!val.isNull()) {
394-
err << "expected ref.null, got " << val << atIndex();
395-
return Err{err.str()};
459+
// non-either case
460+
if (assn.expected[i].size() == 1) {
461+
auto result = matchAlternative((*values)[i], assn.expected[i][0]);
462+
if (auto* e = result.getErr()) {
463+
std::stringstream ss;
464+
ss << "expected " << e->expected << ", got " << (*values)[i];
465+
if (e->lane != -1) {
466+
ss << " at lane " << e->lane;
467+
}
468+
ss << atIndex();
469+
return Err{ss.str()};
396470
}
397-
} else if (auto* nan = std::get_if<NaNResult>(&expected)) {
398-
auto check = checkNaN(val, *nan);
399-
if (auto* e = check.getErr()) {
400-
err << e->msg << atIndex();
401-
return Err{err.str()};
471+
continue;
472+
}
473+
474+
// either case
475+
bool success = false;
476+
std::vector<std::string> expecteds;
477+
int failedLane = -1;
478+
for (const auto& alternative : assn.expected[i]) {
479+
auto result = matchAlternative((*values)[i], alternative);
480+
if (!result.getErr()) {
481+
success = true;
482+
break;
402483
}
403-
} else if (auto* l = std::get_if<LaneResults>(&expected)) {
404-
auto* lanes = &l->lanes;
405-
406-
auto check = [&](const auto& vals) -> Result<> {
407-
for (size_t i = 0; i < vals.size(); ++i) {
408-
auto check = checkLane(vals[i], (*lanes)[i], i);
409-
if (auto* e = check.getErr()) {
410-
err << e->msg << atIndex();
411-
return Err{err.str()};
412-
}
413-
}
414-
return Ok{};
415-
};
416-
417-
bool isFloat = l->type == WATParser::LaneResults::LaneType::Float;
418-
switch (lanes->size()) {
419-
// Use unsigned values for the smaller types here to avoid sign
420-
// extension when storing 8/16-bit values in 32-bit ints. This isn't
421-
// needed for i32 and i64.
422-
case 16: {
423-
// There is no f8.
424-
assert(!isFloat && "float8 does not exist");
425-
CHECK_ERR(check(val.getLanesUI8x16()));
426-
break;
427-
}
428-
case 8: {
429-
CHECK_ERR(
430-
check(isFloat ? val.getLanesF16x8() : val.getLanesUI16x8()));
431-
break;
432-
}
433-
case 4: {
434-
CHECK_ERR(
435-
check(isFloat ? val.getLanesF32x4() : val.getLanesI32x4()));
436-
break;
437-
}
438-
case 2: {
439-
CHECK_ERR(
440-
check(isFloat ? val.getLanesF64x2() : val.getLanesI64x2()));
441-
break;
442-
}
443-
default:
444-
WASM_UNREACHABLE("unexpected number of lanes");
484+
485+
auto* e = result.getErr();
486+
expecteds.push_back(e->expected);
487+
if (failedLane == -1 && e->lane != -1) {
488+
failedLane = e->lane;
445489
}
446-
} else {
447-
WASM_UNREACHABLE("unexpected expectation");
448490
}
491+
if (success) {
492+
continue;
493+
}
494+
std::stringstream ss;
495+
ss << "Expected one of (" << String::join(expecteds, " | ") << ")";
496+
if (failedLane != -1) {
497+
ss << " at lane " << failedLane;
498+
}
499+
ss << " but got " << (*values)[i];
500+
501+
ss << atIndex();
502+
503+
return Err{ss.str()};
449504
}
505+
450506
return Ok{};
451507
}
452508

0 commit comments

Comments
 (0)