@@ -853,26 +853,97 @@ def copy(self, inst, src, on_partial_copy, on_copy_done):
853853
854854#### Future State
855855
856- class FutureEnd (StreamEnd ):
857- def close_after_copy (self , copy_op , inst , buffer , on_copy_done ):
858- assert (buffer .remain () == 1 )
859- def on_copy_done_wrapper (why ):
860- if buffer .remain () == 0 :
861- self .shared .close ()
862- on_copy_done (why )
863- ret = copy_op (inst , buffer , on_partial_copy = None , on_copy_done = on_copy_done_wrapper )
864- if ret == 'done' and buffer .remain () == 0 :
865- self .shared .close ()
866- return ret
856+ LiftValue = Callable [[], any ]
857+ LowerValue = Callable [[any ], None ]
858+
859+ class ReadableFuture :
860+ t : ValType
861+ read : Callable [[ComponentInstance , LowerValue , OnCopyDone ], Literal ['done' ,'blocked' ]]
862+ cancel : Callable [[], None ]
863+ drop : Callable [[]]
864+
865+ class SharedFutureImpl (ReadableFuture ):
866+ dropped : bool
867+ pending_inst : Optional [ComponentInstance ]
868+ pending_copy_value : Optional [LiftValue | LowerValue ]
869+ pending_on_copy_done : Optional [OnCopyDone ]
870+
871+ def __init__ (self , t ):
872+ self .t = t
873+ self .dropped = False
874+ self .reset_pending ()
875+
876+ def reset_pending (self ):
877+ self .set_pending (None , None , None )
878+
879+ def set_pending (self , inst , copy_value , on_copy_done ):
880+ self .pending_inst = inst
881+ self .pending_copy_value = copy_value
882+ self .pending_on_copy_done = on_copy_done
883+
884+ def reset_and_notify_pending (self , why ):
885+ pending_on_copy_done = self .pending_on_copy_done
886+ self .reset_pending ()
887+ pending_on_copy_done (why )
888+
889+ def cancel (self ):
890+ assert (not self .dropped )
891+ self .reset_and_notify_pending ('cancelled' )
892+
893+ def drop (self ):
894+ assert (not self .dropped )
895+ self .dropped = True
896+ if self .pending_on_copy_done :
897+ self .reset_and_notify_pending ('closed' )
898+
899+ def read (self , inst , lower_value , on_copy_done ):
900+ assert (not self .dropped )
901+ return self .copy (inst , lower_value , on_copy_done , self .pending_copy_value , lower_value )
902+
903+ def write (self , inst , lift_value , on_copy_done ):
904+ if self .dropped :
905+ return 'closed'
906+ return self .copy (inst , lift_value , on_copy_done , lift_value , self .pending_copy_value )
907+
908+ def copy (self , inst , copy_value , on_copy_done , lift_value , lower_value ):
909+ if not self .pending_copy_value :
910+ self .set_pending (inst , copy_value , on_copy_done )
911+ return 'blocked'
912+ else :
913+ trap_if (inst is self .pending_inst and self .t is not None ) # temporary
914+ lower_value (lift_value ())
915+ self .reset_and_notify_pending ('done' )
916+ return 'done'
917+
918+ class FutureEnd (Waitable ):
919+ shared : ReadableFuture
920+ copying : bool
921+ done : bool
922+
923+ def __init__ (self , shared ):
924+ Waitable .__init__ (self )
925+ self .shared = shared
926+ self .copying = False
927+ self .done = False
928+
929+ def drop (self ):
930+ trap_if (self .copying )
931+ Waitable .drop (self )
867932
868933class ReadableFutureEnd (FutureEnd ):
869- def copy (self , inst , dst , on_partial_copy , on_copy_done ):
870- return self .close_after_copy (self .shared .read , inst , dst , on_copy_done )
934+ def copy (self , inst , lower_value , on_copy_done ):
935+ return self .shared .read (inst , lower_value , on_copy_done )
936+
937+ def drop (self ):
938+ self .shared .drop ()
939+ FutureEnd .drop (self )
871940
872941class WritableFutureEnd (FutureEnd ):
873- def copy (self , inst , src , on_partial_copy , on_copy_done ):
874- return self .close_after_copy (self .shared .write , inst , src , on_copy_done )
942+ def copy (self , inst , lift_value , on_copy_done ):
943+ return self .shared .write (inst , lift_value , on_copy_done )
944+
875945 def drop (self ):
946+ trap_if (not self .done )
876947 FutureEnd .drop (self )
877948
878949### Despecialization
@@ -1201,19 +1272,20 @@ def lift_borrow(cx, i, t):
12011272 return h .rep
12021273
12031274def lift_stream (cx , i , t ):
1204- return lift_async_value (ReadableStreamEnd , cx , i , t )
1275+ assert (not contains_borrow (t ))
1276+ e = cx .inst .table .remove (i )
1277+ trap_if (not isinstance (e , ReadableStreamEnd ))
1278+ trap_if (e .shared .t != t )
1279+ trap_if (e .copying )
1280+ return e .shared
12051281
12061282def lift_future (cx , i , t ):
1207- v = lift_async_value (ReadableFutureEnd , cx , i , t )
1208- trap_if (v .closed ())
1209- return v
1210-
1211- def lift_async_value (ReadableEndT , cx , i , t ):
12121283 assert (not contains_borrow (t ))
12131284 e = cx .inst .table .remove (i )
1214- trap_if (not isinstance (e , ReadableEndT ))
1285+ trap_if (not isinstance (e , ReadableFutureEnd ))
12151286 trap_if (e .shared .t != t )
12161287 trap_if (e .copying )
1288+ trap_if (e .done )
12171289 return e .shared
12181290
12191291### Storing
@@ -1507,16 +1579,14 @@ def lower_borrow(cx, rep, t):
15071579 return cx .inst .table .add (h )
15081580
15091581def lower_stream (cx , v , t ):
1510- return lower_async_value (ReadableStreamEnd , cx , v , t )
1582+ assert (isinstance (v , ReadableStream ))
1583+ assert (not contains_borrow (t ))
1584+ return cx .inst .table .add (ReadableStreamEnd (v ))
15111585
15121586def lower_future (cx , v , t ):
1513- assert (not v .closed ())
1514- return lower_async_value (ReadableFutureEnd , cx , v , t )
1515-
1516- def lower_async_value (ReadableEndT , cx , v , t ):
1517- assert (isinstance (v , ReadableStream ))
1587+ assert (isinstance (v , ReadableFuture ))
15181588 assert (not contains_borrow (t ))
1519- return cx .inst .table .add (ReadableEndT (v ))
1589+ return cx .inst .table .add (ReadableFutureEnd (v ))
15201590
15211591### Flattening
15221592
@@ -2158,45 +2228,37 @@ async def canon_stream_new(stream_t, task):
21582228
21592229async def canon_future_new (future_t , task ):
21602230 trap_if (not task .inst .may_leave )
2161- shared = SharedStreamImpl (future_t .t )
2231+ shared = SharedFutureImpl (future_t .t )
21622232 ri = task .inst .table .add (ReadableFutureEnd (shared ))
21632233 wi = task .inst .table .add (WritableFutureEnd (shared ))
21642234 return [ ri | (wi << 32 ) ]
21652235
2166- ### 🔀 `canon { stream,future} .{read,write}`
2236+ ### 🔀 `canon stream.{read,write}`
21672237
21682238async def canon_stream_read (stream_t , opts , task , i , ptr , n ):
2169- return await copy (ReadableStreamEnd , WritableBufferGuestImpl , EventCode .STREAM_READ ,
2170- stream_t , opts , task , i , ptr , n )
2239+ return await stream_copy (ReadableStreamEnd , WritableBufferGuestImpl , EventCode .STREAM_READ ,
2240+ stream_t , opts , task , i , ptr , n )
21712241
21722242async def canon_stream_write (stream_t , opts , task , i , ptr , n ):
2173- return await copy (WritableStreamEnd , ReadableBufferGuestImpl , EventCode .STREAM_WRITE ,
2174- stream_t , opts , task , i , ptr , n )
2175-
2176- async def canon_future_read (future_t , opts , task , i , ptr ):
2177- return await copy (ReadableFutureEnd , WritableBufferGuestImpl , EventCode .FUTURE_READ ,
2178- future_t , opts , task , i , ptr , 1 )
2179-
2180- async def canon_future_write (future_t , opts , task , i , ptr ):
2181- return await copy (WritableFutureEnd , ReadableBufferGuestImpl , EventCode .FUTURE_WRITE ,
2182- future_t , opts , task , i , ptr , 1 )
2243+ return await stream_copy (WritableStreamEnd , ReadableBufferGuestImpl , EventCode .STREAM_WRITE ,
2244+ stream_t , opts , task , i , ptr , n )
21832245
2184- async def copy (EndT , BufferT , event_code , stream_or_future_t , opts , task , i , ptr , n ):
2246+ async def stream_copy (EndT , BufferT , event_code , stream_t , opts , task , i , ptr , n ):
21852247 trap_if (not task .inst .may_leave )
21862248 e = task .inst .table .get (i )
21872249 trap_if (not isinstance (e , EndT ))
2188- trap_if (e .shared .t != stream_or_future_t .t )
2250+ trap_if (e .shared .t != stream_t .t )
21892251 trap_if (e .copying )
21902252
2191- assert (not contains_borrow (stream_or_future_t ))
2253+ assert (not contains_borrow (stream_t ))
21922254 cx = LiftLowerContext (opts , task .inst , borrow_scope = None )
2193- buffer = BufferT (stream_or_future_t .t , cx , ptr , n )
2255+ buffer = BufferT (stream_t .t , cx , ptr , n )
21942256
21952257 def copy_event (why , revoke_buffer ):
21962258 revoke_buffer ()
21972259 assert (e .copying )
21982260 e .copying = False
2199- return (event_code , i , pack_copy_result (task , e , buffer , why ))
2261+ return (event_code , i , pack_stream_result (task , e , buffer , why ))
22002262
22012263 def on_partial_copy (revoke_buffer ):
22022264 e .set_event (partial (copy_event , 'completed' , revoke_buffer ))
@@ -2205,7 +2267,7 @@ def on_copy_done(why):
22052267 e .set_event (partial (copy_event , why , revoke_buffer = lambda :()))
22062268
22072269 if e .copy (task .inst , buffer , on_partial_copy , on_copy_done ) == 'done' :
2208- return [pack_copy_result (task , e , buffer , 'completed' )]
2270+ return [pack_stream_result (task , e , buffer , 'completed' )]
22092271 else :
22102272 e .copying = True
22112273 if opts .sync :
@@ -2221,7 +2283,7 @@ def on_copy_done(why):
22212283CLOSED = 0x1
22222284CANCELLED = 0x2
22232285
2224- def pack_copy_result (task , e , buffer , why ):
2286+ def pack_stream_result (task , e , buffer , why ):
22252287 if e .shared .closed ():
22262288 result = CLOSED
22272289 elif why == 'cancelled' :
@@ -2235,6 +2297,66 @@ def pack_copy_result(task, e, buffer, why):
22352297 assert (packed != BLOCKED )
22362298 return packed
22372299
2300+ ### 🔀 `canon future.{read,write}`
2301+
2302+ async def canon_future_read (future_t , opts , task , i , ptr ):
2303+ def lower_value (v ):
2304+ if future_t .t :
2305+ assert (not contains_borrow (future_t ))
2306+ cx = LiftLowerContext (opts , task .inst , borrow_scope = None )
2307+ store (cx , v , future_t .t , ptr )
2308+
2309+ return await future_copy (ReadableFutureEnd , EventCode .FUTURE_READ , lower_value ,
2310+ future_t , opts , task , i , ptr )
2311+
2312+ async def canon_future_write (future_t , opts , task , i , ptr ):
2313+ def lift_value ():
2314+ if future_t .t :
2315+ assert (not contains_borrow (future_t ))
2316+ cx = LiftLowerContext (opts , task .inst , borrow_scope = None )
2317+ return load (cx , ptr , future_t .t )
2318+
2319+ return await future_copy (WritableFutureEnd , EventCode .FUTURE_WRITE , lift_value ,
2320+ future_t , opts , task , i , ptr )
2321+
2322+ async def future_copy (EndT , event_code , copy_value , future_t , opts , task , i , ptr ):
2323+ trap_if (not task .inst .may_leave )
2324+ e = task .inst .table .get (i )
2325+ trap_if (not isinstance (e , EndT ))
2326+ trap_if (e .shared .t != future_t .t )
2327+ trap_if (e .copying or e .done )
2328+
2329+ def on_copy_done (why ):
2330+ def copy_event ():
2331+ assert (e .copying )
2332+ e .copying = False
2333+ if why != 'cancelled' :
2334+ e .done = True
2335+ return (event_code , i , pack_future_result (task , e , why ))
2336+ assert (not e .has_pending_event ())
2337+ e .set_event (copy_event )
2338+
2339+ result = e .copy (task .inst , copy_value , on_copy_done )
2340+ if result != 'blocked' :
2341+ e .done = True
2342+ return [pack_future_result (task , e , result )]
2343+ else :
2344+ e .copying = True
2345+ if opts .sync :
2346+ await task .wait_on (e .wait_for_pending_event (), sync = True )
2347+ code ,index ,payload = e .get_event ()
2348+ assert (code == event_code and index == i )
2349+ return [payload ]
2350+ else :
2351+ return [BLOCKED ]
2352+
2353+ def pack_future_result (task , e , why ):
2354+ match why :
2355+ case 'cancelled' : return CANCELLED
2356+ case 'closed' : return CLOSED
2357+ case 'done' : return (CLOSED | (1 << 4 ))
2358+ assert (False )
2359+
22382360### 🔀 `canon {stream,future}.cancel-{read,write}`
22392361
22402362async def canon_stream_cancel_read (stream_t , sync , task , i ):
0 commit comments