wasmer_wasix/syscalls/wasix/
closure_prepare.rs

1//! Closures provide a way to generate a WASM function that wraps a generic function and an environment.
2//!
3//! A typical usage of this API is as follows:
4//!
5//! 1. Allocate a function pointer for your closure with [`closure_allocate`]
6//! 2. Prepare the closure with [`closure_prepare`]
7//! 3. Call function pointer
8//! 4. Call [`closure_prepare`] again to redefine the function pointer
9//! 5. Notify wasmer that the closure is no longer needed with [`closure_free`]
10
11use crate::{state::DlModuleSpec, syscalls::*};
12use std::{path::PathBuf, sync::atomic::AtomicUsize};
13use wasm_encoder::{
14    CodeSection, CustomSection, ExportKind, ExportSection, FunctionSection, GlobalType,
15    ImportSection, InstructionSink, MemArg, MemoryType, RefType, TableType, TypeSection, ValType,
16};
17use wasmer::{FunctionType, Table, Type, imports};
18
19use wasmer_wasix_types::wasi::WasmValueType;
20
21// Implement helper functions for wasm_encoder::ValType
22trait ValTypeOps
23where
24    Self: Sized,
25{
26    fn from_u8(value: u8) -> Result<Self, Errno>;
27    fn size(&self) -> u64;
28    fn store(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32);
29    fn load(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32);
30}
31impl ValTypeOps for ValType {
32    fn from_u8(value: u8) -> Result<Self, Errno> {
33        let wasix_type = WasmValueType::try_from(value).map_err(|_| Errno::Inval)?;
34        match wasix_type {
35            WasmValueType::I32 => Ok(Self::I32),
36            WasmValueType::I64 => Ok(Self::I64),
37            WasmValueType::F32 => Ok(Self::F32),
38            WasmValueType::F64 => Ok(Self::F64),
39            WasmValueType::V128 => Ok(Self::V128),
40            _ => Err(Errno::Inval),
41        }
42    }
43    fn size(&self) -> u64 {
44        match self {
45            Self::I32 => 4,
46            Self::I64 => 8,
47            Self::F32 => 4,
48            Self::F64 => 8,
49            Self::V128 => 16,
50            // Not supported in closures.
51            Self::Ref(_) => panic!("Cannot get size of reference type"),
52        }
53    }
54    fn store(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32) {
55        match self {
56            Self::I32 => sink.i32_store(MemArg {
57                offset,
58                align: 0,
59                memory_index,
60            }),
61            Self::I64 => sink.i64_store(MemArg {
62                offset,
63                align: 0,
64                memory_index,
65            }),
66            Self::F32 => sink.f32_store(MemArg {
67                offset,
68                align: 0,
69                memory_index,
70            }),
71            Self::F64 => sink.f64_store(MemArg {
72                offset,
73                align: 0,
74                memory_index,
75            }),
76            Self::V128 => sink.v128_store(MemArg {
77                offset,
78                align: 0,
79                memory_index,
80            }),
81            // Not supported in closures
82            Self::Ref(_) => panic!("Cannot store reference type"),
83        };
84    }
85    fn load(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32) {
86        match self {
87            Self::I32 => sink.i32_load(MemArg {
88                offset,
89                align: 0,
90                memory_index,
91            }),
92            Self::I64 => sink.i64_load(MemArg {
93                offset,
94                align: 0,
95                memory_index,
96            }),
97            Self::F32 => sink.f32_load(MemArg {
98                offset,
99                align: 0,
100                memory_index,
101            }),
102            Self::F64 => sink.f64_load(MemArg {
103                offset,
104                align: 0,
105                memory_index,
106            }),
107            Self::V128 => sink.v128_load(MemArg {
108                offset,
109                align: 0,
110                memory_index,
111            }),
112            // Not supported in closures
113            Self::Ref(_) => panic!("Cannot load reference type"),
114        };
115    }
116}
117
118/// Build a dynamically linkable WASM module for the given closure.
119fn build_closure_wasm_bytes(
120    module_name: &str,
121    closure: u32,
122    backing_function: u32,
123    environment_offset: u64,
124    argument_types: &[ValType],
125    result_types: &[ValType],
126) -> Vec<u8> {
127    let mut wasm_module = wasm_encoder::Module::new();
128
129    // Add dylink section
130    let dylink = CustomSection {
131        name: Cow::Borrowed("dylink.0"),
132        data: Cow::Borrowed(&[]),
133    };
134    wasm_module.section(&dylink);
135
136    // Add types section
137    let mut types = TypeSection::new();
138    types.ty().function(vec![], vec![]);
139    let on_load_function_type_index = 0;
140    let mut trampoline_function_params = argument_types.to_vec();
141    let mut trampoline_function_results = result_types.to_vec();
142    types
143        .ty()
144        .function(trampoline_function_params, trampoline_function_results);
145    let trampoline_function_type_index = 1;
146    types
147        .ty()
148        .function(vec![ValType::I32, ValType::I32, ValType::I32], vec![]);
149    let backing_function_type_index = 2;
150    wasm_module.section(&types);
151
152    // Add a imports section
153    let mut imports = ImportSection::new();
154    imports.import(
155        "env",
156        "memory",
157        MemoryType {
158            minimum: 1,
159            maximum: Some(65536),
160            shared: true,
161            memory64: false,
162            page_size_log2: None,
163        },
164    );
165    let main_memory_index = 0;
166    imports.import(
167        "env",
168        "__indirect_function_table",
169        TableType {
170            element_type: RefType::FUNCREF,
171            minimum: 1,
172            maximum: None,
173            shared: false,
174            table64: false,
175        },
176    );
177    let indirect_function_table_index = 0;
178    imports.import(
179        "env",
180        "__stack_pointer",
181        GlobalType {
182            val_type: ValType::I32,
183            mutable: true,
184            shared: false,
185        },
186    );
187    let stack_pointer_index = 0;
188    imports.import(
189        "GOT.func",
190        module_name,
191        GlobalType {
192            val_type: ValType::I32,
193            mutable: true,
194            shared: false,
195        },
196    );
197    let trampoline_function_pointer_index = 1;
198    wasm_module.section(&imports);
199
200    let mut functions = FunctionSection::new();
201    functions.function(on_load_function_type_index);
202    let on_load_function_index = 0;
203    functions.function(trampoline_function_type_index);
204    let trampoline_function_index = 1;
205    wasm_module.section(&functions);
206
207    // Add an export section
208    // FIXME: Look into replacing this with the wasm start function
209    let mut exports = ExportSection::new();
210    exports.export(
211        "__wasm_call_ctors",
212        ExportKind::Func,
213        on_load_function_index,
214    );
215    exports.export(module_name, ExportKind::Func, trampoline_function_index);
216    wasm_module.section(&exports);
217
218    let mut code = CodeSection::new();
219    let mut on_load_function = wasm_encoder::Function::new(vec![]);
220    on_load_function
221        .instructions()
222        .i32_const(closure as i32)
223        .global_get(trampoline_function_pointer_index)
224        .table_get(indirect_function_table_index)
225        .table_set(indirect_function_table_index)
226        .end();
227    code.function(&on_load_function);
228
229    let mut trampoline_function = wasm_encoder::Function::new(vec![(3, ValType::I32)]);
230    let original_stackpointer_local: u32 = argument_types.len() as u32;
231    let arguments_base_local: u32 = argument_types.len() as u32 + 1;
232    let results_base_local: u32 = argument_types.len() as u32 + 2;
233    let mut trampoline_function_instructions = trampoline_function.instructions();
234    let values_size = argument_types
235        .iter()
236        .map(ValType::size)
237        .sum::<u64>()
238        .next_multiple_of(16);
239    let results_size = result_types
240        .iter()
241        .map(ValType::size)
242        .sum::<u64>()
243        .next_multiple_of(16);
244    trampoline_function_instructions
245        .global_get(stack_pointer_index)
246        .local_tee(original_stackpointer_local)
247        .i32_const(results_size as i32)
248        .i32_sub()
249        .local_tee(results_base_local)
250        .i32_const(values_size as i32)
251        .i32_sub()
252        .local_tee(arguments_base_local)
253        .global_set(stack_pointer_index);
254    argument_types.iter().enumerate().fold(
255        (0, &mut trampoline_function_instructions),
256        |mut acc, (index, ty)| {
257            let size = ty.size();
258            acc.1
259                .local_get(arguments_base_local)
260                .local_get(index as u32);
261            ty.store(acc.1, acc.0, main_memory_index);
262            acc.0 += size;
263            acc
264        },
265    );
266    trampoline_function_instructions
267        .local_get(arguments_base_local)
268        .local_get(results_base_local)
269        .i32_const(environment_offset as i32)
270        .i32_const(backing_function as i32)
271        .call_indirect(indirect_function_table_index, backing_function_type_index);
272    result_types.iter().enumerate().fold(
273        (0, &mut trampoline_function_instructions),
274        |mut acc, (index, ty)| {
275            let size = ty.size();
276            acc.1.local_get(results_base_local);
277            ty.load(acc.1, acc.0, main_memory_index);
278            acc.0 += size;
279            acc
280        },
281    );
282    trampoline_function_instructions
283        .local_get(original_stackpointer_local)
284        .global_set(stack_pointer_index)
285        .end();
286    code.function(&trampoline_function);
287    wasm_module.section(&code);
288
289    wasm_module.finish()
290}
291
/// Monotonically incrementing id for closures.
///
/// `closure_prepare` bumps this counter with `fetch_add` on every call and
/// embeds the previous value in the name of the module it generates
/// (`__wasix_closure_<id>`).
static CLOSURE_ID: AtomicUsize = AtomicUsize::new(0);
294
295/// Prepare a closure so that it can be called with a given signature.
296///
297/// When the closure is called after [`closure_prepare`], the arguments will be decoded and passed to the backing function together with a pointer to the environment.
298///
299/// The backing function needs to conform to the following signature:
300///   uint8_t* values - a pointer to a buffer containing the arguments.
301///   uint8_t* results - a pointer to a buffer where the results will be written.
302///   void* environment - the environment that was passed to closure_prepare
303///
304/// `backing_function` is a pointer (index into `__indirect_function_table`) to the backing function
305///
306/// `closure` is a pointer (index into `__indirect_function_table`) to the closure that was obtained via [`closure_allocate`].
307///
308/// `argument_types_ptr` is a pointer to the argument types as a list of [`WasmValueType`]s
309/// `argument_types_length` is the number of arguments
310///
311/// `result_types_ptr` is a pointer to the result types as a list of [`WasmValueType`]s
312/// `result_types_length` is the number of results
313///
314/// `environment` is the closure environment that will be passed to the backing function alongside the decoded arguments and results
315#[instrument(level = "trace", fields(%backing_function, %closure), ret)]
316pub fn closure_prepare<M: MemorySize>(
317    mut ctx: FunctionEnvMut<'_, WasiEnv>,
318    backing_function: u32,
319    closure: u32,
320    argument_types_ptr: WasmPtr<u8, M>,
321    argument_types_length: u32,
322    result_types_ptr: WasmPtr<u8, M>,
323    result_types_length: u32,
324    environment: WasmPtr<u8, M>,
325) -> Result<Errno, WasiError> {
326    WasiEnv::do_pending_operations(&mut ctx)?;
327
328    let (env, mut store) = ctx.data_and_store_mut();
329    let memory = unsafe { env.memory_view(&store) };
330
331    let Some(linker) = env.inner().linker().cloned() else {
332        error!("Closures only work for dynamic modules.");
333        return Ok(Errno::Notsup);
334    };
335
336    let argument_types = {
337        let arg_offset = argument_types_ptr.offset().into();
338        let arguments_slice = wasi_try_mem_ok!(
339            WasmSlice::new(&memory, arg_offset, argument_types_length as u64)
340                .and_then(WasmSlice::access)
341        );
342        wasi_try_ok!(
343            arguments_slice
344                .iter()
345                .map(|t: &u8| ValType::from_u8(*t))
346                .collect::<Result<Vec<_>, Errno>>()
347        )
348    };
349
350    let result_types = {
351        let res_offset = result_types_ptr.offset().into();
352        let result_slice = wasi_try_mem_ok!(
353            WasmSlice::new(&memory, res_offset, result_types_length as u64)
354                .and_then(WasmSlice::access)
355        );
356        wasi_try_ok!(
357            result_slice
358                .iter()
359                .map(|t: &u8| ValType::from_u8(*t))
360                .collect::<Result<Vec<_>, Errno>>()
361        )
362    };
363
364    let module_name = format!(
365        "__wasix_closure_{}",
366        CLOSURE_ID.fetch_add(1, Ordering::SeqCst),
367    );
368
369    let wasm_bytes = build_closure_wasm_bytes(
370        &module_name,
371        closure,
372        backing_function,
373        environment.offset().into(),
374        &argument_types,
375        &result_types,
376    );
377
378    let ld_library_path: [&Path; 0] = [];
379    let wasm_loader = DlModuleSpec::Memory {
380        module_name: &module_name,
381        bytes: &wasm_bytes,
382    };
383    let module_handle = match linker.load_module(wasm_loader, &mut ctx) {
384        Ok(m) => m,
385        Err(e) => {
386            // Should never happen
387            panic!("Failed to load newly built in-memory module: {e}");
388        }
389    };
390
391    return Ok(Errno::Success);
392}