wasmer_wasix/syscalls/wasix/
closure_prepare.rs

//! Closures provide a way to generate a WASM function that wraps a generic function and an environment.
//!
//! A typical usage of this API is as follows:
//!
//! 1. Allocate a function pointer for your closure with [`closure_allocate`].
//! 2. Prepare the closure with [`closure_prepare`].
//! 3. Call the function pointer.
//! 4. If needed, call [`closure_prepare`] again to redefine the function pointer.
//! 5. Notify wasmer that the closure is no longer needed with [`closure_free`].
10
11use crate::{state::DlModuleSpec, syscalls::*};
12use std::{path::PathBuf, sync::atomic::AtomicUsize};
13use wasm_encoder::{
14    CodeSection, CustomSection, ExportKind, ExportSection, FunctionSection, GlobalType,
15    ImportSection, InstructionSink, MemArg, MemoryType, RefType, TableType, TypeSection, ValType,
16};
17use wasmer::{FunctionType, Table, Type, imports};
18
19use wasmer_wasix_types::wasi::WasmValueType;
20
21// Implement helper functions for wasm_encoder::ValType
22trait ValTypeOps
23where
24    Self: Sized,
25{
26    fn from_u8(value: u8) -> Result<Self, Errno>;
27    fn size(&self) -> u64;
28    fn store(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32);
29    fn load(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32);
30}
31impl ValTypeOps for ValType {
32    fn from_u8(value: u8) -> Result<Self, Errno> {
33        let wasix_type = WasmValueType::try_from(value).map_err(|_| Errno::Inval)?;
34        match wasix_type {
35            WasmValueType::I32 => Ok(Self::I32),
36            WasmValueType::I64 => Ok(Self::I64),
37            WasmValueType::F32 => Ok(Self::F32),
38            WasmValueType::F64 => Ok(Self::F64),
39            WasmValueType::V128 => Ok(Self::V128),
40        }
41    }
42    fn size(&self) -> u64 {
43        match self {
44            Self::I32 => 4,
45            Self::I64 => 8,
46            Self::F32 => 4,
47            Self::F64 => 8,
48            Self::V128 => 16,
49            // Not supported in closures.
50            Self::Ref(_) => panic!("Cannot get size of reference type"),
51        }
52    }
53    fn store(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32) {
54        match self {
55            Self::I32 => sink.i32_store(MemArg {
56                offset,
57                align: 0,
58                memory_index,
59            }),
60            Self::I64 => sink.i64_store(MemArg {
61                offset,
62                align: 0,
63                memory_index,
64            }),
65            Self::F32 => sink.f32_store(MemArg {
66                offset,
67                align: 0,
68                memory_index,
69            }),
70            Self::F64 => sink.f64_store(MemArg {
71                offset,
72                align: 0,
73                memory_index,
74            }),
75            Self::V128 => sink.v128_store(MemArg {
76                offset,
77                align: 0,
78                memory_index,
79            }),
80            // Not supported in closures
81            Self::Ref(_) => panic!("Cannot store reference type"),
82        };
83    }
84    fn load(&self, sink: &mut InstructionSink<'_>, offset: u64, memory_index: u32) {
85        match self {
86            Self::I32 => sink.i32_load(MemArg {
87                offset,
88                align: 0,
89                memory_index,
90            }),
91            Self::I64 => sink.i64_load(MemArg {
92                offset,
93                align: 0,
94                memory_index,
95            }),
96            Self::F32 => sink.f32_load(MemArg {
97                offset,
98                align: 0,
99                memory_index,
100            }),
101            Self::F64 => sink.f64_load(MemArg {
102                offset,
103                align: 0,
104                memory_index,
105            }),
106            Self::V128 => sink.v128_load(MemArg {
107                offset,
108                align: 0,
109                memory_index,
110            }),
111            // Not supported in closures
112            Self::Ref(_) => panic!("Cannot load reference type"),
113        };
114    }
115}
116
117/// Build a dynamically linkable WASM module for the given closure.
118fn build_closure_wasm_bytes(
119    module_name: &str,
120    closure: u32,
121    backing_function: u32,
122    environment_offset: u64,
123    argument_types: &[ValType],
124    result_types: &[ValType],
125) -> Vec<u8> {
126    let mut wasm_module = wasm_encoder::Module::new();
127
128    // Add dylink section
129    let dylink = CustomSection {
130        name: Cow::Borrowed("dylink.0"),
131        data: Cow::Borrowed(&[]),
132    };
133    wasm_module.section(&dylink);
134
135    // Add types section
136    let mut types = TypeSection::new();
137    types.ty().function(vec![], vec![]);
138    let on_load_function_type_index = 0;
139    let mut trampoline_function_params = argument_types.to_vec();
140    let mut trampoline_function_results = result_types.to_vec();
141    types
142        .ty()
143        .function(trampoline_function_params, trampoline_function_results);
144    let trampoline_function_type_index = 1;
145    types
146        .ty()
147        .function(vec![ValType::I32, ValType::I32, ValType::I32], vec![]);
148    let backing_function_type_index = 2;
149    wasm_module.section(&types);
150
151    // Add a imports section
152    let mut imports = ImportSection::new();
153    imports.import(
154        "env",
155        "memory",
156        MemoryType {
157            minimum: 1,
158            maximum: Some(65536),
159            shared: true,
160            memory64: false,
161            page_size_log2: None,
162        },
163    );
164    let main_memory_index = 0;
165    imports.import(
166        "env",
167        "__indirect_function_table",
168        TableType {
169            element_type: RefType::FUNCREF,
170            minimum: 1,
171            maximum: None,
172            shared: false,
173            table64: false,
174        },
175    );
176    let indirect_function_table_index = 0;
177    imports.import(
178        "env",
179        "__stack_pointer",
180        GlobalType {
181            val_type: ValType::I32,
182            mutable: true,
183            shared: false,
184        },
185    );
186    let stack_pointer_index = 0;
187    imports.import(
188        "GOT.func",
189        module_name,
190        GlobalType {
191            val_type: ValType::I32,
192            mutable: true,
193            shared: false,
194        },
195    );
196    let trampoline_function_pointer_index = 1;
197    wasm_module.section(&imports);
198
199    let mut functions = FunctionSection::new();
200    functions.function(on_load_function_type_index);
201    let on_load_function_index = 0;
202    functions.function(trampoline_function_type_index);
203    let trampoline_function_index = 1;
204    wasm_module.section(&functions);
205
206    // Add an export section
207    // FIXME: Look into replacing this with the wasm start function
208    let mut exports = ExportSection::new();
209    exports.export(
210        "__wasm_call_ctors",
211        ExportKind::Func,
212        on_load_function_index,
213    );
214    exports.export(module_name, ExportKind::Func, trampoline_function_index);
215    wasm_module.section(&exports);
216
217    let mut code = CodeSection::new();
218    let mut on_load_function = wasm_encoder::Function::new(vec![]);
219    on_load_function
220        .instructions()
221        .i32_const(closure as i32)
222        .global_get(trampoline_function_pointer_index)
223        .table_get(indirect_function_table_index)
224        .table_set(indirect_function_table_index)
225        .end();
226    code.function(&on_load_function);
227
228    let mut trampoline_function = wasm_encoder::Function::new(vec![(3, ValType::I32)]);
229    let original_stackpointer_local: u32 = argument_types.len() as u32;
230    let arguments_base_local: u32 = argument_types.len() as u32 + 1;
231    let results_base_local: u32 = argument_types.len() as u32 + 2;
232    let mut trampoline_function_instructions = trampoline_function.instructions();
233    let values_size = argument_types
234        .iter()
235        .map(ValType::size)
236        .sum::<u64>()
237        .next_multiple_of(16);
238    let results_size = result_types
239        .iter()
240        .map(ValType::size)
241        .sum::<u64>()
242        .next_multiple_of(16);
243    trampoline_function_instructions
244        .global_get(stack_pointer_index)
245        .local_tee(original_stackpointer_local)
246        .i32_const(results_size as i32)
247        .i32_sub()
248        .local_tee(results_base_local)
249        .i32_const(values_size as i32)
250        .i32_sub()
251        .local_tee(arguments_base_local)
252        .global_set(stack_pointer_index);
253    argument_types.iter().enumerate().fold(
254        (0, &mut trampoline_function_instructions),
255        |mut acc, (index, ty)| {
256            let size = ty.size();
257            acc.1
258                .local_get(arguments_base_local)
259                .local_get(index as u32);
260            ty.store(acc.1, acc.0, main_memory_index);
261            acc.0 += size;
262            acc
263        },
264    );
265    trampoline_function_instructions
266        .local_get(arguments_base_local)
267        .local_get(results_base_local)
268        .i32_const(environment_offset as i32)
269        .i32_const(backing_function as i32)
270        .call_indirect(indirect_function_table_index, backing_function_type_index);
271    result_types.iter().enumerate().fold(
272        (0, &mut trampoline_function_instructions),
273        |mut acc, (index, ty)| {
274            let size = ty.size();
275            acc.1.local_get(results_base_local);
276            ty.load(acc.1, acc.0, main_memory_index);
277            acc.0 += size;
278            acc
279        },
280    );
281    trampoline_function_instructions
282        .local_get(original_stackpointer_local)
283        .global_set(stack_pointer_index)
284        .end();
285    code.function(&trampoline_function);
286    wasm_module.section(&code);
287
288    wasm_module.finish()
289}
290
/// Counter that makes every generated closure module name (`__wasix_closure_<id>`) unique;
/// `closure_prepare` takes the next value with `fetch_add` each time it names a module.
static CLOSURE_ID: AtomicUsize = AtomicUsize::new(usize::MIN);
293
294/// Prepare a closure so that it can be called with a given signature.
295///
296/// When the closure is called after [`closure_prepare`], the arguments will be decoded and passed to the backing function together with a pointer to the environment.
297///
298/// The backing function needs to conform to the following signature:
299///   uint8_t* values - a pointer to a buffer containing the arguments.
300///   uint8_t* results - a pointer to a buffer where the results will be written.
301///   void* environment - the environment that was passed to closure_prepare
302///
303/// `backing_function` is a pointer (index into `__indirect_function_table`) to the backing function
304///
305/// `closure` is a pointer (index into `__indirect_function_table`) to the closure that was obtained via [`closure_allocate`].
306///
307/// `argument_types_ptr` is a pointer to the argument types as a list of [`WasmValueType`]s
308/// `argument_types_length` is the number of arguments
309///
310/// `result_types_ptr` is a pointer to the result types as a list of [`WasmValueType`]s
311/// `result_types_length` is the number of results
312///
313/// `environment` is the closure environment that will be passed to the backing function alongside the decoded arguments and results
314#[instrument(level = "trace", fields(%backing_function, %closure), ret)]
315pub fn closure_prepare<M: MemorySize>(
316    mut ctx: FunctionEnvMut<'_, WasiEnv>,
317    backing_function: u32,
318    closure: u32,
319    argument_types_ptr: WasmPtr<u8, M>,
320    argument_types_length: u32,
321    result_types_ptr: WasmPtr<u8, M>,
322    result_types_length: u32,
323    environment: WasmPtr<u8, M>,
324) -> Result<Errno, WasiError> {
325    WasiEnv::do_pending_operations(&mut ctx)?;
326
327    let (env, mut store) = ctx.data_and_store_mut();
328    let memory = unsafe { env.memory_view(&store) };
329
330    let Some(linker) = env.inner().linker().cloned() else {
331        error!("Closures only work for dynamic modules.");
332        return Ok(Errno::Notsup);
333    };
334
335    let argument_types = {
336        let arg_offset = argument_types_ptr.offset().into();
337        let arguments_slice = wasi_try_mem_ok!(
338            WasmSlice::new(&memory, arg_offset, argument_types_length as u64)
339                .and_then(WasmSlice::access)
340        );
341        wasi_try_ok!(
342            arguments_slice
343                .iter()
344                .map(|t: &u8| ValType::from_u8(*t))
345                .collect::<Result<Vec<_>, Errno>>()
346        )
347    };
348
349    let result_types = {
350        let res_offset = result_types_ptr.offset().into();
351        let result_slice = wasi_try_mem_ok!(
352            WasmSlice::new(&memory, res_offset, result_types_length as u64)
353                .and_then(WasmSlice::access)
354        );
355        wasi_try_ok!(
356            result_slice
357                .iter()
358                .map(|t: &u8| ValType::from_u8(*t))
359                .collect::<Result<Vec<_>, Errno>>()
360        )
361    };
362
363    let module_name = format!(
364        "__wasix_closure_{}",
365        CLOSURE_ID.fetch_add(1, Ordering::SeqCst),
366    );
367
368    let wasm_bytes = build_closure_wasm_bytes(
369        &module_name,
370        closure,
371        backing_function,
372        environment.offset().into(),
373        &argument_types,
374        &result_types,
375    );
376
377    let ld_library_path: [&Path; 0] = [];
378    let wasm_loader = DlModuleSpec::Memory {
379        module_name: &module_name,
380        bytes: &wasm_bytes,
381    };
382    let module_handle = match linker.load_module(wasm_loader, &mut ctx) {
383        Ok(m) => m,
384        Err(e) => {
385            // Should never happen
386            panic!("Failed to load newly built in-memory module: {e}");
387        }
388    };
389
390    return Ok(Errno::Success);
391}