Skip to content

Commit 21d926d

Browse files
committed
CABI: simplify stream code a bit more (no change)
1 parent a7029a6 commit 21d926d

3 files changed

Lines changed: 190 additions & 197 deletions

File tree

design/mvp/CanonicalABI.md

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2827,16 +2827,13 @@ of component-level values with types `ts`.
28272827
def lift_flat_values(cx, max_flat, vi, ts):
28282828
flat_types = flatten_types(ts)
28292829
if len(flat_types) > max_flat:
2830-
return lift_heap_values(cx, vi, ts)
2830+
ptr = vi.next('i32')
2831+
tuple_type = TupleType(ts)
2832+
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
2833+
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
2834+
return list(load(cx, ptr, tuple_type).values())
28312835
else:
28322836
return [ lift_flat(cx, vi, t) for t in ts ]
2833-
2834-
def lift_heap_values(cx, vi, ts):
2835-
ptr = vi.next('i32')
2836-
tuple_type = TupleType(ts)
2837-
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
2838-
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
2839-
return list(load(cx, ptr, tuple_type).values())
28402837
```
28412838

28422839
Symmetrically, the `lower_flat_values` function defines how to lower a
@@ -2850,27 +2847,23 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
28502847
cx.inst.may_leave = False
28512848
flat_types = flatten_types(ts)
28522849
if len(flat_types) > max_flat:
2853-
flat_vals = lower_heap_values(cx, vs, ts, out_param)
2850+
tuple_type = TupleType(ts)
2851+
tuple_value = {str(i): v for i,v in enumerate(vs)}
2852+
if out_param is None:
2853+
ptr = cx.opts.realloc(0, 0, alignment(tuple_type), elem_size(tuple_type))
2854+
flat_vals = [ptr]
2855+
else:
2856+
ptr = out_param.next('i32')
2857+
flat_vals = []
2858+
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
2859+
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
2860+
store(cx, tuple_value, tuple_type, ptr)
28542861
else:
28552862
flat_vals = []
28562863
for i in range(len(vs)):
28572864
flat_vals += lower_flat(cx, vs[i], ts[i])
28582865
cx.inst.may_leave = True
28592866
return flat_vals
2860-
2861-
def lower_heap_values(cx, vs, ts, out_param):
2862-
tuple_type = TupleType(ts)
2863-
tuple_value = {str(i): v for i,v in enumerate(vs)}
2864-
if out_param is None:
2865-
ptr = cx.opts.realloc(0, 0, alignment(tuple_type), elem_size(tuple_type))
2866-
flat_vals = [ptr]
2867-
else:
2868-
ptr = out_param.next('i32')
2869-
flat_vals = []
2870-
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
2871-
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
2872-
store(cx, tuple_value, tuple_type, ptr)
2873-
return flat_vals
28742867
```
28752868
The `may_leave` flag is guarded by `canon_lower` below to prevent a component
28762869
from calling out of the component while in the middle of lowering, ensuring

design/mvp/canonical-abi/definitions.py

Lines changed: 80 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ class Buffer:
307307
MAX_LENGTH = 2**28 - 1
308308
t: ValType
309309
remain: Callable[[], int]
310+
is_zero_length: Callable[[], bool]
310311

311312
class ReadableBuffer(Buffer):
312313
read: Callable[[int], list[any]]
@@ -335,6 +336,9 @@ def __init__(self, t, cx, ptr, length):
335336
def remain(self):
336337
return self.length - self.progress
337338

339+
def is_zero_length(self):
340+
return self.length == 0
341+
338342
class ReadableBufferGuestImpl(BufferGuestImpl):
339343
def read(self, n):
340344
assert(n <= self.remain())
@@ -749,13 +753,18 @@ def drop(self):
749753

750754
#### Stream State
751755

756+
class StreamResult(IntEnum):
757+
COMPLETED = 0
758+
CLOSED = 1
759+
CANCELLED = 2
760+
752761
RevokeBuffer = Callable[[], None]
753-
OnPartialCopy = Callable[[RevokeBuffer], None]
754-
OnCopyDone = Callable[[Literal['completed','cancelled']], None]
762+
OnCopy = Callable[[RevokeBuffer], None]
763+
OnStreamResult = Callable[[StreamResult], None]
755764

756765
class ReadableStream:
757766
t: ValType
758-
read: Callable[[ComponentInstance, WritableBuffer, OnPartialCopy, OnCopyDone], Literal['done','blocked']]
767+
read: Callable[[ComponentInstance, WritableBuffer, OnCopy, OnStreamResult], Optional[StreamResult]]
759768
cancel: Callable[[], None]
760769
close: Callable[[]]
761770
closed: Callable[[], bool]
@@ -764,8 +773,8 @@ class SharedStreamImpl(ReadableStream):
764773
closed_: bool
765774
pending_inst: Optional[ComponentInstance]
766775
pending_buffer: Optional[Buffer]
767-
pending_on_partial_copy: Optional[OnPartialCopy]
768-
pending_on_copy_done: Optional[OnCopyDone]
776+
pending_on_copy: Optional[OnCopy]
777+
pending_on_result: Optional[OnStreamResult]
769778

770779
def __init__(self, t):
771780
self.t = t
@@ -775,59 +784,55 @@ def __init__(self, t):
775784
def reset_pending(self):
776785
self.set_pending(None, None, None, None)
777786

778-
def set_pending(self, inst, buffer, on_partial_copy, on_copy_done):
787+
def set_pending(self, inst, buffer, on_copy, on_result):
779788
self.pending_inst = inst
780789
self.pending_buffer = buffer
781-
self.pending_on_partial_copy = on_partial_copy
782-
self.pending_on_copy_done = on_copy_done
790+
self.pending_on_copy = on_copy
791+
self.pending_on_result = on_result
783792

784-
def reset_and_notify_pending(self, why):
785-
pending_on_copy_done = self.pending_on_copy_done
793+
def reset_and_notify_pending(self, result):
794+
pending_on_result = self.pending_on_result
786795
self.reset_pending()
787-
pending_on_copy_done(why)
796+
pending_on_result(result)
788797

789798
def cancel(self):
790-
self.reset_and_notify_pending('cancelled')
799+
self.reset_and_notify_pending(StreamResult.CANCELLED)
791800

792801
def close(self):
793802
if not self.closed_:
794803
self.closed_ = True
795804
if self.pending_buffer:
796-
self.reset_and_notify_pending('completed')
805+
self.reset_and_notify_pending(StreamResult.CLOSED)
797806

798807
def closed(self):
799808
return self.closed_
800809

801-
def read(self, inst, dst, on_partial_copy, on_copy_done):
802-
return self.copy(inst, dst, on_partial_copy, on_copy_done, self.pending_buffer, dst)
810+
def read(self, inst, dst, on_copy, on_result):
811+
return self.copy(inst, dst, on_copy, on_result, self.pending_buffer, dst)
803812

804-
def write(self, inst, src, on_partial_copy, on_copy_done):
805-
return self.copy(inst, src, on_partial_copy, on_copy_done, src, self.pending_buffer)
813+
def write(self, inst, src, on_copy, on_result):
814+
return self.copy(inst, src, on_copy, on_result, src, self.pending_buffer)
806815

807-
def copy(self, inst, buffer, on_partial_copy, on_copy_done, src, dst):
816+
def copy(self, inst, buffer, on_copy, on_result, src, dst):
808817
if self.closed_:
809-
return 'done'
818+
return StreamResult.CLOSED
810819
elif not self.pending_buffer:
811-
self.set_pending(inst, buffer, on_partial_copy, on_copy_done)
812-
return 'blocked'
820+
self.set_pending(inst, buffer, on_copy, on_result)
821+
return None
813822
else:
814823
assert(self.t == src.t == dst.t)
815824
trap_if(inst is self.pending_inst and self.t is not None) # temporary
816825
if self.pending_buffer.remain() > 0:
817826
if buffer.remain() > 0:
818827
dst.write(src.read(min(src.remain(), dst.remain())))
819-
if self.pending_buffer.remain() > 0:
820-
self.pending_on_partial_copy(self.reset_pending)
821-
else:
822-
self.reset_and_notify_pending('completed')
823-
return 'done'
828+
self.pending_on_copy(self.reset_pending)
829+
return StreamResult.COMPLETED
830+
elif buffer is src and buffer.remain() == 0 and self.pending_buffer.is_zero_length():
831+
return StreamResult.COMPLETED
824832
else:
825-
if buffer.remain() > 0 or buffer is dst:
826-
self.reset_and_notify_pending('completed')
827-
self.set_pending(inst, buffer, on_partial_copy, on_copy_done)
828-
return 'blocked'
829-
else:
830-
return 'done'
833+
self.reset_and_notify_pending(StreamResult.COMPLETED)
834+
self.set_pending(inst, buffer, on_copy, on_result)
835+
return None
831836

832837
class StreamEnd(Waitable):
833838
shared: ReadableStream
@@ -844,34 +849,35 @@ def drop(self):
844849
Waitable.drop(self)
845850

846851
class ReadableStreamEnd(StreamEnd):
847-
def copy(self, inst, dst, on_partial_copy, on_copy_done):
848-
return self.shared.read(inst, dst, on_partial_copy, on_copy_done)
852+
def copy(self, inst, dst, on_copy, on_result):
853+
return self.shared.read(inst, dst, on_copy, on_result)
849854

850855
class WritableStreamEnd(StreamEnd):
851-
def copy(self, inst, src, on_partial_copy, on_copy_done):
852-
return self.shared.write(inst, src, on_partial_copy, on_copy_done)
856+
def copy(self, inst, src, on_copy, on_result):
857+
return self.shared.write(inst, src, on_copy, on_result)
853858

854859
#### Future State
855860

856861
class FutureEnd(StreamEnd):
857-
def close_after_copy(self, copy_op, inst, buffer, on_copy_done):
862+
def close_after_copy(self, copy_op, inst, buffer, on_result):
858863
assert(buffer.remain() == 1)
859-
def on_copy_done_wrapper(why):
864+
def on_result_wrapper(result):
860865
if buffer.remain() == 0:
861866
self.shared.close()
862-
on_copy_done(why)
863-
ret = copy_op(inst, buffer, on_partial_copy = None, on_copy_done = on_copy_done_wrapper)
864-
if ret == 'done' and buffer.remain() == 0:
867+
on_result(result)
868+
ret = copy_op(inst, buffer, on_copy = lambda _:(), on_result = on_result_wrapper)
869+
if ret is not None and buffer.remain() == 0:
865870
self.shared.close()
871+
return StreamResult.CLOSED
866872
return ret
867873

868874
class ReadableFutureEnd(FutureEnd):
869-
def copy(self, inst, dst, on_partial_copy, on_copy_done):
870-
return self.close_after_copy(self.shared.read, inst, dst, on_copy_done)
875+
def copy(self, inst, dst, on_copy, on_result):
876+
return self.close_after_copy(self.shared.read, inst, dst, on_result)
871877

872878
class WritableFutureEnd(FutureEnd):
873-
def copy(self, inst, src, on_partial_copy, on_copy_done):
874-
return self.close_after_copy(self.shared.write, inst, src, on_copy_done)
879+
def copy(self, inst, src, on_copy, on_result):
880+
return self.close_after_copy(self.shared.write, inst, src, on_result)
875881
def drop(self):
876882
trap_if(not self.shared.closed())
877883
FutureEnd.drop(self)
@@ -1802,43 +1808,36 @@ def lower_flat_flags(v, labels):
18021808
def lift_flat_values(cx, max_flat, vi, ts):
18031809
flat_types = flatten_types(ts)
18041810
if len(flat_types) > max_flat:
1805-
return lift_heap_values(cx, vi, ts)
1811+
ptr = vi.next('i32')
1812+
tuple_type = TupleType(ts)
1813+
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
1814+
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
1815+
return list(load(cx, ptr, tuple_type).values())
18061816
else:
18071817
return [ lift_flat(cx, vi, t) for t in ts ]
18081818

1809-
def lift_heap_values(cx, vi, ts):
1810-
ptr = vi.next('i32')
1811-
tuple_type = TupleType(ts)
1812-
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
1813-
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
1814-
return list(load(cx, ptr, tuple_type).values())
1815-
18161819
def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
18171820
cx.inst.may_leave = False
18181821
flat_types = flatten_types(ts)
18191822
if len(flat_types) > max_flat:
1820-
flat_vals = lower_heap_values(cx, vs, ts, out_param)
1823+
tuple_type = TupleType(ts)
1824+
tuple_value = {str(i): v for i,v in enumerate(vs)}
1825+
if out_param is None:
1826+
ptr = cx.opts.realloc(0, 0, alignment(tuple_type), elem_size(tuple_type))
1827+
flat_vals = [ptr]
1828+
else:
1829+
ptr = out_param.next('i32')
1830+
flat_vals = []
1831+
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
1832+
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
1833+
store(cx, tuple_value, tuple_type, ptr)
18211834
else:
18221835
flat_vals = []
18231836
for i in range(len(vs)):
18241837
flat_vals += lower_flat(cx, vs[i], ts[i])
18251838
cx.inst.may_leave = True
18261839
return flat_vals
18271840

1828-
def lower_heap_values(cx, vs, ts, out_param):
1829-
tuple_type = TupleType(ts)
1830-
tuple_value = {str(i): v for i,v in enumerate(vs)}
1831-
if out_param is None:
1832-
ptr = cx.opts.realloc(0, 0, alignment(tuple_type), elem_size(tuple_type))
1833-
flat_vals = [ptr]
1834-
else:
1835-
ptr = out_param.next('i32')
1836-
flat_vals = []
1837-
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
1838-
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
1839-
store(cx, tuple_value, tuple_type, ptr)
1840-
return flat_vals
1841-
18421841
### `canon lift`
18431842

18441843
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_block):
@@ -2169,6 +2168,8 @@ async def canon_future_new(future_t, task):
21692168

21702169
### 🔀 `canon {stream,future}.{read,write}`
21712170

2171+
BLOCKED = 0xffff_ffff
2172+
21722173
async def canon_stream_read(stream_t, opts, task, i, ptr, n):
21732174
return await copy(ReadableStreamEnd, WritableBufferGuestImpl, EventCode.STREAM_READ,
21742175
stream_t, opts, task, i, ptr, n)
@@ -2196,20 +2197,21 @@ async def copy(EndT, BufferT, event_code, stream_or_future_t, opts, task, i, ptr
21962197
cx = LiftLowerContext(opts, task.inst, borrow_scope = None)
21972198
buffer = BufferT(stream_or_future_t.t, cx, ptr, n)
21982199

2199-
def copy_event(why, revoke_buffer):
2200+
def copy_event(result, revoke_buffer):
22002201
revoke_buffer()
22012202
assert(e.copying)
22022203
e.copying = False
2203-
return (event_code, i, pack_copy_result(task, e, buffer, why))
2204+
return (event_code, i, pack_copy_result(result, buffer))
22042205

2205-
def on_partial_copy(revoke_buffer):
2206-
e.set_event(partial(copy_event, 'completed', revoke_buffer))
2206+
def on_copy(revoke_buffer):
2207+
e.set_event(partial(copy_event, StreamResult.COMPLETED, revoke_buffer))
22072208

2208-
def on_copy_done(why):
2209-
e.set_event(partial(copy_event, why, revoke_buffer = lambda:()))
2209+
def on_result(result):
2210+
e.set_event(partial(copy_event, result, revoke_buffer = lambda:()))
22102211

2211-
if e.copy(task.inst, buffer, on_partial_copy, on_copy_done) == 'done':
2212-
return [pack_copy_result(task, e, buffer, 'completed')]
2212+
result = e.copy(task.inst, buffer, on_copy, on_result)
2213+
if result is not None:
2214+
return [pack_copy_result(result, buffer)]
22132215
else:
22142216
e.copying = True
22152217
if opts.sync:
@@ -2220,20 +2222,8 @@ def on_copy_done(why):
22202222
else:
22212223
return [BLOCKED]
22222224

2223-
BLOCKED = 0xffff_ffff
2224-
COMPLETED = 0x0
2225-
CLOSED = 0x1
2226-
CANCELLED = 0x2
2227-
2228-
def pack_copy_result(task, e, buffer, why):
2229-
if e.shared.closed():
2230-
result = CLOSED
2231-
elif why == 'cancelled':
2232-
result = CANCELLED
2233-
else:
2234-
assert(why == 'completed')
2235-
assert(not isinstance(e, FutureEnd))
2236-
result = COMPLETED
2225+
def pack_copy_result(result, buffer):
2226+
assert(0 <= result < 2**4)
22372227
assert(buffer.progress <= Buffer.MAX_LENGTH < 2**28)
22382228
packed = result | (buffer.progress << 4)
22392229
assert(packed != BLOCKED)

0 commit comments

Comments
 (0)