@@ -756,7 +756,9 @@ pub(crate) fn poll_and_block<R: Send + Sync + 'static>(
756756 let result = future. await ?;
757757 tls:: get ( move |store| {
758758 let state = store. concurrent_state_mut ( ) ;
759- state. get_mut ( task) ?. result = Some ( Box :: new ( result) as _ ) ;
759+ let host_state = & mut state. get_mut ( task) ?. state ;
760+ assert ! ( matches!( host_state, HostTaskState :: CalleeStarted ) ) ;
761+ * host_state = HostTaskState :: CalleeFinished ( Box :: new ( result) ) ;
760762
761763 Waitable :: Host ( task) . set_event (
762764 state,
@@ -808,14 +810,11 @@ pub(crate) fn poll_and_block<R: Send + Sync + 'static>(
808810 }
809811
810812 // Retrieve and return the result.
811- Ok ( * store
812- . concurrent_state_mut ( )
813- . get_mut ( task) ?
814- . result
815- . take ( )
816- . unwrap ( )
817- . downcast ( )
818- . unwrap ( ) )
813+ let host_state = & mut store. concurrent_state_mut ( ) . get_mut ( task) ?. state ;
814+ match mem:: replace ( host_state, HostTaskState :: CalleeDone ) {
815+ HostTaskState :: CalleeFinished ( result) => Ok ( * result. downcast ( ) . unwrap ( ) ) ,
816+ _ => panic ! ( "unexpected host task state after completion" ) ,
817+ }
819818}
820819
821820/// Execute the specified guest call.
@@ -1550,7 +1549,7 @@ impl StoreOpaque {
15501549 pub fn enter_host_call ( & mut self ) -> Result < ( ) > {
15511550 let state = self . concurrent_state_mut ( ) ;
15521551 let caller = state. unwrap_current_guest_thread ( ) ;
1553- let task = state. push ( HostTask :: new ( caller) ) ?;
1552+ let task = state. push ( HostTask :: new ( caller, HostTaskState :: CalleeStarted ) ) ?;
15541553 log:: trace!( "new host task {task:?}" ) ;
15551554 self . set_thread ( task) ;
15561555 Ok ( ( ) )
@@ -2736,12 +2735,13 @@ impl Instance {
27362735 ///
27372736 /// Whether the future returns `Ready` immediately or later, the `lower`
27382737 /// function will be used to lower the result, if any, into the guest caller's
2739- /// stack and linear memory unless the task has been cancelled.
2738+ /// stack and linear memory. The `lower` function is invoked with `None` if
2739+ /// the future is cancelled.
27402740 pub ( crate ) fn first_poll < T : ' static , R : Send + ' static > (
27412741 self ,
27422742 mut store : StoreContextMut < ' _ , T > ,
27432743 future : impl Future < Output = Result < R > > + Send + ' static ,
2744- lower : impl FnOnce ( StoreContextMut < T > , R ) -> Result < ( ) > + Send + ' static ,
2744+ lower : impl FnOnce ( StoreContextMut < T > , Option < R > ) -> Result < ( ) > + Send + ' static ,
27452745 ) -> Result < Option < u32 > > {
27462746 let token = StoreToken :: new ( store. as_context_mut ( ) ) ;
27472747 let state = store. 0 . concurrent_state_mut ( ) ;
@@ -2751,9 +2751,9 @@ impl Instance {
27512751 // context state for the future.
27522752 let ( join_handle, future) = JoinHandle :: run ( future) ;
27532753 {
2754- let task = state. get_mut ( task) ?;
2755- assert ! ( task . join_handle . is_none ( ) ) ;
2756- task . join_handle = Some ( join_handle) ;
2754+ let state = & mut state. get_mut ( task) ?. state ;
2755+ assert ! ( matches! ( state , HostTaskState :: CalleeStarted ) ) ;
2756+ * state = HostTaskState :: CalleeRunning ( join_handle) ;
27572757 }
27582758
27592759 let mut future = Box :: pin ( future) ;
@@ -2771,7 +2771,7 @@ impl Instance {
27712771 match poll {
27722772 // It finished immediately; lower the result and delete the task.
27732773 Poll :: Ready ( Some ( result) ) => {
2774- lower ( store. as_context_mut ( ) , result?) ?;
2774+ lower ( store. as_context_mut ( ) , Some ( result?) ) ?;
27752775 return Ok ( None ) ;
27762776 }
27772777
@@ -2792,9 +2792,8 @@ impl Instance {
27922792 // the task returned.
27932793 let future = Box :: pin ( async move {
27942794 let result = match future. await {
2795- Some ( result) => result?,
2796- // Task was cancelled; nothing left to do.
2797- None => return Ok ( ( ) ) ,
2795+ Some ( result) => Some ( result?) ,
2796+ None => None ,
27982797 } ;
27992798 let on_complete = move |store : & mut dyn VMStore | {
28002799 // Restore the `current_thread` to be the host so `lower` knows
@@ -2805,15 +2804,16 @@ impl Instance {
28052804 assert ! ( state. current_thread. is_none( ) ) ;
28062805 store. 0 . set_thread ( task) ;
28072806
2807+ let status = if result. is_some ( ) {
2808+ Status :: Returned
2809+ } else {
2810+ Status :: ReturnCancelled
2811+ } ;
2812+
28082813 lower ( store. as_context_mut ( ) , result) ?;
28092814 let state = store. 0 . concurrent_state_mut ( ) ;
2810- state. get_mut ( task) ?. join_handle . take ( ) ;
2811- Waitable :: Host ( task) . set_event (
2812- state,
2813- Some ( Event :: Subtask {
2814- status : Status :: Returned ,
2815- } ) ,
2816- ) ?;
2815+ state. get_mut ( task) ?. state = HostTaskState :: CalleeDone ;
2816+ Waitable :: Host ( task) . set_event ( state, Some ( Event :: Subtask { status } ) ) ?;
28172817
28182818 // Go back to "no current thread" at the end.
28192819 store. 0 . set_thread ( CurrentThread :: None ) ;
@@ -3059,8 +3059,12 @@ impl Instance {
30593059 let ( waitable, expected_caller, delete) = if is_host {
30603060 let id = TableId :: < HostTask > :: new ( rep) ;
30613061 let task = concurrent_state. get_mut ( id) ?;
3062- if task. join_handle . is_some ( ) {
3063- bail ! ( "cannot drop a subtask which has not yet resolved" ) ;
3062+ match & task. state {
3063+ HostTaskState :: CalleeRunning ( _) => {
3064+ bail ! ( "cannot drop a subtask which has not yet resolved" ) ;
3065+ }
3066+ HostTaskState :: CalleeDone => { }
3067+ HostTaskState :: CalleeStarted | HostTaskState :: CalleeFinished ( _) => unreachable ! ( ) ,
30643068 }
30653069 ( Waitable :: Host ( id) , task. caller , true )
30663070 } else {
@@ -3537,11 +3541,30 @@ impl Instance {
35373541
35383542 log:: trace!( "subtask_cancel {waitable:?} (handle {task_id})" ) ;
35393543
3544+ let needs_block;
35403545 if let Waitable :: Host ( host_task) = waitable {
3541- if let Some ( handle) = concurrent_state. get_mut ( host_task) ?. join_handle . take ( ) {
3542- handle. abort ( ) ;
3543- return Ok ( Status :: ReturnCancelled as u32 ) ;
3546+ let state = & mut concurrent_state. get_mut ( host_task) ?. state ;
3547+ match mem:: replace ( state, HostTaskState :: CalleeDone ) {
3548+ // If the callee is still running, signal an abort is requested.
3549+ // Then fall through to determine what to do next.
3550+ HostTaskState :: CalleeRunning ( handle) => handle. abort ( ) ,
3551+
3552+ // Cancellation was already requested, so fail as the task can't
3553+ // be cancelled twice.
3554+ HostTaskState :: CalleeDone => {
3555+ bail ! ( "`subtask.cancel` called after terminal status delivered" ) ;
3556+ }
3557+
3558+ // These states should not be possible for a subtask that's
3559+ // visible from the guest, so panic here.
3560+ HostTaskState :: CalleeStarted | HostTaskState :: CalleeFinished ( _) => unreachable ! ( ) ,
35443561 }
3562+
3563+ // Cancelling host tasks always needs to block on them to await the
3564+ // result of the completion set up in `first_poll`. This'll resolve
3565+ // the race of `handle.abort()` above to see if it actually
3566+ // cancelled something or if the future ended up finishing.
3567+ needs_block = true ;
35453568 } else {
35463569 let caller = concurrent_state. unwrap_current_guest_thread ( ) ;
35473570 let guest_task = TableId :: < GuestTask > :: new ( rep) ;
@@ -3622,16 +3645,31 @@ impl Instance {
36223645 }
36233646 }
36243647
3625- let concurrent_state = store. concurrent_state_mut ( ) ;
3626- let task = concurrent_state. get_mut ( guest_task) ?;
3627- if !task. returned_or_cancelled ( ) {
3628- if async_ {
3629- return Ok ( BLOCKED ) ;
3630- } else {
3631- store. wait_for_event ( Waitable :: Guest ( guest_task) ) ?;
3632- }
3633- }
3648+ // Guest tasks need to block if they have not yet returned or
3649+ // cancelled, even as a result of the event delivery above.
3650+ needs_block = !store
3651+ . concurrent_state_mut ( )
3652+ . get_mut ( guest_task) ?
3653+ . returned_or_cancelled ( )
3654+ } else {
3655+ needs_block = false ;
3656+ }
3657+ } ;
3658+
3659+ // If we need to block waiting on the terminal status of this subtask
3660+ // then return immediately in `async` mode, or otherwise wait for the
3661+ // event to get signaled through the store.
3662+ if needs_block {
3663+ if async_ {
3664+ return Ok ( BLOCKED ) ;
36343665 }
3666+
3667+ // Wait for this waitable to get signaled with its terminal status
3668+ // from the completion callback enqueued by `first_poll`. Once
3669+ // that's done fall through to the sahred
3670+ store. wait_for_event ( waitable) ?;
3671+
3672+ // .. fall through to determine what event's in store for us.
36353673 }
36363674
36373675 let event = waitable. take_event ( store. concurrent_state_mut ( ) ) ?;
@@ -4147,24 +4185,39 @@ struct HostTask {
41474185 /// borrows to the host, for example.
41484186 call_context : CallContext ,
41494187
4150- /// For host tasks which end up doing some asynchronous work (e.g.
4151- /// async-lowered and didn't complete on the first poll) this handle is used
4152- /// as a signal to cancel the future as it resides in the store's
4153- /// `FuturesUnordered`.
4154- join_handle : Option < JoinHandle > ,
4188+ state : HostTaskState ,
4189+ }
41554190
4156- /// Box<Any> of the result of this host task.
4157- result : Option < LiftedResult > ,
4191+ enum HostTaskState {
4192+ /// A host task has been created and it's considered "started".
4193+ ///
4194+ /// The host task has yet to enter `first_poll` or `poll_and_block` which
4195+ /// is where this will get updated further.
4196+ CalleeStarted ,
4197+
4198+ /// State used for tasks in `first_poll` meaning that the guest did an async
4199+ /// lower of a host async function which is blocked. The specified handle is
4200+ /// linked to the future in the main `FuturesUnordered` of a store which is
4201+ /// used to cancel it if the guest requests cancellation.
4202+ CalleeRunning ( JoinHandle ) ,
4203+
4204+ /// Terminal state used for tasks in `poll_and_block` to store the result of
4205+ /// their computation. Note that this state is not used for tasks in
4206+ /// `first_poll`.
4207+ CalleeFinished ( LiftedResult ) ,
4208+
4209+ /// Terminal state for host tasks meaning that the task was cancelled or the
4210+ /// result was taken.
4211+ CalleeDone ,
41584212}
41594213
41604214impl HostTask {
4161- fn new ( caller : QualifiedThreadId ) -> Self {
4215+ fn new ( caller : QualifiedThreadId , state : HostTaskState ) -> Self {
41624216 Self {
41634217 common : WaitableCommon :: default ( ) ,
41644218 call_context : CallContext :: default ( ) ,
41654219 caller,
4166- join_handle : None ,
4167- result : None ,
4220+ state,
41684221 }
41694222 }
41704223}
0 commit comments