Skip to content

Commit 369bbb0

Browse files
authored
rust: Fix receiving async events in busy yield loops (#1439)
This commit fixes an issue with the rust async support where when a yield-loop was detected it suspended with `CALLBACK_CODE_YIELD` which meant that it wasn't possible to deliver async events for other async operations in progress. The fix is to instead return `CALLBACK_CODE_POLL` with the waitable-set that has all the waitables inside of it. This enables delivery of async events from that set to ensure that if the yielding task is waiting on something else that it gets delivered. This also refactors a few things here and there, such as `CALLBACK_CODE_*`, to be less error prone. Support was added for composing 3+ components with `wasm-compose` since `wac` doesn't yet use the same `wasm-tools`.
1 parent 2a6d845 commit 369bbb0

7 files changed

Lines changed: 148 additions & 27 deletions

File tree

crates/guest-rust/src/rt/async_support.rs

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ impl FutureState<'_> {
156156

157157
/// Handles the `event{0,1,2}` event codes and returns a corresponding
158158
/// return code along with a flag whether this future is "done" or not.
159-
fn callback(&mut self, event0: u32, event1: u32, event2: u32) -> (u32, bool) {
159+
fn callback(&mut self, event0: u32, event1: u32, event2: u32) -> CallbackCode {
160160
match event0 {
161161
EVENT_NONE => rtdebug!("EVENT_NONE"),
162162
EVENT_SUBTASK => rtdebug!("EVENT_SUBTASK({event1:#x}, {event2:#x})"),
@@ -171,7 +171,7 @@ impl FutureState<'_> {
171171
// code/bool indicating we're done. The caller will then
172172
// appropriately deallocate this `FutureState` which will
173173
// transitively run all destructors.
174-
return (CALLBACK_CODE_EXIT, true);
174+
return CallbackCode::Exit;
175175
}
176176
_ => unreachable!(),
177177
}
@@ -198,7 +198,7 @@ impl FutureState<'_> {
198198
///
199199
/// Returns the code representing what happened along with a boolean as to
200200
/// whether this execution is done.
201-
fn poll(&mut self) -> (u32, bool) {
201+
fn poll(&mut self) -> CallbackCode {
202202
self.with_p3_task_set(|me| {
203203
let mut context = Context::from_waker(&me.waker_clone);
204204

@@ -223,9 +223,9 @@ impl FutureState<'_> {
223223
assert!(me.tasks.is_empty());
224224
if me.remaining_work() {
225225
let waitable = me.waitable_set.as_ref().unwrap().as_raw();
226-
break (CALLBACK_CODE_WAIT | (waitable << 4), false);
226+
break CallbackCode::Wait(waitable);
227227
} else {
228-
break (CALLBACK_CODE_EXIT, true);
228+
break CallbackCode::Exit;
229229
}
230230
}
231231

@@ -236,12 +236,18 @@ impl FutureState<'_> {
236236
Poll::Pending => {
237237
assert!(!me.tasks.is_empty());
238238
if me.waker.0.load(Ordering::Relaxed) {
239-
break (CALLBACK_CODE_YIELD, false);
239+
let code = if me.remaining_work() {
240+
let waitable = me.waitable_set.as_ref().unwrap().as_raw();
241+
CallbackCode::Poll(waitable)
242+
} else {
243+
CallbackCode::Yield
244+
};
245+
break code;
240246
}
241247

242248
assert!(me.remaining_work());
243249
let waitable = me.waitable_set.as_ref().unwrap().as_raw();
244-
break (CALLBACK_CODE_WAIT | (waitable << 4), false);
250+
break CallbackCode::Wait(waitable);
245251
}
246252
}
247253
}
@@ -331,10 +337,24 @@ const EVENT_FUTURE_READ: u32 = 4;
331337
const EVENT_FUTURE_WRITE: u32 = 5;
332338
const EVENT_CANCEL: u32 = 6;
333339

334-
const CALLBACK_CODE_EXIT: u32 = 0;
335-
const CALLBACK_CODE_YIELD: u32 = 1;
336-
const CALLBACK_CODE_WAIT: u32 = 2;
337-
const _CALLBACK_CODE_POLL: u32 = 3;
340+
#[derive(PartialEq, Debug)]
341+
enum CallbackCode {
342+
Exit,
343+
Yield,
344+
Wait(u32),
345+
Poll(u32),
346+
}
347+
348+
impl CallbackCode {
349+
fn encode(self) -> u32 {
350+
match self {
351+
CallbackCode::Exit => 0,
352+
CallbackCode::Yield => 1,
353+
CallbackCode::Wait(waitable) => 2 | (waitable << 4),
354+
CallbackCode::Poll(waitable) => 3 | (waitable << 4),
355+
}
356+
}
357+
}
338358

339359
const STATUS_STARTING: u32 = 0;
340360
const STATUS_STARTED: u32 = 1;
@@ -425,14 +445,14 @@ pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 {
425445
// our future so deallocate it. Otherwise put our future back in
426446
// context-local storage and forward the code.
427447
unsafe {
428-
let (rc, done) = (*state).callback(event0, event1, event2);
429-
if done {
448+
let rc = (*state).callback(event0, event1, event2);
449+
if rc == CallbackCode::Exit {
430450
drop(Box::from_raw(state));
431451
} else {
432452
context_set(state.cast());
433453
}
434-
rtdebug!(" => (cb) {rc:#x}");
435-
rc
454+
rtdebug!(" => (cb) {rc:?}");
455+
rc.encode()
436456
}
437457
}
438458

@@ -449,12 +469,14 @@ pub fn block_on<T: 'static>(future: impl Future<Output = T>) -> T {
449469
let mut event = (EVENT_NONE, 0, 0);
450470
loop {
451471
match state.callback(event.0, event.1, event.2) {
452-
(_, true) => {
472+
CallbackCode::Exit => {
453473
drop(state);
454474
break result.unwrap();
455475
}
456-
(CALLBACK_CODE_YIELD, false) => event = state.waitable_set.as_ref().unwrap().poll(),
457-
_ => event = state.waitable_set.as_ref().unwrap().wait(),
476+
CallbackCode::Yield | CallbackCode::Poll(_) => {
477+
event = state.waitable_set.as_ref().unwrap().poll()
478+
}
479+
CallbackCode::Wait(_) => event = state.waitable_set.as_ref().unwrap().wait(),
458480
}
459481
}
460482
}

crates/guest-rust/src/rt/async_support/waitable_set.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ impl WaitableSet {
1717
}
1818

1919
pub fn remove_waitable_from_all_sets(waitable: u32) {
20-
rtdebug!("waitable-set.join({waitable}, 0)");
20+
rtdebug!("waitable.join({waitable}, 0)");
2121
unsafe { join(waitable, 0) }
2222
}
2323

2424
pub fn wait(&self) -> (u32, u32, u32) {
2525
unsafe {
2626
let mut payload = [0; 2];
27+
rtdebug!("waitable-set.wait({}) = ...", self.0.get());
2728
let event0 = wait(self.0.get(), &mut payload);
2829
rtdebug!(
2930
"waitable-set.wait({}) = ({event0}, {:#x}, {:#x})",
@@ -38,6 +39,7 @@ impl WaitableSet {
3839
pub fn poll(&self) -> (u32, u32, u32) {
3940
unsafe {
4041
let mut payload = [0; 2];
42+
rtdebug!("waitable-set.poll({}) = ...", self.0.get());
4143
let event0 = poll(self.0.get(), &mut payload);
4244
rtdebug!(
4345
"waitable-set.poll({}) = ({event0}, {:#x}, {:#x})",

crates/test/src/lib.rs

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ impl Runner<'_> {
837837
// done for async tests at this time to ensure that there's a version of
838838
// composition that's done which is at the same version as wasmparser
839839
// and friends.
840-
let composed = if case.config.wac.is_none() && test_components.len() == 1 {
840+
let composed = if case.config.wac.is_none() {
841841
self.compose_wasm_with_wasm_compose(runner_wasm, test_components)?
842842
} else {
843843
self.compose_wasm_with_wac(case, runner, runner_wasm, test_components)?
@@ -865,13 +865,32 @@ impl Runner<'_> {
865865
runner_wasm: &Path,
866866
test_components: &[(&Component, &Path)],
867867
) -> Result<Vec<u8>> {
868-
assert!(test_components.len() == 1);
869-
let test_wasm = test_components[0].1;
870-
let mut config = wasm_compose::config::Config::default();
871-
config.definitions = vec![test_wasm.to_path_buf()];
872-
wasm_compose::composer::ComponentComposer::new(runner_wasm, &config)
873-
.compose()
874-
.with_context(|| format!("failed to compose {runner_wasm:?} with {test_wasm:?}"))
868+
assert!(test_components.len() > 0);
869+
let mut last_bytes = None;
870+
let mut path: PathBuf;
871+
for (i, (_component, component_path)) in test_components.iter().enumerate() {
872+
let main = match last_bytes.take() {
873+
Some(bytes) => {
874+
path = runner_wasm.with_extension(&format!("composition{i}.wasm"));
875+
std::fs::write(&path, &bytes)
876+
.with_context(|| format!("failed to write temporary file {path:?}"))?;
877+
path.as_path()
878+
}
879+
None => runner_wasm,
880+
};
881+
882+
let mut config = wasm_compose::config::Config::default();
883+
config.definitions = vec![component_path.to_path_buf()];
884+
last_bytes = Some(
885+
wasm_compose::composer::ComponentComposer::new(main, &config)
886+
.compose()
887+
.with_context(|| {
888+
format!("failed to compose {main:?} with {component_path:?}")
889+
})?,
890+
);
891+
}
892+
893+
Ok(last_bytes.unwrap())
875894
}
876895

877896
fn compose_wasm_with_wac(
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
include!(env!("BINDINGS"));
2+
3+
struct Component;
4+
5+
export!(Component);
6+
7+
impl crate::exports::test::common::i_middle::Guest for Component {
8+
async fn f() {
9+
for _ in 0..2 {
10+
wit_bindgen::yield_async().await;
11+
}
12+
}
13+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
include!(env!("BINDINGS"));
2+
3+
use crate::test::common::i_middle::f;
4+
use std::task::Poll;
5+
6+
pub struct Component;
7+
8+
export!(Component);
9+
10+
static mut HIT: bool = false;
11+
12+
impl crate::exports::test::common::i_runner::Guest for Component {
13+
async fn f() {
14+
wit_bindgen::spawn(async move {
15+
f().await;
16+
unsafe {
17+
HIT = true;
18+
}
19+
});
20+
21+
// This is an "infinite loop" but it's also effectively a yield which
22+
// should enable not only making progress on sibling rust-level tasks
23+
// but additionally async events should be deliverable.
24+
std::future::poll_fn(|cx| unsafe {
25+
if HIT {
26+
Poll::Ready(())
27+
} else {
28+
cx.waker().wake_by_ref();
29+
Poll::Pending
30+
}
31+
})
32+
.await;
33+
}
34+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
include!(env!("BINDINGS"));
2+
3+
fn main() {
4+
wit_bindgen::block_on(async {
5+
crate::test::common::i_runner::f().await;
6+
});
7+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//@ dependencies = ['middle', 'leaf']
2+
3+
package test:common;
4+
5+
world runner {
6+
import i-runner;
7+
}
8+
9+
interface i-runner {
10+
f: async func();
11+
}
12+
13+
world middle {
14+
export i-runner;
15+
import i-middle;
16+
}
17+
18+
interface i-middle {
19+
f: async func();
20+
}
21+
22+
world leaf {
23+
export i-middle;
24+
}

0 commit comments

Comments
 (0)