From 283f22d9fb06f9459210f0c73eae921fd32af151 Mon Sep 17 00:00:00 2001 From: Gordon Smith Date: Sun, 28 Sep 2025 08:11:37 +0100 Subject: [PATCH 1/6] fix: Complete resource handle lifecycle Signed-off-by: Gordon Smith --- docs/issue-backlog.md | 161 +++++++++++++++++++++++ include/cmcpp.hpp | 1 + include/cmcpp/context.hpp | 260 ++++++++++++++++++++++++++++++++------ include/cmcpp/util.hpp | 8 -- test/main.cpp | 92 ++++++++++++++ 5 files changed, 472 insertions(+), 50 deletions(-) create mode 100644 docs/issue-backlog.md diff --git a/docs/issue-backlog.md b/docs/issue-backlog.md new file mode 100644 index 0000000..cb505e3 --- /dev/null +++ b/docs/issue-backlog.md @@ -0,0 +1,161 @@ +# GitHub Issue Backlog + +Draft issues ready to be copy-pasted into GitHub. Each entry lists a suggested title, labels, background, concrete scope, and acceptance criteria aligned with the canonical ABI reference implementation. + +--- + +## Issue: Implement canonical async runtime scaffolding +- **Labels:** `enhancement`, `async` + +### Background +The canonical ABI reference (`design/mvp/canonical-abi/definitions.py`) includes a `Store`, `Thread`, and scheduling loop that enable cooperative async execution via `tick()`. The current C++ headers lack any runtime to drive asynchronous component calls. + +### Scope +- Add `Store`, `Thread`, `Task`, and related scheduling types to `cmcpp` mirroring the canonical semantics. +- Implement queueing of pending threads and `tick()` to resume ready work. +- Expose hooks for component functions to register `on_start`/`on_resolve` callbacks and support cancellation tokens on the returned `Call` object. +- Provide doctest (or equivalent) coverage that simulates async invocation and verifies correct scheduling behavior. + +### Acceptance Criteria +1. New runtime types are available in the public API and documented. +2. Asynchronous component calls can progress via repeated `tick()` calls without blocking the host thread. +3. 
Unit tests demonstrate thread scheduling, `on_start` argument delivery, and cooperative cancellation. +4. Documentation explains how hosts drive async work. + +--- + +## Issue: Complete resource handle lifecycle +- **Labels:** `enhancement`, `abi` + +### Background +`ResourceHandle`, `ResourceType`, and `ComponentInstance.resources` are currently empty shells. The canonical implementation tracks ownership, borrow counts, destructors, and exposes `canon resource.{new,drop,rep}`. + +### Scope +- Flesh out `ResourceHandle` with `own`, `borrow_scope`, and `num_lends` semantics. +- Implement `ResourceType` destructor registration, including async callback support. +- Add `canon resource.new`, `canon resource.drop`, and `canon resource.rep` helper functions that update component tables and trap on misuse. +- Write tests covering own/borrow flows, lend counting, and destructor invocation. + +### Acceptance Criteria +1. Resource creation traps if the table would overflow or if ownership rules are violated. +2. Dropping resources invokes registered destructors exactly once and respects async/sync constraints. +3. Borrowed handles track lend counts and trap on invalid drops or reps. +4. Tests mirror canonical success and trap cases. + +--- + +## Issue: Add waitable, stream, and future infrastructure +- **Labels:** `enhancement`, `async` + +### Background +Canonical ABI defines waitables, waitable-sets, buffers, streams, and futures plus their cancellation behaviors. Our headers only contain empty structs. + +### Scope +- Model waitable and waitable-set state, including registration with the component store. +- Implement buffer management for {stream,future} types, covering readable/writable halves. +- Provide APIs for `canon waitable-set.{new,wait,poll,drop}`, `canon waitable.join`, and `canon {stream,future}.{new,read,write,cancel,drop}`. +- Add tests verifying read/write ordering, cancellation pathways, and polling semantics. + +### Acceptance Criteria +1. 
Streams and futures can be created, awaited, and canceled via the new APIs. +2. Waitable sets correctly block/unblock pending tasks and surface readiness in tests. +3. Cancellation propagates to queued operations per canonical rules. +4. Documentation describes how hosts integrate these constructs. + +--- + +## Issue: Implement backpressure and task lifecycle management +- **Labels:** `enhancement`, `async` + +### Background +`ComponentInstance` holds flags for `may_leave`, `backpressure`, and call-state tracking but they are unused. Canonical ABI specifies `canon task.{return,cancel}`, `canon yield`, and backpressure counters governing concurrent entry. + +### Scope +- Track outstanding synchronous/async calls and enforce `may_leave` invariants when entering/leaving component code. +- Implement `canon task.return` and `canon task.cancel` helpers wired to pending task queues. +- Support `canon yield` to hand control back to the embedder. +- Ensure backpressure counters gate `Store.invoke` while tasks are paused or pending. + +### Acceptance Criteria +1. Re-entrant sync calls are rejected per canonical rules (tests cover both allowed and disallowed cases). +2. Tasks marked for cancellation resolve promptly with `on_resolve(None)`. +3. `canon yield` transitions tasks to pending and requires a subsequent `tick()` to resume. +4. Backpressure metrics are exposed for debugging and verified in tests. + +--- + +## Issue: Support context locals and error-context APIs +- **Labels:** `enhancement`, `abi` + +### Background +`LiftLowerContext` currently omits instance references and borrow scopes, and `ContextLocalStorage`/`ErrorContext` types are unused. The canonical ABI exposes `canon context.{get,set}` and `canon error-context.{new,debug-message,drop}`. + +### Scope +- Extend `LiftLowerContext` to hold the active `ComponentInstance` and scoped borrow variant. +- Implement context-local storage getters/setters with bounds validation. 
+- Provide error-context creation, debug-message formatting via the host converter, and drop semantics that respect async callbacks. +- Add tests ensuring invalid indices and double drops trap. + +### Acceptance Criteria +1. Borrowed resources capture their scope and trap when accessed outside it. +2. Context locals persist across lift/lower calls and reset appropriately between tasks. +3. Error-context debug messages surface through the host trap mechanism. +4. Test coverage includes both success and failure paths for each API. + +--- + +## Issue: Finish function flattening utilities +- **Labels:** `enhancement`, `abi` + +### Background +`include/cmcpp/func.hpp` contains commented-out flattening helpers. Canonical ABI requires flattening functions to honor `MAX_FLAT_PARAMS/RESULTS` and spill to heap memory via the provided `realloc`. + +### Scope +- Implement `cmcpp::func::flatten`, `pack_flags_into_int`, and the associated load/store helpers. +- Respect max-flat thresholds and ensure out-params allocate via `LiftLowerContext::opts.realloc`. +- Add regression tests covering large tuples, records, and flag types that exceed the flat limit. +- Compare flattened signatures against outputs from the canonical Python definitions for validation. + +### Acceptance Criteria +1. Flattened core signatures match canonical expectations for representative WIT signatures. +2. Heap-based lowering is invoked automatically when flat limits are exceeded. +3. Flags marshal correctly between bitsets and flat integers. +4. Tests demonstrate both flat and heap pathways. + +--- + +## Issue: Wire canonical options and callbacks through lift/lower +- **Labels:** `enhancement`, `abi` + +### Background +`CanonicalOptions` exposes `post_return`, `callback`, and `sync` but they are currently unused. Canonical ABI requires these flags to control async vs sync paths and post-return cleanup. + +### Scope +- Invoke `post_return` after successful lowerings when provided. 
+- Enforce `sync` by trapping when async behavior would occur while `sync == true`. +- Invoke `callback` when async continuations schedule additional work. +- Ensure option fields propagate through `InstanceContext::createLiftLowerContext` and task lifecycles. + +### Acceptance Criteria +1. `post_return` is called exactly once per lowering when configured (verified via tests). +2. Async lowering attempts while `sync == true` trap with a descriptive error. +3. Registered callbacks fire for asynchronous continuations and can trigger host-side scheduling. +4. Documentation clarifies option usage and interaction with new runtime pieces. + +--- + +## Issue: Expand docs and tests for canonical runtime features +- **Labels:** `documentation`, `testing` + +### Background +New runtime pieces require supporting documentation and tests. Currently, README lacks guidance and test coverage mirrors only existing functionality. + +### Scope +- Update `README.md` (or add a new guide) summarizing the async runtime, resource management, and waitable APIs. +- Add doctest/ICU-backed unit tests covering the new behavior to `test/main.cpp` (or adjacent files). +- Optionally add a Python cross-check using `ref/component-model/design/mvp/canonical-abi/run_tests.py` for parity. + +### Acceptance Criteria +1. Documentation explains how to configure `InstanceContext`, allocate options, and drive async flows. +2. New tests pass in CI and cover at least one example for each newly implemented feature. +3. The backlog of canonical ABI features is reflected as "Done" within this issue once merged. 
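
To make the resource-handle scope above concrete, here is a minimal executable model of the handle-table semantics several of these issues target, adapted from the canonical reference (`design/mvp/canonical-abi/definitions.py`). The helper names `resource_new`/`resource_rep`/`resource_drop` are illustrative stand-ins for `canon resource.{new,rep,drop}`, not the cmcpp API:

```python
# Toy model of the canonical handle table (adapted from the canonical-abi
# reference definitions.py); illustrates new/rep/drop and free-list reuse.

class Trap(Exception):
    pass

def trap_if(cond, msg="trap"):
    if cond:
        raise Trap(msg)

class HandleElem:
    def __init__(self, rep, own, scope=None):
        self.rep = rep
        self.own = own          # own vs. borrow
        self.scope = scope      # borrow scope (None for own handles)
        self.lend_count = 0     # outstanding lends of an own handle

class HandleTable:
    MAX_LENGTH = 1 << 30

    def __init__(self):
        self.array = [None]     # index 0 is reserved and never handed out
        self.free = []          # indices available for reuse

    def get(self, i):
        trap_if(i >= len(self.array) or self.array[i] is None, "bad handle index")
        return self.array[i]

    def add(self, h):
        if self.free:
            i = self.free.pop()
            self.array[i] = h
        else:
            trap_if(len(self.array) >= self.MAX_LENGTH, "handle table overflow")
            self.array.append(h)
            i = len(self.array) - 1
        return i

    def remove(self, i):
        h = self.get(i)
        self.array[i] = None
        self.free.append(i)
        return h

def resource_new(table, rep):
    return table.add(HandleElem(rep, own=True))

def resource_rep(table, i):
    return table.get(i).rep

def resource_drop(table, i, dtor=None):
    h = table.remove(i)
    if h.own:
        trap_if(h.lend_count != 0, "dropped own handle with outstanding lends")
        if dtor is not None:
            dtor(h.rep)         # destructor runs exactly once, on the final drop
```

Dropping a handle returns its slot to the free list, so the next `resource_new` reuses the most recently freed index — the same behavior the acceptance criteria above ask the C++ doctests to assert.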
diff --git a/include/cmcpp.hpp b/include/cmcpp.hpp index dfd5c21..c664e21 100644 --- a/include/cmcpp.hpp +++ b/include/cmcpp.hpp @@ -1,6 +1,7 @@ #ifndef CMCPP_HPP #define CMCPP_HPP +#include #include #include #include diff --git a/include/cmcpp/context.hpp b/include/cmcpp/context.hpp index 8523387..20ccf6d 100644 --- a/include/cmcpp/context.hpp +++ b/include/cmcpp/context.hpp @@ -4,6 +4,10 @@ #include "traits.hpp" #include +#include +#include +#include +#include #if __has_include() #include #else @@ -53,9 +57,167 @@ namespace cmcpp bool allways_task_return = false; }; - // Runtime State --- - struct ResourceHandle + struct ComponentInstance; + struct HandleElement; + + class LiftLowerContext + { + public: + HostTrap trap; + HostUnicodeConversion convert; + + LiftLowerOptions opts; + ComponentInstance *inst = nullptr; + std::vector lenders; + uint32_t borrow_count = 0; + + LiftLowerContext(const HostTrap &host_trap, const HostUnicodeConversion &conversion, const LiftLowerOptions &options, ComponentInstance *instance = nullptr) + : trap(host_trap), convert(conversion), opts(options), inst(instance) {} + + void track_owning_lend(HandleElement &lending_handle); + void exit_call(); + }; + + inline void trap_if(const LiftLowerContext &cx, bool condition, const char *message = nullptr) noexcept(false) + { + if (!condition) + { + return; + } + + const char *msg = message == nullptr ? 
"Unknown trap" : message; + if (cx.trap) + { + cx.trap(msg); + return; + } + throw std::runtime_error(msg); + } + + inline LiftLowerContext make_trap_context(const HostTrap &trap) + { + HostUnicodeConversion convert{}; + LiftLowerOptions opts{}; + return LiftLowerContext(trap, convert, opts); + } + + struct ResourceType + { + ComponentInstance *impl = nullptr; + std::function dtor; + + ResourceType() = default; + + explicit ResourceType(ComponentInstance &instance, std::function destructor = {}) + : impl(&instance), dtor(std::move(destructor)) {} + }; + + struct HandleElement { + uint32_t rep = 0; + bool own = false; + LiftLowerContext *scope = nullptr; + uint32_t lend_count = 0; + }; + + class HandleTable + { + public: + static constexpr uint32_t MAX_LENGTH = 1u << 30; + + HandleElement &get(uint32_t index, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, index >= entries_.size(), "resource index out of bounds"); + auto &slot = entries_[index]; + trap_if(trap_cx, !slot.has_value(), "resource slot empty"); + return slot.value(); + } + + const HandleElement &get(uint32_t index, const HostTrap &trap) const + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, index >= entries_.size(), "resource index out of bounds"); + const auto &slot = entries_[index]; + trap_if(trap_cx, !slot.has_value(), "resource slot empty"); + return slot.value(); + } + + uint32_t add(const HandleElement &element, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + uint32_t index; + if (!free_.empty()) + { + index = free_.back(); + free_.pop_back(); + entries_[index] = element; + } + else + { + trap_if(trap_cx, entries_.size() >= MAX_LENGTH, "resource table overflow"); + entries_.push_back(element); + index = static_cast(entries_.size() - 1); + } + return index; + } + + HandleElement remove(uint32_t index, const HostTrap &trap) + { + HandleElement element = get(index, trap); + entries_[index].reset(); + free_.push_back(index); 
+ return element; + } + + const std::vector> &entries() const + { + return entries_; + } + + const std::vector &free_list() const + { + return free_; + } + + private: + std::vector> entries_{std::nullopt}; + std::vector free_; + }; + + class HandleTables + { + public: + HandleElement &get(ResourceType &rt, uint32_t index, const HostTrap &trap) + { + return table(rt).get(index, trap); + } + + const HandleElement &get(const ResourceType &rt, uint32_t index, const HostTrap &trap) const + { + auto it = tables_.find(&rt); + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, it == tables_.end(), "resource table missing"); + return it->second.get(index, trap); + } + + uint32_t add(ResourceType &rt, const HandleElement &element, const HostTrap &trap) + { + return table(rt).add(element, trap); + } + + HandleElement remove(ResourceType &rt, uint32_t index, const HostTrap &trap) + { + return table(rt).remove(index, trap); + } + + HandleTable &table(ResourceType &rt) + { + return tables_[&rt]; + } + + private: + std::unordered_map tables_; }; struct Waitable @@ -72,16 +234,9 @@ namespace cmcpp struct ComponentInstance { - std::vector resources; - std::vector waitables; - std::vector waitable_sets; - std::vector error_contexts; bool may_leave = true; - bool backpressure = false; - bool calling_sync_export = false; - bool calling_sync_import = false; - // std::vector, Future>> pending_tasks; - bool starting_pending_tasks = false; + bool may_enter = true; + HandleTables handles; }; class ContextLocalStorage @@ -103,25 +258,6 @@ namespace cmcpp } }; - template - class Task - { - public: - using function_type = func_t; - - CanonicalOptions opts; - ComponentInstance inst; - function_type ft; - std::optional supertask; - std::optional> on_return; - std::function(std::future)> on_block; - int num_borrows = 0; - ContextLocalStorage context(); - - Task(CanonicalOptions &opts, ComponentInstance &inst, function_type &ft, std::optional &supertask = std::nullopt, std::optional> 
&on_return = std::nullopt, std::function(std::future)> &on_block = std::nullopt) - : opts(opts), inst(inst), ft(ft), supertask(supertask), on_return(on_return), on_block(on_block) {} - }; - struct Subtask : Waitable { }; @@ -130,21 +266,61 @@ namespace cmcpp { }; - // Lifting and Lowering Context --- - // template - class LiftLowerContext + inline void LiftLowerContext::track_owning_lend(HandleElement &lending_handle) { - public: - HostTrap trap; - HostUnicodeConversion convert; + trap_if(*this, !lending_handle.own, "lender must own resource"); + lending_handle.lend_count += 1; + lenders.push_back(&lending_handle); + } - LiftLowerOptions opts; - // ComponentInstance inst; - // std::optional, Subtask>> borrow_scope; + inline void LiftLowerContext::exit_call() + { + trap_if(*this, borrow_count != 0, "borrow count mismatch on exit"); + for (auto *handle : lenders) + { + if (handle && handle->lend_count > 0) + { + handle->lend_count -= 1; + } + } + lenders.clear(); + } - LiftLowerContext(const HostTrap &trap, const HostUnicodeConversion &convert, const LiftLowerOptions &opts) - : trap(trap), convert(convert), opts(opts) {} - }; + inline uint32_t canon_resource_new(ComponentInstance &inst, ResourceType &rt, uint32_t rep, const HostTrap &trap) + { + HandleElement element; + element.rep = rep; + element.own = true; + return inst.handles.add(rt, element, trap); + } + + inline void canon_resource_drop(ComponentInstance &inst, ResourceType &rt, uint32_t index, const HostTrap &trap) + { + HandleElement element = inst.handles.remove(rt, index, trap); + auto trap_cx = make_trap_context(trap); + if (element.own) + { + trap_if(trap_cx, element.scope != nullptr, "own handle cannot have borrow scope"); + trap_if(trap_cx, element.lend_count != 0, "resource has outstanding lends"); + trap_if(trap_cx, rt.impl != nullptr && (&inst != rt.impl) && !rt.impl->may_enter, "resource impl may not enter"); + if (rt.dtor) + { + rt.dtor(element.rep); + } + } + else + { + trap_if(trap_cx, 
element.scope == nullptr, "borrow scope missing"); + trap_if(trap_cx, element.scope->borrow_count == 0, "borrow scope underflow"); + element.scope->borrow_count -= 1; + } + } + + inline uint32_t canon_resource_rep(ComponentInstance &inst, ResourceType &rt, uint32_t index, const HostTrap &trap) + { + const HandleElement &element = inst.handles.get(rt, index, trap); + return element.rep; + } // ---------------------------- diff --git a/include/cmcpp/util.hpp b/include/cmcpp/util.hpp index 1df9b3c..e774ba2 100644 --- a/include/cmcpp/util.hpp +++ b/include/cmcpp/util.hpp @@ -7,14 +7,6 @@ namespace cmcpp { const bool DETERMINISTIC_PROFILE = false; - inline void trap_if(const LiftLowerContext &cx, bool condition, const char *message = nullptr) noexcept(false) - { - if (condition) - { - cx.trap(message == nullptr ? "Unknown trap" : message); - } - } - inline bool_t convert_int_to_bool(uint8_t i) { return i > 0; diff --git a/test/main.cpp b/test/main.cpp index 24fc42a..b39015b 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -14,6 +14,7 @@ using namespace cmcpp; #include #include #include +#include // #include #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN @@ -204,6 +205,97 @@ TEST_CASE("Async runtime propagates cancellation") CHECK(thread->completed()); } +TEST_CASE("Resource handle lifecycle mirrors canonical definitions") +{ + ComponentInstance resource_impl; + ComponentInstance inst; + std::vector dtor_calls; + + HostTrap host_trap = [](const char *msg) + { + throw std::runtime_error(msg ? 
msg : "trap"); + }; + + ResourceType rt(resource_impl, [&](uint32_t rep) + { dtor_calls.push_back(rep); }); + + REQUIRE(inst.may_leave); + REQUIRE(inst.may_enter); + + uint32_t h1 = canon_resource_new(inst, rt, 42, host_trap); + uint32_t h2 = canon_resource_new(inst, rt, 43, host_trap); + + CHECK(h1 == 1); + CHECK(h2 == 2); + + LiftLowerOptions borrow_opts; + HostUnicodeConversion noop_convert = [](void *, uint32_t, const void *, uint32_t, Encoding, Encoding) + { + return std::pair{nullptr, 0}; + }; + LiftLowerContext borrow_scope(host_trap, noop_convert, borrow_opts, &inst); + borrow_scope.borrow_count = 1; + + HandleElement borrowed; + borrowed.rep = 44; + borrowed.own = false; + borrowed.scope = &borrow_scope; + uint32_t h3 = inst.handles.add(rt, borrowed, host_trap); + CHECK(h3 == 3); + + uint32_t h4 = canon_resource_new(inst, rt, 45, host_trap); + CHECK(h4 == 4); + + const auto &table_entries = inst.handles.table(rt).entries(); + CHECK(table_entries.size() == 5); + CHECK_FALSE(table_entries[0].has_value()); + CHECK(table_entries[1].has_value()); + CHECK(table_entries[2].has_value()); + CHECK(table_entries[3].has_value()); + CHECK(table_entries[4].has_value()); + + CHECK(canon_resource_rep(inst, rt, h1, host_trap) == 42); + CHECK(canon_resource_rep(inst, rt, h2, host_trap) == 43); + CHECK(canon_resource_rep(inst, rt, h3, host_trap) == 44); + CHECK(canon_resource_rep(inst, rt, h4, host_trap) == 45); + + dtor_calls.clear(); + canon_resource_drop(inst, rt, h1, host_trap); + CHECK(dtor_calls == std::vector{42}); + auto &table_after_drop = inst.handles.table(rt); + CHECK(table_after_drop.entries()[1].has_value() == false); + CHECK(table_after_drop.free_list().size() == 1); + + uint32_t h5 = canon_resource_new(inst, rt, 46, host_trap); + CHECK(h5 == 1); + CHECK(table_after_drop.entries().size() == 5); + CHECK(table_after_drop.entries()[1].has_value()); + CHECK(dtor_calls == std::vector{42}); + + borrow_scope.borrow_count = 1; + canon_resource_drop(inst, rt, h3, 
host_trap);
+    CHECK(dtor_calls == std::vector<uint32_t>{42});
+    CHECK(borrow_scope.borrow_count == 0);
+    CHECK(table_after_drop.entries()[3].has_value() == false);
+    CHECK(table_after_drop.free_list().size() == 1);
+
+    canon_resource_drop(inst, rt, h2, host_trap);
+    CHECK(dtor_calls == std::vector<uint32_t>{42, 43});
+
+    canon_resource_drop(inst, rt, h4, host_trap);
+    CHECK(dtor_calls == std::vector<uint32_t>{42, 43, 45});
+
+    canon_resource_drop(inst, rt, h5, host_trap);
+    CHECK(dtor_calls == std::vector<uint32_t>{42, 43, 45, 46});
+
+    auto &final_table = inst.handles.table(rt);
+    CHECK(final_table.free_list().size() == 4);
+    for (size_t i = 1; i < final_table.entries().size(); ++i)
+    {
+        CHECK_FALSE(final_table.entries()[i].has_value());
+    }
+}
+
 TEST_CASE("Boolean")
 {
     Heap heap(1024 * 1024);

From 7e85a21a8b02ac9141687960eed888e6a0a5f345 Mon Sep 17 00:00:00 2001
From: Gordon Smith
Date: Sun, 28 Sep 2025 09:16:41 +0100
Subject: [PATCH 2/6] fix: Add waitable, stream, and future infrastructure

Signed-off-by: Gordon Smith
---
 README.md                 |    4 +
 definitions.py            | 1177 -------------------------------------
 include/cmcpp/context.hpp |  945 ++++++++++++++++++++++++++++-
 run_tests.py              |  469 ---------------
 test/main.cpp             |  159 ++++-
 5 files changed, 1099 insertions(+), 1655 deletions(-)
 delete mode 100644 definitions.py
 delete mode 100644 run_tests.py

diff --git a/README.md b/README.md
index a496499..8466642 100644
--- a/README.md
+++ b/README.md
@@ -211,6 +211,10 @@ auto call = store.invoke(func, nullptr, [] { return std::vector{}; },
 store.tick();
 ```
 
+### Waitables, streams, and futures
+
+The canonical async ABI surfaces are implemented via `canon_waitable_*`, `canon_stream_*`, and `canon_future_*` helpers on `ComponentInstance`. Waitable sets can be joined to readable/writable stream ends or futures, and `canon_waitable_set_poll` reports readiness using the same event payload layout defined by the spec. See the doctests in `test/main.cpp` for end-to-end examples.
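
The readiness contract described above can be pictured with a small host-side model. This is a sketch of the semantics only (in Python, mirroring the style of the canonical reference; `deliver` is a hypothetical stand-in for whatever marks a stream end or future ready) — it is not the cmcpp C++ API:

```python
# Toy model of waitable/waitable-set readiness: each waitable buffers at most
# one pending event; polling a set delivers the first ready member's event.

class Waitable:
    def __init__(self):
        self.pending = None   # buffered event payload, or None
        self.wset = None      # waitable-set this waitable is joined to

    def deliver(self, event):
        self.pending = event

    def join(self, wset):
        # Joining a new set (or None) first leaves the current one.
        if self.wset is not None:
            self.wset.members.remove(self)
        self.wset = wset
        if wset is not None:
            wset.members.append(self)

class WaitableSet:
    def __init__(self):
        self.members = []

    def poll(self):
        # Non-blocking: return the first pending event, or None if nothing is
        # ready.  A real runtime's `wait` would park the task instead.
        for w in self.members:
            if w.pending is not None:
                event, w.pending = w.pending, None
                return event
        return None
```

`poll()` returning `None` corresponds roughly to the spec's "no event" result; a real implementation also carries the structured event payload (event code, waitable index, payload) rather than an opaque value.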
+ Call `tick()` in your host loop until all pending work completes. Cancellation is cooperative: calling `Call::request_cancellation()` marks the associated thread as cancelled before the next `tick()`. diff --git a/definitions.py b/definitions.py deleted file mode 100644 index bc23771..0000000 --- a/definitions.py +++ /dev/null @@ -1,1177 +0,0 @@ -# After the Boilerplate section, this file is ordered to line up with the code -# blocks in ../CanonicalABI.md (split by # comment lines). If you update this -# file, don't forget to update ../CanonicalABI.md. - -### Boilerplate - -from __future__ import annotations -import math -import struct -import random -from dataclasses import dataclass -from typing import Optional -from typing import Callable -from typing import MutableMapping - -class Trap(BaseException): pass -class CoreWebAssemblyException(BaseException): pass - -def trap(): - raise Trap() - -def trap_if(cond): - if cond: - raise Trap() - -class Type: pass -class ValType(Type): pass -class ExternType(Type): pass -class CoreExternType(Type): pass - -@dataclass -class CoreImportDecl: - module: str - field: str - t: CoreExternType - -@dataclass -class CoreExportDecl: - name: str - t: CoreExternType - -@dataclass -class ModuleType(ExternType): - imports: list[CoreImportDecl] - exports: list[CoreExportDecl] - -@dataclass -class CoreFuncType(CoreExternType): - params: list[str] - results: list[str] - -@dataclass -class CoreMemoryType(CoreExternType): - initial: list[int] - maximum: Optional[int] - -@dataclass -class ExternDecl: - name: str - t: ExternType - -@dataclass -class ComponentType(ExternType): - imports: list[ExternDecl] - exports: list[ExternDecl] - -@dataclass -class InstanceType(ExternType): - exports: list[ExternDecl] - -@dataclass -class FuncType(ExternType): - params: list[tuple[str,ValType]] - results: list[ValType|tuple[str,ValType]] - def param_types(self): - return self.extract_types(self.params) - def result_types(self): - return 
self.extract_types(self.results) - def extract_types(self, vec): - if len(vec) == 0: - return [] - if isinstance(vec[0], ValType): - return vec - return [t for name,t in vec] - -@dataclass -class ValueType(ExternType): - t: ValType - -class Bounds: pass - -@dataclass -class Eq(Bounds): - t: Type - -@dataclass -class TypeType(ExternType): - bounds: Bounds - -class Bool(ValType): pass -class S8(ValType): pass -class U8(ValType): pass -class S16(ValType): pass -class U16(ValType): pass -class S32(ValType): pass -class U32(ValType): pass -class S64(ValType): pass -class U64(ValType): pass -class Float32(ValType): pass -class Float64(ValType): pass -class Char(ValType): pass -class String(ValType): pass - -@dataclass -class List(ValType): - t: ValType - -@dataclass -class Field: - label: str - t: ValType - -@dataclass -class Record(ValType): - fields: list[Field] - -@dataclass -class Tuple(ValType): - ts: list[ValType] - -@dataclass -class Case: - label: str - t: Optional[ValType] - refines: Optional[str] = None - -@dataclass -class Variant(ValType): - cases: list[Case] - -@dataclass -class Enum(ValType): - labels: list[str] - -@dataclass -class Option(ValType): - t: ValType - -@dataclass -class Result(ValType): - ok: Optional[ValType] - error: Optional[ValType] - -@dataclass -class Flags(ValType): - labels: list[str] - -@dataclass -class Own(ValType): - rt: ResourceType - -@dataclass -class Borrow(ValType): - rt: ResourceType - -### Despecialization - -def despecialize(t): - match t: - case Tuple(ts) : return Record([ Field(str(i), t) for i,t in enumerate(ts) ]) - case Enum(labels) : return Variant([ Case(l, None) for l in labels ]) - case Option(t) : return Variant([ Case("none", None), Case("some", t) ]) - case Result(ok, error) : return Variant([ Case("ok", ok), Case("error", error) ]) - case _ : return t - -### Alignment - -def alignment(t): - match despecialize(t): - case Bool() : return 1 - case S8() | U8() : return 1 - case S16() | U16() : return 2 - case S32() 
| U32() : return 4 - case S64() | U64() : return 8 - case Float32() : return 4 - case Float64() : return 8 - case Char() : return 4 - case String() | List(_) : return 4 - case Record(fields) : return alignment_record(fields) - case Variant(cases) : return alignment_variant(cases) - case Flags(labels) : return alignment_flags(labels) - case Own(_) | Borrow(_) : return 4 - -def alignment_record(fields): - a = 1 - for f in fields: - a = max(a, alignment(f.t)) - return a - -def alignment_variant(cases): - return max(alignment(discriminant_type(cases)), max_case_alignment(cases)) - -def discriminant_type(cases): - n = len(cases) - assert(0 < n < (1 << 32)) - match math.ceil(math.log2(n)/8): - case 0: return U8() - case 1: return U8() - case 2: return U16() - case 3: return U32() - -def max_case_alignment(cases): - a = 1 - for c in cases: - if c.t is not None: - a = max(a, alignment(c.t)) - return a - -def alignment_flags(labels): - n = len(labels) - if n <= 8: return 1 - if n <= 16: return 2 - return 4 - -### Size - -def size(t): - match despecialize(t): - case Bool() : return 1 - case S8() | U8() : return 1 - case S16() | U16() : return 2 - case S32() | U32() : return 4 - case S64() | U64() : return 8 - case Float32() : return 4 - case Float64() : return 8 - case Char() : return 4 - case String() | List(_) : return 8 - case Record(fields) : return size_record(fields) - case Variant(cases) : return size_variant(cases) - case Flags(labels) : return size_flags(labels) - case Own(_) | Borrow(_) : return 4 - -def size_record(fields): - s = 0 - for f in fields: - s = align_to(s, alignment(f.t)) - s += size(f.t) - assert(s > 0) - return align_to(s, alignment_record(fields)) - -def align_to(ptr, alignment): - return math.ceil(ptr / alignment) * alignment - -def size_variant(cases): - s = size(discriminant_type(cases)) - s = align_to(s, max_case_alignment(cases)) - cs = 0 - for c in cases: - if c.t is not None: - cs = max(cs, size(c.t)) - s += cs - return align_to(s, 
alignment_variant(cases)) - -def size_flags(labels): - n = len(labels) - assert(n > 0) - if n <= 8: return 1 - if n <= 16: return 2 - return 4 * num_i32_flags(labels) - -def num_i32_flags(labels): - return math.ceil(len(labels) / 32) - -### Runtime State - -class LiftLowerContext: - opts: CanonicalOptions - inst: ComponentInstance - lenders: list[HandleElem] - borrow_count: int - - def __init__(self, opts, inst): - self.opts = opts - self.inst = inst - self.lenders = [] - self.borrow_count = 0 - - def track_owning_lend(self, lending_handle): - assert(lending_handle.own) - lending_handle.lend_count += 1 - self.lenders.append(lending_handle) - - def exit_call(self): - trap_if(self.borrow_count != 0) - for h in self.lenders: - h.lend_count -= 1 - -class CanonicalOptions: - memory: bytearray - string_encoding: str - realloc: Callable[[int,int,int,int],int] - post_return: Callable[[],None] - -class ComponentInstance: - may_leave: bool - may_enter: bool - handles: HandleTables - - def __init__(self): - self.may_leave = True - self.may_enter = True - self.handles = HandleTables() - -class ResourceType(Type): - impl: ComponentInstance - dtor: Optional[Callable[[int],None]] - - def __init__(self, impl, dtor = None): - self.impl = impl - self.dtor = dtor - -class HandleElem: - rep: int - own: bool - scope: Optional[LiftLowerContext] - lend_count: int - - def __init__(self, rep, own, scope = None): - self.rep = rep - self.own = own - self.scope = scope - self.lend_count = 0 - -class HandleTable: - array: list[Optional[HandleElem]] - free: list[int] - - def __init__(self): - self.array = [None] - self.free = [] - - def get(self, i): - trap_if(i >= len(self.array)) - trap_if(self.array[i] is None) - return self.array[i] - - def add(self, h): - if self.free: - i = self.free.pop() - assert(self.array[i] is None) - self.array[i] = h - else: - i = len(self.array) - trap_if(i >= 2**30) - self.array.append(h) - return i - - def remove(self, rt, i): - h = self.get(i) - self.array[i] = 
None - self.free.append(i) - return h - -class HandleTables: - rt_to_table: MutableMapping[ResourceType, HandleTable] - - def __init__(self): - self.rt_to_table = dict() - - def table(self, rt): - if rt not in self.rt_to_table: - self.rt_to_table[rt] = HandleTable() - return self.rt_to_table[rt] - - def get(self, rt, i): - return self.table(rt).get(i) - def add(self, rt, h): - return self.table(rt).add(h) - def remove(self, rt, i): - return self.table(rt).remove(rt, i) - -### Loading - -def load(cx, ptr, t): - assert(ptr == align_to(ptr, alignment(t))) - assert(ptr + size(t) <= len(cx.opts.memory)) - match despecialize(t): - case Bool() : return convert_int_to_bool(load_int(cx, ptr, 1)) - case U8() : return load_int(cx, ptr, 1) - case U16() : return load_int(cx, ptr, 2) - case U32() : return load_int(cx, ptr, 4) - case U64() : return load_int(cx, ptr, 8) - case S8() : return load_int(cx, ptr, 1, signed=True) - case S16() : return load_int(cx, ptr, 2, signed=True) - case S32() : return load_int(cx, ptr, 4, signed=True) - case S64() : return load_int(cx, ptr, 8, signed=True) - case Float32() : return decode_i32_as_float(load_int(cx, ptr, 4)) - case Float64() : return decode_i64_as_float(load_int(cx, ptr, 8)) - case Char() : return convert_i32_to_char(cx, load_int(cx, ptr, 4)) - case String() : return load_string(cx, ptr) - case List(t) : return load_list(cx, ptr, t) - case Record(fields) : return load_record(cx, ptr, fields) - case Variant(cases) : return load_variant(cx, ptr, cases) - case Flags(labels) : return load_flags(cx, ptr, labels) - case Own() : return lift_own(cx, load_int(cx, ptr, 4), t) - case Borrow() : return lift_borrow(cx, load_int(cx, ptr, 4), t) - -def load_int(cx, ptr, nbytes, signed = False): - return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed) - -def convert_int_to_bool(i): - assert(i >= 0) - return bool(i) - -DETERMINISTIC_PROFILE = False # or True -CANONICAL_FLOAT32_NAN = 0x7fc00000 -CANONICAL_FLOAT64_NAN = 
0x7ff8000000000000 - -def canonicalize_nan32(f): - if math.isnan(f): - f = core_f32_reinterpret_i32(CANONICAL_FLOAT32_NAN) - assert(math.isnan(f)) - return f - -def canonicalize_nan64(f): - if math.isnan(f): - f = core_f64_reinterpret_i64(CANONICAL_FLOAT64_NAN) - assert(math.isnan(f)) - return f - -def decode_i32_as_float(i): - return canonicalize_nan32(core_f32_reinterpret_i32(i)) - -def decode_i64_as_float(i): - return canonicalize_nan64(core_f64_reinterpret_i64(i)) - -def core_f32_reinterpret_i32(i): - return struct.unpack('!f', struct.pack('!I', i))[0] # f32.reinterpret_i32 - -def core_f64_reinterpret_i64(i): - return struct.unpack('!d', struct.pack('!Q', i))[0] # f64.reinterpret_i64 - -def convert_i32_to_char(cx, i): - trap_if(i >= 0x110000) - trap_if(0xD800 <= i <= 0xDFFF) - return chr(i) - -def load_string(cx, ptr): - begin = load_int(cx, ptr, 4) - tagged_code_units = load_int(cx, ptr + 4, 4) - return load_string_from_range(cx, begin, tagged_code_units) - -UTF16_TAG = 1 << 31 - -def load_string_from_range(cx, ptr, tagged_code_units): - match cx.opts.string_encoding: - case 'utf8': - alignment = 1 - byte_length = tagged_code_units - encoding = 'utf-8' - case 'utf16': - alignment = 2 - byte_length = 2 * tagged_code_units - encoding = 'utf-16-le' - case 'latin1+utf16': - alignment = 2 - if bool(tagged_code_units & UTF16_TAG): - byte_length = 2 * (tagged_code_units ^ UTF16_TAG) - encoding = 'utf-16-le' - else: - byte_length = tagged_code_units - encoding = 'latin-1' - - trap_if(ptr != align_to(ptr, alignment)) - trap_if(ptr + byte_length > len(cx.opts.memory)) - try: - s = cx.opts.memory[ptr : ptr+byte_length].decode(encoding) - except UnicodeError: - trap() - - return (s, cx.opts.string_encoding, tagged_code_units) - -def load_list(cx, ptr, elem_type): - begin = load_int(cx, ptr, 4) - length = load_int(cx, ptr + 4, 4) - return load_list_from_range(cx, begin, length, elem_type) - -def load_list_from_range(cx, ptr, length, elem_type): - trap_if(ptr != 
align_to(ptr, alignment(elem_type))) - trap_if(ptr + length * size(elem_type) > len(cx.opts.memory)) - a = [] - for i in range(length): - a.append(load(cx, ptr + i * size(elem_type), elem_type)) - return a - -def load_record(cx, ptr, fields): - record = {} - for field in fields: - ptr = align_to(ptr, alignment(field.t)) - record[field.label] = load(cx, ptr, field.t) - ptr += size(field.t) - return record - -def load_variant(cx, ptr, cases): - disc_size = size(discriminant_type(cases)) - case_index = load_int(cx, ptr, disc_size) - ptr += disc_size - trap_if(case_index >= len(cases)) - c = cases[case_index] - ptr = align_to(ptr, max_case_alignment(cases)) - case_label = case_label_with_refinements(c, cases) - if c.t is None: - return { case_label: None } - return { case_label: load(cx, ptr, c.t) } - -def case_label_with_refinements(c, cases): - label = c.label - while c.refines is not None: - c = cases[find_case(c.refines, cases)] - label += '|' + c.label - return label - -def find_case(label, cases): - matches = [i for i,c in enumerate(cases) if c.label == label] - assert(len(matches) <= 1) - if len(matches) == 1: - return matches[0] - return -1 - -def load_flags(cx, ptr, labels): - i = load_int(cx, ptr, size_flags(labels)) - return unpack_flags_from_int(i, labels) - -def unpack_flags_from_int(i, labels): - record = {} - for l in labels: - record[l] = bool(i & 1) - i >>= 1 - return record - -def lift_own(cx, i, t): - h = cx.inst.handles.remove(t.rt, i) - trap_if(h.lend_count != 0) - trap_if(not h.own) - return h.rep - -def lift_borrow(cx, i, t): - h = cx.inst.handles.get(t.rt, i) - if h.own: - cx.track_owning_lend(h) - return h.rep - -### Storing - -def store(cx, v, t, ptr): - assert(ptr == align_to(ptr, alignment(t))) - assert(ptr + size(t) <= len(cx.opts.memory)) - match despecialize(t): - case Bool() : store_int(cx, int(bool(v)), ptr, 1) - case U8() : store_int(cx, v, ptr, 1) - case U16() : store_int(cx, v, ptr, 2) - case U32() : store_int(cx, v, ptr, 4) - case 
U64() : store_int(cx, v, ptr, 8) - case S8() : store_int(cx, v, ptr, 1, signed=True) - case S16() : store_int(cx, v, ptr, 2, signed=True) - case S32() : store_int(cx, v, ptr, 4, signed=True) - case S64() : store_int(cx, v, ptr, 8, signed=True) - case Float32() : store_int(cx, encode_float_as_i32(v), ptr, 4) - case Float64() : store_int(cx, encode_float_as_i64(v), ptr, 8) - case Char() : store_int(cx, char_to_i32(v), ptr, 4) - case String() : store_string(cx, v, ptr) - case List(t) : store_list(cx, v, ptr, t) - case Record(fields) : store_record(cx, v, ptr, fields) - case Variant(cases) : store_variant(cx, v, ptr, cases) - case Flags(labels) : store_flags(cx, v, ptr, labels) - case Own() : store_int(cx, lower_own(cx.opts, v, t), ptr, 4) - case Borrow() : store_int(cx, lower_borrow(cx.opts, v, t), ptr, 4) - -def store_int(cx, v, ptr, nbytes, signed = False): - cx.opts.memory[ptr : ptr+nbytes] = int.to_bytes(v, nbytes, 'little', signed=signed) - -def maybe_scramble_nan32(f): - if math.isnan(f): - if DETERMINISTIC_PROFILE: - f = core_f32_reinterpret_i32(CANONICAL_FLOAT32_NAN) - else: - f = core_f32_reinterpret_i32(random_nan_bits(32, 8)) - assert(math.isnan(f)) - return f - -def maybe_scramble_nan64(f): - if math.isnan(f): - if DETERMINISTIC_PROFILE: - f = core_f64_reinterpret_i64(CANONICAL_FLOAT64_NAN) - else: - f = core_f64_reinterpret_i64(random_nan_bits(64, 11)) - assert(math.isnan(f)) - return f - -def random_nan_bits(total_bits, exponent_bits): - fraction_bits = total_bits - exponent_bits - 1 - bits = random.getrandbits(total_bits) - bits |= ((1 << exponent_bits) - 1) << fraction_bits - bits |= 1 << random.randrange(fraction_bits - 1) - return bits - -def encode_float_as_i32(f): - return core_i32_reinterpret_f32(maybe_scramble_nan32(f)) - -def encode_float_as_i64(f): - return core_i64_reinterpret_f64(maybe_scramble_nan64(f)) - -def core_i32_reinterpret_f32(f): - return struct.unpack('!I', struct.pack('!f', f))[0] # i32.reinterpret_f32 - -def 
core_i64_reinterpret_f64(f): - return struct.unpack('!Q', struct.pack('!d', f))[0] # i64.reinterpret_f64 - -def char_to_i32(c): - i = ord(c) - assert(0 <= i <= 0xD7FF or 0xD800 <= i <= 0x10FFFF) - return i - -def store_string(cx, v, ptr): - begin, tagged_code_units = store_string_into_range(cx, v) - store_int(cx, begin, ptr, 4) - store_int(cx, tagged_code_units, ptr + 4, 4) - -def store_string_into_range(cx, v): - src, src_encoding, src_tagged_code_units = v - - if src_encoding == 'latin1+utf16': - if bool(src_tagged_code_units & UTF16_TAG): - src_simple_encoding = 'utf16' - src_code_units = src_tagged_code_units ^ UTF16_TAG - else: - src_simple_encoding = 'latin1' - src_code_units = src_tagged_code_units - else: - src_simple_encoding = src_encoding - src_code_units = src_tagged_code_units - - match cx.opts.string_encoding: - case 'utf8': - match src_simple_encoding: - case 'utf8' : return store_string_copy(cx, src, src_code_units, 1, 1, 'utf-8') - case 'utf16' : return store_utf16_to_utf8(cx, src, src_code_units) - case 'latin1' : return store_latin1_to_utf8(cx, src, src_code_units) - case 'utf16': - match src_simple_encoding: - case 'utf8' : return store_utf8_to_utf16(cx, src, src_code_units) - case 'utf16' : return store_string_copy(cx, src, src_code_units, 2, 2, 'utf-16-le') - case 'latin1' : return store_string_copy(cx, src, src_code_units, 2, 2, 'utf-16-le') - case 'latin1+utf16': - match src_encoding: - case 'utf8' : return store_string_to_latin1_or_utf16(cx, src, src_code_units) - case 'utf16' : return store_string_to_latin1_or_utf16(cx, src, src_code_units) - case 'latin1+utf16' : - match src_simple_encoding: - case 'latin1' : return store_string_copy(cx, src, src_code_units, 1, 2, 'latin-1') - case 'utf16' : return store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units) - -MAX_STRING_BYTE_LENGTH = (1 << 31) - 1 - -def store_string_copy(cx, src, src_code_units, dst_code_unit_size, dst_alignment, dst_encoding): - dst_byte_length = 
dst_code_unit_size * src_code_units - trap_if(dst_byte_length > MAX_STRING_BYTE_LENGTH) - ptr = cx.opts.realloc(0, 0, dst_alignment, dst_byte_length) - trap_if(ptr != align_to(ptr, dst_alignment)) - trap_if(ptr + dst_byte_length > len(cx.opts.memory)) - encoded = src.encode(dst_encoding) - assert(dst_byte_length == len(encoded)) - cx.opts.memory[ptr : ptr+len(encoded)] = encoded - return (ptr, src_code_units) - -def store_utf16_to_utf8(cx, src, src_code_units): - worst_case_size = src_code_units * 3 - return store_string_to_utf8(cx, src, src_code_units, worst_case_size) - -def store_latin1_to_utf8(cx, src, src_code_units): - worst_case_size = src_code_units * 2 - return store_string_to_utf8(cx, src, src_code_units, worst_case_size) - -def store_string_to_utf8(cx, src, src_code_units, worst_case_size): - assert(src_code_units <= MAX_STRING_BYTE_LENGTH) - ptr = cx.opts.realloc(0, 0, 1, src_code_units) - trap_if(ptr + src_code_units > len(cx.opts.memory)) - encoded = src.encode('utf-8') - assert(src_code_units <= len(encoded)) - cx.opts.memory[ptr : ptr+src_code_units] = encoded[0 : src_code_units] - if src_code_units < len(encoded): - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) - ptr = cx.opts.realloc(ptr, src_code_units, 1, worst_case_size) - trap_if(ptr + worst_case_size > len(cx.opts.memory)) - cx.opts.memory[ptr+src_code_units : ptr+len(encoded)] = encoded[src_code_units : ] - if worst_case_size > len(encoded): - ptr = cx.opts.realloc(ptr, worst_case_size, 1, len(encoded)) - trap_if(ptr + len(encoded) > len(cx.opts.memory)) - return (ptr, len(encoded)) - -def store_utf8_to_utf16(cx, src, src_code_units): - worst_case_size = 2 * src_code_units - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) - ptr = cx.opts.realloc(0, 0, 2, worst_case_size) - trap_if(ptr != align_to(ptr, 2)) - trap_if(ptr + worst_case_size > len(cx.opts.memory)) - encoded = src.encode('utf-16-le') - cx.opts.memory[ptr : ptr+len(encoded)] = encoded - if len(encoded) < worst_case_size: - 
ptr = cx.opts.realloc(ptr, worst_case_size, 2, len(encoded)) - trap_if(ptr != align_to(ptr, 2)) - trap_if(ptr + len(encoded) > len(cx.opts.memory)) - code_units = int(len(encoded) / 2) - return (ptr, code_units) - -def store_string_to_latin1_or_utf16(cx, src, src_code_units): - assert(src_code_units <= MAX_STRING_BYTE_LENGTH) - ptr = cx.opts.realloc(0, 0, 2, src_code_units) - trap_if(ptr != align_to(ptr, 2)) - trap_if(ptr + src_code_units > len(cx.opts.memory)) - dst_byte_length = 0 - for usv in src: - if ord(usv) < (1 << 8): - cx.opts.memory[ptr + dst_byte_length] = ord(usv) - dst_byte_length += 1 - else: - worst_case_size = 2 * src_code_units - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) - ptr = cx.opts.realloc(ptr, src_code_units, 2, worst_case_size) - trap_if(ptr != align_to(ptr, 2)) - trap_if(ptr + worst_case_size > len(cx.opts.memory)) - for j in range(dst_byte_length-1, -1, -1): - cx.opts.memory[ptr + 2*j] = cx.opts.memory[ptr + j] - cx.opts.memory[ptr + 2*j + 1] = 0 - encoded = src.encode('utf-16-le') - cx.opts.memory[ptr+2*dst_byte_length : ptr+len(encoded)] = encoded[2*dst_byte_length : ] - if worst_case_size > len(encoded): - ptr = cx.opts.realloc(ptr, worst_case_size, 2, len(encoded)) - trap_if(ptr != align_to(ptr, 2)) - trap_if(ptr + len(encoded) > len(cx.opts.memory)) - tagged_code_units = int(len(encoded) / 2) | UTF16_TAG - return (ptr, tagged_code_units) - if dst_byte_length < src_code_units: - ptr = cx.opts.realloc(ptr, src_code_units, 2, dst_byte_length) - trap_if(ptr != align_to(ptr, 2)) - trap_if(ptr + dst_byte_length > len(cx.opts.memory)) - return (ptr, dst_byte_length) - -def store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units): - src_byte_length = 2 * src_code_units - trap_if(src_byte_length > MAX_STRING_BYTE_LENGTH) - ptr = cx.opts.realloc(0, 0, 2, src_byte_length) - trap_if(ptr != align_to(ptr, 2)) - trap_if(ptr + src_byte_length > len(cx.opts.memory)) - encoded = src.encode('utf-16-le') - cx.opts.memory[ptr : 
ptr+len(encoded)] = encoded - if any(ord(c) >= (1 << 8) for c in src): - tagged_code_units = int(len(encoded) / 2) | UTF16_TAG - return (ptr, tagged_code_units) - latin1_size = int(len(encoded) / 2) - for i in range(latin1_size): - cx.opts.memory[ptr + i] = cx.opts.memory[ptr + 2*i] - ptr = cx.opts.realloc(ptr, src_byte_length, 1, latin1_size) - trap_if(ptr + latin1_size > len(cx.opts.memory)) - return (ptr, latin1_size) - -def store_list(cx, v, ptr, elem_type): - begin, length = store_list_into_range(cx, v, elem_type) - store_int(cx, begin, ptr, 4) - store_int(cx, length, ptr + 4, 4) - -def store_list_into_range(cx, v, elem_type): - byte_length = len(v) * size(elem_type) - trap_if(byte_length >= (1 << 32)) - ptr = cx.opts.realloc(0, 0, alignment(elem_type), byte_length) - trap_if(ptr != align_to(ptr, alignment(elem_type))) - trap_if(ptr + byte_length > len(cx.opts.memory)) - for i,e in enumerate(v): - store(cx, e, elem_type, ptr + i * size(elem_type)) - return (ptr, len(v)) - -def store_record(cx, v, ptr, fields): - for f in fields: - ptr = align_to(ptr, alignment(f.t)) - store(cx, v[f.label], f.t, ptr) - ptr += size(f.t) - -def store_variant(cx, v, ptr, cases): - case_index, case_value = match_case(v, cases) - disc_size = size(discriminant_type(cases)) - store_int(cx, case_index, ptr, disc_size) - ptr += disc_size - ptr = align_to(ptr, max_case_alignment(cases)) - c = cases[case_index] - if c.t is not None: - store(cx, case_value, c.t, ptr) - -def match_case(v, cases): - assert(len(v.keys()) == 1) - key = list(v.keys())[0] - value = list(v.values())[0] - for label in key.split('|'): - case_index = find_case(label, cases) - if case_index != -1: - return (case_index, value) - -def store_flags(cx, v, ptr, labels): - i = pack_flags_into_int(v, labels) - store_int(cx, i, ptr, size_flags(labels)) - -def pack_flags_into_int(v, labels): - i = 0 - shift = 0 - for l in labels: - i |= (int(bool(v[l])) << shift) - shift += 1 - return i - -def lower_own(cx, rep, t): - h = 
HandleElem(rep, own=True) - return cx.inst.handles.add(t.rt, h) - -def lower_borrow(cx, rep, t): - if cx.inst is t.rt.impl: - return rep - h = HandleElem(rep, own=False, scope=cx) - cx.borrow_count += 1 - return cx.inst.handles.add(t.rt, h) - -### Flattening - -MAX_FLAT_PARAMS = 16 -MAX_FLAT_RESULTS = 1 - -def flatten_functype(ft, context): - flat_params = flatten_types(ft.param_types()) - if len(flat_params) > MAX_FLAT_PARAMS: - flat_params = ['i32'] - - flat_results = flatten_types(ft.result_types()) - if len(flat_results) > MAX_FLAT_RESULTS: - match context: - case 'lift': - flat_results = ['i32'] - case 'lower': - flat_params += ['i32'] - flat_results = [] - - return CoreFuncType(flat_params, flat_results) - -def flatten_types(ts): - return [ft for t in ts for ft in flatten_type(t)] - -def flatten_type(t): - match despecialize(t): - case Bool() : return ['i32'] - case U8() | U16() | U32() : return ['i32'] - case S8() | S16() | S32() : return ['i32'] - case S64() | U64() : return ['i64'] - case Float32() : return ['f32'] - case Float64() : return ['f64'] - case Char() : return ['i32'] - case String() | List(_) : return ['i32', 'i32'] - case Record(fields) : return flatten_record(fields) - case Variant(cases) : return flatten_variant(cases) - case Flags(labels) : return ['i32'] * num_i32_flags(labels) - case Own(_) | Borrow(_) : return ['i32'] - -def flatten_record(fields): - flat = [] - for f in fields: - flat += flatten_type(f.t) - return flat - -def flatten_variant(cases): - flat = [] - for c in cases: - if c.t is not None: - for i,ft in enumerate(flatten_type(c.t)): - if i < len(flat): - flat[i] = join(flat[i], ft) - else: - flat.append(ft) - return flatten_type(discriminant_type(cases)) + flat - -def join(a, b): - if a == b: return a - if (a == 'i32' and b == 'f32') or (a == 'f32' and b == 'i32'): return 'i32' - return 'i64' - -### Flat Lifting - -@dataclass -class Value: - t: str # 'i32'|'i64'|'f32'|'f64' - v: int|float - -@dataclass -class ValueIter: - 
values: list[Value] - i = 0 - def next(self, t): - v = self.values[self.i] - self.i += 1 - assert(v.t == t) - return v.v - -def lift_flat(cx, vi, t): - match despecialize(t): - case Bool() : return convert_int_to_bool(vi.next('i32')) - case U8() : return lift_flat_unsigned(vi, 32, 8) - case U16() : return lift_flat_unsigned(vi, 32, 16) - case U32() : return lift_flat_unsigned(vi, 32, 32) - case U64() : return lift_flat_unsigned(vi, 64, 64) - case S8() : return lift_flat_signed(vi, 32, 8) - case S16() : return lift_flat_signed(vi, 32, 16) - case S32() : return lift_flat_signed(vi, 32, 32) - case S64() : return lift_flat_signed(vi, 64, 64) - case Float32() : return canonicalize_nan32(vi.next('f32')) - case Float64() : return canonicalize_nan64(vi.next('f64')) - case Char() : return convert_i32_to_char(cx, vi.next('i32')) - case String() : return lift_flat_string(cx, vi) - case List(t) : return lift_flat_list(cx, vi, t) - case Record(fields) : return lift_flat_record(cx, vi, fields) - case Variant(cases) : return lift_flat_variant(cx, vi, cases) - case Flags(labels) : return lift_flat_flags(vi, labels) - case Own() : return lift_own(cx, vi.next('i32'), t) - case Borrow() : return lift_borrow(cx, vi.next('i32'), t) - -def lift_flat_unsigned(vi, core_width, t_width): - i = vi.next('i' + str(core_width)) - assert(0 <= i < (1 << core_width)) - return i % (1 << t_width) - -def lift_flat_signed(vi, core_width, t_width): - i = vi.next('i' + str(core_width)) - assert(0 <= i < (1 << core_width)) - i %= (1 << t_width) - if i >= (1 << (t_width - 1)): - return i - (1 << t_width) - return i - -def lift_flat_string(cx, vi): - ptr = vi.next('i32') - packed_length = vi.next('i32') - return load_string_from_range(cx, ptr, packed_length) - -def lift_flat_list(cx, vi, elem_type): - ptr = vi.next('i32') - length = vi.next('i32') - return load_list_from_range(cx, ptr, length, elem_type) - -def lift_flat_record(cx, vi, fields): - record = {} - for f in fields: - record[f.label] = 
lift_flat(cx, vi, f.t) - return record - -def lift_flat_variant(cx, vi, cases): - flat_types = flatten_variant(cases) - assert(flat_types.pop(0) == 'i32') - case_index = vi.next('i32') - trap_if(case_index >= len(cases)) - class CoerceValueIter: - def next(self, want): - have = flat_types.pop(0) - x = vi.next(have) - match (have, want): - case ('i32', 'f32') : return decode_i32_as_float(x) - case ('i64', 'i32') : return wrap_i64_to_i32(x) - case ('i64', 'f32') : return decode_i32_as_float(wrap_i64_to_i32(x)) - case ('i64', 'f64') : return decode_i64_as_float(x) - case _ : return x - c = cases[case_index] - if c.t is None: - v = None - else: - v = lift_flat(cx, CoerceValueIter(), c.t) - for have in flat_types: - _ = vi.next(have) - return { case_label_with_refinements(c, cases): v } - -def wrap_i64_to_i32(i): - assert(0 <= i < (1 << 64)) - return i % (1 << 32) - -def lift_flat_flags(vi, labels): - i = 0 - shift = 0 - for _ in range(num_i32_flags(labels)): - i |= (vi.next('i32') << shift) - shift += 32 - return unpack_flags_from_int(i, labels) - -### Flat Lowering - -def lower_flat(cx, v, t): - match despecialize(t): - case Bool() : return [Value('i32', int(v))] - case U8() : return [Value('i32', v)] - case U16() : return [Value('i32', v)] - case U32() : return [Value('i32', v)] - case U64() : return [Value('i64', v)] - case S8() : return lower_flat_signed(v, 32) - case S16() : return lower_flat_signed(v, 32) - case S32() : return lower_flat_signed(v, 32) - case S64() : return lower_flat_signed(v, 64) - case Float32() : return [Value('f32', maybe_scramble_nan32(v))] - case Float64() : return [Value('f64', maybe_scramble_nan64(v))] - case Char() : return [Value('i32', char_to_i32(v))] - case String() : return lower_flat_string(cx, v) - case List(t) : return lower_flat_list(cx, v, t) - case Record(fields) : return lower_flat_record(cx, v, fields) - case Variant(cases) : return lower_flat_variant(cx, v, cases) - case Flags(labels) : return lower_flat_flags(v, labels) - 
case Own() : return [Value('i32', lower_own(cx, v, t))] - case Borrow() : return [Value('i32', lower_borrow(cx, v, t))] - -def lower_flat_signed(i, core_bits): - if i < 0: - i += (1 << core_bits) - return [Value('i' + str(core_bits), i)] - -def lower_flat_string(cx, v): - ptr, packed_length = store_string_into_range(cx, v) - return [Value('i32', ptr), Value('i32', packed_length)] - -def lower_flat_list(cx, v, elem_type): - (ptr, length) = store_list_into_range(cx, v, elem_type) - return [Value('i32', ptr), Value('i32', length)] - -def lower_flat_record(cx, v, fields): - flat = [] - for f in fields: - flat += lower_flat(cx, v[f.label], f.t) - return flat - -def lower_flat_variant(cx, v, cases): - case_index, case_value = match_case(v, cases) - flat_types = flatten_variant(cases) - assert(flat_types.pop(0) == 'i32') - c = cases[case_index] - if c.t is None: - payload = [] - else: - payload = lower_flat(cx, case_value, c.t) - for i,have in enumerate(payload): - want = flat_types.pop(0) - match (have.t, want): - case ('f32', 'i32') : payload[i] = Value('i32', encode_float_as_i32(have.v)) - case ('i32', 'i64') : payload[i] = Value('i64', have.v) - case ('f32', 'i64') : payload[i] = Value('i64', encode_float_as_i32(have.v)) - case ('f64', 'i64') : payload[i] = Value('i64', encode_float_as_i64(have.v)) - case _ : pass - for want in flat_types: - payload.append(Value(want, 0)) - return [Value('i32', case_index)] + payload - -def lower_flat_flags(v, labels): - i = pack_flags_into_int(v, labels) - flat = [] - for _ in range(num_i32_flags(labels)): - flat.append(Value('i32', i & 0xffffffff)) - i >>= 32 - assert(i == 0) - return flat - -### Lifting and Lowering Values - -def lift_values(cx, max_flat, vi, ts): - flat_types = flatten_types(ts) - if len(flat_types) > max_flat: - ptr = vi.next('i32') - tuple_type = Tuple(ts) - trap_if(ptr != align_to(ptr, alignment(tuple_type))) - trap_if(ptr + size(tuple_type) > len(cx.opts.memory)) - return list(load(cx, ptr, 
tuple_type).values()) - else: - return [ lift_flat(cx, vi, t) for t in ts ] - -def lower_values(cx, max_flat, vs, ts, out_param = None): - flat_types = flatten_types(ts) - if len(flat_types) > max_flat: - tuple_type = Tuple(ts) - tuple_value = {str(i): v for i,v in enumerate(vs)} - if out_param is None: - ptr = cx.opts.realloc(0, 0, alignment(tuple_type), size(tuple_type)) - else: - ptr = out_param.next('i32') - trap_if(ptr != align_to(ptr, alignment(tuple_type))) - trap_if(ptr + size(tuple_type) > len(cx.opts.memory)) - store(cx, tuple_value, tuple_type, ptr) - return [ Value('i32', ptr) ] - else: - flat_vals = [] - for i in range(len(vs)): - flat_vals += lower_flat(cx, vs[i], ts[i]) - return flat_vals - -### `canon lift` - -def canon_lift(opts, inst, callee, ft, args): - cx = LiftLowerContext(opts, inst) - trap_if(not inst.may_enter) - - assert(inst.may_leave) - inst.may_leave = False - flat_args = lower_values(cx, MAX_FLAT_PARAMS, args, ft.param_types()) - inst.may_leave = True - - try: - flat_results = callee(flat_args) - except CoreWebAssemblyException: - trap() - - results = lift_values(cx, MAX_FLAT_RESULTS, ValueIter(flat_results), ft.result_types()) - - def post_return(): - if opts.post_return is not None: - opts.post_return(flat_results) - cx.exit_call() - - return (results, post_return) - -### `canon lower` - -def canon_lower(opts, inst, callee, calling_import, ft, flat_args): - cx = LiftLowerContext(opts, inst) - trap_if(not inst.may_leave) - - assert(inst.may_enter) - if calling_import: - inst.may_enter = False - - flat_args = ValueIter(flat_args) - args = lift_values(cx, MAX_FLAT_PARAMS, flat_args, ft.param_types()) - - results, post_return = callee(args) - - inst.may_leave = False - flat_results = lower_values(cx, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args) - inst.may_leave = True - - post_return() - cx.exit_call() - - if calling_import: - inst.may_enter = True - - return flat_results - -### `canon resource.new` - -def 
canon_resource_new(inst, rt, rep): - h = HandleElem(rep, own=True) - return inst.handles.add(rt, h) - -### `canon resource.drop` - -def canon_resource_drop(inst, rt, i): - h = inst.handles.remove(rt, i) - if h.own: - assert(h.scope is None) - trap_if(h.lend_count != 0) - trap_if(inst is not rt.impl and not rt.impl.may_enter) - if rt.dtor: - rt.dtor(h.rep) - else: - assert(h.scope is not None) - assert(h.scope.borrow_count > 0) - h.scope.borrow_count -= 1 - -### `canon resource.rep` - -def canon_resource_rep(inst, rt, i): - h = inst.handles.get(rt, i) - return h.rep diff --git a/include/cmcpp/context.hpp b/include/cmcpp/context.hpp index 20ccf6d..b85df17 100644 --- a/include/cmcpp/context.hpp +++ b/include/cmcpp/context.hpp @@ -3,11 +3,18 @@ #include "traits.hpp" +#include +#include #include +#include #include +#include #include +#include #include #include +#include +#include #if __has_include() #include #else @@ -101,6 +108,747 @@ namespace cmcpp return LiftLowerContext(trap, convert, opts); } + constexpr uint32_t BLOCKED = 0xFFFF'FFFFu; + + enum class EventCode : uint8_t + { + NONE = 0, + SUBTASK = 1, + STREAM_READ = 2, + STREAM_WRITE = 3, + FUTURE_READ = 4, + FUTURE_WRITE = 5, + TASK_CANCELLED = 6 + }; + + struct Event + { + EventCode code = EventCode::NONE; + uint32_t index = 0; + uint32_t payload = 0; + }; + + struct TableEntry + { + virtual ~TableEntry() = default; + }; + + class WaitableSet; + + class Waitable : public TableEntry + { + public: + Waitable() = default; + + void set_pending_event(const Event &event) + { + pending_event_ = event; + } + + bool has_pending_event() const + { + return pending_event_.has_value(); + } + + Event get_pending_event(const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, !pending_event_.has_value(), "waitable pending event missing"); + Event event = *pending_event_; + pending_event_.reset(); + return event; + } + + void clear_pending_event() + { + pending_event_.reset(); + } + + void 
join(WaitableSet *set, const HostTrap &trap); + + WaitableSet *joined_set() const + { + return wset_; + } + + void drop(const HostTrap &trap); + + private: + std::optional pending_event_; + WaitableSet *wset_ = nullptr; + }; + + class WaitableSet : public TableEntry + { + public: + void add_waitable(Waitable &waitable) + { + if (std::find(waitables_.begin(), waitables_.end(), &waitable) == waitables_.end()) + { + waitables_.push_back(&waitable); + } + } + + void remove_waitable(Waitable &waitable) + { + auto it = std::find(waitables_.begin(), waitables_.end(), &waitable); + if (it != waitables_.end()) + { + waitables_.erase(it); + } + } + + bool has_pending_event() const + { + return std::any_of(waitables_.begin(), waitables_.end(), [](const Waitable *w) + { return w && w->has_pending_event(); }); + } + + Event take_pending_event(const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, waitables_.empty(), "waitable set empty"); + for (auto *w : waitables_) + { + if (w != nullptr && w->has_pending_event()) + { + return w->get_pending_event(trap); + } + } + trap_if(trap_cx, true, "waitable set missing event"); + return {}; + } + + void drop(const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, !waitables_.empty(), "waitable set not empty"); + trap_if(trap_cx, num_waiting_ != 0, "waitable set has waiters"); + } + + void begin_wait() + { + num_waiting_ += 1; + } + + void end_wait() + { + if (num_waiting_ > 0) + { + num_waiting_ -= 1; + } + } + + uint32_t num_waiting() const + { + return num_waiting_; + } + + private: + std::vector waitables_; + uint32_t num_waiting_ = 0; + }; + + inline void Waitable::join(WaitableSet *set, const HostTrap &) + { + if (wset_ == set) + { + return; + } + if (wset_) + { + wset_->remove_waitable(*this); + } + wset_ = set; + if (wset_) + { + wset_->add_waitable(*this); + } + } + + inline void Waitable::drop(const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + 
trap_if(trap_cx, has_pending_event(), "waitable drop with pending event"); + if (wset_) + { + wset_->remove_waitable(*this); + wset_ = nullptr; + } + } + + class InstanceTable + { + public: + uint32_t add(const std::shared_ptr &entry, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, !entry, "null table entry"); + uint32_t index; + if (!free_.empty()) + { + index = free_.back(); + free_.pop_back(); + entries_[index] = entry; + } + else + { + trap_if(trap_cx, entries_.size() >= (1u << 30), "instance table overflow"); + entries_.push_back(entry); + index = static_cast(entries_.size() - 1); + } + return index; + } + + std::shared_ptr get_entry(uint32_t index, const HostTrap &trap) const + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, index == 0 || index >= entries_.size(), "table index out of bounds"); + auto entry = entries_[index]; + trap_if(trap_cx, !entry, "table slot empty"); + return entry; + } + + std::shared_ptr remove_entry(uint32_t index, const HostTrap &trap) + { + auto entry = get_entry(index, trap); + entries_[index].reset(); + free_.push_back(index); + return entry; + } + + template + std::shared_ptr get(uint32_t index, const HostTrap &trap) const + { + auto base = get_entry(index, trap); + auto derived = std::dynamic_pointer_cast(base); + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, !derived, "table entry type mismatch"); + return derived; + } + + template + std::shared_ptr remove(uint32_t index, const HostTrap &trap) + { + auto base = remove_entry(index, trap); + auto derived = std::dynamic_pointer_cast(base); + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, !derived, "table entry type mismatch"); + return derived; + } + + private: + std::vector> entries_{nullptr}; + std::vector free_; + }; + + struct StreamDescriptor + { + uint32_t element_size = 0; + uint32_t alignment = 1; + std::type_index type = typeid(void); + }; + + template + StreamDescriptor 
make_stream_descriptor() + { + return StreamDescriptor{ValTrait::size, ValTrait::alignment, std::type_index(typeid(T))}; + } + + struct FutureDescriptor + { + uint32_t element_size = 0; + uint32_t alignment = 1; + std::type_index type = typeid(void); + }; + + template + FutureDescriptor make_future_descriptor() + { + return FutureDescriptor{ValTrait::size, ValTrait::alignment, std::type_index(typeid(T))}; + } + + inline uint8_t normalize_alignment(uint32_t alignment) + { + if (alignment == 0) + { + return 1; + } + return static_cast(std::min(alignment, 255)); + } + + inline void ensure_memory_range(const LiftLowerContext &cx, uint32_t ptr, uint32_t count, uint32_t alignment, uint32_t elem_size) + { + auto align_value = normalize_alignment(alignment); + trap_if(cx, ptr != align_to(ptr, align_value), "misaligned memory access"); + uint64_t total_bytes = static_cast(count) * elem_size; + trap_if(cx, ptr + total_bytes > cx.opts.memory.size(), "memory overflow"); + } + + inline void write_event_fields(GuestMemory mem, uint32_t ptr, uint32_t p1, uint32_t p2, const HostTrap &trap) + { + if (ptr + 8 > mem.size()) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, true, "event write out of bounds"); + } + std::memcpy(mem.data() + ptr, &p1, sizeof(uint32_t)); + std::memcpy(mem.data() + ptr + sizeof(uint32_t), &p2, sizeof(uint32_t)); + } + + enum class CopyResult : uint32_t + { + Completed = 0, + Dropped = 1, + Cancelled = 2 + }; + + enum class CopyState : uint8_t + { + Idle = 0, + Copying = 1, + Done = 2 + }; + + inline uint32_t pack_copy_result(CopyResult result, uint32_t progress) + { + return static_cast(result) | (progress << 4); + } + + inline void validate_descriptor(const StreamDescriptor &expected, const StreamDescriptor &actual, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, expected.element_size != actual.element_size, "stream descriptor size mismatch"); + trap_if(trap_cx, expected.alignment != actual.alignment, 
"stream descriptor alignment mismatch"); + trap_if(trap_cx, expected.type != actual.type, "stream descriptor type mismatch"); + } + + inline void validate_descriptor(const FutureDescriptor &expected, const FutureDescriptor &actual, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, expected.element_size != actual.element_size, "future descriptor size mismatch"); + trap_if(trap_cx, expected.alignment != actual.alignment, "future descriptor alignment mismatch"); + trap_if(trap_cx, expected.type != actual.type, "future descriptor type mismatch"); + } + + class ReadableStreamEnd; + class WritableStreamEnd; + + struct SharedStreamState + { + explicit SharedStreamState(const StreamDescriptor &desc) : descriptor(desc) {} + + StreamDescriptor descriptor; + std::deque> queue; + bool readable_dropped = false; + bool writable_dropped = false; + + struct PendingRead + { + std::shared_ptr cx; + uint32_t ptr = 0; + uint32_t requested = 0; + uint32_t progress = 0; + uint32_t handle_index = 0; + ReadableStreamEnd *endpoint = nullptr; + }; + + std::optional pending_read; + }; + + inline void copy_into_queue(const std::shared_ptr &cx, uint32_t ptr, uint32_t count, SharedStreamState &state, const HostTrap &trap) + { + if (count == 0) + { + return; + } + ensure_memory_range(*cx, ptr, count, state.descriptor.alignment, state.descriptor.element_size); + auto *src = cx->opts.memory.data() + ptr; + for (uint32_t i = 0; i < count; ++i) + { + std::vector bytes(state.descriptor.element_size); + std::memcpy(bytes.data(), src + i * state.descriptor.element_size, state.descriptor.element_size); + state.queue.push_back(std::move(bytes)); + } + } + + inline uint32_t copy_from_queue(const std::shared_ptr &cx, + uint32_t ptr, + uint32_t offset, + uint32_t max_count, + SharedStreamState &state, + const HostTrap &trap) + { + if (max_count == 0) + { + return 0; + } + uint32_t available = std::min(max_count, static_cast(state.queue.size())); + if (available == 0) + { 
+            return 0;
+        }
+        ensure_memory_range(*cx, ptr, offset + available, state.descriptor.alignment, state.descriptor.element_size);
+        auto *dest = cx->opts.memory.data() + ptr + offset * state.descriptor.element_size;
+        auto trap_cx = make_trap_context(trap);
+        for (uint32_t i = 0; i < available; ++i)
+        {
+            const auto &bytes = state.queue.front();
+            trap_if(trap_cx, bytes.size() != state.descriptor.element_size, "stream element size mismatch");
+            std::memcpy(dest + i * state.descriptor.element_size, bytes.data(), state.descriptor.element_size);
+            state.queue.pop_front();
+        }
+        return available;
+    }
+
+    inline void satisfy_pending_read(SharedStreamState &state, const HostTrap &trap);
+
+    class ReadableStreamEnd : public Waitable
+    {
+    public:
+        explicit ReadableStreamEnd(std::shared_ptr<SharedStreamState> shared) : shared_(std::move(shared)) {}
+
+        const StreamDescriptor &descriptor() const
+        {
+            return shared_->descriptor;
+        }
+
+        uint32_t read(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, uint32_t n, bool sync, const HostTrap &trap);
+        uint32_t cancel(bool sync, const HostTrap &trap);
+        void drop(const HostTrap &trap);
+        void complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap);
+
+    private:
+        std::shared_ptr<SharedStreamState> shared_;
+        CopyState state_ = CopyState::Idle;
+    };
+
+    class WritableStreamEnd : public Waitable
+    {
+    public:
+        explicit WritableStreamEnd(std::shared_ptr<SharedStreamState> shared) : shared_(std::move(shared)) {}
+
+        const StreamDescriptor &descriptor() const
+        {
+            return shared_->descriptor;
+        }
+
+        uint32_t write(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, uint32_t n, const HostTrap &trap);
+        uint32_t cancel(bool sync, const HostTrap &trap);
+        void drop(const HostTrap &trap);
+
+    private:
+        std::shared_ptr<SharedStreamState> shared_;
+        CopyState state_ = CopyState::Idle;
+    };
+
+    inline uint32_t ReadableStreamEnd::read(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, uint32_t n, bool sync, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !shared_, "stream state missing");
+        trap_if(trap_cx, shared_->descriptor.element_size == 0, "invalid stream descriptor");
+        trap_if(trap_cx, state_ != CopyState::Idle, "stream read busy");
+
+        uint32_t consumed = copy_from_queue(cx, ptr, 0, n, *shared_, trap);
+        if (consumed > 0 || n == 0)
+        {
+            set_pending_event({EventCode::STREAM_READ, handle_index, pack_copy_result(CopyResult::Completed, consumed)});
+            auto event = get_pending_event(trap);
+            state_ = CopyState::Idle;
+            return event.payload;
+        }
+
+        if (shared_->writable_dropped)
+        {
+            set_pending_event({EventCode::STREAM_READ, handle_index, pack_copy_result(CopyResult::Dropped, 0)});
+            auto event = get_pending_event(trap);
+            state_ = CopyState::Done;
+            return event.payload;
+        }
+
+        trap_if(trap_cx, sync, "sync stream read would block");
+        shared_->pending_read = SharedStreamState::PendingRead{cx, ptr, n, 0, handle_index, this};
+        state_ = CopyState::Copying;
+        return BLOCKED;
+    }
+
+    inline uint32_t ReadableStreamEnd::cancel(bool sync, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, state_ != CopyState::Copying, "no pending stream read");
+        trap_if(trap_cx, !shared_ || !shared_->pending_read, "no pending stream read");
+
+        auto pending = std::move(*shared_->pending_read);
+        shared_->pending_read.reset();
+        set_pending_event({EventCode::STREAM_READ, pending.handle_index, pack_copy_result(CopyResult::Cancelled, pending.progress)});
+        state_ = CopyState::Done;
+
+        if (sync)
+        {
+            auto event = get_pending_event(trap);
+            return event.payload;
+        }
+        return BLOCKED;
+    }
+
+    inline void ReadableStreamEnd::drop(const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, state_ == CopyState::Copying, "cannot drop pending stream read");
+        trap_if(trap_cx, shared_ && shared_->pending_read.has_value(), "pending read must complete before drop");
+        if (shared_)
+        {
+            shared_->readable_dropped = true;
+        }
+        state_ = CopyState::Done;
+        Waitable::drop(trap);
+    }
+
+    inline void ReadableStreamEnd::complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap)
+    {
+        set_pending_event({EventCode::STREAM_READ, handle_index, pack_copy_result(result, progress)});
+        state_ = (result == CopyResult::Completed) ? CopyState::Idle : CopyState::Done;
+    }
+
+    inline void satisfy_pending_read(SharedStreamState &state, const HostTrap &trap)
+    {
+        if (!state.pending_read)
+        {
+            return;
+        }
+        auto &pending = *state.pending_read;
+        uint32_t remaining = pending.requested - pending.progress;
+        uint32_t consumed = copy_from_queue(pending.cx, pending.ptr, pending.progress, remaining, state, trap);
+        pending.progress += consumed;
+        if (pending.progress >= pending.requested)
+        {
+            pending.endpoint->complete_async(pending.handle_index, CopyResult::Completed, pending.progress, trap);
+            state.pending_read.reset();
+        }
+    }
+
+    inline uint32_t WritableStreamEnd::write(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, uint32_t n, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !shared_, "stream state missing");
+        trap_if(trap_cx, shared_->descriptor.element_size == 0, "invalid stream descriptor");
+        trap_if(trap_cx, state_ != CopyState::Idle, "stream write busy");
+
+        copy_into_queue(cx, ptr, n, *shared_, trap);
+        satisfy_pending_read(*shared_, trap);
+
+        set_pending_event({EventCode::STREAM_WRITE, handle_index, pack_copy_result(CopyResult::Completed, n)});
+        auto event = get_pending_event(trap);
+        state_ = CopyState::Idle;
+        return event.payload;
+    }
+
+    inline uint32_t WritableStreamEnd::cancel(bool, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, true, "no pending stream write");
+        return BLOCKED;
+    }
+
+    inline void WritableStreamEnd::drop(const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, state_ == CopyState::Copying, "cannot drop pending stream write");
+        if (shared_)
+        {
+            if (shared_->pending_read)
+            {
+                auto pending = std::move(*shared_->pending_read);
+                shared_->pending_read.reset();
+                pending.endpoint->complete_async(pending.handle_index, CopyResult::Dropped, pending.progress, trap);
+            }
+            shared_->writable_dropped = true;
+        }
+        state_ = CopyState::Done;
+        Waitable::drop(trap);
+    }
+
+    class ReadableFutureEnd;
+    class WritableFutureEnd;
+
+    struct SharedFutureState
+    {
+        explicit SharedFutureState(const FutureDescriptor &desc) : descriptor(desc), value(desc.element_size) {}
+
+        FutureDescriptor descriptor;
+        bool readable_dropped = false;
+        bool writable_dropped = false;
+        bool value_ready = false;
+        std::vector<uint8_t> value;
+
+        struct PendingRead
+        {
+            std::shared_ptr<LiftLowerContext> cx;
+            uint32_t ptr = 0;
+            uint32_t handle_index = 0;
+            ReadableFutureEnd *endpoint = nullptr;
+        };
+
+        std::optional<PendingRead> pending_read;
+    };
+
+    class ReadableFutureEnd : public Waitable
+    {
+    public:
+        explicit ReadableFutureEnd(std::shared_ptr<SharedFutureState> shared) : shared_(std::move(shared)) {}
+
+        const FutureDescriptor &descriptor() const
+        {
+            return shared_->descriptor;
+        }
+
+        uint32_t read(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, bool sync, const HostTrap &trap);
+        uint32_t cancel(bool sync, const HostTrap &trap);
+        void drop(const HostTrap &trap);
+        void complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap);
+
+    private:
+        std::shared_ptr<SharedFutureState> shared_;
+        CopyState state_ = CopyState::Idle;
+    };
+
+    class WritableFutureEnd : public Waitable
+    {
+    public:
+        explicit WritableFutureEnd(std::shared_ptr<SharedFutureState> shared) : shared_(std::move(shared)) {}
+
+        const FutureDescriptor &descriptor() const
+        {
+            return shared_->descriptor;
+        }
+
+        uint32_t write(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, const HostTrap &trap);
+        uint32_t cancel(bool sync, const HostTrap &trap);
+        void drop(const HostTrap &trap);
+
+    private:
+        std::shared_ptr<SharedFutureState> shared_;
+        CopyState state_ = CopyState::Idle;
+    };
+
+    inline uint32_t ReadableFutureEnd::read(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, bool sync, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !shared_, "future state missing");
+        trap_if(trap_cx, shared_->descriptor.element_size == 0, "invalid future descriptor");
+        trap_if(trap_cx, state_ != CopyState::Idle, "future read busy");
+
+        if (shared_->value_ready)
+        {
+            ensure_memory_range(*cx, ptr, 1, shared_->descriptor.alignment, shared_->descriptor.element_size);
+            std::memcpy(cx->opts.memory.data() + ptr, shared_->value.data(), shared_->descriptor.element_size);
+            set_pending_event({EventCode::FUTURE_READ, handle_index, pack_copy_result(CopyResult::Completed, 1)});
+            auto event = get_pending_event(trap);
+            state_ = CopyState::Idle;
+            return event.payload;
+        }
+
+        if (shared_->writable_dropped)
+        {
+            set_pending_event({EventCode::FUTURE_READ, handle_index, pack_copy_result(CopyResult::Dropped, 0)});
+            auto event = get_pending_event(trap);
+            state_ = CopyState::Done;
+            return event.payload;
+        }
+
+        trap_if(trap_cx, sync, "sync future read would block");
+        shared_->pending_read = SharedFutureState::PendingRead{cx, ptr, handle_index, this};
+        state_ = CopyState::Copying;
+        return BLOCKED;
+    }
+
+    inline uint32_t ReadableFutureEnd::cancel(bool sync, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, state_ != CopyState::Copying, "no pending future read");
+        trap_if(trap_cx, !shared_ || !shared_->pending_read, "no pending future read");
+
+        auto pending = std::move(*shared_->pending_read);
+        shared_->pending_read.reset();
+        set_pending_event({EventCode::FUTURE_READ, pending.handle_index, pack_copy_result(CopyResult::Cancelled, 0)});
+        state_ = CopyState::Done;
+
+        if (sync)
+        {
+            auto event = get_pending_event(trap);
+            return event.payload;
+        }
+        return BLOCKED;
+    }
+
+    inline void ReadableFutureEnd::drop(const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, state_ == CopyState::Copying, "cannot drop pending future read");
+        trap_if(trap_cx, shared_ && shared_->pending_read.has_value(), "pending future read must complete before drop");
+        if (shared_)
+        {
+            shared_->readable_dropped = true;
+        }
+        state_ = CopyState::Done;
+        Waitable::drop(trap);
+    }
+
+    inline void ReadableFutureEnd::complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap)
+    {
+        set_pending_event({EventCode::FUTURE_READ, handle_index, pack_copy_result(result, progress)});
+        state_ = (result == CopyResult::Completed) ? CopyState::Idle : CopyState::Done;
+    }
+
+    inline uint32_t WritableFutureEnd::write(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !shared_, "future state missing");
+        trap_if(trap_cx, shared_->descriptor.element_size == 0, "invalid future descriptor");
+        trap_if(trap_cx, shared_->value_ready, "future already resolved");
+
+        ensure_memory_range(*cx, ptr, 1, shared_->descriptor.alignment, shared_->descriptor.element_size);
+        std::memcpy(shared_->value.data(), cx->opts.memory.data() + ptr, shared_->descriptor.element_size);
+        shared_->value_ready = true;
+
+        if (shared_->pending_read)
+        {
+            auto pending = std::move(*shared_->pending_read);
+            shared_->pending_read.reset();
+            ensure_memory_range(*pending.cx, pending.ptr, 1, shared_->descriptor.alignment, shared_->descriptor.element_size);
+            std::memcpy(pending.cx->opts.memory.data() + pending.ptr, shared_->value.data(), shared_->descriptor.element_size);
+            pending.endpoint->complete_async(pending.handle_index, CopyResult::Completed, 1, trap);
+        }
+
+        set_pending_event({EventCode::FUTURE_WRITE, handle_index, pack_copy_result(CopyResult::Completed, 1)});
+        auto event = get_pending_event(trap);
+        state_ = CopyState::Idle;
+        return event.payload;
+    }
+
+    inline uint32_t WritableFutureEnd::cancel(bool, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, true, "no pending future write");
+        return BLOCKED;
+    }
+
+    inline void WritableFutureEnd::drop(const HostTrap &trap)
+    {
+        if (shared_ && !shared_->value_ready)
+        {
+            if (shared_->pending_read)
+            {
+                auto pending = std::move(*shared_->pending_read);
+                shared_->pending_read.reset();
+                pending.endpoint->complete_async(pending.handle_index, CopyResult::Dropped, 0, trap);
+            }
+            shared_->writable_dropped = true;
+        }
+        state_ = CopyState::Done;
+        Waitable::drop(trap);
+    }
+
     struct ResourceType
     {
         ComponentInstance *impl = nullptr;
@@ -220,14 +968,6 @@ namespace cmcpp
         std::unordered_map tables_;
     };
 
-    struct Waitable
-    {
-    };
-
-    struct WaitableSet
-    {
-    };
-
     struct ErrorContext
     {
     };
@@ -236,7 +976,11 @@
    {
        bool may_leave = true;
        bool may_enter = true;
+       bool exclusive = false;
+       uint32_t backpressure = 0;
+       uint32_t num_waiting_to_enter = 0;
        HandleTables handles;
+       InstanceTable table;
    };

    class ContextLocalStorage
@@ -322,6 +1066,191 @@
        return element.rep;
    }

+    inline uint32_t canon_waitable_set_new(ComponentInstance &inst, const HostTrap &trap)
+    {
+        auto wset = std::make_shared<WaitableSet>();
+        return inst.table.add(wset, trap);
+    }
+
+    inline uint32_t canon_waitable_set_wait(bool /*cancellable*/, GuestMemory mem, ComponentInstance &inst, uint32_t set_index, uint32_t ptr, const HostTrap &trap)
+    {
+        auto wset = inst.table.get<WaitableSet>(set_index, trap);
+        wset->begin_wait();
+        if (!wset->has_pending_event())
+        {
+            wset->end_wait();
+            write_event_fields(mem, ptr, 0, 0, trap);
+            return BLOCKED;
+        }
+        auto event = wset->take_pending_event(trap);
+        wset->end_wait();
+        write_event_fields(mem, ptr, event.index, event.payload, trap);
+        return static_cast<uint32_t>(event.code);
+    }
+
+    inline uint32_t canon_waitable_set_poll(bool /*cancellable*/, GuestMemory mem, ComponentInstance &inst, uint32_t set_index, uint32_t ptr, const HostTrap &trap)
+    {
+        auto wset = inst.table.get<WaitableSet>(set_index, trap);
+        if (!wset->has_pending_event())
+        {
+            write_event_fields(mem, ptr, 0, 0, trap);
+            return static_cast<uint32_t>(EventCode::NONE);
+        }
+        auto event = wset->take_pending_event(trap);
+        write_event_fields(mem, ptr, event.index, event.payload, trap);
+        return static_cast<uint32_t>(event.code);
+    }
+
+    inline void canon_waitable_set_drop(ComponentInstance &inst, uint32_t set_index, const HostTrap &trap)
+    {
+        auto wset = inst.table.remove<WaitableSet>(set_index, trap);
+        wset->drop(trap);
+    }
+
+    inline void canon_waitable_join(ComponentInstance &inst, uint32_t waitable_index, uint32_t set_index, const HostTrap &trap)
+    {
+        auto waitable = inst.table.get<Waitable>(waitable_index, trap);
+        if (set_index == 0)
+        {
+            waitable->join(nullptr, trap);
+            return;
+        }
+        auto wset = inst.table.get<WaitableSet>(set_index, trap);
+        waitable->join(wset.get(), trap);
+    }
+
+    inline uint64_t canon_stream_new(ComponentInstance &inst, const StreamDescriptor &descriptor, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, descriptor.element_size == 0, "stream descriptor invalid");
+        auto shared = std::make_shared<SharedStreamState>(descriptor);
+        auto readable = std::make_shared<ReadableStreamEnd>(shared);
+        auto writable = std::make_shared<WritableStreamEnd>(shared);
+        uint32_t readable_index = inst.table.add(readable, trap);
+        uint32_t writable_index = inst.table.add(writable, trap);
+        return (static_cast<uint64_t>(writable_index) << 32) | readable_index;
+    }
+
+    inline uint32_t canon_stream_read(ComponentInstance &inst,
+                                      const StreamDescriptor &descriptor,
+                                      uint32_t readable_index,
+                                      const std::shared_ptr<LiftLowerContext> &cx,
+                                      uint32_t ptr,
+                                      uint32_t n,
+                                      bool sync,
+                                      const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !cx, "lift/lower context required");
+        auto readable = inst.table.get<ReadableStreamEnd>(readable_index, trap);
+        validate_descriptor(descriptor, readable->descriptor(), trap);
+        return readable->read(cx, readable_index, ptr, n, sync, trap);
+    }
+
+    inline uint32_t canon_stream_write(ComponentInstance &inst,
+                                       const StreamDescriptor &descriptor,
+                                       uint32_t writable_index,
+                                       const std::shared_ptr<LiftLowerContext> &cx,
+                                       uint32_t ptr,
+                                       uint32_t n,
+                                       const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !cx, "lift/lower context required");
+        auto writable = inst.table.get<WritableStreamEnd>(writable_index, trap);
+        validate_descriptor(descriptor, writable->descriptor(), trap);
+        return writable->write(cx, writable_index, ptr, n, trap);
+    }
+
+    inline uint32_t canon_stream_cancel_read(ComponentInstance &inst, uint32_t readable_index, bool sync, const HostTrap &trap)
+    {
+        auto readable = inst.table.get<ReadableStreamEnd>(readable_index, trap);
+        return readable->cancel(sync, trap);
+    }
+
+    inline uint32_t canon_stream_cancel_write(ComponentInstance &inst, uint32_t writable_index, bool sync, const HostTrap &trap)
+    {
+        auto writable = inst.table.get<WritableStreamEnd>(writable_index, trap);
+        return writable->cancel(sync, trap);
+    }
+
+    inline void canon_stream_drop_readable(ComponentInstance &inst, uint32_t readable_index, const HostTrap &trap)
+    {
+        auto readable = inst.table.remove<ReadableStreamEnd>(readable_index, trap);
+        readable->drop(trap);
+    }
+
+    inline void canon_stream_drop_writable(ComponentInstance &inst, uint32_t writable_index, const HostTrap &trap)
+    {
+        auto writable = inst.table.remove<WritableStreamEnd>(writable_index, trap);
+        writable->drop(trap);
+    }
+
+    inline uint64_t canon_future_new(ComponentInstance &inst, const FutureDescriptor &descriptor, const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, descriptor.element_size == 0, "future descriptor invalid");
+        auto shared = std::make_shared<SharedFutureState>(descriptor);
+        auto readable = std::make_shared<ReadableFutureEnd>(shared);
+        auto writable = std::make_shared<WritableFutureEnd>(shared);
+        uint32_t readable_index = inst.table.add(readable, trap);
+        uint32_t writable_index = inst.table.add(writable, trap);
+        return (static_cast<uint64_t>(writable_index) << 32) | readable_index;
+    }
+
+    inline uint32_t canon_future_read(ComponentInstance &inst,
+                                      const FutureDescriptor &descriptor,
+                                      uint32_t readable_index,
+                                      const std::shared_ptr<LiftLowerContext> &cx,
+                                      uint32_t ptr,
+                                      bool sync,
+                                      const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !cx, "lift/lower context required");
+        auto readable = inst.table.get<ReadableFutureEnd>(readable_index, trap);
+        validate_descriptor(descriptor, readable->descriptor(), trap);
+        return readable->read(cx, readable_index, ptr, sync, trap);
+    }
+
+    inline uint32_t canon_future_write(ComponentInstance &inst,
+                                       const FutureDescriptor &descriptor,
+                                       uint32_t writable_index,
+                                       const std::shared_ptr<LiftLowerContext> &cx,
+                                       uint32_t ptr,
+                                       const HostTrap &trap)
+    {
+        auto trap_cx = make_trap_context(trap);
+        trap_if(trap_cx, !cx, "lift/lower context required");
+        auto writable = inst.table.get<WritableFutureEnd>(writable_index, trap);
+        validate_descriptor(descriptor, writable->descriptor(), trap);
+        return writable->write(cx, writable_index, ptr, trap);
+    }
+
+    inline uint32_t canon_future_cancel_read(ComponentInstance &inst, uint32_t readable_index, bool sync, const HostTrap &trap)
+    {
+        auto readable = inst.table.get<ReadableFutureEnd>(readable_index, trap);
+        return readable->cancel(sync, trap);
+    }
+
+    inline uint32_t canon_future_cancel_write(ComponentInstance &inst, uint32_t writable_index, bool sync, const HostTrap &trap)
+    {
+        auto writable = inst.table.get<WritableFutureEnd>(writable_index, trap);
+        return writable->cancel(sync, trap);
+    }
+
+    inline void canon_future_drop_readable(ComponentInstance &inst, uint32_t readable_index, const HostTrap &trap)
+    {
+        auto readable = inst.table.remove<ReadableFutureEnd>(readable_index, trap);
+        readable->drop(trap);
+    }
+
+    inline void canon_future_drop_writable(ComponentInstance &inst, uint32_t writable_index, const HostTrap &trap)
+    {
+        auto writable = inst.table.remove<WritableFutureEnd>(writable_index, trap);
+        writable->drop(trap);
+    }
+
    // ----------------------------

    struct InstanceContext
diff --git a/run_tests.py b/run_tests.py
deleted file mode 100644
index 5489776..0000000
--- a/run_tests.py
+++ /dev/null
@@ -1,469 +0,0 @@
-import definitions
-from definitions import *
-
-def equal_modulo_string_encoding(s, t):
-    if s is None and t is None:
-        return True
-    if isinstance(s, (bool,int,float,str)) and isinstance(t, (bool,int,float,str)):
-        return s == t
-    if isinstance(s, tuple) and isinstance(t, tuple):
-        assert(isinstance(s[0], str))
-        assert(isinstance(t[0], str))
-        return s[0] == t[0]
-    if isinstance(s, dict) and isinstance(t, dict):
-        return all(equal_modulo_string_encoding(sv,tv) for sv,tv in zip(s.values(), t.values(), strict=True))
-    if isinstance(s, list) and isinstance(t, list):
-        return all(equal_modulo_string_encoding(sv,tv) for sv,tv in zip(s, t, strict=True))
-    assert(False)
-
-class Heap:
-    def __init__(self, arg):
-        self.memory = bytearray(arg)
-        self.last_alloc = 0
-
-    def realloc(self, original_ptr, original_size, alignment, new_size):
-        if original_ptr != 0 and new_size < original_size:
-            return align_to(original_ptr, alignment)
-        ret = align_to(self.last_alloc, alignment)
-        self.last_alloc = ret + new_size
-        if self.last_alloc > len(self.memory):
-            print('oom: have {} need {}'.format(len(self.memory), self.last_alloc))
-            trap()
-        self.memory[ret : ret + original_size] = self.memory[original_ptr : original_ptr + original_size]
-        return ret
-
-def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None):
-    opts = CanonicalOptions()
-    opts.memory = memory
-    opts.string_encoding = encoding
-    opts.realloc = realloc
-    opts.post_return = post_return
-    return opts
-
-def mk_cx(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None):
-    opts = mk_opts(memory, encoding, realloc, post_return)
-    return LiftLowerContext(opts, ComponentInstance())
-
-def mk_str(s):
-    return (s, 'utf8', len(s.encode('utf-8')))
-
-def mk_tup(*a):
-    def mk_tup_rec(x):
-        if isinstance(x, list):
-            return { str(i):mk_tup_rec(v) for i,v in enumerate(x) }
-        return x
-    return { str(i):mk_tup_rec(v) for i,v in enumerate(a) }
-
-def fail(msg):
-    raise BaseException(msg)
-
-def test(t, vals_to_lift, v,
-         cx = mk_cx(),
-         dst_encoding = None,
-         lower_t = None,
-         lower_v = None):
-    def test_name():
-        return "test({},{},{}):".format(t, vals_to_lift, v)
-
-    vi = ValueIter([Value(ft, v) for ft,v in zip(flatten_type(t), vals_to_lift, strict=True)])
-
-    if v is None:
-        try:
-            got = lift_flat(cx, vi, t)
-            fail("{} expected trap, but got {}".format(test_name(), got))
-        except Trap:
-            return
-
-    got = lift_flat(cx, vi, t)
-    assert(vi.i == len(vi.values))
-    if got != v:
-        fail("{} initial lift_flat() expected {} but got {}".format(test_name(), v, got))
-
-    if lower_t is None:
-        lower_t = t
-    if lower_v is None:
-        lower_v = v
-
-    heap = Heap(5*len(cx.opts.memory))
-    if dst_encoding is None:
-        dst_encoding = cx.opts.string_encoding
-    cx = mk_cx(heap.memory, dst_encoding, heap.realloc)
-    lowered_vals = lower_flat(cx, v, lower_t)
-    assert(flatten_type(lower_t) == list(map(lambda v: v.t, lowered_vals)))
-
-    vi = ValueIter(lowered_vals)
-    got = lift_flat(cx, vi, lower_t)
-    if not equal_modulo_string_encoding(got, lower_v):
-        fail("{} re-lift expected {} but got {}".format(test_name(), lower_v, got))
-
-# Empty record types are not permitted yet.
-#test(Record([]), [], {})
-test(Record([Field('x',U8()), Field('y',U16()), Field('z',U32())]), [1,2,3], {'x':1,'y':2,'z':3})
-test(Tuple([Tuple([U8(),U8()]),U8()]), [1,2,3], {'0':{'0':1,'1':2},'1':3})
-# Empty flags types are not permitted yet.
-#t = Flags([])
-#test(t, [], {})
-t = Flags(['a','b'])
-test(t, [0], {'a':False,'b':False})
-test(t, [2], {'a':False,'b':True})
-test(t, [3], {'a':True,'b':True})
-test(t, [4], {'a':False,'b':False})
-test(Flags([str(i) for i in range(33)]), [0xffffffff,0x1], { str(i):True for i in range(33) })
-t = Variant([Case('x',U8()),Case('y',Float32()),Case('z',None)])
-test(t, [0,42], {'x': 42})
-test(t, [0,256], {'x': 0})
-test(t, [1,0x4048f5c3], {'y': 3.140000104904175})
-test(t, [2,0xffffffff], {'z': None})
-t = Option(Float32())
-test(t, [0,3.14], {'none':None})
-test(t, [1,3.14], {'some':3.14})
-t = Result(U8(),U32())
-test(t, [0, 42], {'ok':42})
-test(t, [1, 1000], {'error':1000})
-t = Variant([Case('w',U8()), Case('x',U8(),'w'), Case('y',U8()), Case('z',U8(),'x')])
-test(t, [0, 42], {'w':42})
-test(t, [1, 42], {'x|w':42})
-test(t, [2, 42], {'y':42})
-test(t, [3, 42], {'z|x|w':42})
-t2 = Variant([Case('w',U8())])
-test(t, [0, 42], {'w':42}, lower_t=t2, lower_v={'w':42})
-test(t, [1, 42], {'x|w':42}, lower_t=t2, lower_v={'w':42})
-test(t, [3, 42], {'z|x|w':42}, lower_t=t2, lower_v={'w':42})
-
-def test_pairs(t, pairs):
-    for arg,expect in pairs:
-        test(t, [arg], expect)
-
-test_pairs(Bool(), [(0,False),(1,True),(2,True),(4294967295,True)])
-test_pairs(U8(), [(127,127),(128,128),(255,255),(256,0),
-                  (4294967295,255),(4294967168,128),(4294967167,127)])
-test_pairs(S8(), [(127,127),(128,-128),(255,-1),(256,0),
-                  (4294967295,-1),(4294967168,-128),(4294967167,127)])
-test_pairs(U16(), [(32767,32767),(32768,32768),(65535,65535),(65536,0),
-                   ((1<<32)-1,65535),((1<<32)-32768,32768),((1<<32)-32769,32767)])
-test_pairs(S16(), [(32767,32767),(32768,-32768),(65535,-1),(65536,0),
-                   ((1<<32)-1,-1),((1<<32)-32768,-32768),((1<<32)-32769,32767)])
-test_pairs(U32(), [((1<<31)-1,(1<<31)-1),(1<<31,1<<31),(((1<<32)-1),(1<<32)-1)])
-test_pairs(S32(), [((1<<31)-1,(1<<31)-1),(1<<31,-(1<<31)),((1<<32)-1,-1)])
-test_pairs(U64(), [((1<<63)-1,(1<<63)-1), (1<<63,1<<63), ((1<<64)-1,(1<<64)-1)])
-test_pairs(S64(), [((1<<63)-1,(1<<63)-1), (1<<63,-(1<<63)), ((1<<64)-1,-1)])
-test_pairs(Float32(), [(3.14,3.14)])
-test_pairs(Float64(), [(3.14,3.14)])
-test_pairs(Char(), [(0,'\x00'), (65,'A'), (0xD7FF,'\uD7FF'), (0xD800,None), (0xDFFF,None)])
-test_pairs(Char(), [(0xE000,'\uE000'), (0x10FFFF,'\U0010FFFF'), (0x110000,None), (0xFFFFFFFF,None)])
-test_pairs(Enum(['a','b']), [(0,{'a':None}), (1,{'b':None}), (2,None)])
-
-def test_nan32(inbits, outbits):
-    origf = decode_i32_as_float(inbits)
-    f = lift_flat(mk_cx(), ValueIter([Value('f32', origf)]), Float32())
-    if DETERMINISTIC_PROFILE:
-        assert(encode_float_as_i32(f) == outbits)
-    else:
-        assert(not math.isnan(origf) or math.isnan(f))
-    cx = mk_cx(int.to_bytes(inbits, 4, 'little'))
-    f = load(cx, 0, Float32())
-    if DETERMINISTIC_PROFILE:
-        assert(encode_float_as_i32(f) == outbits)
-    else:
-        assert(not math.isnan(origf) or math.isnan(f))
-
-def test_nan64(inbits, outbits):
-    origf = decode_i64_as_float(inbits)
-    f = lift_flat(mk_cx(), ValueIter([Value('f64', origf)]), Float64())
-    if DETERMINISTIC_PROFILE:
-        assert(encode_float_as_i64(f) == outbits)
-    else:
-        assert(not math.isnan(origf) or math.isnan(f))
-    cx = mk_cx(int.to_bytes(inbits, 8, 'little'))
-    f = load(cx, 0, Float64())
-    if DETERMINISTIC_PROFILE:
-        assert(encode_float_as_i64(f) == outbits)
-    else:
-        assert(not math.isnan(origf) or math.isnan(f))
-
-test_nan32(0x7fc00000, CANONICAL_FLOAT32_NAN)
-test_nan32(0x7fc00001, CANONICAL_FLOAT32_NAN)
-test_nan32(0x7fe00000, CANONICAL_FLOAT32_NAN)
-test_nan32(0x7fffffff, CANONICAL_FLOAT32_NAN)
-test_nan32(0xffffffff, CANONICAL_FLOAT32_NAN)
-test_nan32(0x7f800000, 0x7f800000)
-test_nan32(0x3fc00000, 0x3fc00000)
-test_nan64(0x7ff8000000000000, CANONICAL_FLOAT64_NAN)
-test_nan64(0x7ff8000000000001, CANONICAL_FLOAT64_NAN)
-test_nan64(0x7ffc000000000000, CANONICAL_FLOAT64_NAN)
-test_nan64(0x7fffffffffffffff, CANONICAL_FLOAT64_NAN)
-test_nan64(0xffffffffffffffff, CANONICAL_FLOAT64_NAN)
-test_nan64(0x7ff0000000000000, 0x7ff0000000000000)
-test_nan64(0x3ff0000000000000, 0x3ff0000000000000)
-
-def test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units):
-    heap = Heap(len(encoded))
-    heap.memory[:] = encoded[:]
-    cx = mk_cx(heap.memory, src_encoding)
-    v = (s, src_encoding, tagged_code_units)
-    test(String(), [0, tagged_code_units], v, cx, dst_encoding)
-
-def test_string(src_encoding, dst_encoding, s):
-    if src_encoding == 'utf8':
-        encoded = s.encode('utf-8')
-        tagged_code_units = len(encoded)
-        test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units)
-    elif src_encoding == 'utf16':
-        encoded = s.encode('utf-16-le')
-        tagged_code_units = int(len(encoded) / 2)
-        test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units)
-    elif src_encoding == 'latin1+utf16':
-        try:
-            encoded = s.encode('latin-1')
-            tagged_code_units = len(encoded)
-            test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units)
-        except UnicodeEncodeError:
-            pass
-        encoded = s.encode('utf-16-le')
-        tagged_code_units = int(len(encoded) / 2) | UTF16_TAG
-        test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units)
-
-encodings = ['utf8', 'utf16', 'latin1+utf16']
-
-fun_strings = ['', 'a', 'hi', '\x00', 'a\x00b', '\x80', '\x80b', 'ab\xefc',
-               '\u01ffy', 'xy\u01ff', 'a\ud7ffb', 'a\u02ff\u03ff\u04ffbc',
-               '\uf123', '\uf123\uf123abc', 'abcdef\uf123']
-
-for src_encoding in encodings:
-    for dst_encoding in encodings:
-        for s in fun_strings:
-            test_string(src_encoding, dst_encoding, s)
-
-def test_heap(t, expect, args, byte_array):
-    heap = Heap(byte_array)
-    cx = mk_cx(heap.memory)
-    test(t, args, expect, cx)
-
-# Empty record types are not permitted yet.
-#test_heap(List(Record([])), [{},{},{}], [0,3], [])
-test_heap(List(Bool()), [True,False,True], [0,3], [1,0,1])
-test_heap(List(Bool()), [True,False,True], [0,3], [1,0,2])
-test_heap(List(Bool()), [True,False,True], [3,3], [0xff,0xff,0xff, 1,0,1])
-test_heap(List(U8()), [1,2,3], [0,3], [1,2,3])
-test_heap(List(U16()), [1,2,3], [0,3], [1,0, 2,0, 3,0 ])
-test_heap(List(U16()), None, [1,3], [0, 1,0, 2,0, 3,0 ])
-test_heap(List(U32()), [1,2,3], [0,3], [1,0,0,0, 2,0,0,0, 3,0,0,0])
-test_heap(List(U64()), [1,2], [0,2], [1,0,0,0,0,0,0,0, 2,0,0,0,0,0,0,0])
-test_heap(List(S8()), [-1,-2,-3], [0,3], [0xff,0xfe,0xfd])
-test_heap(List(S16()), [-1,-2,-3], [0,3], [0xff,0xff, 0xfe,0xff, 0xfd,0xff])
-test_heap(List(S32()), [-1,-2,-3], [0,3], [0xff,0xff,0xff,0xff, 0xfe,0xff,0xff,0xff, 0xfd,0xff,0xff,0xff])
-test_heap(List(S64()), [-1,-2], [0,2], [0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, 0xfe,0xff,0xff,0xff,0xff,0xff,0xff,0xff])
-test_heap(List(Char()), ['A','B','c'], [0,3], [65,00,00,00, 66,00,00,00, 99,00,00,00])
-test_heap(List(String()), [mk_str("hi"),mk_str("wat")], [0,2],
-          [16,0,0,0, 2,0,0,0, 21,0,0,0, 3,0,0,0,
-           ord('h'), ord('i'), 0xf,0xf,0xf, ord('w'), ord('a'), ord('t')])
-test_heap(List(List(U8())), [[3,4,5],[],[6,7]], [0,3],
-          [24,0,0,0, 3,0,0,0, 0,0,0,0, 0,0,0,0, 27,0,0,0, 2,0,0,0,
-           3,4,5, 6,7])
-test_heap(List(List(U16())), [[5,6]], [0,1],
-          [8,0,0,0, 2,0,0,0,
-           5,0, 6,0])
-test_heap(List(List(U16())), None, [0,1],
-          [9,0,0,0, 2,0,0,0,
-           0, 5,0, 6,0])
-test_heap(List(Tuple([U8(),U8(),U16(),U32()])), [mk_tup(6,7,8,9),mk_tup(4,5,6,7)], [0,2],
-          [6, 7, 8,0, 9,0,0,0, 4, 5, 6,0, 7,0,0,0])
-test_heap(List(Tuple([U8(),U16(),U8(),U32()])), [mk_tup(6,7,8,9),mk_tup(4,5,6,7)], [0,2],
-          [6,0xff, 7,0, 8,0xff,0xff,0xff, 9,0,0,0, 4,0xff, 5,0, 6,0xff,0xff,0xff, 7,0,0,0])
-test_heap(List(Tuple([U16(),U8()])), [mk_tup(6,7),mk_tup(8,9)], [0,2],
-          [6,0, 7, 0x0ff, 8,0, 9, 0xff])
-test_heap(List(Tuple([Tuple([U16(),U8()]),U8()])), [mk_tup([4,5],6),mk_tup([7,8],9)], [0,2],
-          [4,0, 5,0xff, 6,0xff, 7,0, 8,0xff, 9,0xff])
-# Empty flags types are not permitted yet.
-#t = List(Flags([]))
-#test_heap(t, [{},{},{}], [0,3],
-#          [])
-#t = List(Tuple([Flags([]), U8()]))
-#test_heap(t, [mk_tup({}, 42), mk_tup({}, 43), mk_tup({}, 44)], [0,3],
-#          [42,43,44])
-t = List(Flags(['a','b']))
-test_heap(t, [{'a':False,'b':False},{'a':False,'b':True},{'a':True,'b':True}], [0,3],
-          [0,2,3])
-test_heap(t, [{'a':False,'b':False},{'a':False,'b':True},{'a':False,'b':False}], [0,3],
-          [0,2,4])
-t = List(Flags([str(i) for i in range(9)]))
-v = [{ str(i):b for i in range(9) } for b in [True,False]]
-test_heap(t, v, [0,2],
-          [0xff,0x1, 0,0])
-test_heap(t, v, [0,2],
-          [0xff,0x3, 0,0])
-t = List(Flags([str(i) for i in range(17)]))
-v = [{ str(i):b for i in range(17) } for b in [True,False]]
-test_heap(t, v, [0,2],
-          [0xff,0xff,0x1,0, 0,0,0,0])
-test_heap(t, v, [0,2],
-          [0xff,0xff,0x3,0, 0,0,0,0])
-t = List(Flags([str(i) for i in range(33)]))
-v = [{ str(i):b for i in range(33) } for b in [True,False]]
-test_heap(t, v, [0,2],
-          [0xff,0xff,0xff,0xff,0x1,0,0,0, 0,0,0,0,0,0,0,0])
-test_heap(t, v, [0,2],
-          [0xff,0xff,0xff,0xff,0x3,0,0,0, 0,0,0,0,0,0,0,0])
-
-def test_flatten(t, params, results):
-    expect = CoreFuncType(params, results)
-
-    if len(params) > definitions.MAX_FLAT_PARAMS:
-        expect.params = ['i32']
-
-    if len(results) > definitions.MAX_FLAT_RESULTS:
-        expect.results = ['i32']
-    got = flatten_functype(t, 'lift')
-    assert(got == expect)
-
-    if len(results) > definitions.MAX_FLAT_RESULTS:
-        expect.params += ['i32']
-        expect.results = []
-    got = flatten_functype(t, 'lower')
-    assert(got == expect)
-
-test_flatten(FuncType([U8(),Float32(),Float64()],[]), ['i32','f32','f64'], [])
-test_flatten(FuncType([U8(),Float32(),Float64()],[Float32()]), ['i32','f32','f64'], ['f32'])
-test_flatten(FuncType([U8(),Float32(),Float64()],[U8()]), ['i32','f32','f64'], ['i32'])
-test_flatten(FuncType([U8(),Float32(),Float64()],[Tuple([Float32()])]), ['i32','f32','f64'], ['f32'])
-test_flatten(FuncType([U8(),Float32(),Float64()],[Tuple([Float32(),Float32()])]), ['i32','f32','f64'], ['f32','f32']) -test_flatten(FuncType([U8(),Float32(),Float64()],[Float32(),Float32()]), ['i32','f32','f64'], ['f32','f32']) -test_flatten(FuncType([U8() for _ in range(17)],[]), ['i32' for _ in range(17)], []) -test_flatten(FuncType([U8() for _ in range(17)],[Tuple([U8(),U8()])]), ['i32' for _ in range(17)], ['i32','i32']) - -def test_roundtrip(t, v): - before = definitions.MAX_FLAT_RESULTS - definitions.MAX_FLAT_RESULTS = 16 - - ft = FuncType([t],[t]) - callee = lambda x: x - - callee_heap = Heap(1000) - callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc, lambda x: () ) - callee_inst = ComponentInstance() - lifted_callee = lambda args: canon_lift(callee_opts, callee_inst, callee, ft, args) - - caller_heap = Heap(1000) - caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc) - caller_inst = ComponentInstance() - caller_cx = LiftLowerContext(caller_opts, caller_inst) - - flat_args = lower_flat(caller_cx, v, t) - flat_results = canon_lower(caller_opts, caller_inst, lifted_callee, True, ft, flat_args) - got = lift_flat(caller_cx, ValueIter(flat_results), t) - - if got != v: - fail("test_roundtrip({},{},{}) got {}".format(t, v, caller_args, got)) - - assert(caller_inst.may_leave and caller_inst.may_enter) - assert(callee_inst.may_leave and callee_inst.may_enter) - definitions.MAX_FLAT_RESULTS = before - -test_roundtrip(S8(), -1) -test_roundtrip(Tuple([U16(),U16()]), mk_tup(3,4)) -test_roundtrip(List(String()), [mk_str("hello there")]) -test_roundtrip(List(List(String())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) -test_roundtrip(List(Option(Tuple([String(),U16()]))), [{'some':mk_tup(mk_str("answer"),42)}]) - -def test_handles(): - before = definitions.MAX_FLAT_RESULTS - definitions.MAX_FLAT_RESULTS = 16 - - dtor_value = None - def dtor(x): - nonlocal dtor_value - dtor_value = x - - rt = ResourceType(ComponentInstance(), 
dtor) # usable in imports and exports - - inst = ComponentInstance() - rt2 = ResourceType(inst, dtor) # only usable in exports - opts = mk_opts() - - def host_import(args): - assert(len(args) == 2) - assert(args[0] == 42) - assert(args[1] == 44) - return ([45], lambda:()) - - def core_wasm(args): - nonlocal dtor_value - - assert(len(args) == 4) - assert(len(inst.handles.table(rt).array) == 4) - assert(inst.handles.table(rt).array[0] is None) - assert(args[0].t == 'i32' and args[0].v == 1) - assert(args[1].t == 'i32' and args[1].v == 2) - assert(args[2].t == 'i32' and args[2].v == 3) - assert(args[3].t == 'i32' and args[3].v == 13) - assert(canon_resource_rep(inst, rt, 1) == 42) - assert(canon_resource_rep(inst, rt, 2) == 43) - assert(canon_resource_rep(inst, rt, 3) == 44) - - host_ft = FuncType([ - Borrow(rt), - Borrow(rt) - ],[ - Own(rt) - ]) - args = [ - Value('i32',1), - Value('i32',3) - ] - results = canon_lower(opts, inst, host_import, True, host_ft, args) - assert(len(results) == 1) - assert(results[0].t == 'i32' and results[0].v == 4) - assert(canon_resource_rep(inst, rt, 4) == 45) - - dtor_value = None - canon_resource_drop(inst, rt, 1) - assert(dtor_value == 42) - assert(len(inst.handles.table(rt).array) == 5) - assert(inst.handles.table(rt).array[1] is None) - assert(len(inst.handles.table(rt).free) == 1) - - h = canon_resource_new(inst, rt, 46) - assert(h == 1) - assert(len(inst.handles.table(rt).array) == 5) - assert(inst.handles.table(rt).array[1] is not None) - assert(len(inst.handles.table(rt).free) == 0) - - dtor_value = None - canon_resource_drop(inst, rt, 3) - assert(dtor_value is None) - assert(len(inst.handles.table(rt).array) == 5) - assert(inst.handles.table(rt).array[3] is None) - assert(len(inst.handles.table(rt).free) == 1) - - return [Value('i32', 1), Value('i32', 2), Value('i32', 4)] - - ft = FuncType([ - Own(rt), - Own(rt), - Borrow(rt), - Borrow(rt2) - ],[ - Own(rt), - Own(rt), - Own(rt) - ]) - args = [ - 42, - 43, - 44, - 13 - ] - 
got,post_return = canon_lift(opts, inst, core_wasm, ft, args) - - assert(len(got) == 3) - assert(got[0] == 46) - assert(got[1] == 43) - assert(got[2] == 45) - assert(len(inst.handles.table(rt).array) == 5) - assert(all(inst.handles.table(rt).array[i] is None for i in range(4))) - assert(len(inst.handles.table(rt).free) == 4) - definitions.MAX_FLAT_RESULTS = before - -test_handles() - -print("All tests passed") diff --git a/test/main.cpp b/test/main.cpp index b39015b..44b692a 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -15,6 +15,7 @@ using namespace cmcpp; #include #include #include +#include // #include #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN @@ -597,7 +598,6 @@ TEST_CASE("Float Special Values - Enhanced") auto v_zero = lower_flat(*cx, zero); auto result_zero = lift_flat(*cx, v_zero); CHECK(result_zero == 0.0f); - float32_t neg_zero = -0.0f; auto v_neg_zero = lower_flat(*cx, neg_zero); auto result_neg_zero = lift_flat(*cx, v_neg_zero); @@ -620,6 +620,163 @@ TEST_CASE("Float Special Values - Enhanced") CHECK(result_max64 == max_val64); } +TEST_CASE("Waitable set surfaces stream readiness") +{ + ComponentInstance inst; + HostTrap host_trap = [](const char *msg) + { + throw std::runtime_error(msg ? 
msg : "trap"); + }; + + auto desc = make_stream_descriptor(); + uint64_t handles = canon_stream_new(inst, desc, host_trap); + uint32_t readable = static_cast<uint32_t>(handles & 0xFFFFFFFFu); + uint32_t writable = static_cast<uint32_t>(handles >> 32); + + uint32_t waitable_set = canon_waitable_set_new(inst, host_trap); + canon_waitable_join(inst, readable, waitable_set, host_trap); + + Heap heap(256); + auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, Encoding::Utf8).release(), [](LiftLowerContext *ptr) + { delete ptr; }); + + uint32_t read_ptr = 0; + uint32_t write_ptr = 32; + uint32_t event_ptr = 128; + + int32_t to_write[2] = {42, 87}; + std::memcpy(heap.memory.data() + write_ptr, to_write, sizeof(to_write)); + + auto blocked = canon_stream_read(inst, desc, readable, cx, read_ptr, 2, false, host_trap); + CHECK(blocked == BLOCKED); + + GuestMemory mem(heap.memory.data(), heap.memory.size()); + auto code = canon_waitable_set_poll(false, mem, inst, waitable_set, event_ptr, host_trap); + CHECK(code == static_cast<uint32_t>(EventCode::NONE)); + + auto write_payload = canon_stream_write(inst, desc, writable, cx, write_ptr, 2, host_trap); + CHECK((write_payload & 0xF) == static_cast<uint32_t>(CopyResult::Completed)); + auto write_count = write_payload >> 4; + CHECK(write_count == 2); + + code = canon_waitable_set_poll(false, mem, inst, waitable_set, event_ptr, host_trap); + CHECK(code == static_cast<uint32_t>(EventCode::STREAM_READ)); + + uint32_t reported_index = 0; + uint32_t payload = 0; + std::memcpy(&reported_index, heap.memory.data() + event_ptr, sizeof(uint32_t)); + std::memcpy(&payload, heap.memory.data() + event_ptr + sizeof(uint32_t), sizeof(uint32_t)); + CHECK(reported_index == readable); + CHECK((payload & 0xF) == static_cast<uint32_t>(CopyResult::Completed)); + auto read_count = payload >> 4; + CHECK(read_count == 2); + + int32_t read_values[2] = {0, 0}; + std::memcpy(read_values, heap.memory.data() + read_ptr, sizeof(read_values)); + CHECK(read_values[0] == 42); + CHECK(read_values[1] == 87); + 
canon_stream_drop_readable(inst, readable, host_trap); + canon_stream_drop_writable(inst, writable, host_trap); + canon_waitable_set_drop(inst, waitable_set, host_trap); +} + +TEST_CASE("Stream cancellation posts events") +{ + ComponentInstance inst; + HostTrap host_trap = [](const char *msg) + { + throw std::runtime_error(msg ? msg : "trap"); + }; + + auto desc = make_stream_descriptor(); + uint64_t handles = canon_stream_new(inst, desc, host_trap); + uint32_t readable = static_cast<uint32_t>(handles & 0xFFFFFFFFu); + + uint32_t waitable_set = canon_waitable_set_new(inst, host_trap); + canon_waitable_join(inst, readable, waitable_set, host_trap); + + Heap heap(128); + auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, Encoding::Utf8).release(), [](LiftLowerContext *ptr) + { delete ptr; }); + + uint32_t read_ptr = 0; + uint32_t event_ptr = 64; + + auto blocked = canon_stream_read(inst, desc, readable, cx, read_ptr, 1, false, host_trap); + CHECK(blocked == BLOCKED); + + auto cancel_payload = canon_stream_cancel_read(inst, readable, false, host_trap); + CHECK(cancel_payload == BLOCKED); + + GuestMemory mem(heap.memory.data(), heap.memory.size()); + auto code = canon_waitable_set_poll(false, mem, inst, waitable_set, event_ptr, host_trap); + CHECK(code == static_cast<uint32_t>(EventCode::STREAM_READ)); + + uint32_t payload = 0; + std::memcpy(&payload, heap.memory.data() + event_ptr + sizeof(uint32_t), sizeof(uint32_t)); + CHECK((payload & 0xF) == static_cast<uint32_t>(CopyResult::Cancelled)); + auto cancel_count = payload >> 4; + CHECK(cancel_count == 0); + + canon_stream_drop_readable(inst, readable, host_trap); + canon_waitable_set_drop(inst, waitable_set, host_trap); +} + +TEST_CASE("Future lifecycle completes") +{ + ComponentInstance inst; + HostTrap host_trap = [](const char *msg) + { + throw std::runtime_error(msg ? 
msg : "trap"); + }; + + auto desc = make_future_descriptor(); + uint64_t handles = canon_future_new(inst, desc, host_trap); + uint32_t readable = static_cast<uint32_t>(handles & 0xFFFFFFFFu); + uint32_t writable = static_cast<uint32_t>(handles >> 32); + + uint32_t waitable_set = canon_waitable_set_new(inst, host_trap); + canon_waitable_join(inst, readable, waitable_set, host_trap); + + Heap heap(256); + auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, Encoding::Utf8).release(), [](LiftLowerContext *ptr) + { delete ptr; }); + + uint32_t read_ptr = 0; + uint32_t write_ptr = 32; + uint32_t event_ptr = 96; + + auto read_blocked = canon_future_read(inst, desc, readable, cx, read_ptr, false, host_trap); + CHECK(read_blocked == BLOCKED); + + int32_t value = 99; + std::memcpy(heap.memory.data() + write_ptr, &value, sizeof(int32_t)); + + auto write_payload = canon_future_write(inst, desc, writable, cx, write_ptr, host_trap); + CHECK((write_payload & 0xF) == static_cast<uint32_t>(CopyResult::Completed)); + auto write_count = write_payload >> 4; + CHECK(write_count == 1); + + GuestMemory mem(heap.memory.data(), heap.memory.size()); + auto code = canon_waitable_set_poll(false, mem, inst, waitable_set, event_ptr, host_trap); + CHECK(code == static_cast<uint32_t>(EventCode::FUTURE_READ)); + + uint32_t payload = 0; + std::memcpy(&payload, heap.memory.data() + event_ptr + sizeof(uint32_t), sizeof(uint32_t)); + CHECK((payload & 0xF) == static_cast<uint32_t>(CopyResult::Completed)); + auto read_count = payload >> 4; + CHECK(read_count == 1); + + int32_t observed = 0; + std::memcpy(&observed, heap.memory.data() + read_ptr, sizeof(int32_t)); + CHECK(observed == value); + + canon_future_drop_readable(inst, readable, host_trap); + canon_future_drop_writable(inst, writable, host_trap); + canon_waitable_set_drop(inst, waitable_set, host_trap); +} + const char *const hw = "hello World!"; const char *const hw8 = "hello 世界-🌍-!"; const char16_t *hw16 = u"hello 世界-🌍-!"; From cf2d0bbdf7a27caf77a500bd2dcf633927407efd Mon Sep 17 00:00:00 
2001 From: Gordon Smith Date: Sun, 28 Sep 2025 10:10:20 +0100 Subject: [PATCH 3/6] fix: add backpressure and task lifecycle Signed-off-by: Gordon Smith --- README.md | 9 +- include/cmcpp/context.hpp | 312 ++++++++++++++++++++++++++++++++++++++ include/cmcpp/runtime.hpp | 195 +++++++++++++++++++++--- test/main.cpp | 117 ++++++++++++++ 4 files changed, 607 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 8466642..5273bde 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,7 @@ This repository contains a C++ ABI implementation of the WebAssembly Component M - [x] F32 - [x] F64 - [x] Char -- [x] String -- [x] utf8 String -- [x] utf16 String -- [x] latin1+utf16 String +- [x] Strings (UTF-8, UTF-16, Latin-1+UTF-16) - [x] List - [x] Record - [x] Tuple @@ -46,6 +43,8 @@ This repository contains a C++ ABI implementation of the WebAssembly Component M - [x] Option - [x] Result - [x] Flags +- [x] Streams (readable/writable) +- [x] Futures (readable/writable) - [ ] Own - [ ] Borrow @@ -181,6 +180,8 @@ The canonical Component Model runtime is cooperative: hosts must drive pending w - `FuncInst` is the callable signature hosts use to wrap guest functions. - `Thread::create` builds a pending task with user-supplied readiness/resume callbacks. - `Call::from_thread` returns a cancellation-capable handle to the caller. +- `Task` coordinates canonical backpressure, `canon_task.{return,cancel}`, and `canon_yield` helpers exposed through `context.hpp`. +- `canon_backpressure_{set,inc,dec}` update in-flight counters; most canonical entry points now guard `ComponentInstance::may_leave` before touching guest state. 
Typical usage: diff --git a/include/cmcpp/context.hpp b/include/cmcpp/context.hpp index b85df17..430c0fe 100644 --- a/include/cmcpp/context.hpp +++ b/include/cmcpp/context.hpp @@ -2,6 +2,7 @@ #define CMCPP_CONTEXT_HPP #include "traits.hpp" +#include "runtime.hpp" #include #include @@ -974,6 +975,7 @@ namespace cmcpp struct ComponentInstance { + Store *store = nullptr; bool may_leave = true; bool may_enter = true; bool exclusive = false; @@ -983,6 +985,297 @@ namespace cmcpp InstanceTable table; }; + + inline void ensure_may_leave(ComponentInstance &inst, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, !inst.may_leave, "component may not leave"); + } + + inline void canon_backpressure_set(ComponentInstance &inst, bool enabled) + { + inst.backpressure = enabled ? 1u : 0u; + } + + inline void canon_backpressure_inc(ComponentInstance &inst, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, inst.backpressure >= 0x1'0000u, "backpressure overflow"); + inst.backpressure += 1; + } + + inline void canon_backpressure_dec(ComponentInstance &inst, const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, inst.backpressure == 0, "backpressure underflow"); + inst.backpressure -= 1; + } + + class Task : public std::enable_shared_from_this<Task> + { + public: + enum class State + { + Initial, + PendingCancel, + CancelDelivered, + Resolved + }; + + Task(ComponentInstance &instance, + CanonicalOptions options = {}, + SupertaskPtr supertask = {}, + OnResolve on_resolve = {}) + : opts_(std::move(options)), inst_(&instance), supertask_(std::move(supertask)), on_resolve_(std::move(on_resolve)) + { + } + + void set_thread(const std::shared_ptr<Thread> &thread) + { + thread_ = thread; + if (thread_) + { + thread_->set_allow_cancellation(!opts_.sync); + thread_->set_in_event_loop(opts_.callback.has_value()); + if (inst_) + { + auto super = std::make_shared<Supertask>(); + super->instance = inst_; + super->thread = 
thread_; + super->parent = supertask_; + supertask_ = std::move(super); + } + } + } + + std::shared_ptr<Thread> thread() const + { + return thread_; + } + + void set_on_resolve(OnResolve on_resolve) + { + on_resolve_ = std::move(on_resolve); + } + + bool enter(const HostTrap &trap) + { + auto *thread_ptr = thread_.get(); + auto *inst = inst_; + if (!thread_ptr || !inst) + { + return false; + } + + auto has_backpressure = [inst, this]() -> bool + { + return inst->backpressure > 0 || (needs_exclusive() && inst->exclusive); + }; + + if (has_backpressure() || inst->num_waiting_to_enter > 0) + { + inst->num_waiting_to_enter += 1; + bool completed = thread_ptr->suspend_until([has_backpressure]() + { return !has_backpressure(); }, + true); + inst->num_waiting_to_enter -= 1; + if (!completed) + { + if (state_ == State::CancelDelivered) + { + cancel(trap); + } + return false; + } + } + + if (needs_exclusive()) + { + inst->exclusive = true; + } + return true; + } + + void exit() + { + if (!inst_) + { + return; + } + if (needs_exclusive()) + { + inst_->exclusive = false; + } + } + + void request_cancellation() + { + if (state_ != State::Initial || !thread_) + { + return; + } + + if (ready_for_cancellation()) + { + state_ = State::CancelDelivered; + } + else + { + state_ = State::PendingCancel; + } + thread_->request_cancellation(); + } + + bool suspend_until(Thread::ReadyFn ready, bool cancellable) + { + if (cancellable && state_ == State::CancelDelivered) + { + return false; + } + if (cancellable && state_ == State::PendingCancel) + { + state_ = State::CancelDelivered; + return false; + } + if (!thread_) + { + return false; + } + bool completed = thread_->suspend_until(std::move(ready), cancellable); + if (!completed && cancellable && state_ == State::PendingCancel) + { + state_ = State::CancelDelivered; + } + return completed; + } + + Event yield_until(Thread::ReadyFn ready, bool cancellable) + { + if (!suspend_until(std::move(ready), cancellable)) + { + return 
{EventCode::TASK_CANCELLED, 0, 0}; + } + return {EventCode::NONE, 0, 0}; + } + + void return_result(std::vector<std::any> result, const HostTrap &trap) + { + ensure_resolvable(trap); + if (on_resolve_) + { + on_resolve_(std::optional<std::vector<std::any>>(std::move(result))); + } + state_ = State::Resolved; + } + + void cancel(const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, state_ != State::CancelDelivered, "task cancellation not delivered"); + trap_if(trap_cx, num_borrows_ > 0, "task has outstanding borrows"); + if (on_resolve_) + { + on_resolve_(std::nullopt); + } + state_ = State::Resolved; + } + + State state() const + { + return state_; + } + + ComponentInstance *component_instance() const + { + return inst_; + } + + const CanonicalOptions &options() const + { + return opts_; + } + + void incr_borrows() + { + num_borrows_ += 1; + } + + void decr_borrows() + { + if (num_borrows_ > 0) + { + num_borrows_ -= 1; + } + } + + private: + bool needs_exclusive() const + { + return opts_.sync || opts_.callback.has_value(); + } + + void ensure_resolvable(const HostTrap &trap) + { + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, state_ == State::Resolved, "task already resolved"); + trap_if(trap_cx, num_borrows_ > 0, "task has outstanding borrows"); + } + + bool ready_for_cancellation() const + { + if (!thread_) + { + return false; + } + return thread_->cancellable() && !(thread_->in_event_loop() && inst_ && inst_->exclusive); + } + + CanonicalOptions opts_; + ComponentInstance *inst_ = nullptr; + SupertaskPtr supertask_; + OnResolve on_resolve_; + uint32_t num_borrows_ = 0; + std::shared_ptr<Thread> thread_; + State state_ = State::Initial; + }; + + inline void canon_task_return(Task &task, std::vector<std::any> result, const HostTrap &trap) + { + if (auto *inst = task.component_instance()) + { + ensure_may_leave(*inst, trap); + } + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, task.options().sync, "task.return requires async context"); + 
task.return_result(std::move(result), trap); + } + + inline void canon_task_cancel(Task &task, const HostTrap &trap) + { + if (auto *inst = task.component_instance()) + { + ensure_may_leave(*inst, trap); + } + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, task.options().sync, "task.cancel requires async context"); + task.cancel(trap); + } + + inline uint32_t canon_yield(bool cancellable, Task &task, const HostTrap &trap) + { + if (auto *inst = task.component_instance()) + { + ensure_may_leave(*inst, trap); + } + if (task.state() == Task::State::CancelDelivered || task.state() == Task::State::PendingCancel) + { + return 1u; + } + auto event = task.yield_until([] + { return true; }, + cancellable); + return event.code == EventCode::TASK_CANCELLED ? 1u : 0u; + } + class ContextLocalStorage { public: @@ -1068,12 +1361,14 @@ namespace cmcpp inline uint32_t canon_waitable_set_new(ComponentInstance &inst, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto wset = std::make_shared(); return inst.table.add(wset, trap); } inline uint32_t canon_waitable_set_wait(bool /*cancellable*/, GuestMemory mem, ComponentInstance &inst, uint32_t set_index, uint32_t ptr, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto wset = inst.table.get(set_index, trap); wset->begin_wait(); if (!wset->has_pending_event()) @@ -1090,6 +1385,7 @@ namespace cmcpp inline uint32_t canon_waitable_set_poll(bool /*cancellable*/, GuestMemory mem, ComponentInstance &inst, uint32_t set_index, uint32_t ptr, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto wset = inst.table.get(set_index, trap); if (!wset->has_pending_event()) { @@ -1103,12 +1399,14 @@ namespace cmcpp inline void canon_waitable_set_drop(ComponentInstance &inst, uint32_t set_index, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto wset = inst.table.remove(set_index, trap); wset->drop(trap); } inline void canon_waitable_join(ComponentInstance &inst, uint32_t waitable_index, uint32_t set_index, 
const HostTrap &trap) { + ensure_may_leave(inst, trap); auto waitable = inst.table.get(waitable_index, trap); if (set_index == 0) { @@ -1121,6 +1419,7 @@ namespace cmcpp inline uint64_t canon_stream_new(ComponentInstance &inst, const StreamDescriptor &descriptor, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto trap_cx = make_trap_context(trap); trap_if(trap_cx, descriptor.element_size == 0, "stream descriptor invalid"); auto shared = std::make_shared(descriptor); @@ -1140,6 +1439,7 @@ namespace cmcpp bool sync, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto trap_cx = make_trap_context(trap); trap_if(trap_cx, !cx, "lift/lower context required"); auto readable = inst.table.get(readable_index, trap); @@ -1155,6 +1455,7 @@ namespace cmcpp uint32_t n, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto trap_cx = make_trap_context(trap); trap_if(trap_cx, !cx, "lift/lower context required"); auto writable = inst.table.get(writable_index, trap); @@ -1164,30 +1465,35 @@ namespace cmcpp inline uint32_t canon_stream_cancel_read(ComponentInstance &inst, uint32_t readable_index, bool sync, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto readable = inst.table.get(readable_index, trap); return readable->cancel(sync, trap); } inline uint32_t canon_stream_cancel_write(ComponentInstance &inst, uint32_t writable_index, bool sync, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto writable = inst.table.get(writable_index, trap); return writable->cancel(sync, trap); } inline void canon_stream_drop_readable(ComponentInstance &inst, uint32_t readable_index, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto readable = inst.table.remove(readable_index, trap); readable->drop(trap); } inline void canon_stream_drop_writable(ComponentInstance &inst, uint32_t writable_index, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto writable = inst.table.remove(writable_index, trap); writable->drop(trap); } inline uint64_t 
canon_future_new(ComponentInstance &inst, const FutureDescriptor &descriptor, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto trap_cx = make_trap_context(trap); trap_if(trap_cx, descriptor.element_size == 0, "future descriptor invalid"); auto shared = std::make_shared(descriptor); @@ -1206,6 +1512,7 @@ namespace cmcpp bool sync, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto trap_cx = make_trap_context(trap); trap_if(trap_cx, !cx, "lift/lower context required"); auto readable = inst.table.get(readable_index, trap); @@ -1220,6 +1527,7 @@ namespace cmcpp uint32_t ptr, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto trap_cx = make_trap_context(trap); trap_if(trap_cx, !cx, "lift/lower context required"); auto writable = inst.table.get(writable_index, trap); @@ -1229,24 +1537,28 @@ namespace cmcpp inline uint32_t canon_future_cancel_read(ComponentInstance &inst, uint32_t readable_index, bool sync, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto readable = inst.table.get(readable_index, trap); return readable->cancel(sync, trap); } inline uint32_t canon_future_cancel_write(ComponentInstance &inst, uint32_t writable_index, bool sync, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto writable = inst.table.get(writable_index, trap); return writable->cancel(sync, trap); } inline void canon_future_drop_readable(ComponentInstance &inst, uint32_t readable_index, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto readable = inst.table.remove(readable_index, trap); readable->drop(trap); } inline void canon_future_drop_writable(ComponentInstance &inst, uint32_t writable_index, const HostTrap &trap) { + ensure_may_leave(inst, trap); auto writable = inst.table.remove(writable_index, trap); writable->drop(trap); } diff --git a/include/cmcpp/runtime.hpp b/include/cmcpp/runtime.hpp index 7ad5545..524188a 100644 --- a/include/cmcpp/runtime.hpp +++ b/include/cmcpp/runtime.hpp @@ -3,6 +3,8 @@ #include #include 
+#include <atomic> +#include <deque> #include #include #include @@ -13,6 +15,7 @@ namespace cmcpp { class Store; + struct ComponentInstance; struct Supertask; using SupertaskPtr = std::shared_ptr; @@ -30,6 +33,7 @@ static std::shared_ptr<Thread> create(Store &store, ReadyFn ready, ResumeFn resume, bool cancellable = false, CancelFn on_cancel = {}); Thread(Store &store, ReadyFn ready, ResumeFn resume, bool cancellable, CancelFn on_cancel); + bool ready() const; void resume(); void request_cancellation(); @@ -37,6 +41,13 @@ namespace cmcpp bool cancelled() const; bool completed() const; + bool suspend_until(ReadyFn ready, bool cancellable, bool force_yield = false); + void set_ready(ReadyFn ready); + void set_allow_cancellation(bool allow); + bool allow_cancellation() const; + void set_in_event_loop(bool value); + bool in_event_loop() const; + private: enum class State { @@ -51,10 +62,13 @@ namespace cmcpp ReadyFn ready_; ResumeFn resume_; CancelFn on_cancel_; + bool allow_cancellation_; bool cancellable_; bool cancelled_; + bool in_event_loop_; mutable std::mutex mutex_; State state_; + std::atomic<bool> reschedule_requested_{false}; }; class Call @@ -63,38 +77,38 @@ namespace cmcpp using CancelRequest = std::function; Call() = default; - explicit Call(CancelRequest cancel) : request_cancellation_(std::move(cancel)) {} + explicit Call(CancelRequest cancel_req, bool cancellable = true) + : request_cancellation_(std::move(cancel_req)), cancellable_(cancellable) {} void request_cancellation() const { + if (!cancellable_) + { + return; + } if (request_cancellation_) { request_cancellation_(); } } - static Call from_thread(const std::shared_ptr<Thread> &thread) + bool cancellable() const { - if (!thread) - { - return Call(); - } - std::weak_ptr<Thread> weak = thread; - return Call([weak]() - { - if (auto locked = weak.lock()) - { - locked->request_cancellation(); - } }); + return cancellable_; } + static Call from_thread(const std::shared_ptr<Thread> &thread); + private: CancelRequest request_cancellation_; + 
bool cancellable_ = false; }; struct Supertask { SupertaskPtr parent; + std::weak_ptr<Thread> thread; + ComponentInstance *instance = nullptr; }; using FuncInst = std::function; @@ -106,11 +120,14 @@ namespace cmcpp void tick(); void schedule(const std::shared_ptr<Thread> &thread); std::size_t pending_size() const; + void enqueue(std::function<void()> microtask); private: friend class Thread; + mutable std::mutex mutex_; std::vector<std::shared_ptr<Thread>> pending_; + std::deque<std::function<void()>> microtasks_; }; inline std::shared_ptr<Thread> Thread::create(Store &store, ReadyFn ready, ResumeFn resume, bool cancellable, CancelFn on_cancel) @@ -121,7 +138,15 @@ namespace cmcpp } inline Thread::Thread(Store &store, ReadyFn ready, ResumeFn resume, bool cancellable, CancelFn on_cancel) - : store_(&store), ready_(std::move(ready)), resume_(std::move(resume)), on_cancel_(std::move(on_cancel)), cancellable_(cancellable), cancelled_(false), state_(State::Pending) + : store_(&store), + ready_(std::move(ready)), + resume_(std::move(resume)), + on_cancel_(std::move(on_cancel)), + allow_cancellation_(cancellable), + cancellable_(cancellable), + cancelled_(false), + in_event_loop_(false), + state_(State::Pending) { } @@ -144,6 +169,7 @@ namespace cmcpp auto self = shared_from_this(); ResumeFn resume; bool was_cancelled = false; + { std::lock_guard lock(mutex_); if (state_ != State::Pending) { return; } state_ = State::Running; resume = resume_; was_cancelled = cancelled_; + cancelled_ = false; } bool keep_pending = false; if (resume) { keep_pending = resume(was_cancelled); } - set_pending(keep_pending, self); + bool requested = reschedule_requested_.exchange(false, std::memory_order_relaxed); + set_pending(keep_pending || requested, self); } inline void Thread::request_cancellation() { CancelFn cancel; { std::lock_guard lock(mutex_); - if (cancelled_ || !cancellable_) + if (!allow_cancellation_ || cancelled_) { return; } cancelled_ = true; cancel = on_cancel_; } + if 
(cancel) { cancel(); @@ -200,18 +229,113 @@ namespace cmcpp return state_ == State::Completed; } + inline bool Thread::suspend_until(ReadyFn ready, bool cancellable, bool force_yield) + { + bool ready_now = false; + if (ready && !force_yield) + { + ready_now = ready(); + } + + if (ready_now) + { + return true; + } + + auto gate = std::make_shared<std::atomic<bool>>(false); + ReadyFn wrapped = [ready = std::move(ready), gate, force_yield]() mutable -> bool + { + if (force_yield && !gate->exchange(true, std::memory_order_relaxed)) + { + return false; + } + if (!ready) + { + return true; + } + return ready(); + }; + + { + std::lock_guard lock(mutex_); + ready_ = std::move(wrapped); + cancellable_ = allow_cancellation_ && cancellable; + } + + reschedule_requested_.store(true, std::memory_order_relaxed); + return false; + } + + inline void Thread::set_ready(ReadyFn ready) + { + std::lock_guard lock(mutex_); + ready_ = std::move(ready); + } + + inline void Thread::set_allow_cancellation(bool allow) + { + std::lock_guard lock(mutex_); + allow_cancellation_ = allow; + if (!allow) + { + cancellable_ = false; + } + } + + inline bool Thread::allow_cancellation() const + { + std::lock_guard lock(mutex_); + return allow_cancellation_; + } + + inline void Thread::set_in_event_loop(bool value) + { + std::lock_guard lock(mutex_); + in_event_loop_ = value; + } + + inline bool Thread::in_event_loop() const + { + std::lock_guard lock(mutex_); + return in_event_loop_; + } + inline void Thread::set_pending(bool pending_again, const std::shared_ptr<Thread> &self) { { std::lock_guard lock(mutex_); state_ = pending_again ? 
State::Pending : State::Completed; + if (!pending_again) + { + ready_ = nullptr; + cancellable_ = false; + } } + if (pending_again) { store_->schedule(self); } } + inline Call Call::from_thread(const std::shared_ptr<Thread> &thread) + { + if (!thread) + { + return Call(); + } + std::weak_ptr<Thread> weak = thread; + return Call( + [weak]() + { + if (auto locked = weak.lock()) + { + locked->request_cancellation(); + } + }, + thread->allow_cancellation()); + } + inline Call Store::invoke(const FuncInst &func, SupertaskPtr caller, OnStart on_start, OnResolve on_resolve) { if (!func) @@ -223,18 +347,35 @@ namespace cmcpp inline void Store::tick() { + std::function<void()> microtask; std::shared_ptr<Thread> selected; + { std::lock_guard lock(mutex_); - auto it = std::find_if(pending_.begin(), pending_.end(), [](const std::shared_ptr<Thread> &thread) - { return thread && thread->ready(); }); - if (it == pending_.end()) + if (!microtasks_.empty()) { - return; + microtask = std::move(microtasks_.front()); + microtasks_.pop_front(); } - selected = *it; - pending_.erase(it); + else + { + auto it = std::find_if(pending_.begin(), pending_.end(), [](const std::shared_ptr<Thread> &thread) + { return thread && thread->ready(); }); + if (it == pending_.end()) + { + return; + } + selected = *it; + pending_.erase(it); + } + } + + if (microtask) + { + microtask(); + return; } + if (selected) { selected->resume(); @@ -256,6 +397,16 @@ namespace cmcpp std::lock_guard lock(mutex_); return pending_.size(); } + + inline void Store::enqueue(std::function<void()> microtask) + { + if (!microtask) + { + return; + } + std::lock_guard lock(mutex_); + microtasks_.push_back(std::move(microtask)); + } } #endif diff --git a/test/main.cpp b/test/main.cpp index 44b692a..84c56b8 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -206,6 +206,123 @@ TEST_CASE("Async runtime propagates cancellation") CHECK(thread->completed()); } +TEST_CASE("Backpressure counters and may_leave guards") +{ + ComponentInstance inst; + HostTrap trap = [](const char *msg) + { + 
+        throw std::runtime_error(msg ? msg : "trap");
+    };
+
+    canon_backpressure_set(inst, true);
+    CHECK(inst.backpressure == 1);
+    canon_backpressure_inc(inst, trap);
+    CHECK(inst.backpressure == 2);
+    canon_backpressure_dec(inst, trap);
+    CHECK(inst.backpressure == 1);
+    canon_backpressure_set(inst, false);
+    CHECK(inst.backpressure == 0);
+    CHECK_THROWS(canon_backpressure_dec(inst, trap));
+
+    Store store;
+    inst.store = &store;
+    inst.may_leave = false;
+    CHECK_THROWS(canon_waitable_set_new(inst, trap));
+    inst.may_leave = true;
+    CHECK_NOTHROW(canon_waitable_set_new(inst, trap));
+}
+
+TEST_CASE("Task yield, cancel, and return")
+{
+    Store store;
+    ComponentInstance inst;
+    inst.store = &store;
+    HostTrap trap = [](const char *msg)
+    {
+        throw std::runtime_error(msg ? msg : "trap");
+    };
+
+    CanonicalOptions async_opts;
+    async_opts.sync = false;
+
+    bool resolved_called = false;
+    std::optional<std::vector<std::any>> resolved_value;
+
+    auto cancel_task = std::make_shared<Task>(inst, async_opts, nullptr, [&](std::optional<std::vector<std::any>> values)
+                                              {
+        resolved_called = true;
+        resolved_value = std::move(values); });
+
+    auto cancel_gate = std::make_shared<std::atomic<bool>>(true);
+    auto cancel_thread = Thread::create(
+        store,
+        [cancel_gate]()
+        {
+            return cancel_gate->load();
+        },
+        [cancel_task, &trap](bool was_cancelled)
+        {
+            CHECK(was_cancelled);
+            REQUIRE(cancel_task->enter(trap));
+            auto event_code = canon_yield(true, *cancel_task, trap);
+            CHECK(event_code == 1);
+            canon_task_cancel(*cancel_task, trap);
+            cancel_task->exit();
+            return false;
+        },
+        true,
+        [cancel_gate]()
+        {
+            cancel_gate->store(true);
+        });
+
+    cancel_task->set_thread(cancel_thread);
+    cancel_task->request_cancellation();
+    CHECK(cancel_task->state() == Task::State::CancelDelivered);
+    store.tick();
+
+    CHECK(resolved_called);
+    CHECK_FALSE(resolved_value.has_value());
+    CHECK(store.pending_size() == 0);
+
+    resolved_called = false;
+    resolved_value.reset();
+
+    auto return_task = std::make_shared<Task>(inst, async_opts, nullptr, [&](std::optional<std::vector<std::any>> values)
+                                              {
+        resolved_called = true;
+        resolved_value = std::move(values); });
+
+    auto return_gate = std::make_shared<std::atomic<bool>>(true);
+    auto return_thread = Thread::create(
+        store,
+        [return_gate]()
+        {
+            return return_gate->load();
+        },
+        [return_task, &trap](bool was_cancelled)
+        {
+            CHECK_FALSE(was_cancelled);
+            REQUIRE(return_task->enter(trap));
+            auto event_code = canon_yield(false, *return_task, trap);
+            CHECK(event_code == 0);
+            std::vector<std::any> payload;
+            payload.emplace_back(int32_t(42));
+            canon_task_return(*return_task, std::move(payload), trap);
+            return_task->exit();
+            return false;
+        });
+
+    return_task->set_thread(return_thread);
+    store.tick();
+
+    CHECK(resolved_called);
+    REQUIRE(resolved_value.has_value());
+    REQUIRE(resolved_value->size() == 1);
+    CHECK(std::any_cast<int32_t>((*resolved_value)[0]) == 42);
+    CHECK(store.pending_size() == 0);
+}
+
 TEST_CASE("Resource handle lifecycle mirrors canonical definitions")
 {
     ComponentInstance resource_impl;

From e2eeba91de6b805ef11d3d1147072ffb67508ed3 Mon Sep 17 00:00:00 2001
From: Gordon Smith
Date: Sun, 28 Sep 2025 10:33:45 +0100
Subject: [PATCH 4/6] fix: Support context locals and error-context APIs

Signed-off-by: Gordon Smith
---
 include/cmcpp.hpp               |  1 +
 include/cmcpp/context.hpp       | 78 +++++++++++++++++++++++----------
 include/cmcpp/error_context.hpp | 74 +++++++++++++++++++++++++++++++
 test/main.cpp                   | 76 ++++++++++++++++++++++++++++++++
 4 files changed, 206 insertions(+), 23 deletions(-)
 create mode 100644 include/cmcpp/error_context.hpp

diff --git a/include/cmcpp.hpp b/include/cmcpp.hpp
index c664e21..e5c3166 100644
--- a/include/cmcpp.hpp
+++ b/include/cmcpp.hpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <cmcpp/error_context.hpp>
 #include
 #include
 #include

diff --git a/include/cmcpp/context.hpp b/include/cmcpp/context.hpp
index 430c0fe..708646d 100644
--- a/include/cmcpp/context.hpp
+++ b/include/cmcpp/context.hpp
@@ -5,6 +5,7 @@
 #include "runtime.hpp"

 #include
+#include <array>
 #include
 #include
 #include
@@ -969,10 +970,6 @@ namespace cmcpp
         std::unordered_map tables_;
     };

-    struct ErrorContext
-    {
-    };
-
     struct ComponentInstance
     {
         Store *store = nullptr;
@@ -1010,6 +1007,27 @@ namespace cmcpp
         inst.backpressure -= 1;
     }

+    class ContextLocalStorage
+    {
+    public:
+        static constexpr uint32_t LENGTH = 1;
+
+        ContextLocalStorage() = default;
+
+        void set(uint32_t index, int32_t value)
+        {
+            storage_[index] = value;
+        }
+
+        int32_t get(uint32_t index) const
+        {
+            return storage_[index];
+        }
+
+    private:
+        std::array<int32_t, LENGTH> storage_{};
+    };
+
     class Task : public std::enable_shared_from_this<Task>
     {
     public:
@@ -1184,6 +1202,16 @@ namespace cmcpp
             return state_;
         }

+        ContextLocalStorage &context()
+        {
+            return context_;
+        }
+
+        const ContextLocalStorage &context() const
+        {
+            return context_;
+        }
+
         ComponentInstance *component_instance() const
         {
             return inst_;
@@ -1236,6 +1264,7 @@ namespace cmcpp
         uint32_t num_borrows_ = 0;
         std::shared_ptr<Thread> thread_;
         State state_ = State::Initial;
+        ContextLocalStorage context_{};
     };

     inline void canon_task_return(Task &task, std::vector<std::any> result, const HostTrap &trap)
@@ -1276,25 +1305,6 @@ namespace cmcpp
         return event.code == EventCode::TASK_CANCELLED ?
1u : 0u; } - class ContextLocalStorage - { - public: - static constexpr int LENGTH = 2; - int array[LENGTH] = {0, 0}; - - ContextLocalStorage() = default; - - void set(int i, int v) - { - array[i] = v; - } - - int get(int i) - { - return array[i]; - } - }; - struct Subtask : Waitable { }; @@ -1359,6 +1369,28 @@ namespace cmcpp return element.rep; } + inline int32_t canon_context_get(Task &task, uint32_t index, const HostTrap &trap) + { + if (auto *inst = task.component_instance()) + { + ensure_may_leave(*inst, trap); + } + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, index >= ContextLocalStorage::LENGTH, "context index out of bounds"); + return task.context().get(index); + } + + inline void canon_context_set(Task &task, uint32_t index, int32_t value, const HostTrap &trap) + { + if (auto *inst = task.component_instance()) + { + ensure_may_leave(*inst, trap); + } + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, index >= ContextLocalStorage::LENGTH, "context index out of bounds"); + task.context().set(index, value); + } + inline uint32_t canon_waitable_set_new(ComponentInstance &inst, const HostTrap &trap) { ensure_may_leave(inst, trap); diff --git a/include/cmcpp/error_context.hpp b/include/cmcpp/error_context.hpp new file mode 100644 index 0000000..4fbd293 --- /dev/null +++ b/include/cmcpp/error_context.hpp @@ -0,0 +1,74 @@ +#ifndef CMCPP_ERROR_CONTEXT_HPP +#define CMCPP_ERROR_CONTEXT_HPP + +#include "context.hpp" +#include "string.hpp" +#include + +namespace cmcpp +{ + class ErrorContext : public TableEntry + { + public: + explicit ErrorContext(string_t message) + : debug_message_(std::move(message)) + { + } + + const string_t &debug_message() const + { + return debug_message_; + } + + private: + string_t debug_message_{}; + }; + + inline uint32_t canon_error_context_new(Task &task, + const LiftLowerContext *cx, + uint32_t ptr, + uint32_t tagged_code_units, + const HostTrap &trap) + { + auto *inst = task.component_instance(); + auto 
trap_cx = make_trap_context(trap); + trap_if(trap_cx, inst == nullptr, "task missing component instance"); + ensure_may_leave(*inst, trap); + + string_t message; + if (cx != nullptr) + { + message = string::load_from_range(*cx, ptr, tagged_code_units); + } + + auto entry = std::make_shared(std::move(message)); + return inst->table.add(entry, trap); + } + + inline void canon_error_context_debug_message(Task &task, + LiftLowerContext &cx, + uint32_t index, + uint32_t ptr, + const HostTrap &trap) + { + auto *inst = task.component_instance(); + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, inst == nullptr, "task missing component instance"); + ensure_may_leave(*inst, trap); + + auto errctx = inst->table.get(index, trap); + string::store(cx, errctx->debug_message(), ptr); + } + + inline void canon_error_context_drop(Task &task, uint32_t index, const HostTrap &trap) + { + auto *inst = task.component_instance(); + auto trap_cx = make_trap_context(trap); + trap_if(trap_cx, inst == nullptr, "task missing component instance"); + ensure_may_leave(*inst, trap); + + inst->table.remove(index, trap); + } +} + +#endif // CMCPP_ERROR_CONTEXT_HPP diff --git a/test/main.cpp b/test/main.cpp index 84c56b8..277127d 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -5,6 +5,7 @@ using namespace cmcpp; #include +#include #include #include #include @@ -232,6 +233,81 @@ TEST_CASE("Backpressure counters and may_leave guards") CHECK_NOTHROW(canon_waitable_set_new(inst, trap)); } +TEST_CASE("Context locals provide per-task storage") +{ + ComponentInstance inst; + HostTrap trap = [](const char *msg) + { + throw std::runtime_error(msg ? 
msg : "trap"); + }; + + Task task(inst); + + CHECK(ContextLocalStorage::LENGTH == 1); + CHECK(canon_context_get(task, 0, trap) == 0); + + canon_context_set(task, 0, 42, trap); + CHECK(canon_context_get(task, 0, trap) == 42); + + CHECK_THROWS(canon_context_get(task, ContextLocalStorage::LENGTH, trap)); + CHECK_THROWS(canon_context_set(task, ContextLocalStorage::LENGTH, 99, trap)); + + inst.may_leave = false; + CHECK_THROWS(canon_context_get(task, 0, trap)); + inst.may_leave = true; +} + +TEST_CASE("Error context APIs manage debug messages") +{ + Store store; + ComponentInstance inst; + inst.store = &store; + + HostTrap trap = [](const char *msg) + { + throw std::runtime_error(msg ? msg : "trap"); + }; + + Task task(inst); + + Heap heap(1024); + auto cx = createLiftLowerContext(&heap, Encoding::Utf8); + cx->inst = &inst; + + const std::string guest_message = "component failure"; + const uint32_t guest_ptr = 32; + std::copy(guest_message.begin(), guest_message.end(), heap.memory.begin() + guest_ptr); + uint32_t tagged_code_units = static_cast(guest_message.size()); + + uint32_t err_index = canon_error_context_new(task, cx.get(), guest_ptr, tagged_code_units, trap); + CHECK(err_index != 0); + + const uint32_t record_ptr = 128; + canon_error_context_debug_message(task, *cx, err_index, record_ptr, trap); + + uint32_t stored_ptr = integer::load(*cx, record_ptr); + uint32_t stored_len = integer::load(*cx, record_ptr + 4); + auto roundtrip = string::load_from_range(*cx, stored_ptr, stored_len); + CHECK(roundtrip == guest_message); + + inst.may_leave = false; + CHECK_THROWS(canon_error_context_debug_message(task, *cx, err_index, record_ptr + 16, trap)); + inst.may_leave = true; + + canon_error_context_drop(task, err_index, trap); + CHECK_THROWS(canon_error_context_drop(task, err_index, trap)); + + uint32_t empty_index = canon_error_context_new(task, nullptr, 0, 0, trap); + CHECK(empty_index != 0); + + const uint32_t empty_record = 192; + 
canon_error_context_debug_message(task, *cx, empty_index, empty_record, trap); + uint32_t empty_len = integer::load(*cx, empty_record + 4); + CHECK(empty_len == 0); + + canon_error_context_drop(task, empty_index, trap); +} + TEST_CASE("Task yield, cancel, and return") { Store store; From 92f1cda9543582c4702e30df8af69af138aa266a Mon Sep 17 00:00:00 2001 From: Gordon Smith Date: Sun, 28 Sep 2025 10:54:10 +0100 Subject: [PATCH 5/6] fix: Finish function flattening utilities Signed-off-by: Gordon Smith --- include/cmcpp/flags.hpp | 21 ++-- include/cmcpp/func.hpp | 214 +++++++++++++++++++++------------------ include/cmcpp/lower.hpp | 2 +- include/cmcpp/traits.hpp | 1 + test/main.cpp | 94 +++++++++++++++++ 5 files changed, 224 insertions(+), 108 deletions(-) diff --git a/include/cmcpp/flags.hpp b/include/cmcpp/flags.hpp index 196f10d..851233a 100644 --- a/include/cmcpp/flags.hpp +++ b/include/cmcpp/flags.hpp @@ -1,6 +1,8 @@ #ifndef CMCPP_FLAGS_HPP #define CMCPP_FLAGS_HPP +#include + #include "context.hpp" #include "integer.hpp" #include "util.hpp" @@ -12,14 +14,14 @@ namespace cmcpp template int32_t pack_flags_into_int(const T &v) { - return v.to_ulong(); + return static_cast(v.to_ulong()); } template void store(LiftLowerContext &cx, const T &v, offset ptr) { auto i = pack_flags_into_int(v); - std::memcpy(&cx.opts.memory[ptr], i, ValTrait::size); + std::memcpy(&cx.opts.memory[ptr], &i, ValTrait::size); } template @@ -31,24 +33,25 @@ namespace cmcpp template T unpack_flags_from_int(const uint32_t &buff) { - return {buff}; + T value{}; + using bitset_type = typename ValTrait::inner_type; + static_cast(value) = bitset_type(static_cast(buff)); + return value; } template T load(const LiftLowerContext &cx, uint32_t ptr) { - uint8_t buff[ValTrait::size]; - std::memcpy(&buff, &cx.opts.memory[ptr], ValTrait::size); - return unpack_flags_from_int(buff); + uint32_t raw = 0; + std::memcpy(&raw, &cx.opts.memory[ptr], ValTrait::size); + return unpack_flags_from_int(raw); } template T 
lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi) { auto i = vi.next(); - uint8_t buff[ValTrait::size]; - std::memcpy(&buff, &i, ValTrait::size); - return unpack_flags_from_int(i); + return unpack_flags_from_int(static_cast(i)); } } diff --git a/include/cmcpp/func.hpp b/include/cmcpp/func.hpp index e0bc137..716116f 100644 --- a/include/cmcpp/func.hpp +++ b/include/cmcpp/func.hpp @@ -2,109 +2,127 @@ #define CMCPP_FUNC_HPP #include "context.hpp" +#include "flags.hpp" namespace cmcpp { namespace func { - // template - // int32_t pack_flags_into_int(const T &v) - // { - // return v.to_ulong(); - // } - - // template - // void store(LiftLowerContext &cx, const T &v, offset ptr) - // { - // auto i = pack_flags_into_int(v); - // std::memcpy(&cx.opts.memory[ptr], i, ValTrait::size); - // } - - // template - // WasmValVector lower_flat(LiftLowerContext &cx, const T &v) - // { - // return {pack_flags_into_int(v)}; - // } - - // template - // T unpack_flags_from_int(const uint32_t &buff) - // { - // return {buff}; - // } - - // template - // T load(const LiftLowerContext &cx, uint32_t ptr) - // { - // uint8_t buff[ValTrait::size]; - // std::memcpy(&buff, &cx.opts.memory[ptr], ValTrait::size); - // return unpack_flags_from_int(buff); - // } - - // template - // T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi) - // { - // auto i = vi.next(); - // uint8_t buff[ValTrait::size]; - // std::memcpy(&buff, &i, ValTrait::size); - // return unpack_flags_from_int(i); - // } - - // enum class ContextType - // { - // Lift, - // Lower - // }; - - // template - // inline core_func_t flatten(LiftLowerContext &cx, ContextType context) - // { - // std::vector flat_params(ValTrait::flat_params_types.begin(), ValTrait::flat_params_types.end()); - // std::vector flat_results(ValTrait::flat_result_types.begin(), ValTrait::flat_result_types.end()); - // // if (cx.opts.sync == true) - // { - // if (flat_params.size() > MAX_FLAT_PARAMS) - // { - // flat_params = 
{WasmValType::i32}; - // } - // if (flat_results.size() > MAX_FLAT_RESULTS) - // { - // switch (context) - // { - // case ContextType::Lift: - // flat_results = {WasmValType::i32}; - // break; - // case ContextType::Lower: - // flat_params.push_back(WasmValType::i32); - // flat_results = {}; - // } - // } - // } - // return {flat_params, flat_results}; - // } - - // template - // inline void store(LiftLowerContext &cx, const T &v, uint32_t ptr) - // { - // flags::store(cx, v, ptr); - // } - - // template - // inline WasmValVector lower_flat(LiftLowerContext &cx, const T &v) - // { - // return flags::lower_flat(cx, v); - // } - - // template - // inline T load(const LiftLowerContext &cx, uint32_t ptr) - // { - // return flags::load(cx, ptr); - // } - - // template - // inline T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi) - // { - // return flags::lift_flat(cx, vi); - // } + enum class ContextType + { + Lift, + Lower + }; + + template + inline int32_t pack_flags_into_int(const T &v) + { + return flags::pack_flags_into_int(v); + } + + template + inline void store(LiftLowerContext &cx, const T &v, offset ptr) + { + flags::store(cx, v, ptr); + } + + template + inline WasmValVector lower_flat(LiftLowerContext &cx, const T &v) + { + return flags::lower_flat(cx, v); + } + + template + inline T unpack_flags_from_int(uint32_t value) + { + return flags::unpack_flags_from_int(value); + } + + template + inline T load(const LiftLowerContext &cx, uint32_t ptr) + { + return flags::load(cx, ptr); + } + + template + inline T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi) + { + return flags::lift_flat(cx, vi); + } + + template + inline core_func_t flatten(const CanonicalOptions &opts, ContextType context) + { + using params_trait = ValTrait::params_t>; + using result_trait = ValTrait::result_t>; + + WasmValTypeVector flat_params(params_trait::flat_types.begin(), params_trait::flat_types.end()); + WasmValTypeVector 
flat_results(result_trait::flat_types.begin(), result_trait::flat_types.end()); + + const size_t raw_param_count = flat_params.size(); + const size_t raw_result_count = flat_results.size(); + + auto pointer_type = []() -> WasmValTypeVector + { + return {WasmValType::i32}; + }; + + if (opts.sync) + { + if (raw_param_count > MAX_FLAT_PARAMS) + { + flat_params = pointer_type(); + } + + if (raw_result_count > MAX_FLAT_RESULTS) + { + if (context == ContextType::Lift) + { + flat_results = pointer_type(); + } + else + { + flat_params.push_back(WasmValType::i32); + flat_results.clear(); + } + } + } + else + { + if (context == ContextType::Lift) + { + if (raw_param_count > MAX_FLAT_PARAMS) + { + flat_params = pointer_type(); + } + + if (opts.callback.has_value()) + { + flat_results = pointer_type(); + } + else + { + flat_results.clear(); + } + } + else + { + if (raw_param_count > MAX_FLAT_ASYNC_PARAMS) + { + flat_params = pointer_type(); + } + + if (raw_result_count > 0) + { + flat_params.push_back(WasmValType::i32); + } + + flat_results = pointer_type(); + } + } + + return {std::move(flat_params), std::move(flat_results)}; + } } } #endif diff --git a/include/cmcpp/lower.hpp b/include/cmcpp/lower.hpp index c4dbca8..adc2f3a 100644 --- a/include/cmcpp/lower.hpp +++ b/include/cmcpp/lower.hpp @@ -62,7 +62,7 @@ namespace cmcpp if (out_param == nullptr) { ptr = cx.opts.realloc(0, 0, ValTrait::alignment, ValTrait::size); - flat_vals = {ptr}; + flat_vals = {static_cast(ptr)}; } else { diff --git a/include/cmcpp/traits.hpp b/include/cmcpp/traits.hpp index aa400d3..30dd27f 100644 --- a/include/cmcpp/traits.hpp +++ b/include/cmcpp/traits.hpp @@ -809,6 +809,7 @@ namespace cmcpp // Func -------------------------------------------------------------------- constexpr uint32_t MAX_FLAT_PARAMS = 16; constexpr uint32_t MAX_FLAT_RESULTS = 1; + constexpr uint32_t MAX_FLAT_ASYNC_PARAMS = 4; template struct func_t_impl; diff --git a/test/main.cpp b/test/main.cpp index 277127d..74a26d4 100644 --- 
a/test/main.cpp +++ b/test/main.cpp @@ -1317,6 +1317,14 @@ TEST_CASE("Flags") CHECK(flags.test<"ccc">() == f.test<"ccc">()); CHECK(flags == f); + auto encoded_flags = func::pack_flags_into_int(flags); + auto decoded_flags = func::unpack_flags_from_int(static_cast(encoded_flags)); + CHECK(decoded_flags == flags); + + func::store(*cx, flags, 0); + auto loaded_flags = func::load(*cx, 0); + CHECK(loaded_flags == flags); + using MyFlags2 = flags_t<"one", "two", "three", "four", "five", "six", "seven", "8", "nine">; CHECK(ValTrait::size == 2); CHECK(MyFlags2::labelsSize == 9); @@ -1487,6 +1495,7 @@ TEST_CASE("Records") auto v = lower_flat(*cx, p_in); auto p_out = lift_flat(*cx, v); CHECK(p_in.name == p_out.name); + CHECK(p_in.age == p_out.age); CHECK(p_in.weight == p_out.weight); @@ -1551,6 +1560,91 @@ TEST_CASE("Records") CHECK(pex3_in.p.phones == pex3_out.p.phones); } +TEST_CASE("Function flattening honors canonical limits") +{ + CanonicalOptions opts; + opts.sync = true; + + using HeavyParamFn = std::function; + + auto lift_flat = func::flatten(opts, func::ContextType::Lift); + CHECK(lift_flat.params.size() == 1); + CHECK(lift_flat.params[0] == WasmValType::i32); + CHECK(lift_flat.results.empty()); + + auto lower_flat = func::flatten(opts, func::ContextType::Lower); + CHECK(lower_flat.params.size() == 1); + CHECK(lower_flat.params[0] == WasmValType::i32); + CHECK(lower_flat.results.empty()); + + using HeavyResultFn = std::function()>; + + auto lift_results = func::flatten(opts, func::ContextType::Lift); + CHECK(lift_results.params.empty()); + CHECK(lift_results.results.size() == 1); + CHECK(lift_results.results[0] == WasmValType::i32); + + auto lower_results = func::flatten(opts, func::ContextType::Lower); + CHECK(lower_results.params.size() == 1); + CHECK(lower_results.params.back() == WasmValType::i32); + CHECK(lower_results.results.empty()); +} + +TEST_CASE("Async function flattening matches canonical ABI") +{ + using AsyncFn = std::function; + + CanonicalOptions 
async_opts; + async_opts.sync = false; + + auto lower_async = func::flatten(async_opts, func::ContextType::Lower); + CHECK(lower_async.params.size() == 2); + CHECK(lower_async.params[0] == WasmValType::i32); + CHECK(lower_async.params[1] == WasmValType::i32); + CHECK(lower_async.results.size() == 1); + CHECK(lower_async.results[0] == WasmValType::i32); + + using AsyncLiftFn = std::function; + + CanonicalOptions lift_opts; + lift_opts.sync = false; + + auto lift_no_callback = func::flatten(lift_opts, func::ContextType::Lift); + CHECK(lift_no_callback.params.size() == 1); + CHECK(lift_no_callback.params[0] == WasmValType::i32); + CHECK(lift_no_callback.results.empty()); + + CanonicalOptions lift_with_callback; + lift_with_callback.sync = false; + lift_with_callback.callback = GuestCallback([]() {}); + + auto lift_callback = func::flatten(lift_with_callback, func::ContextType::Lift); + CHECK(lift_callback.results.size() == 1); + CHECK(lift_callback.results[0] == WasmValType::i32); +} + +TEST_CASE("Heap-based lowering triggers for oversized results") +{ + Heap heap(1024 * 1024); + auto cx = createLiftLowerContext(&heap, Encoding::Utf8); + + using ResultTuple = tuple_t; + ResultTuple value{"alpha", "beta"}; + + auto lowered = lower_flat_values(*cx, MAX_FLAT_RESULTS, nullptr, std::move(value)); + REQUIRE(lowered.size() == 1); + auto ptr = std::get(lowered[0]); + + auto stored = load(*cx, ptr); + CHECK(std::get<0>(stored) == "alpha"); + CHECK(std::get<1>(stored) == "beta"); + + CoreValueIter iter(lowered); + auto lifted = lift_flat_values(*cx, MAX_FLAT_RESULTS, iter); + CHECK(std::get<0>(lifted) == "alpha"); + CHECK(std::get<1>(lifted) == "beta"); +} + TEST_CASE("Variant") { Heap heap(1024 * 1024); From 945ab1ef2e0dab1e5f77bafb259ecd0ee49e52a7 Mon Sep 17 00:00:00 2001 From: Gordon Smith Date: Sun, 28 Sep 2025 11:19:03 +0100 Subject: [PATCH 6/6] fix: Wire canonical options and callbacks through lift/lower Signed-off-by: Gordon Smith --- include/cmcpp/context.hpp | 131 
+++++++++++++++++++++++++++++++++----- include/cmcpp/lower.hpp | 6 ++ test/host-util.hpp | 12 +++- test/main.cpp | 111 ++++++++++++++++++++++++++++++-- 4 files changed, 240 insertions(+), 20 deletions(-) diff --git a/include/cmcpp/context.hpp b/include/cmcpp/context.hpp index 708646d..14dad52 100644 --- a/include/cmcpp/context.hpp +++ b/include/cmcpp/context.hpp @@ -26,11 +26,13 @@ namespace cmcpp { + enum class EventCode : uint8_t; + using HostTrap = std::function; using GuestRealloc = std::function; using GuestMemory = std::span; using GuestPostReturn = std::function; - using GuestCallback = std::function; + using GuestCallback = std::function; using HostUnicodeConversion = std::function(void *dest, uint32_t dest_byte_len, const void *src, uint32_t src_byte_len, Encoding from_encoding, Encoding to_encoding)>; // Canonical ABI Options --- @@ -83,8 +85,18 @@ namespace cmcpp LiftLowerContext(const HostTrap &host_trap, const HostUnicodeConversion &conversion, const LiftLowerOptions &options, ComponentInstance *instance = nullptr) : trap(host_trap), convert(conversion), opts(options), inst(instance) {} + void set_canonical_options(CanonicalOptions options); + CanonicalOptions *canonical_options(); + const CanonicalOptions *canonical_options() const; + bool is_sync() const; + void invoke_post_return() const; + void notify_async_event(EventCode code, uint32_t index, uint32_t payload) const; + void track_owning_lend(HandleElement &lending_handle); void exit_call(); + + private: + std::optional canonical_opts_; }; inline void trap_if(const LiftLowerContext &cx, bool condition, const char *message = nullptr) noexcept(false) @@ -103,6 +115,54 @@ namespace cmcpp throw std::runtime_error(msg); } + inline void LiftLowerContext::set_canonical_options(CanonicalOptions options) + { + canonical_opts_ = std::move(options); + opts = *canonical_opts_; + } + + inline CanonicalOptions *LiftLowerContext::canonical_options() + { + return canonical_opts_ ? 
&*canonical_opts_ : nullptr; + } + + inline const CanonicalOptions *LiftLowerContext::canonical_options() const + { + return canonical_opts_ ? &*canonical_opts_ : nullptr; + } + + inline bool LiftLowerContext::is_sync() const + { + if (auto *canon = canonical_options()) + { + return canon->sync; + } + return true; + } + + inline void LiftLowerContext::invoke_post_return() const + { + if (auto *canon = canonical_options()) + { + if (canon->post_return) + { + (*canon->post_return)(); + } + } + } + + inline void LiftLowerContext::notify_async_event(EventCode code, uint32_t index, uint32_t payload) const + { + if (auto *canon = canonical_options()) + { + trap_if(*this, canon->sync, "async continuation requires async canonical options"); + if (canon->callback) + { + (*canon->callback)(code, index, payload); + } + } + } + inline LiftLowerContext make_trap_context(const HostTrap &trap) { HostUnicodeConversion convert{}; @@ -518,7 +578,7 @@ namespace cmcpp uint32_t read(const std::shared_ptr &cx, uint32_t handle_index, uint32_t ptr, uint32_t n, bool sync, const HostTrap &trap); uint32_t cancel(bool sync, const HostTrap &trap); void drop(const HostTrap &trap); - void complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap); + void complete_async(const std::shared_ptr &cx, uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap); private: std::shared_ptr shared_; @@ -582,7 +642,8 @@ namespace cmcpp auto pending = std::move(*shared_->pending_read); shared_->pending_read.reset(); - set_pending_event({EventCode::STREAM_READ, pending.handle_index, pack_copy_result(CopyResult::Cancelled, pending.progress)}); + auto payload = pack_copy_result(CopyResult::Cancelled, pending.progress); + set_pending_event({EventCode::STREAM_READ, pending.handle_index, payload}); state_ = CopyState::Done; if (sync) @@ -590,6 +651,10 @@ namespace cmcpp auto event = get_pending_event(trap); return event.payload; } + if (pending.cx) + { 
+ pending.cx->notify_async_event(EventCode::STREAM_READ, pending.handle_index, payload); + } return BLOCKED; } @@ -606,10 +671,15 @@ namespace cmcpp Waitable::drop(trap); } - inline void ReadableStreamEnd::complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap) + inline void ReadableStreamEnd::complete_async(const std::shared_ptr &cx, uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap) { - set_pending_event({EventCode::STREAM_READ, handle_index, pack_copy_result(result, progress)}); + auto payload = pack_copy_result(result, progress); + set_pending_event({EventCode::STREAM_READ, handle_index, payload}); state_ = (result == CopyResult::Completed) ? CopyState::Idle : CopyState::Done; + if (cx) + { + cx->notify_async_event(EventCode::STREAM_READ, handle_index, payload); + } } inline void satisfy_pending_read(SharedStreamState &state, const HostTrap &trap) @@ -624,7 +694,7 @@ namespace cmcpp pending.progress += consumed; if (pending.progress >= pending.requested) { - pending.endpoint->complete_async(pending.handle_index, CopyResult::Completed, pending.progress, trap); + pending.endpoint->complete_async(pending.cx, pending.handle_index, CopyResult::Completed, pending.progress, trap); state.pending_read.reset(); } } @@ -662,7 +732,7 @@ namespace cmcpp { auto pending = std::move(*shared_->pending_read); shared_->pending_read.reset(); - pending.endpoint->complete_async(pending.handle_index, CopyResult::Dropped, pending.progress, trap); + pending.endpoint->complete_async(pending.cx, pending.handle_index, CopyResult::Dropped, pending.progress, trap); } shared_->writable_dropped = true; } @@ -707,7 +777,7 @@ namespace cmcpp uint32_t read(const std::shared_ptr &cx, uint32_t handle_index, uint32_t ptr, bool sync, const HostTrap &trap); uint32_t cancel(bool sync, const HostTrap &trap); void drop(const HostTrap &trap); - void complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const 
                              HostTrap &trap);
+        void complete_async(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap);
 
     private:
         std::shared_ptr shared_;
@@ -772,7 +842,8 @@ namespace cmcpp
             auto pending = std::move(*shared_->pending_read);
             shared_->pending_read.reset();
-            set_pending_event({EventCode::FUTURE_READ, pending.handle_index, pack_copy_result(CopyResult::Cancelled, 0)});
+            auto payload = pack_copy_result(CopyResult::Cancelled, 0);
+            set_pending_event({EventCode::FUTURE_READ, pending.handle_index, payload});
             state_ = CopyState::Done;
 
             if (sync)
@@ -780,6 +851,10 @@ namespace cmcpp
                 auto event = get_pending_event(trap);
                 return event.payload;
             }
+            if (pending.cx)
+            {
+                pending.cx->notify_async_event(EventCode::FUTURE_READ, pending.handle_index, payload);
+            }
             return BLOCKED;
         }
 
@@ -796,10 +871,15 @@ namespace cmcpp
         Waitable::drop(trap);
     }
 
-    inline void ReadableFutureEnd::complete_async(uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap)
+    inline void ReadableFutureEnd::complete_async(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, CopyResult result, uint32_t progress, const HostTrap &trap)
     {
-        set_pending_event({EventCode::FUTURE_READ, handle_index, pack_copy_result(result, progress)});
+        auto payload = pack_copy_result(result, progress);
+        set_pending_event({EventCode::FUTURE_READ, handle_index, payload});
         state_ = (result == CopyResult::Completed) ? CopyState::Idle : CopyState::Done;
+        if (cx)
+        {
+            cx->notify_async_event(EventCode::FUTURE_READ, handle_index, payload);
+        }
     }
 
     inline uint32_t WritableFutureEnd::write(const std::shared_ptr<LiftLowerContext> &cx, uint32_t handle_index, uint32_t ptr, const HostTrap &trap)
@@ -819,7 +899,7 @@ namespace cmcpp
             shared_->pending_read.reset();
             ensure_memory_range(*pending.cx, pending.ptr, 1, shared_->descriptor.alignment, shared_->descriptor.element_size);
             std::memcpy(pending.cx->opts.memory.data() + pending.ptr, shared_->value.data(), shared_->descriptor.element_size);
-            pending.endpoint->complete_async(pending.handle_index, CopyResult::Completed, 1, trap);
+            pending.endpoint->complete_async(pending.cx, pending.handle_index, CopyResult::Completed, 1, trap);
         }
 
         set_pending_event({EventCode::FUTURE_WRITE, handle_index, pack_copy_result(CopyResult::Completed, 1)});
@@ -843,7 +923,7 @@ namespace cmcpp
         {
             auto pending = std::move(*shared_->pending_read);
             shared_->pending_read.reset();
-            pending.endpoint->complete_async(pending.handle_index, CopyResult::Dropped, 0, trap);
+            pending.endpoint->complete_async(pending.cx, pending.handle_index, CopyResult::Dropped, 0, trap);
         }
         shared_->writable_dropped = true;
     }
@@ -1603,10 +1683,31 @@ namespace cmcpp
         HostUnicodeConversion convert;
         GuestRealloc realloc;
 
-        std::unique_ptr<LiftLowerContext> createLiftLowerContext(const GuestMemory &memory, const Encoding &string_encoding = Encoding::Utf8, const std::optional<GuestPostReturn> &post_return = std::nullopt)
+        std::unique_ptr<LiftLowerContext> createLiftLowerContext(const GuestMemory &memory,
+                                                                 const Encoding &string_encoding = Encoding::Utf8,
+                                                                 const std::optional<GuestPostReturn> &post_return = std::nullopt,
+                                                                 bool sync = true,
+                                                                 const std::optional<GuestCallback> &callback = std::nullopt)
         {
-            LiftLowerOptions opts(string_encoding, memory, realloc);
+            CanonicalOptions options;
+            options.string_encoding = string_encoding;
+            options.memory = memory;
+            options.realloc = realloc;
+            options.post_return = post_return;
+            options.sync = sync;
+            options.callback = callback;
+            return createLiftLowerContext(std::move(options));
+        }
+
+        std::unique_ptr<LiftLowerContext> createLiftLowerContext(CanonicalOptions options)
+        {
+            if (!options.realloc)
+            {
+                options.realloc = realloc;
+            }
+            LiftLowerOptions opts(options.string_encoding, options.memory, options.realloc);
             auto retVal = std::make_unique<LiftLowerContext>(trap, convert, opts);
+            retVal->set_canonical_options(std::move(options));
             return retVal;
         }
     };
diff --git a/include/cmcpp/lower.hpp b/include/cmcpp/lower.hpp
index adc2f3a..312040a 100644
--- a/include/cmcpp/lower.hpp
+++ b/include/cmcpp/lower.hpp
@@ -78,6 +78,10 @@ namespace cmcpp
     template <typename... Ts>
     inline WasmValVector lower_flat_values(LiftLowerContext &cx, uint32_t max_flat, uint32_t *out_param, Ts &&...vs)
     {
+        if (auto *canon = cx.canonical_options())
+        {
+            trap_if(cx, canon->sync && max_flat == 0, "async lowering requires async canonical options");
+        }
         WasmValVector retVal = {};
         // cx.inst.may_leave=false;
         constexpr auto flat_types = ValTrait<tuple_t<Ts...>>::flat_types;
@@ -93,9 +97,11 @@ namespace cmcpp
                 retVal.insert(retVal.end(), flat.begin(), flat.end());
             };
             (lower_v(vs), ...);
+            cx.invoke_post_return();
             return retVal;
         }
         // cx.inst.may_leave=true;
+        cx.invoke_post_return();
         return retVal;
     }
diff --git a/test/host-util.hpp b/test/host-util.hpp
index d7ee519..77c60ae 100644
--- a/test/host-util.hpp
+++ b/test/host-util.hpp
@@ -43,12 +43,22 @@ class Heap
     }
 };
 
+inline std::unique_ptr<LiftLowerContext> createLiftLowerContext(Heap *heap, CanonicalOptions options);
+
 inline std::unique_ptr<LiftLowerContext> createLiftLowerContext(Heap *heap, Encoding encoding)
+{
+    CanonicalOptions options;
+    options.string_encoding = encoding;
+    return createLiftLowerContext(heap, std::move(options));
+}
+
+inline std::unique_ptr<LiftLowerContext> createLiftLowerContext(Heap *heap, CanonicalOptions options)
 {
     std::unique_ptr instanceContext = std::make_unique(trap, convert, [heap](int original_ptr, int original_size, int alignment, int new_size) -> int
                                                        { return heap->realloc(original_ptr, original_size, alignment, new_size); });
-    return instanceContext->createLiftLowerContext(heap->memory, encoding);
+    options.memory = heap->memory;
+    return instanceContext->createLiftLowerContext(std::move(options));
 }
diff --git a/test/main.cpp b/test/main.cpp
index 74a26d4..d62fce1 100644
--- a/test/main.cpp
+++ b/test/main.cpp
@@ -399,6 +399,103 @@ TEST_CASE("Task yield, cancel, and return")
     CHECK(store.pending_size() == 0);
 }
 
+TEST_CASE("Canonical options control lift/lower callbacks")
+{
+    SUBCASE("post_return runs once for heap spill")
+    {
+        Heap heap(1024);
+        int post_return_calls = 0;
+
+        CanonicalOptions options;
+        options.post_return = GuestPostReturn([&]()
+                                              { ++post_return_calls; });
+
+        auto cx = createLiftLowerContext(&heap, std::move(options));
+        using ResultTuple = tuple_t;
+        ResultTuple value{"alpha", "beta"};
+
+        auto lowered = lower_flat_values(*cx, MAX_FLAT_RESULTS, nullptr, std::move(value));
+        CHECK(lowered.size() == 1);
+        CHECK(post_return_calls == 1);
+    }
+
+    SUBCASE("post_return runs once when using provided out pointer")
+    {
+        Heap heap(1024);
+        int post_return_calls = 0;
+
+        CanonicalOptions options;
+        options.post_return = GuestPostReturn([&]()
+                                              { ++post_return_calls; });
+
+        auto cx = createLiftLowerContext(&heap, std::move(options));
+        using ResultTuple = tuple_t;
+        ResultTuple value{"gamma", "delta"};
+
+        uint32_t out_ptr = align_to(32u, ValTrait<ResultTuple>::alignment);
+        auto lowered = lower_flat_values(*cx, MAX_FLAT_RESULTS, &out_ptr, std::move(value));
+        CHECK(lowered.empty());
+        CHECK(post_return_calls == 1);
+    }
+
+    SUBCASE("sync context traps on async lowering")
+    {
+        Heap heap(1024);
+        CanonicalOptions options;
+        auto cx = createLiftLowerContext(&heap, std::move(options));
+        using ResultTuple = tuple_t;
+        CHECK_THROWS(lower_flat_values(*cx, 0, nullptr, ResultTuple{"omega", "sigma"}));
+    }
+
+    SUBCASE("async callback fires for stream completion")
+    {
+        Heap heap(1024);
+        bool callback_called = false;
+        EventCode observed_code = EventCode::NONE;
+        uint32_t observed_index = 0;
+        uint32_t observed_payload = 0;
+
+        CanonicalOptions options;
+        options.sync = false;
+        options.callback = GuestCallback([&](EventCode code, uint32_t index, uint32_t payload)
+                                         {
+                                             callback_called = true;
+                                             observed_code = code;
+                                             observed_index = index;
+                                             observed_payload = payload; });
+
+        auto cx_unique = createLiftLowerContext(&heap, options);
+        std::shared_ptr<LiftLowerContext> cx(cx_unique.release(), [](LiftLowerContext *ptr)
+                                             { delete ptr; });
+
+        HostTrap trap = [](const char *msg)
+        {
+            throw std::runtime_error(msg ? msg : "trap");
+        };
+
+        auto descriptor = make_stream_descriptor();
+        auto shared_state = std::make_shared(descriptor);
+        ReadableStreamEnd readable(shared_state);
+        WritableStreamEnd writable(shared_state);
+
+        uint32_t read_ptr = 0;
+        uint32_t write_ptr = 64;
+        heap.memory[write_ptr] = 0x42;
+
+        auto blocked = readable.read(cx, 1, read_ptr, 1, false, trap);
+        CHECK(blocked == BLOCKED);
+        CHECK_FALSE(callback_called);
+
+        writable.write(cx, 2, write_ptr, 1, trap);
+
+        CHECK(callback_called);
+        CHECK(observed_code == EventCode::STREAM_READ);
+        CHECK(observed_index == 1);
+        CHECK(observed_payload == pack_copy_result(CopyResult::Completed, 1));
+        CHECK(heap.memory[read_ptr] == heap.memory[write_ptr]);
+    }
+}
+
 TEST_CASE("Resource handle lifecycle mirrors canonical definitions")
 {
     ComponentInstance resource_impl;
@@ -830,7 +927,9 @@ TEST_CASE("Waitable set surfaces stream readiness")
     canon_waitable_join(inst, readable, waitable_set, host_trap);
 
     Heap heap(256);
-    auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, Encoding::Utf8).release(), [](LiftLowerContext *ptr)
+    CanonicalOptions options;
+    options.sync = false;
+    auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, options).release(), [](LiftLowerContext *ptr)
                                                 { delete ptr; });
 
     uint32_t read_ptr = 0;
@@ -890,7 +989,9 @@ TEST_CASE("Stream cancellation posts events")
     canon_waitable_join(inst, readable, waitable_set, host_trap);
 
     Heap heap(128);
-    auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, Encoding::Utf8).release(), [](LiftLowerContext *ptr)
+    CanonicalOptions options;
+    options.sync = false;
+    auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, options).release(), [](LiftLowerContext *ptr)
                                                 { delete ptr; });
 
     uint32_t read_ptr = 0;
@@ -933,7 +1034,9 @@ TEST_CASE("Future lifecycle completes")
     canon_waitable_join(inst, readable, waitable_set, host_trap);
 
     Heap heap(256);
-    auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, Encoding::Utf8).release(), [](LiftLowerContext *ptr)
+    CanonicalOptions options;
+    options.sync = false;
+    auto cx = std::shared_ptr<LiftLowerContext>(createLiftLowerContext(&heap, options).release(), [](LiftLowerContext *ptr)
                                                 { delete ptr; });
 
     uint32_t read_ptr = 0;
@@ -1616,7 +1719,7 @@ TEST_CASE("Async function flattening matches canonical ABI")
 
     CanonicalOptions lift_with_callback;
     lift_with_callback.sync = false;
-    lift_with_callback.callback = GuestCallback([]() {});
+    lift_with_callback.callback = GuestCallback([](EventCode, uint32_t, uint32_t) {});
 
     auto lift_callback = func::flatten(lift_with_callback, func::ContextType::Lift);
     CHECK(lift_callback.results.size() == 1);