@@ -88,6 +88,7 @@ class InstanceType(ExternType):
8888class FuncType (ExternType ):
8989 params : list [tuple [str ,ValType ]]
9090 result : list [ValType | tuple [str ,ValType ]]
91+ async_ : bool = False
9192 def param_types (self ):
9293 return self .extract_types (self .params )
9394 def result_type (self ):
@@ -402,6 +403,7 @@ def resume(self, suspend_result = SuspendResult.NOT_CANCELLED):
402403 assert (not self .running ())
403404
404405 def suspend (self , cancellable ) -> SuspendResult :
406+ assert (self .task .may_suspend ())
405407 assert (self .running () and not self .cancellable and self .suspend_result is None )
406408 self .cancellable = cancellable
407409 self .parent_lock .release ()
@@ -566,8 +568,13 @@ def trap_if_on_the_stack(self, inst):
566568 def needs_exclusive (self ):
567569 return not self .opts .async_ or self .opts .callback
568570
571+ def may_suspend (self ):
572+ return self .ft .async_ or self .state == Task .State .RESOLVED
573+
569574 def enter (self , thread ):
570575 assert (thread in self .threads and thread .task is self )
576+ if not self .ft .async_ :
577+ return True
571578 def has_backpressure ():
572579 return self .inst .backpressure > 0 or (self .needs_exclusive () and self .inst .exclusive )
573580 if has_backpressure () or self .inst .num_waiting_to_enter > 0 :
@@ -584,6 +591,8 @@ def has_backpressure():
584591
585592 def exit (self ):
586593 assert (len (self .threads ) > 0 )
594+ if not self .ft .async_ :
595+ return
587596 if self .needs_exclusive ():
588597 assert (self .inst .exclusive )
589598 self .inst .exclusive = False
@@ -2023,12 +2032,17 @@ def thread_func(thread):
20232032 inst .exclusive = False
20242033 match code :
20252034 case CallbackCode .YIELD :
2026- event = task .yield_until (lambda : not inst .exclusive , thread , cancellable = True )
2035+ if not thread .task .may_suspend ():
2036+ event = (EventCode .NONE , 0 , 0 )
2037+ else :
2038+ event = task .yield_until (lambda : not inst .exclusive , thread , cancellable = True )
20272039 case CallbackCode .WAIT :
2040+ trap_if (not thread .task .may_suspend ())
20282041 wset = inst .table .get (si )
20292042 trap_if (not isinstance (wset , WaitableSet ))
20302043 event = task .wait_until (lambda : not inst .exclusive , thread , wset , cancellable = True )
20312044 case CallbackCode .POLL :
2045+ trap_if (not thread .task .may_suspend ())
20322046 wset = inst .table .get (si )
20332047 trap_if (not isinstance (wset , WaitableSet ))
20342048 event = task .poll_until (lambda : not inst .exclusive , thread , wset , cancellable = True )
@@ -2069,6 +2083,7 @@ def call_and_trap_on_throw(callee, thread, args):
20692083
20702084def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
20712085 trap_if (not thread .task .inst .may_leave )
2086+ trap_if (ft .async_ and not opts .async_ and not thread .task .may_suspend ())
20722087 subtask = Subtask ()
20732088 cx = LiftLowerContext (opts , thread .task .inst , subtask )
20742089
@@ -2108,6 +2123,7 @@ def on_resolve(result):
21082123 flat_results = lower_flat_values (cx , max_flat_results , result , ft .result_type (), flat_args )
21092124
21102125 subtask .callee = callee (thread .task , on_start , on_resolve )
2126+ assert (ft .async_ or subtask .resolved ())
21112127
21122128 if not opts .async_ :
21132129 if not subtask .resolved ():
@@ -2142,31 +2158,30 @@ def canon_resource_new(rt, thread, rep):
21422158
21432159### `canon resource.drop`
21442160
2145- def canon_resource_drop (rt , async_ , thread , i ):
2161+ def canon_resource_drop (rt , thread , i ):
21462162 trap_if (not thread .task .inst .may_leave )
21472163 inst = thread .task .inst
21482164 h = inst .table .remove (i )
21492165 trap_if (not isinstance (h , ResourceHandle ))
21502166 trap_if (h .rt is not rt )
21512167 trap_if (h .num_lends != 0 )
2152- flat_results = [] if not async_ else [0 ]
21532168 if h .own :
21542169 assert (h .borrow_scope is None )
21552170 if inst is rt .impl :
21562171 if rt .dtor :
21572172 rt .dtor (h .rep )
21582173 else :
21592174 if rt .dtor :
2160- caller_opts = CanonicalOptions (async_ = async_ )
2175+ caller_opts = CanonicalOptions (async_ = False )
21612176 callee_opts = CanonicalOptions (async_ = rt .dtor_async , callback = rt .dtor_callback )
2162- ft = FuncType ([U32Type ()],[])
2177+ ft = FuncType ([U32Type ()],[], async_ = False )
21632178 callee = partial (canon_lift , callee_opts , rt .impl , ft , rt .dtor )
2164- flat_results = canon_lower (caller_opts , ft , callee , thread , [h .rep ])
2179+ [] = canon_lower (caller_opts , ft , callee , thread , [h .rep ])
21652180 else :
21662181 thread .task .trap_if_on_the_stack (rt .impl )
21672182 else :
21682183 h .borrow_scope .num_borrows -= 1
2169- return flat_results
2184+ return []
21702185
21712186### `canon resource.rep`
21722187
@@ -2244,6 +2259,7 @@ def canon_waitable_set_new(thread):
22442259
22452260def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
22462261 trap_if (not thread .task .inst .may_leave )
2262+ trap_if (not thread .task .may_suspend ())
22472263 wset = thread .task .inst .table .get (si )
22482264 trap_if (not isinstance (wset , WaitableSet ))
22492265 event = thread .task .wait_until (lambda : True , thread , wset , cancellable )
@@ -2260,6 +2276,7 @@ def unpack_event(mem, thread, ptr, e: EventTuple):
22602276
22612277def canon_waitable_set_poll (cancellable , mem , thread , si , ptr ):
22622278 trap_if (not thread .task .inst .may_leave )
2279+ trap_if (not thread .task .may_suspend ())
22632280 wset = thread .task .inst .table .get (si )
22642281 trap_if (not isinstance (wset , WaitableSet ))
22652282 event = thread .task .poll_until (lambda : True , thread , wset , cancellable )
@@ -2294,6 +2311,7 @@ def canon_waitable_join(thread, wi, si):
22942311
22952312def canon_subtask_cancel (async_ , thread , i ):
22962313 trap_if (not thread .task .inst .may_leave )
2314+ trap_if (not async_ and not thread .task .may_suspend ())
22972315 subtask = thread .task .inst .table .get (i )
22982316 trap_if (not isinstance (subtask , Subtask ))
22992317 trap_if (subtask .resolve_delivered ())
@@ -2350,6 +2368,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23502368
23512369def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
23522370 trap_if (not thread .task .inst .may_leave )
2371+ trap_if (not opts .async_ and not thread .task .may_suspend ())
23532372 e = thread .task .inst .table .get (i )
23542373 trap_if (not isinstance (e , EndT ))
23552374 trap_if (e .shared .t != stream_t .t )
@@ -2401,6 +2420,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24012420
24022421def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24032422 trap_if (not thread .task .inst .may_leave )
2423+ trap_if (not opts .async_ and not thread .task .may_suspend ())
24042424 e = thread .task .inst .table .get (i )
24052425 trap_if (not isinstance (e , EndT ))
24062426 trap_if (e .shared .t != future_t .t )
@@ -2451,6 +2471,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24512471
24522472def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
24532473 trap_if (not thread .task .inst .may_leave )
2474+ trap_if (not async_ and not thread .task .may_suspend ())
24542475 e = thread .task .inst .table .get (i )
24552476 trap_if (not isinstance (e , EndT ))
24562477 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2527,6 +2548,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25272548
25282549def canon_thread_suspend (cancellable , thread ):
25292550 trap_if (not thread .task .inst .may_leave )
2551+ trap_if (not thread .task .may_suspend ())
25302552 suspend_result = thread .task .suspend (thread , cancellable )
25312553 return [suspend_result ]
25322554
@@ -2554,6 +2576,8 @@ def canon_thread_yield_to(cancellable, thread, i):
25542576
25552577def canon_thread_yield (cancellable , thread ):
25562578 trap_if (not thread .task .inst .may_leave )
2579+ if not thread .task .may_suspend ():
2580+ return [SuspsendResult .COMPLETED ]
25572581 event_code ,_ ,_ = thread .task .yield_until (lambda : True , thread , cancellable )
25582582 match event_code :
25592583 case EventCode .NONE :
0 commit comments