Skip to content

Commit 0a9638a

Browse files
committed
wasm3: validate function type
1 parent 7d4c8f1 commit 0a9638a

1 file changed

Lines changed: 51 additions & 7 deletions

File tree

test/utils/wasm3_engine.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
#include <test/utils/adler32.hpp>
88
#include <test/utils/wasm_engine.hpp>
9+
#include <algorithm>
910
#include <cassert>
1011
#include <cstring>
12+
#include <stdexcept>
1113

1214
namespace fizzy::test
1315
{
@@ -38,6 +40,33 @@ class Wasm3Engine final : public WasmEngine
3840

3941
namespace
4042
{
43+
M3ValueType translate_valtype(char input)
44+
{
45+
if (input == 'i')
46+
return M3ValueType::c_m3Type_i32;
47+
else if (input == 'I')
48+
return M3ValueType::c_m3Type_i64;
49+
else
50+
throw std::runtime_error{"invalid type"};
51+
}
52+
53+
std::pair<std::vector<M3ValueType>, std::vector<M3ValueType>> translate_signature(
54+
std::string_view signature)
55+
{
56+
const auto delimiter_pos = signature.find(':');
57+
assert(delimiter_pos != std::string_view::npos);
58+
const auto inputs = signature.substr(0, delimiter_pos);
59+
const auto outputs = signature.substr(delimiter_pos + 1);
60+
61+
std::vector<M3ValueType> input_types;
62+
std::vector<M3ValueType> output_types;
63+
std::transform(
64+
std::begin(inputs), std::end(inputs), std::back_inserter(input_types), translate_valtype);
65+
std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_types),
66+
translate_valtype);
67+
return {std::move(input_types), std::move(output_types)};
68+
}
69+
4170
const void* env_adler32(
4271
IM3Runtime /*runtime*/, IM3ImportContext /*context*/, uint64_t* stack, void* mem)
4372
{
@@ -115,14 +144,29 @@ fizzy::bytes_view Wasm3Engine::get_memory() const
115144
}
116145

117146
std::optional<WasmEngine::FuncRef> Wasm3Engine::find_function(
118-
std::string_view name, std::string_view) const
147+
std::string_view name, std::string_view signature) const
119148
{
120149
IM3Function function;
121-
if (m3_FindFunction(&function, m_runtime, name.data()) == m3Err_none)
122-
// TODO: validate input/output types
123-
// (m3_GetArgCount/m3_GetArgType/m3_GetRetCount/m3_GetRetType)
124-
return reinterpret_cast<WasmEngine::FuncRef>(function);
125-
return std::nullopt;
150+
if (m3_FindFunction(&function, m_runtime, name.data()) != m3Err_none)
151+
return std::nullopt;
152+
153+
std::vector<M3ValueType> inputs;
154+
std::vector<M3ValueType> outputs;
155+
std::tie(inputs, outputs) = translate_signature(signature);
156+
157+
if (inputs.size() != m3_GetArgCount(function))
158+
return std::nullopt;
159+
for (unsigned i = 0; i < m3_GetArgCount(function); i++)
160+
if (inputs[i] != m3_GetArgType(function, i))
161+
return std::nullopt;
162+
163+
if (outputs.size() != m3_GetRetCount(function))
164+
return std::nullopt;
165+
for (unsigned i = 0; i < m3_GetRetCount(function); i++)
166+
if (outputs[i] != m3_GetRetType(function, i))
167+
return std::nullopt;
168+
169+
return reinterpret_cast<WasmEngine::FuncRef>(function);
126170
}
127171

128172
WasmEngine::Result Wasm3Engine::execute(
@@ -137,7 +181,7 @@ WasmEngine::Result Wasm3Engine::execute(
137181

138182
// This ensures input count/type matches. For the return value we assume find_function did the
139183
// validation.
140-
if (m3_Call(function, static_cast<uint32_t>(args.size()), argPtrs.data()) == m3Err_none)
184+
if (m3_Call(function, static_cast<uint32_t>(argPtrs.size()), argPtrs.data()) == m3Err_none)
141185
{
142186
if (m3_GetRetCount(function) == 0)
143187
return {false, std::nullopt};

0 commit comments

Comments
 (0)