Skip to content

Commit f3156fe

Browse files
authored
Update fibers to avoid no-return functions (#12928)
* Update fibers to avoid no-return functions This commit is aimed at fixing the ASAN false positives in #12899. Initially the fix there was to invoke some `__asan_*` intrinsics, and I ended up finding a sort of smaller set of `__asan_*` intrinsics to call as well. In the end what's happening though is that fibers, upon terminating, have a few frames of Rust code on the stack before switching off. To ASAN these frames never returned so when a stack is subsequently reused ASAN is tricked into thinking this is buffer overflow or use-after-free since it's stomping on frames that haven't returned. The fix in this commit is to avoid this style of function which doesn't returns. Functions which don't return in Rust are easy to leak memory from and are a hazard from a safety perspective as well (e.g. it's unsafe to skip running destructors of stack variables). I feel we've had better success over time with "all Rust functions always return" and so what's what was applied here. Unlike #12899 or my thoughts on that PR this does not have any new `__asan_*` intrinsic calls. Instead what this does is it shuffles around responsibility for what exact piece of the infrastructure is responsible for what. Specifically `fiber_start` functions now actually return, meaning the `wasmtime_fiber_start` naked function actually resumes execution, unlike before. The `wasmtime_fiber_start` then delegates to `wasmtime_fiber_switch` immediately to perform the final switch. Effectively there's now only two function frames that never return, and both of these frames are handwritten inline assembly. This means that ASAN gets to see that all normal functions return and updates all of its metadata accordingly. The end result is that the original issue from #12899 is fixed and this I feel is in general more robust as well. One caveat is that the handwritten `wasmtime_fiber_start` assembly needs to invoke a sibling `wasmtime_fiber_switch_` function. In lieu of trying to figure out how to get PIC-vs-not calls working (e.g. static calls) I've opted to use indirect function calls and pointers instead. This mirrors historical changes in our fiber implementation too. * Fix CI builds * Fix miri
1 parent b7c30d1 commit f3156fe

13 files changed

Lines changed: 116 additions & 77 deletions

File tree

crates/fiber/src/lib.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
255255
inner: imp::Suspend,
256256
initial: Resume,
257257
func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return,
258-
) {
258+
) -> imp::Suspend {
259259
let mut suspend = Suspend {
260260
inner,
261261
_phantom: PhantomData,
@@ -278,7 +278,8 @@ impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
278278
#[cfg(not(feature = "std"))]
279279
let result = RunResult::Returned((func)(initial, &mut suspend));
280280

281-
suspend.inner.exit::<Resume, Yield, Return>(result);
281+
suspend.inner.start_exit::<Resume, Yield, Return>(result);
282+
suspend.inner
282283
}
283284
}
284285

@@ -431,7 +432,7 @@ mod tests {
431432

432433
#[test]
433434
fn fiber_stack_max_size() {
434-
if cfg!(windows) {
435+
if cfg!(windows) || cfg!(miri) {
435436
return;
436437
}
437438
assert!(FiberStack::new(usize::MAX, true).is_err());

crates/fiber/src/miri.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,17 @@ where
116116
// Execute this fiber through `Suspend::execute` and once that's done
117117
// deallocate the `state` that we have.
118118
let state = Arc::into_raw(state);
119-
super::Suspend::<A, B, C>::execute(
119+
let mut suspend = super::Suspend::<A, B, C>::execute(
120120
Suspend {
121121
state: state.cast(),
122122
},
123123
init,
124124
func.0,
125125
);
126+
match suspend.block_until_notified::<A, B, C>() {
127+
State::Exiting => {}
128+
_ => unreachable!(),
129+
}
126130
unsafe {
127131
drop(Arc::from_raw(state));
128132
}
@@ -150,8 +154,7 @@ impl Fiber {
150154
let state = state.clone();
151155
let func = IgnoreSendSync(func);
152156
move || run(state, func)
153-
})
154-
.unwrap()
157+
})?
155158
};
156159

157160
// Cast the fiber back into a raw pointer to lose the type parameters
@@ -210,7 +213,7 @@ impl Fiber {
210213
}
211214

212215
impl Suspend {
213-
fn suspend<A, B, C>(&mut self, result: RunResult<A, B, C>) -> State<A, B, C> {
216+
fn set_result<A, B, C>(&mut self, result: RunResult<A, B, C>) {
214217
let state = unsafe { self.state() };
215218
let mut lock = state.state.lock().unwrap();
216219

@@ -219,9 +222,11 @@ impl Suspend {
219222
assert!(matches!(*lock, State::None));
220223
*lock = State::SuspendWith(result);
221224
state.cond.notify_one();
225+
}
222226

223-
// Wait for the resumption to come back, which is returned from this
224-
// method.
227+
fn block_until_notified<A, B, C>(&mut self) -> State<A, B, C> {
228+
let state = unsafe { self.state() };
229+
let mut lock = state.state.lock().unwrap();
225230
lock = state
226231
.cond
227232
.wait_while(lock, |s| {
@@ -232,17 +237,18 @@ impl Suspend {
232237
}
233238

234239
pub(crate) fn switch<A, B, C>(&mut self, result: RunResult<A, B, C>) -> A {
235-
match self.suspend(result) {
240+
self.set_result(result);
241+
242+
// Wait for the resumption to come back, which is returned from this
243+
// method.
244+
match self.block_until_notified::<A, B, C>() {
236245
State::ResumeWith(RunResult::Resuming(a)) => a,
237246
_ => unreachable!(),
238247
}
239248
}
240249

241-
pub(crate) fn exit<A, B, C>(&mut self, result: RunResult<A, B, C>) {
242-
match self.suspend(result) {
243-
State::Exiting => {}
244-
_ => unreachable!(),
245-
}
250+
pub(crate) fn start_exit<A, B, C>(&mut self, result: RunResult<A, B, C>) {
251+
self.set_result(result);
246252
}
247253

248254
unsafe fn state<A, B, C>(&self) -> &SharedFiberState<A, B, C> {

crates/fiber/src/nostd.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,16 @@ pub struct Suspend {
115115
top_of_stack: *mut u8,
116116
}
117117

118-
extern "C" fn fiber_start<F, A, B, C>(arg0: *mut u8, top_of_stack: *mut u8)
118+
extern "C" fn fiber_start<F, A, B, C>(arg0: *mut u8, top_of_stack: *mut u8) -> *mut u8
119119
where
120120
F: FnOnce(A, &mut super::Suspend<A, B, C>) -> C,
121121
{
122122
unsafe {
123123
let inner = Suspend { top_of_stack };
124124
let initial = inner.take_resume::<A, B, C>();
125-
super::Suspend::<A, B, C>::execute(inner, initial, Box::from_raw(arg0.cast::<F>()))
125+
let inner =
126+
super::Suspend::<A, B, C>::execute(inner, initial, Box::from_raw(arg0.cast::<F>()));
127+
inner.top_of_stack
126128
}
127129
}
128130

@@ -176,9 +178,10 @@ impl Suspend {
176178
}
177179
}
178180

179-
pub(crate) fn exit<A, B, C>(&mut self, result: RunResult<A, B, C>) {
180-
self.switch(result);
181-
unreachable!();
181+
pub(crate) fn start_exit<A, B, C>(&mut self, result: RunResult<A, B, C>) {
182+
unsafe {
183+
(*self.result_location::<A, B, C>()).set(result);
184+
}
182185
}
183186

184187
unsafe fn take_resume<A, B, C>(&self) -> A {

crates/fiber/src/stackswitch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ mod unsupported {
6565

6666
pub(crate) unsafe fn wasmtime_fiber_init(
6767
_top_of_stack: *mut u8,
68-
_entry: extern "C" fn(*mut u8, *mut u8),
68+
_entry: extern "C" fn(*mut u8, *mut u8) -> *mut u8,
6969
_entry_arg0: *mut u8,
7070
) {
7171
unreachable!();

crates/fiber/src/stackswitch/aarch64.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ unsafe extern "C" fn wasmtime_fiber_switch_(top_of_stack: *mut u8 /* x0 */) {
9292

9393
pub(crate) unsafe fn wasmtime_fiber_init(
9494
top_of_stack: *mut u8,
95-
entry_point: extern "C" fn(*mut u8, *mut u8),
95+
entry_point: extern "C" fn(*mut u8, *mut u8) -> *mut u8,
9696
entry_arg0: *mut u8, // x2
9797
) {
9898
#[repr(C)]
@@ -132,6 +132,7 @@ pub(crate) unsafe fn wasmtime_fiber_init(
132132
x19: top_of_stack,
133133
x20: entry_point as *mut u8,
134134
x21: entry_arg0,
135+
x22: wasmtime_fiber_switch_ as *mut u8,
135136

136137
// We set up the newly initialized fiber, so that it resumes
137138
// execution from wasmtime_fiber_start(). As a result, we need a
@@ -208,6 +209,11 @@ unsafe extern "C" fn wasmtime_fiber_start() -> ! {
208209
// ... and then we call the function! Note that this is a function call
209210
// so our frame stays on the stack to backtrace through.
210211
blr x20
212+
213+
// The entry function returns where to switch to as the final switch, so
214+
// that's performed here in inline assembly.
215+
blr x22
216+
211217
// Unreachable, here for safety. This should help catch unexpected
212218
// behaviors. Use a noticeable payload so one can grep for it in the
213219
// codebase.

crates/fiber/src/stackswitch/arm.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ unsafe extern "C" fn wasmtime_fiber_switch_(top_of_stack: *mut u8 /* r0 */) {
3636

3737
pub(crate) unsafe fn wasmtime_fiber_init(
3838
top_of_stack: *mut u8,
39-
entry_point: extern "C" fn(*mut u8, *mut u8),
39+
entry_point: extern "C" fn(*mut u8, *mut u8) -> *mut u8,
4040
entry_arg0: *mut u8,
4141
) {
4242
#[repr(C)]
@@ -60,6 +60,7 @@ pub(crate) unsafe fn wasmtime_fiber_init(
6060
unsafe {
6161
let initial_stack = top_of_stack.cast::<InitialStack>().sub(1);
6262
initial_stack.write(InitialStack {
63+
r8: wasmtime_fiber_switch_ as *mut u8,
6364
r9: entry_arg0,
6465
r10: entry_point as *mut u8,
6566
r11: top_of_stack,
@@ -103,6 +104,7 @@ unsafe extern "C" fn wasmtime_fiber_start() -> ! {
103104
mov r1, r11
104105
mov r0, r9
105106
blx r10
107+
blx r8
106108
.cfi_endproc
107109
",
108110
);

crates/fiber/src/stackswitch/riscv32imac.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ unsafe extern "C" fn wasmtime_fiber_switch_(top_of_stack: *mut u8 /* a0 */) {
7272

7373
pub(crate) unsafe fn wasmtime_fiber_init(
7474
top_of_stack: *mut u8,
75-
entry_point: extern "C" fn(*mut u8, *mut u8),
75+
entry_point: extern "C" fn(*mut u8, *mut u8) -> *mut u8,
7676
entry_arg0: *mut u8,
7777
) {
7878
#[repr(C)]
@@ -106,6 +106,7 @@ pub(crate) unsafe fn wasmtime_fiber_init(
106106
initial_stack.write(InitialStack {
107107
s1: entry_point as *mut u8,
108108
s2: entry_arg0,
109+
s3: wasmtime_fiber_switch_ as *mut u8,
109110
fp: top_of_stack,
110111
ra: wasmtime_fiber_start as *mut u8,
111112
last_sp: initial_stack.cast(),
@@ -147,6 +148,7 @@ unsafe extern "C" fn wasmtime_fiber_start() -> ! {
147148
mv a0, s2
148149
mv a1, fp
149150
jalr s1
151+
jalr s3
150152
// .4byte 0 will cause panic.
151153
// for safety just like x86_64.rs and riscv64.rs.
152154
.4byte 0

crates/fiber/src/stackswitch/riscv64.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ unsafe extern "C" fn wasmtime_fiber_switch_(top_of_stack: *mut u8 /* a0 */) {
9191

9292
pub(crate) unsafe fn wasmtime_fiber_init(
9393
top_of_stack: *mut u8,
94-
entry_point: extern "C" fn(*mut u8, *mut u8),
94+
entry_point: extern "C" fn(*mut u8, *mut u8) -> *mut u8,
9595
entry_arg0: *mut u8,
9696
) {
9797
#[repr(C)]
@@ -126,6 +126,7 @@ pub(crate) unsafe fn wasmtime_fiber_init(
126126
initial_stack.write(InitialStack {
127127
s1: entry_point as *mut u8,
128128
s2: entry_arg0,
129+
s3: wasmtime_fiber_switch_ as *mut u8,
129130
fp: top_of_stack,
130131
ra: wasmtime_fiber_start as *mut u8,
131132
last_sp: initial_stack.cast(),
@@ -143,10 +144,10 @@ unsafe extern "C" fn wasmtime_fiber_start() -> ! {
143144
144145
145146
.cfi_escape 0x0f, /* DW_CFA_def_cfa_expression */ \
146-
5, /* the byte length of this expression */ \
147-
0x52, /* DW_OP_reg2 (sp) */ \
148-
0x06, /* DW_OP_deref */ \
149-
0x08, 0xd0, /* DW_OP_const1u 0xc8 */ \
147+
5, /* the byte length of this expression */ \
148+
0x52, /* DW_OP_reg2 (sp) */ \
149+
0x06, /* DW_OP_deref */ \
150+
0x08, 0xd0, /* DW_OP_const1u 0xc8 */ \
150151
0x22 /* DW_OP_plus */
151152
152153
@@ -178,7 +179,9 @@ unsafe extern "C" fn wasmtime_fiber_start() -> ! {
178179
179180
mv a0, s2
180181
mv a1, fp
181-
jalr s1
182+
jalr s1 // entry_point
183+
jalr s3 // wasmtime_fiber_switch_
184+
182185
// .4byte 0 will cause panic.
183186
// for safety just like x86_64.rs.
184187
.4byte 0

crates/fiber/src/stackswitch/s390x.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ unsafe extern "C" fn wasmtime_fiber_switch_(top_of_stack: *mut u8 /* x0 */) {
5252

5353
pub(crate) unsafe fn wasmtime_fiber_init(
5454
top_of_stack: *mut u8,
55-
entry_point: extern "C" fn(*mut u8, *mut u8),
55+
entry_point: extern "C" fn(*mut u8, *mut u8) -> *mut u8,
5656
entry_arg0: *mut u8,
5757
) {
5858
#[repr(C)]
@@ -104,6 +104,7 @@ pub(crate) unsafe fn wasmtime_fiber_init(
104104
r6: top_of_stack,
105105
r7: entry_point as *mut u8,
106106
r8: entry_arg0,
107+
r9: wasmtime_fiber_switch_ as *mut u8,
107108

108109
last_sp: initial_stack.cast(),
109110
..InitialStack::default()
@@ -145,6 +146,10 @@ unsafe extern "C" fn wasmtime_fiber_start() -> ! {
145146
// ... and then we call the function! Note that this is a function call so
146147
// our frame stays on the stack to backtrace through.
147148
basr %r14, %r7 // entry_point
149+
150+
// Perform the final switch.
151+
basr %r14, %r9 // wasmtime_fiber_switch_
152+
148153
// .. technically we shouldn't get here, so just trap.
149154
.word 0x0000
150155
.cfi_endproc

crates/fiber/src/stackswitch/x86.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ unsafe extern "C" fn wasmtime_fiber_switch_(top_of_stack: *mut u8) {
4747

4848
pub(crate) unsafe fn wasmtime_fiber_init(
4949
top_of_stack: *mut u8,
50-
entry_point: extern "C" fn(*mut u8, *mut u8),
50+
entry_point: extern "C" fn(*mut u8, *mut u8) -> *mut u8,
5151
entry_arg0: *mut u8,
5252
) {
5353
// Our stack from top-to-bottom looks like:
@@ -84,6 +84,7 @@ pub(crate) unsafe fn wasmtime_fiber_init(
8484
let initial_stack = top_of_stack.cast::<InitialStack>().sub(1);
8585
initial_stack.write(InitialStack {
8686
ebp: entry_point as *mut u8,
87+
esi: wasmtime_fiber_switch_ as *mut u8,
8788
return_address: wasmtime_fiber_start as *mut u8,
8889
arg1: entry_arg0,
8990
arg2: top_of_stack,
@@ -112,8 +113,15 @@ unsafe extern "C" fn wasmtime_fiber_start() -> ! {
112113
.cfi_rel_offset edi, -20
113114
114115
// Our arguments and stack alignment are all prepped by
115-
// `wasmtime_fiber_init`.
116+
// `wasmtime_fiber_init`. After calling `entry_point` clean up the
117+
// stack arguments.
116118
call ebp
119+
pop edi
120+
pop edi
121+
122+
// Call `wasmtime_fiber_switch_`, pushing its argument onto the stack.
123+
push eax
124+
call esi
117125
ud2
118126
.cfi_endproc
119127
",

0 commit comments

Comments
 (0)