66
77#include < test/utils/adler32.hpp>
88#include < test/utils/wasm_engine.hpp>
9+ #include < algorithm>
910#include < cassert>
1011#include < cstring>
12+ #include < stdexcept>
1113
1214namespace fizzy ::test
1315{
@@ -38,6 +40,33 @@ class Wasm3Engine final : public WasmEngine
3840
3941namespace
4042{
43+ M3ValueType translate_valtype (char input)
44+ {
45+ if (input == ' i' )
46+ return M3ValueType::c_m3Type_i32;
47+ else if (input == ' I' )
48+ return M3ValueType::c_m3Type_i64;
49+ else
50+ throw std::runtime_error{" invalid type" };
51+ }
52+
53+ std::pair<std::vector<M3ValueType>, std::vector<M3ValueType>> translate_signature (
54+ std::string_view signature)
55+ {
56+ const auto delimiter_pos = signature.find (' :' );
57+ assert (delimiter_pos != std::string_view::npos);
58+ const auto inputs = signature.substr (0 , delimiter_pos);
59+ const auto outputs = signature.substr (delimiter_pos + 1 );
60+
61+ std::vector<M3ValueType> input_types;
62+ std::vector<M3ValueType> output_types;
63+ std::transform (
64+ std::begin (inputs), std::end (inputs), std::back_inserter (input_types), translate_valtype);
65+ std::transform (std::begin (outputs), std::end (outputs), std::back_inserter (output_types),
66+ translate_valtype);
67+ return {std::move (input_types), std::move (output_types)};
68+ }
69+
4170const void * env_adler32 (
4271 IM3Runtime /* runtime*/ , IM3ImportContext /* context*/ , uint64_t * stack, void * mem)
4372{
@@ -115,14 +144,29 @@ fizzy::bytes_view Wasm3Engine::get_memory() const
115144}
116145
117146std::optional<WasmEngine::FuncRef> Wasm3Engine::find_function (
118- std::string_view name, std::string_view) const
147+ std::string_view name, std::string_view signature ) const
119148{
120149 IM3Function function;
121- if (m3_FindFunction (&function, m_runtime, name.data ()) == m3Err_none)
122- // TODO: validate input/output types
123- // (m3_GetArgCount/m3_GetArgType/m3_GetRetCount/m3_GetRetType)
124- return reinterpret_cast <WasmEngine::FuncRef>(function);
125- return std::nullopt ;
150+ if (m3_FindFunction (&function, m_runtime, name.data ()) != m3Err_none)
151+ return std::nullopt ;
152+
153+ std::vector<M3ValueType> inputs;
154+ std::vector<M3ValueType> outputs;
155+ std::tie (inputs, outputs) = translate_signature (signature);
156+
157+ if (inputs.size () != m3_GetArgCount (function))
158+ return std::nullopt ;
159+ for (unsigned i = 0 ; i < m3_GetArgCount (function); i++)
160+ if (inputs[i] != m3_GetArgType (function, i))
161+ return std::nullopt ;
162+
163+ if (outputs.size () != m3_GetRetCount (function))
164+ return std::nullopt ;
165+ for (unsigned i = 0 ; i < m3_GetRetCount (function); i++)
166+ if (outputs[i] != m3_GetRetType (function, i))
167+ return std::nullopt ;
168+
169+ return reinterpret_cast <WasmEngine::FuncRef>(function);
126170}
127171
128172WasmEngine::Result Wasm3Engine::execute (
@@ -137,7 +181,7 @@ WasmEngine::Result Wasm3Engine::execute(
137181
138182 // This ensures input count/type matches. For the return value we assume find_function did the
139183 // validation.
140- if (m3_Call (function, static_cast <uint32_t >(args .size ()), argPtrs.data ()) == m3Err_none)
184+ if (m3_Call (function, static_cast <uint32_t >(argPtrs .size ()), argPtrs.data ()) == m3Err_none)
141185 {
142186 if (m3_GetRetCount (function) == 0 )
143187 return {false , std::nullopt };
0 commit comments