@@ -853,26 +853,98 @@ def copy(self, inst, src, on_partial_copy, on_copy_done):
853853
854854#### Future State
855855
856- class FutureEnd (StreamEnd ):
857- def close_after_copy (self , copy_op , inst , buffer , on_copy_done ):
858- assert (buffer .remain () == 1 )
859- def on_copy_done_wrapper (why ):
860- if buffer .remain () == 0 :
861- self .shared .close ()
862- on_copy_done (why )
863- ret = copy_op (inst , buffer , on_partial_copy = None , on_copy_done = on_copy_done_wrapper )
864- if ret == 'done' and buffer .remain () == 0 :
865- self .shared .close ()
866- return ret
856+ OnRead = Callable [[any | Literal ['cancelled' ]]], None ]
857+ OnWrite = Callable [[Optional [Literal ['cancelled' ,'dropped' ]]], Optional [any ]]
858+
859+ class ReadableFuture :
860+ t : ValType
861+ read : Callable [[ComponentInstance , OnRead ], Literal ['done' ,'blocked' ]]
862+ cancel : Callable [[], None ]
863+ drop : Callable [[]]
864+
865+ class SharedFutureImpl (ReadableFuture ):
866+ reader_dropped : bool
867+ pending_inst : Optional [ComponentInstance ]
868+ pending_on_read : Optional [OnRead ]
869+ pending_on_write : Optional [OnWrite ]
870+
871+ def __init__ (self , t ):
872+ self .t = t
873+ self .reader_dropped = False
874+ self .reset_pending ()
875+
876+ def reset_pending (self ):
877+ self .pending_inst = None
878+ self .pending_on_read = None
879+ self .pending_on_write = None
880+
881+ def cancel (self ):
882+ assert (self .pending_on_read ^ self .pending_on_write )
883+ if self .pending_on_read :
884+ self .pending_on_read ('cancelled' )
885+ else :
886+ self .pending_on_write ('cancelled' )
887+ self .reset_pending ()
888+
889+ def drop (self ):
890+ assert (not self .reader_dropped and not self .pending_on_read )
891+ self .reader_dropped = True
892+ if self .pending_on_write :
893+ self .pending_on_write ('dropped' )
894+ self .reset_pending ()
895+
896+ def read (self , inst , on_read ):
897+ assert (not self .reader_dropped and not self .pending_on_read )
898+ if not self .pending_on_write :
899+ self .pending_inst = inst
900+ self .pending_on_read = on_read
901+ return 'blocked'
902+ else :
903+ trap_if (inst is self .pending_inst and self .t is not None ) # temporary
904+ on_read (self .pending_write ())
905+ return 'done'
906+
907+ def write (self , inst , on_write ):
908+ assert (not self .pending_on_write )
909+ if self .reader_dropped :
910+ return 'done'
911+ elif not self .pending_on_read :
912+ self .pending_inst = inst
913+ self .pending_on_write = on_write
914+ else :
915+ trap_if (inst is self .pending_inst and self .t is not None ) # temporary
916+ self .pending_on_read (on_write ())
917+ return 'done'
918+
919+ class FutureEnd (Waitable ):
920+ shared : ReadableFuture
921+ copying : bool
922+ done : bool
923+
924+ def __init__ (self , shared ):
925+ Waitable .__init__ (self )
926+ self .shared = shared
927+ self .copying = False
928+ self .done = False
929+
930+ def drop (self ):
931+ trap_if (self .copying )
932+ Waitable .drop (self )
867933
868934class ReadableFutureEnd (FutureEnd ):
869- def copy (self , inst , dst , on_partial_copy , on_copy_done ):
870- return self .close_after_copy (self .shared .read , inst , dst , on_copy_done )
935+ def copy (self , inst , on_read ):
936+ return self .shared .read (inst , on_read )
937+
938+ def drop (self ):
939+ self .shared .drop ()
940+ FutureEnd .drop (self )
871941
872942class WritableFutureEnd (FutureEnd ):
873- def copy (self , inst , src , on_partial_copy , on_copy_done ):
874- return self .close_after_copy (self .shared .write , inst , src , on_copy_done )
943+ def copy (self , inst , on_write ):
944+ return self .shared .write (inst , on_write )
945+
875946 def drop (self ):
947+ trap_if (not self .done )
876948 FutureEnd .drop (self )
877949
878950### Despecialization
@@ -2158,50 +2230,42 @@ async def canon_stream_new(stream_t, task):
21582230
21592231async def canon_future_new (future_t , task ):
21602232 trap_if (not task .inst .may_leave )
2161- shared = SharedStreamImpl (future_t .t )
2233+ shared = SharedFutureImpl (future_t .t )
21622234 ri = task .inst .table .add (ReadableFutureEnd (shared ))
21632235 wi = task .inst .table .add (WritableFutureEnd (shared ))
21642236 return [ ri | (wi << 32 ) ]
21652237
2166- ### 🔀 `canon { stream,future} .{read,write}`
2238+ ### 🔀 `canon stream.{read,write}`
21672239
21682240async def canon_stream_read (stream_t , opts , task , i , ptr , n ):
2169- return await copy (ReadableStreamEnd , WritableBufferGuestImpl , EventCode .STREAM_READ ,
2170- stream_t , opts , task , i , ptr , n )
2241+ return await stream_copy (ReadableStreamEnd , WritableBufferGuestImpl , EventCode .STREAM_READ ,
2242+ stream_t , opts , task , i , ptr , n )
21712243
21722244async def canon_stream_write (stream_t , opts , task , i , ptr , n ):
2173- return await copy (WritableStreamEnd , ReadableBufferGuestImpl , EventCode .STREAM_WRITE ,
2174- stream_t , opts , task , i , ptr , n )
2175-
2176- async def canon_future_read (future_t , opts , task , i , ptr ):
2177- return await copy (ReadableFutureEnd , WritableBufferGuestImpl , EventCode .FUTURE_READ ,
2178- future_t , opts , task , i , ptr , 1 )
2179-
2180- async def canon_future_write (future_t , opts , task , i , ptr ):
2181- return await copy (WritableFutureEnd , ReadableBufferGuestImpl , EventCode .FUTURE_WRITE ,
2182- future_t , opts , task , i , ptr , 1 )
2245+ return await stream_copy (WritableStreamEnd , ReadableBufferGuestImpl , EventCode .STREAM_WRITE ,
2246+ stream_t , opts , task , i , ptr , n )
21832247
2184- async def copy (EndT , BufferT , event_code , stream_or_future_t , opts , task , i , ptr , n ):
2248+ async def stream_copy (EndT , BufferT , event_code , stream_t , opts , task , i , ptr , n ):
21852249 trap_if (not task .inst .may_leave )
21862250 e = task .inst .table .get (i )
21872251 trap_if (not isinstance (e , EndT ))
2188- trap_if (e .shared .t != stream_or_future_t .t )
2252+ trap_if (e .shared .t != stream_t .t )
21892253 trap_if (e .copying )
21902254
2191- assert (not contains_borrow (stream_or_future_t ))
2255+ assert (not contains_borrow (stream_t ))
21922256 cx = LiftLowerContext (opts , task .inst , borrow_scope = None )
2193- buffer = BufferT (stream_or_future_t .t , cx , ptr , n )
2257+ buffer = BufferT (stream_t .t , cx , ptr , n )
21942258
2195- def copy_event (why , revoke_buffer ):
2259+ def stream_copy_event (why , revoke_buffer ):
21962260 revoke_buffer ()
21972261 e .copying = False
21982262 return (event_code , i , pack_copy_result (task , e , buffer , why ))
21992263
22002264 def on_partial_copy (revoke_buffer ):
2201- e .set_event (partial (copy_event , 'completed' , revoke_buffer ))
2265+ e .set_event (partial (stream_copy_event , 'completed' , revoke_buffer ))
22022266
22032267 def on_copy_done (why ):
2204- e .set_event (partial (copy_event , why , revoke_buffer = lambda :()))
2268+ e .set_event (partial (stream_copy_event , why , revoke_buffer = lambda :()))
22052269
22062270 if e .copy (task .inst , buffer , on_partial_copy , on_copy_done ) == 'done' :
22072271 return [pack_copy_result (task , e , buffer , 'completed' )]
@@ -2234,6 +2298,55 @@ def pack_copy_result(task, e, buffer, why):
22342298 assert (packed != BLOCKED )
22352299 return packed
22362300
2301+ ### 🔀 `canon future.{read,write}`
2302+
2303+ async def canon_future_read (future_t , opts , task , i , ptr ):
2304+ assert (not contains_borrow (stream_t ))
2305+ cx = LiftLowerContext (opts , task .inst , borrow_scope = None )
2306+ def lower (v ):
2307+ store (cx , v , future_t .t , ptr )
2308+ return await future_copy (ReadableFutureEnd , EventCode .FUTURE_READ , lower ,
2309+ future_t , opts , task , i , ptr )
2310+
2311+ async def canon_future_write (future_t , opts , task , i , ptr ):
2312+ assert (not contains_borrow (stream_t ))
2313+ cx = LiftLowerContext (opts , task .inst , borrow_scope = None )
2314+ def lift (_ ):
2315+ return load (cx , ptr , future_t .t )
2316+ return await future_copy (WritableFutureEnd , EventCode .FUTURE_WRITE , lift ,
2317+ future_t , opts , task , i , ptr )
2318+
2319+ async def future_copy (EndT , event_code , copy , future_t , opts , task , i , ptr ):
2320+ trap_if (not task .inst .may_leave )
2321+ e = task .inst .table .get (i )
2322+ trap_if (not isinstance (e , EndT ))
2323+ trap_if (e .shared .t != future_t .t )
2324+ trap_if (e .copying or e .done )
2325+
2326+ def future_copy_event (why ):
2327+ e .copying = False
2328+ if why != 'cancelled' :
2329+ e .done = True
2330+ return (event_code , i , pack_copy_result (task , e , buffer , why ))
2331+
2332+ def on_copy (why , v ):
2333+ assert (not e .has_pending_event ())
2334+ e .set_event (partial (future_copy_event , why ))
2335+ if why != 'dropped' and why != 'cancelled' :
2336+ return copy (v )
2337+
2338+ if e .copy (task .inst , on_copy ) == 'done' :
2339+ return [pack_copy_result (task , e , buffer , 'completed' )]
2340+ else :
2341+ if opts .sync :
2342+ await task .wait_on (e .wait_for_pending_event (), sync = True )
2343+ code ,index ,payload = e .get_event ()
2344+ assert (code == event_code and index == i )
2345+ return [payload ]
2346+ else :
2347+ e .copying = True
2348+ return [BLOCKED ]
2349+
22372350### 🔀 `canon {stream,future}.cancel-{read,write}`
22382351
22392352async def canon_stream_cancel_read (stream_t , sync , task , i ):
0 commit comments