Skip to content

Commit 6f67b5c

Browse files
feat(core): implement async export abi call lift lower (#1394)
* feat(core): implement abi::call lift/lower for async export * fix(core): max flat params * refactor(core): slightly improve todos
1 parent cff17ba commit 6f67b5c

File tree

1 file changed

+66
-50
lines changed

1 file changed

+66
-50
lines changed

crates/core/src/abi.rs

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -958,13 +958,9 @@ impl<'a, B: Bindgen> Generator<'a, B> {
958958

959959
match lift_lower {
960960
LiftLower::LowerArgsLiftResults => {
961-
assert!(!async_, "generators should not be using this for async");
962-
963961
self.realloc = Some(realloc);
964-
if let (AbiVariant::GuestExport, true) = (variant, async_) {
965-
unimplemented!("host-side code generation for async lift/lower not supported");
966-
}
967962

963+
// Create a function that performs individual lowering of operands
968964
let lower_to_memory = |self_: &mut Self, ptr: B::Operand| {
969965
let mut offset = ArchitectureSize::default();
970966
for (nth, (_, ty)) in func.params.iter().enumerate() {
@@ -977,26 +973,26 @@ impl<'a, B: Bindgen> Generator<'a, B> {
977973
self_.stack.push(ptr);
978974
};
979975

980-
if !sig.indirect_params {
981-
// If the parameters for this function aren't indirect
982-
// (there aren't too many) then we simply do a normal lower
983-
// operation for them all.
984-
for (nth, (_, ty)) in func.params.iter().enumerate() {
985-
self.emit(&Instruction::GetArg { nth });
986-
self.lower(ty);
987-
}
988-
} else {
989-
// ... otherwise if parameters are indirect space is
976+
// Lower parameters
977+
if sig.indirect_params {
978+
// If parameters are indirect space is
990979
// allocated for them and each argument is lowered
991980
// individually into memory.
992981
let ElementInfo { size, align } = self
993982
.bindgen
994983
.sizes()
995984
.record(func.params.iter().map(|t| &t.1));
985+
986+
// Resolve the pointer to the indirectly stored parameters
996987
let ptr = match variant {
997988
// When a wasm module calls an import it will provide
998989
// space that isn't explicitly deallocated.
999990
AbiVariant::GuestImport => self.bindgen.return_pointer(size, align),
991+
992+
AbiVariant::GuestImportAsync => {
993+
todo!("direct param lowering for async guest import not implemented")
994+
}
995+
1000996
// When calling a wasm module from the outside, though,
1001997
// malloc needs to be called.
1002998
AbiVariant::GuestExport => {
@@ -1007,39 +1003,45 @@ impl<'a, B: Bindgen> Generator<'a, B> {
10071003
});
10081004
self.stack.pop().unwrap()
10091005
}
1010-
AbiVariant::GuestImportAsync
1011-
| AbiVariant::GuestExportAsync
1012-
| AbiVariant::GuestExportAsyncStackful => {
1013-
unreachable!()
1006+
1007+
AbiVariant::GuestExportAsync | AbiVariant::GuestExportAsyncStackful => {
1008+
todo!("direct param lowering for async not implemented")
10141009
}
10151010
};
1011+
1012+
// Lower the parameters to memory
10161013
lower_to_memory(self, ptr);
1014+
} else {
1015+
// ... otherwise arguments are direct,
1016+
// (there aren't too many) then we simply do a normal lower
1017+
// operation for them all.
1018+
for (nth, (_, ty)) in func.params.iter().enumerate() {
1019+
self.emit(&Instruction::GetArg { nth });
1020+
self.lower(ty);
1021+
}
10171022
}
10181023
self.realloc = None;
10191024

1020-
// If necessary we may need to prepare a return pointer for
1021-
// this ABI.
1025+
// If necessary we may need to prepare a return pointer for this ABI.
10221026
if variant == AbiVariant::GuestImport && sig.retptr {
10231027
let info = self.bindgen.sizes().params(&func.result);
10241028
let ptr = self.bindgen.return_pointer(info.size, info.align);
10251029
self.return_pointer = Some(ptr.clone());
10261030
self.stack.push(ptr);
10271031
}
10281032

1033+
// Call the Wasm function
10291034
assert_eq!(self.stack.len(), sig.params.len());
10301035
self.emit(&Instruction::CallWasm {
10311036
name: &func.name,
10321037
sig: &sig,
10331038
});
10341039

1035-
if !sig.retptr {
1036-
// With no return pointer in use we can simply lift the
1037-
// result(s) of the function from the result of the core
1038-
// wasm function.
1039-
if let Some(ty) = &func.result {
1040-
self.lift(ty)
1041-
}
1042-
} else {
1040+
// Handle the result
1041+
if sig.retptr {
1042+
// If there is a return pointer, we must get the pointer to where results
1043+
// should be stored, and store the results there?
1044+
10431045
let ptr = match variant {
10441046
// imports into guests means it's a wasm module
10451047
// calling an imported function. We supplied the
@@ -1063,16 +1065,34 @@ impl<'a, B: Bindgen> Generator<'a, B> {
10631065
}
10641066
};
10651067

1066-
self.read_results_from_memory(
1067-
&func.result,
1068-
ptr.clone(),
1069-
ArchitectureSize::default(),
1070-
);
1071-
self.emit(&Instruction::Flush {
1072-
amt: usize::from(func.result.is_some()),
1073-
});
1068+
if let (AbiVariant::GuestExport, true) = (variant, async_) {
1069+
// If we're dealing with an async function, the result should not be read from memory
1070+
// immediately, as it's the async call result
1071+
//
1072+
// We can leave the result of the call (the indication of what to do as an async call)
1073+
// on the stack as a return
1074+
self.stack.push(ptr);
1075+
} else {
1076+
// If we're not dealing with an async call, the result must be in memory at this point and can be read out
1077+
self.read_results_from_memory(
1078+
&func.result,
1079+
ptr.clone(),
1080+
ArchitectureSize::default(),
1081+
);
1082+
self.emit(&Instruction::Flush {
1083+
amt: usize::from(func.result.is_some()),
1084+
});
1085+
}
1086+
} else {
1087+
// With no return pointer in use we can simply lift the
1088+
// result(s) of the function from the result of the core
1089+
// wasm function.
1090+
if let Some(ty) = &func.result {
1091+
self.lift(ty)
1092+
}
10741093
}
10751094

1095+
// Emit the function return
10761096
self.emit(&Instruction::Return {
10771097
func,
10781098
amt: usize::from(func.result.is_some()),
@@ -1081,9 +1101,7 @@ impl<'a, B: Bindgen> Generator<'a, B> {
10811101

10821102
LiftLower::LiftArgsLowerResults => {
10831103
let max_flat_params = match (variant, async_) {
1084-
(AbiVariant::GuestImport | AbiVariant::GuestImportAsync, _is_async @ true) => {
1085-
MAX_FLAT_ASYNC_PARAMS
1086-
}
1104+
(AbiVariant::GuestImportAsync, _is_async @ true) => MAX_FLAT_ASYNC_PARAMS,
10871105
_ => MAX_FLAT_PARAMS,
10881106
};
10891107

@@ -1113,9 +1131,11 @@ impl<'a, B: Bindgen> Generator<'a, B> {
11131131
// argument in succession from the component wasm types that
11141132
// make-up the type.
11151133
let mut offset = 0;
1116-
for (_, ty) in func.params.iter() {
1117-
let types = flat_types(self.resolve, ty, Some(max_flat_params))
1118-
.expect(&format!("direct parameter load failed to produce types during generation of fn call (func name: '{}')", func.name));
1134+
for (param_name, ty) in func.params.iter() {
1135+
let Some(types) = flat_types(self.resolve, ty, Some(max_flat_params))
1136+
else {
1137+
panic!("failed to flatten types during direct parameter lifting ('{param_name}' in func '{}')", func.name);
1138+
};
11191139
for _ in 0..types.len() {
11201140
self.emit(&Instruction::GetArg { nth: offset });
11211141
offset += 1;
@@ -2473,12 +2493,8 @@ fn cast(from: WasmType, to: WasmType) -> Bitcast {
24732493
/// It is sometimes necessary to restrict the number of max parameters dynamically,
24742494
/// for example during an async guest import call (flat params are limited to 4)
24752495
fn flat_types(resolve: &Resolve, ty: &Type, max_params: Option<usize>) -> Option<Vec<WasmType>> {
2476-
let mut storage =
2477-
iter::repeat_n(WasmType::I32, max_params.unwrap_or(MAX_FLAT_PARAMS)).collect::<Vec<_>>();
2496+
let max_params = max_params.unwrap_or(MAX_FLAT_PARAMS);
2497+
let mut storage = iter::repeat_n(WasmType::I32, max_params).collect::<Vec<_>>();
24782498
let mut flat = FlatTypes::new(storage.as_mut_slice());
2479-
if resolve.push_flat(ty, &mut flat) {
2480-
Some(flat.to_vec())
2481-
} else {
2482-
None
2483-
}
2499+
resolve.push_flat(ty, &mut flat).then_some(flat.to_vec())
24842500
}

0 commit comments

Comments
 (0)