Skip to content

Commit af69fb8

Browse files
committed
Fix: flattened functype marshalling indirect return values into params for lowering
1 parent 9669681 commit af69fb8

3 files changed

Lines changed: 130 additions & 134 deletions

File tree

crates/environ/src/component/types.rs

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -372,36 +372,43 @@ impl ComponentTypes {
372372
}
373373
}
374374

375-
/// Returns the flat storage ABI representation for an interface type.
376-
/// If the flat representation is larger than `limit` number of flat types, returns
377-
/// storage with a pointer.
375+
/// Returns the flat representation of a function's params and returns for
376+
/// the underlying core wasm function according to the Canonical ABI
378377
///
379-
/// The intention of this method is to determine the flat ABI on host-to-wasm
380-
/// transitions (return from hostcall, or entry into wasmcall). When the type is
381-
/// not encodable in flat types, the values are all lowered to memory, implied by
382-
/// the pointer storage.
383-
pub fn flat_types_storage_or_pointer(
378+
/// As per the Canonical ABI, when the representation is larger than MAX_FLAT_RESULTS
379+
/// or MAX_FLAT_PARAMS, the core wasm function will take a pointer to the arg/result list.
380+
/// Returns (param_iterator, result_iterator)
381+
pub fn flat_func_type(
384382
&self,
385-
ty: &InterfaceType,
386-
limit: usize,
387-
) -> FlatTypesStorage {
388-
assert!(
389-
limit <= MAX_FLAT_TYPES,
390-
"limit exceeding maximum flat types not allowed"
391-
);
392-
self.flat_types_storage_inner(ty, limit).unwrap_or_else(|| {
393-
let mut flat = FlatTypesStorage::new();
394-
// Pointer representation for wasm32 and wasm64 respectively
395-
flat.push(FlatType::I32, FlatType::I64);
396-
flat
397-
})
383+
ty: &TypeFunc,
384+
context: FlatFuncTypeContext,
385+
) -> (FlatTypesStorage, FlatTypesStorage) {
386+
let mut params_storage = self
387+
.flat_interface_type(&InterfaceType::Tuple(ty.params), MAX_FLAT_PARAMS)
388+
.unwrap_or_else(|| {
389+
let mut flat = FlatTypesStorage::new();
390+
flat.push(FlatType::I32, FlatType::I64);
391+
flat
392+
});
393+
let results_storage = self
394+
.flat_interface_type(&InterfaceType::Tuple(ty.results), MAX_FLAT_RESULTS)
395+
.unwrap_or_else(|| {
396+
let mut flat = FlatTypesStorage::new();
397+
match context {
398+
FlatFuncTypeContext::Lift => {
399+
flat.push(FlatType::I32, FlatType::I64);
400+
}
401+
// For lowers, the retptr is passed as the last parameter
402+
FlatFuncTypeContext::Lower => {
403+
params_storage.push(FlatType::I32, FlatType::I64);
404+
}
405+
}
406+
flat
407+
});
408+
(params_storage, results_storage)
398409
}
399410

400-
fn flat_types_storage_inner(
401-
&self,
402-
ty: &InterfaceType,
403-
limit: usize,
404-
) -> Option<FlatTypesStorage> {
411+
fn flat_interface_type(&self, ty: &InterfaceType, limit: usize) -> Option<FlatTypesStorage> {
405412
// Helper routines
406413
let push = |storage: &mut FlatTypesStorage, t32: FlatType, t64: FlatType| -> bool {
407414
storage.push(t32, t64);
@@ -502,13 +509,14 @@ impl ComponentTypes {
502509
&& push(storage, FlatType::I32, FlatType::I64)
503510
}
504511

505-
InterfaceType::Record(i) => self[*i].fields.iter().all(|field| {
506-
push_storage(storage, self.flat_types_storage_inner(&field.ty, limit))
507-
}),
512+
InterfaceType::Record(i) => self[*i]
513+
.fields
514+
.iter()
515+
.all(|field| push_storage(storage, self.flat_interface_type(&field.ty, limit))),
508516
InterfaceType::Tuple(i) => self[*i]
509517
.types
510518
.iter()
511-
.all(|field| push_storage(storage, self.flat_types_storage_inner(field, limit))),
519+
.all(|field| push_storage(storage, self.flat_interface_type(field, limit))),
512520
InterfaceType::Flags(i) => match FlagsSize::from_count(self[*i].names.len()) {
513521
FlagsSize::Size0 => true,
514522
FlagsSize::Size1 | FlagsSize::Size2 => push(storage, FlatType::I32, FlatType::I32),
@@ -519,9 +527,7 @@ impl ComponentTypes {
519527
InterfaceType::Variant(i) => {
520528
push_discrim(storage)
521529
&& self[*i].cases.values().all(|case| {
522-
let case_flat = case
523-
.as_ref()
524-
.map(|ty| self.flat_types_storage_inner(ty, limit));
530+
let case_flat = case.as_ref().map(|ty| self.flat_interface_type(ty, limit));
525531
push_storage_variant_case(storage, case_flat)
526532
})
527533
}
@@ -530,22 +536,18 @@ impl ComponentTypes {
530536
&& push_storage_variant_case(storage, None)
531537
&& push_storage_variant_case(
532538
storage,
533-
Some(self.flat_types_storage_inner(&self[*i].ty, limit)),
539+
Some(self.flat_interface_type(&self[*i].ty, limit)),
534540
)
535541
}
536542
InterfaceType::Result(i) => {
537543
push_discrim(storage)
538544
&& push_storage_variant_case(
539545
storage,
540-
self[*i]
541-
.ok
542-
.map(|ty| self.flat_types_storage_inner(&ty, limit)),
546+
self[*i].ok.map(|ty| self.flat_interface_type(&ty, limit)),
543547
)
544548
&& push_storage_variant_case(
545549
storage,
546-
self[*i]
547-
.err
548-
.map(|ty| self.flat_types_storage_inner(&ty, limit)),
550+
self[*i].err.map(|ty| self.flat_interface_type(&ty, limit)),
549551
)
550552
}
551553
}
@@ -1393,7 +1395,6 @@ const fn max_flat(a: Option<u8>, b: Option<u8>) -> Option<u8> {
13931395
/// that's 24 bytes. Otherwise `FlatType` is 1 byte large and
13941396
/// `MAX_FLAT_TYPES` is 16, so it should ideally be more space-efficient to
13951397
/// use a flat array instead of a heap-based vector.
1396-
#[derive(Debug)]
13971398
pub struct FlatTypesStorage {
13981399
/// Representation for 32-bit memory
13991400
pub memory32: [FlatType; MAX_FLAT_TYPES],
@@ -1514,3 +1515,14 @@ impl FlatType {
15141515
}
15151516
}
15161517
}
1518+
1519+
/// Context under which the flat ABI is considered for functypes.
1520+
///
1521+
/// Note that this is necessary since the same signature can have different
1522+
/// ABIs depending on whether it is a lifted function or a lowered function.
1523+
pub enum FlatFuncTypeContext {
1524+
/// Flattening args for a lifted function
1525+
Lift,
1526+
/// Flattening args for a lowered function
1527+
Lower,
1528+
}

crates/wasmtime/src/runtime/component/func/host.rs

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ use core::mem::{self, MaybeUninit};
2121
use core::pin::Pin;
2222
use core::ptr::NonNull;
2323
use wasmtime_environ::component::{
24-
CanonicalAbiInfo, ComponentTypes, InterfaceType, MAX_FLAT_ASYNC_PARAMS, MAX_FLAT_PARAMS,
25-
MAX_FLAT_RESULTS, OptionsIndex, RuntimeComponentInstanceIndex, TypeFuncIndex, TypeTuple,
24+
CanonicalAbiInfo, ComponentTypes, FlatFuncTypeContext, FlatTypesStorage, InterfaceType,
25+
MAX_FLAT_ASYNC_PARAMS, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, OptionsIndex,
26+
RuntimeComponentInstanceIndex, TypeFunc, TypeFuncIndex, TypeTuple,
2627
};
2728

2829
pub struct HostFunc {
@@ -213,6 +214,27 @@ where
213214
Ok(())
214215
}
215216

217+
#[cfg(feature = "rr")]
218+
#[inline(always)]
219+
fn flat_func_type(
220+
types: &ComponentTypes,
221+
ty: &TypeFunc,
222+
context: FlatFuncTypeContext,
223+
) -> (FlatTypesStorage, FlatTypesStorage) {
224+
types.flat_func_type(ty, context)
225+
}
226+
227+
#[cfg(not(feature = "rr"))]
228+
#[inline(always)]
229+
/// This will get DCEd when RR is disabled
230+
fn flat_func_type(
231+
_types: &ComponentTypes,
232+
_ty: &TypeFunc,
233+
_context: FlatFuncTypeContext,
234+
) -> (FlatTypesStorage, FlatTypesStorage) {
235+
(FlatTypesStorage::new(), FlatTypesStorage::new())
236+
}
237+
216238
/// The "meat" of calling a host function from wasm.
217239
///
218240
/// This function is delegated to from implementations of
@@ -267,7 +289,10 @@ where
267289
let param_tys = InterfaceType::Tuple(ty.params);
268290
let result_tys = InterfaceType::Tuple(ty.results);
269291

270-
rr::component_hooks::record_validate_host_func_entry(storage, &types, &param_tys, store.0)?;
292+
let (param_flat_types, result_flat_types) =
293+
flat_func_type(types, &ty, FlatFuncTypeContext::Lower);
294+
295+
rr::component_hooks::record_validate_host_func_entry(storage, param_flat_types, store.0)?;
271296

272297
if async_ {
273298
#[cfg(feature = "component-model-async")]
@@ -338,7 +363,7 @@ where
338363
};
339364

340365
let mut lower = LowerContext::new(store, options, instance);
341-
storage.lower_results(&mut lower, InterfaceType::U32, status)?;
366+
storage.lower_results(&mut lower, InterfaceType::U32, result_flat_types, status)?;
342367
}
343368
#[cfg(not(feature = "component-model-async"))]
344369
{
@@ -366,7 +391,7 @@ where
366391
flags.set_may_leave(false);
367392
}
368393
let mut lower = LowerContext::new(store, options, instance);
369-
storage.lower_results(&mut lower, result_tys, ret)?;
394+
storage.lower_results(&mut lower, result_tys, result_flat_types, ret)?;
370395
unsafe {
371396
flags.set_may_leave(true);
372397
}
@@ -590,6 +615,7 @@ where
590615
&mut self,
591616
cx: &mut LowerContext<'_, T>,
592617
ty: InterfaceType,
618+
result_flat_types: FlatTypesStorage,
593619
ret: R,
594620
) -> Result<()> {
595621
match self.lower_dst() {
@@ -601,8 +627,7 @@ where
601627
);
602628
rr::component_hooks::record_host_func_return(
603629
unsafe { storage_as_slice_mut(storage) },
604-
cx.types,
605-
&ty,
630+
result_flat_types,
606631
cx.store.0,
607632
)?;
608633
result
@@ -615,11 +640,10 @@ where
615640
ty,
616641
offset,
617642
);
618-
// Record the pointer
643+
// Pointer will be recorded by params; not necessary here
619644
rr::component_hooks::record_host_func_return(
620-
&[MaybeUninit::new(*ptr)],
621-
cx.types,
622-
&InterfaceType::U32,
645+
&[],
646+
result_flat_types,
623647
cx.store.0,
624648
)?;
625649
result
@@ -846,10 +870,12 @@ where
846870
params_and_results.push(Val::Bool(false));
847871
}
848872

873+
let (param_flat_types, result_flat_types) =
874+
flat_func_type(types, func_ty, FlatFuncTypeContext::Lower);
875+
849876
rr::component_hooks::record_validate_host_func_entry(
850877
storage,
851-
types,
852-
&InterfaceType::Tuple(func_ty.params),
878+
param_flat_types,
853879
store.0.store_opaque_mut(),
854880
)?;
855881

@@ -928,12 +954,7 @@ where
928954
)?;
929955
}
930956
assert!(dst.next().is_none());
931-
rr::component_hooks::record_host_func_return(
932-
storage,
933-
cx.types,
934-
&InterfaceType::Tuple(func_ty.results),
935-
cx.store.0,
936-
)?;
957+
rr::component_hooks::record_host_func_return(storage, result_flat_types, cx.store.0)?;
937958
} else {
938959
let ret_ptr = unsafe { storage[ret_index].assume_init_ref() };
939960
let mut ptr = validate_inbounds_dynamic(&result_tys.abi, cx.as_slice_mut(), ret_ptr)?;
@@ -946,13 +967,8 @@ where
946967
offset,
947968
)?;
948969
}
949-
// Lower store into pointer
950-
rr::component_hooks::record_host_func_return(
951-
&storage[ret_index..ret_index + 1],
952-
cx.types,
953-
&InterfaceType::U32,
954-
cx.store.0,
955-
)?;
970+
// Ret ptr is passed through params, and doesn't need to be recorded.
971+
rr::component_hooks::record_host_func_return(&[], result_flat_types, cx.store.0)?;
956972
}
957973

958974
unsafe {
@@ -977,43 +993,31 @@ unsafe fn call_host_dynamic_replay<T>(
977993
#[cfg(feature = "rr")]
978994
{
979995
use crate::rr::component_hooks::ReplayLoweringPhase;
980-
// Mirror of `dynamic_params_load` for replay. Keep in sync
981-
fn dynamic_params_load_replay(param_tys: &TypeTuple, max_flat_params: usize) -> usize {
982-
if let Some(param_count) = param_tys.abi.flat_count(max_flat_params) {
983-
param_count
984-
} else {
985-
1
986-
}
987-
}
988996

989997
if async_ {
990998
unreachable!(
991999
"Replay logic should be unreachable with component async-ABI (currently unsupported)"
9921000
);
9931001
}
9941002
let func_ty = &types[ty];
995-
let param_tys = &types[func_ty.params];
1003+
let (param_flat_types, _) = flat_func_type(types, func_ty, FlatFuncTypeContext::Lower);
9961004
let result_tys = &types[func_ty.results];
9971005

9981006
rr::component_hooks::replay_validate_host_func_entry(
9991007
storage,
1000-
types,
1001-
&InterfaceType::Tuple(func_ty.params),
1008+
param_flat_types,
10021009
store.0.store_opaque_mut(),
10031010
)?;
10041011

10051012
let mut cx = LowerContext::new(store, options, instance);
10061013

10071014
// Skip lifting/lowering logic, and just replaying the lowering state
1008-
let ret_index = dynamic_params_load_replay(param_tys, MAX_FLAT_PARAMS);
1009-
// Copy the entire contiguous storage slice instead of looping
10101015
if let Some(_cnt) = result_tys.abi.flat_count(MAX_FLAT_RESULTS) {
1016+
// Copy the entire contiguous storage slice instead of looping
10111017
cx.replay_lowering(Some(storage), ReplayLoweringPhase::HostFuncReturn)?;
10121018
} else {
1013-
cx.replay_lowering(
1014-
Some(&mut storage[ret_index..ret_index + 1]),
1015-
ReplayLoweringPhase::HostFuncReturn,
1016-
)?;
1019+
// The retptr is passed through params for lowering
1020+
cx.replay_lowering(None, ReplayLoweringPhase::HostFuncReturn)?;
10171021
}
10181022
Ok(())
10191023
}

0 commit comments

Comments
 (0)