Skip to content

Commit 065baac

Browse files
Fix/improve host subtask cancellation (#12640)
This commit refactors some of the internals of `subtask.cancel` with respect to host subtasks. Notably, a few panics and semantic bugs are fixed here. The main bug was that host subtasks could be aborted while their completion might still have been queued up, which would then produce the result somewhere or assert that the task exists. Cancellation is changed to use `wait_for_event` to ensure that this completion is executed before `subtask.cancel` returns. This helps keep host subtasks looking more similar to guest subtasks in that respect. Closes #12631. Closes #12632. Co-authored-by: Jelle van den Hooff <jelle@vandenhooff.name>
1 parent b0bdcf8 commit 065baac

4 files changed

Lines changed: 522 additions & 61 deletions

File tree

crates/wasmtime/src/runtime/component/concurrent.rs

Lines changed: 103 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,9 @@ pub(crate) fn poll_and_block<R: Send + Sync + 'static>(
756756
let result = future.await?;
757757
tls::get(move |store| {
758758
let state = store.concurrent_state_mut();
759-
state.get_mut(task)?.result = Some(Box::new(result) as _);
759+
let host_state = &mut state.get_mut(task)?.state;
760+
assert!(matches!(host_state, HostTaskState::CalleeStarted));
761+
*host_state = HostTaskState::CalleeFinished(Box::new(result));
760762

761763
Waitable::Host(task).set_event(
762764
state,
@@ -808,14 +810,11 @@ pub(crate) fn poll_and_block<R: Send + Sync + 'static>(
808810
}
809811

810812
// Retrieve and return the result.
811-
Ok(*store
812-
.concurrent_state_mut()
813-
.get_mut(task)?
814-
.result
815-
.take()
816-
.unwrap()
817-
.downcast()
818-
.unwrap())
813+
let host_state = &mut store.concurrent_state_mut().get_mut(task)?.state;
814+
match mem::replace(host_state, HostTaskState::CalleeDone) {
815+
HostTaskState::CalleeFinished(result) => Ok(*result.downcast().unwrap()),
816+
_ => panic!("unexpected host task state after completion"),
817+
}
819818
}
820819

821820
/// Execute the specified guest call.
@@ -1550,7 +1549,7 @@ impl StoreOpaque {
15501549
pub fn enter_host_call(&mut self) -> Result<()> {
15511550
let state = self.concurrent_state_mut();
15521551
let caller = state.unwrap_current_guest_thread();
1553-
let task = state.push(HostTask::new(caller))?;
1552+
let task = state.push(HostTask::new(caller, HostTaskState::CalleeStarted))?;
15541553
log::trace!("new host task {task:?}");
15551554
self.set_thread(task);
15561555
Ok(())
@@ -2736,12 +2735,13 @@ impl Instance {
27362735
///
27372736
/// Whether the future returns `Ready` immediately or later, the `lower`
27382737
/// function will be used to lower the result, if any, into the guest caller's
2739-
/// stack and linear memory unless the task has been cancelled.
2738+
/// stack and linear memory. The `lower` function is invoked with `None` if
2739+
/// the future is cancelled.
27402740
pub(crate) fn first_poll<T: 'static, R: Send + 'static>(
27412741
self,
27422742
mut store: StoreContextMut<'_, T>,
27432743
future: impl Future<Output = Result<R>> + Send + 'static,
2744-
lower: impl FnOnce(StoreContextMut<T>, R) -> Result<()> + Send + 'static,
2744+
lower: impl FnOnce(StoreContextMut<T>, Option<R>) -> Result<()> + Send + 'static,
27452745
) -> Result<Option<u32>> {
27462746
let token = StoreToken::new(store.as_context_mut());
27472747
let state = store.0.concurrent_state_mut();
@@ -2751,9 +2751,9 @@ impl Instance {
27512751
// context state for the future.
27522752
let (join_handle, future) = JoinHandle::run(future);
27532753
{
2754-
let task = state.get_mut(task)?;
2755-
assert!(task.join_handle.is_none());
2756-
task.join_handle = Some(join_handle);
2754+
let state = &mut state.get_mut(task)?.state;
2755+
assert!(matches!(state, HostTaskState::CalleeStarted));
2756+
*state = HostTaskState::CalleeRunning(join_handle);
27572757
}
27582758

27592759
let mut future = Box::pin(future);
@@ -2771,7 +2771,7 @@ impl Instance {
27712771
match poll {
27722772
// It finished immediately; lower the result and delete the task.
27732773
Poll::Ready(Some(result)) => {
2774-
lower(store.as_context_mut(), result?)?;
2774+
lower(store.as_context_mut(), Some(result?))?;
27752775
return Ok(None);
27762776
}
27772777

@@ -2792,9 +2792,8 @@ impl Instance {
27922792
// the task returned.
27932793
let future = Box::pin(async move {
27942794
let result = match future.await {
2795-
Some(result) => result?,
2796-
// Task was cancelled; nothing left to do.
2797-
None => return Ok(()),
2795+
Some(result) => Some(result?),
2796+
None => None,
27982797
};
27992798
let on_complete = move |store: &mut dyn VMStore| {
28002799
// Restore the `current_thread` to be the host so `lower` knows
@@ -2805,15 +2804,16 @@ impl Instance {
28052804
assert!(state.current_thread.is_none());
28062805
store.0.set_thread(task);
28072806

2807+
let status = if result.is_some() {
2808+
Status::Returned
2809+
} else {
2810+
Status::ReturnCancelled
2811+
};
2812+
28082813
lower(store.as_context_mut(), result)?;
28092814
let state = store.0.concurrent_state_mut();
2810-
state.get_mut(task)?.join_handle.take();
2811-
Waitable::Host(task).set_event(
2812-
state,
2813-
Some(Event::Subtask {
2814-
status: Status::Returned,
2815-
}),
2816-
)?;
2815+
state.get_mut(task)?.state = HostTaskState::CalleeDone;
2816+
Waitable::Host(task).set_event(state, Some(Event::Subtask { status }))?;
28172817

28182818
// Go back to "no current thread" at the end.
28192819
store.0.set_thread(CurrentThread::None);
@@ -3059,8 +3059,12 @@ impl Instance {
30593059
let (waitable, expected_caller, delete) = if is_host {
30603060
let id = TableId::<HostTask>::new(rep);
30613061
let task = concurrent_state.get_mut(id)?;
3062-
if task.join_handle.is_some() {
3063-
bail!("cannot drop a subtask which has not yet resolved");
3062+
match &task.state {
3063+
HostTaskState::CalleeRunning(_) => {
3064+
bail!("cannot drop a subtask which has not yet resolved");
3065+
}
3066+
HostTaskState::CalleeDone => {}
3067+
HostTaskState::CalleeStarted | HostTaskState::CalleeFinished(_) => unreachable!(),
30643068
}
30653069
(Waitable::Host(id), task.caller, true)
30663070
} else {
@@ -3537,11 +3541,30 @@ impl Instance {
35373541

35383542
log::trace!("subtask_cancel {waitable:?} (handle {task_id})");
35393543

3544+
let needs_block;
35403545
if let Waitable::Host(host_task) = waitable {
3541-
if let Some(handle) = concurrent_state.get_mut(host_task)?.join_handle.take() {
3542-
handle.abort();
3543-
return Ok(Status::ReturnCancelled as u32);
3546+
let state = &mut concurrent_state.get_mut(host_task)?.state;
3547+
match mem::replace(state, HostTaskState::CalleeDone) {
3548+
// If the callee is still running, signal an abort is requested.
3549+
// Then fall through to determine what to do next.
3550+
HostTaskState::CalleeRunning(handle) => handle.abort(),
3551+
3552+
// The task is already in its terminal state (it was cancelled or its
3553+
// result was taken), so fail as there is nothing left to cancel.
3554+
HostTaskState::CalleeDone => {
3555+
bail!("`subtask.cancel` called after terminal status delivered");
3556+
}
3557+
3558+
// These states should not be possible for a subtask that's
3559+
// visible from the guest, so panic here.
3560+
HostTaskState::CalleeStarted | HostTaskState::CalleeFinished(_) => unreachable!(),
35443561
}
3562+
3563+
// Cancelling host tasks always needs to block on them to await the
3564+
// result of the completion set up in `first_poll`. This'll resolve
3565+
// the race of `handle.abort()` above to see if it actually
3566+
// cancelled something or if the future ended up finishing.
3567+
needs_block = true;
35453568
} else {
35463569
let caller = concurrent_state.unwrap_current_guest_thread();
35473570
let guest_task = TableId::<GuestTask>::new(rep);
@@ -3622,16 +3645,31 @@ impl Instance {
36223645
}
36233646
}
36243647

3625-
let concurrent_state = store.concurrent_state_mut();
3626-
let task = concurrent_state.get_mut(guest_task)?;
3627-
if !task.returned_or_cancelled() {
3628-
if async_ {
3629-
return Ok(BLOCKED);
3630-
} else {
3631-
store.wait_for_event(Waitable::Guest(guest_task))?;
3632-
}
3633-
}
3648+
// Guest tasks need to block if they have not yet returned or
3649+
// cancelled, even as a result of the event delivery above.
3650+
needs_block = !store
3651+
.concurrent_state_mut()
3652+
.get_mut(guest_task)?
3653+
.returned_or_cancelled()
3654+
} else {
3655+
needs_block = false;
3656+
}
3657+
};
3658+
3659+
// If we need to block waiting on the terminal status of this subtask
3660+
// then return immediately in `async` mode, or otherwise wait for the
3661+
// event to get signaled through the store.
3662+
if needs_block {
3663+
if async_ {
3664+
return Ok(BLOCKED);
36343665
}
3666+
3667+
// Wait for this waitable to get signaled with its terminal status
3668+
// from the completion callback enqueued by `first_poll`. Once
3669+
// that's done fall through to the shared event handling below.
3670+
store.wait_for_event(waitable)?;
3671+
3672+
// .. fall through to determine what event's in store for us.
36353673
}
36363674

36373675
let event = waitable.take_event(store.concurrent_state_mut())?;
@@ -4147,24 +4185,39 @@ struct HostTask {
41474185
/// borrows to the host, for example.
41484186
call_context: CallContext,
41494187

4150-
/// For host tasks which end up doing some asynchronous work (e.g.
4151-
/// async-lowered and didn't complete on the first poll) this handle is used
4152-
/// as a signal to cancel the future as it resides in the store's
4153-
/// `FuturesUnordered`.
4154-
join_handle: Option<JoinHandle>,
4188+
state: HostTaskState,
4189+
}
41554190

4156-
/// Box<Any> of the result of this host task.
4157-
result: Option<LiftedResult>,
4191+
enum HostTaskState {
4192+
/// A host task has been created and it's considered "started".
4193+
///
4194+
/// The host task has yet to enter `first_poll` or `poll_and_block` which
4195+
/// is where this will get updated further.
4196+
CalleeStarted,
4197+
4198+
/// State used for tasks in `first_poll` meaning that the guest did an async
4199+
/// lower of a host async function which is blocked. The specified handle is
4200+
/// linked to the future in the main `FuturesUnordered` of a store which is
4201+
/// used to cancel it if the guest requests cancellation.
4202+
CalleeRunning(JoinHandle),
4203+
4204+
/// State used for tasks in `poll_and_block` to store the result of their
4205+
/// computation until it is taken, at which point the task moves to
4206+
/// `CalleeDone`. Note that this state is not used for tasks in `first_poll`.
4207+
CalleeFinished(LiftedResult),
4208+
4209+
/// Terminal state for host tasks meaning that the task was cancelled or the
4210+
/// result was taken.
4211+
CalleeDone,
41584212
}
41594213

41604214
impl HostTask {
4161-
fn new(caller: QualifiedThreadId) -> Self {
4215+
fn new(caller: QualifiedThreadId, state: HostTaskState) -> Self {
41624216
Self {
41634217
common: WaitableCommon::default(),
41644218
call_context: CallContext::default(),
41654219
caller,
4166-
join_handle: None,
4167-
result: None,
4220+
state,
41684221
}
41694222
}
41704223
}

crates/wasmtime/src/runtime/component/func/host.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ where
447447
ptr,
448448
)?)
449449
};
450-
Self::lower_result_and_exit_call(&mut lower, ty, ret, dst)
450+
Self::lower_result_and_exit_call(&mut lower, ty, Some(ret), dst)
451451
}
452452

453453
/// Implementation of the "async" ABI of the component model.
@@ -499,7 +499,7 @@ where
499499
Self::lower_result_and_exit_call(
500500
&mut LowerContext::new(store, options, instance),
501501
ty,
502-
result?,
502+
Some(result?),
503503
Destination::Memory(retptr),
504504
)?;
505505
None
@@ -568,17 +568,19 @@ where
568568
fn lower_result_and_exit_call(
569569
lower: &mut LowerContext<'_, T>,
570570
ty: TypeFuncIndex,
571-
ret: R,
571+
ret: Option<R>,
572572
dst: Destination<'_>,
573573
) -> Result<()> {
574-
let caller_instance = lower.options().instance;
575-
let mut flags = lower.instance_mut().instance_flags(caller_instance);
576-
unsafe {
577-
flags.set_may_leave(false);
578-
}
579-
Self::lower_result(lower, ty, ret, dst)?;
580-
unsafe {
581-
flags.set_may_leave(true);
574+
if let Some(ret) = ret {
575+
let caller_instance = lower.options().instance;
576+
let mut flags = lower.instance_mut().instance_flags(caller_instance);
577+
unsafe {
578+
flags.set_may_leave(false);
579+
}
580+
Self::lower_result(lower, ty, ret, dst)?;
581+
unsafe {
582+
flags.set_may_leave(true);
583+
}
582584
}
583585
lower.validate_scope_exit()?;
584586
Ok(())

crates/wast/src/spectest.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,5 +197,26 @@ pub fn link_component_spectest<T>(linker: &mut component::Linker<T>) -> Result<(
197197
Ok(())
198198
},
199199
)?;
200+
i.func_wrap_concurrent("never-return", |_, _: ()| {
201+
Box::pin(async move { std::future::pending::<Result<()>>().await })
202+
})?;
203+
i.func_wrap_concurrent("return-two-slowly", |_, _: ()| {
204+
Box::pin(async move {
205+
tokio::task::yield_now().await;
206+
Ok((2,))
207+
})
208+
})?;
209+
i.func_wrap_concurrent("echo-slowly", |_, (a,): (u32,)| {
210+
Box::pin(async move {
211+
tokio::task::yield_now().await;
212+
Ok((a,))
213+
})
214+
})?;
215+
i.func_wrap_concurrent(
216+
"[method]resource1.never-return",
217+
|_, (_,): (Resource<Resource1>,)| {
218+
Box::pin(async move { std::future::pending::<Result<()>>().await })
219+
},
220+
)?;
200221
Ok(())
201222
}

0 commit comments

Comments
 (0)