wasmer_vm/trap/
traphandlers.rs

1// This file contains code from external sources.
2// Attributions: https://github.com/wasmerio/wasmer/blob/main/docs/ATTRIBUTIONS.md
3
4#![allow(static_mut_refs)]
5
//! WebAssembly trap handling, which is built on top of the lower-level
//! signal-handling mechanisms.
8
9#[cfg(all(unix, feature = "experimental-host-interrupt"))]
10use crate::interrupt_registry;
11use crate::vmcontext::{VMFunctionContext, VMTrampoline};
12use crate::{Trap, VMContext, VMFunctionBody};
13use backtrace::Backtrace;
14use bytesize::ByteSize;
15use core::ptr::{read, read_unaligned};
16use corosensei::stack::{DefaultStack, Stack};
17use corosensei::trap::{CoroutineTrapHandler, TrapHandlerRegs};
18use corosensei::{CoroutineResult, ScopedCoroutine, Yielder};
19use scopeguard::defer;
20use std::any::Any;
21use std::cell::Cell;
22use std::error::Error;
23use std::io;
24use std::mem;
25#[cfg(unix)]
26use std::mem::MaybeUninit;
27use std::ptr::{self, NonNull};
28use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering, compiler_fence};
29use std::sync::{LazyLock, Once};
30use wasmer_types::TrapCode;
31
32/// Convenience extension for [`Stack`] that exposes the total mapped size.
33trait StackExt: Stack {
34    /// Returns the total size of the stack mapping (including guard page).
35    fn size(&self) -> usize {
36        self.base().get() - self.limit().get()
37    }
38}
39impl<T: Stack> StackExt for T {}
40
/// Configuration for the runtime VM.
///
/// Currently only the stack size is configurable.
pub struct VMConfig {
    /// Optional stack size (in bytes) of the VM. Values lower than 8K will be rounded to 8K.
    pub wasm_stack_size: Option<usize>,
}
47
// Trap information can be stored in the "Undefined Instruction" itself.
// On x86_64, 0xC? selects a "Register" for the Mod R/M part of "ud1" (so with no other bytes after).
// On Arm64, `udf` allows for a 16-bit value, so we use the same 0xC? to store the trap info.
// `process_illegal_op` checks for these marker bits and keeps the low nibble as the trap code.
static MAGIC: u8 = 0xc0;

/// Process-wide default stack size in bytes. Starts at 1 MiB and is accessed
/// through `set_stack_size` / `get_stack_size`.
static DEFAULT_STACK_SIZE: AtomicUsize = AtomicUsize::new(ByteSize::mib(1).as_u64() as usize);

/// Maximum allowed default stack size (100 MiB) for the process-wide
/// configuration set via `set_stack_size`.
pub const MAX_STACK_SIZE: usize = ByteSize::mib(100).as_u64() as usize;
58
// Current definition of `ucontext_t` in the `libc` crate is incorrect
// on aarch64-apple-darwin so it's defined here with a more accurate definition.
#[repr(C)]
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
#[allow(non_camel_case_types)]
struct ucontext_t {
    uc_onstack: libc::c_int,
    uc_sigmask: libc::sigset_t,
    uc_stack: libc::stack_t,
    uc_link: *mut libc::ucontext_t,
    uc_mcsize: usize,
    // Machine context; `get_pc_sp` and `update_context` dereference this to
    // reach the saved registers (`__ss`).
    uc_mcontext: libc::mcontext_t,
}
72
73#[cfg(all(unix, not(all(target_arch = "aarch64", target_os = "macos"))))]
74use libc::ucontext_t;
75
76/// Sets the process-wide default stack size for new Wasmer coroutines.
77/// The value is clamped to [8 KiB, MAX_STACK_SIZE].
78pub fn set_stack_size(size: usize) {
79    DEFAULT_STACK_SIZE.store(
80        size.clamp(ByteSize::kib(8).as_u64() as usize, MAX_STACK_SIZE),
81        Ordering::Relaxed,
82    );
83}
84
/// Returns the process-wide default stack size in bytes.
///
/// This is the value most recently stored by `set_stack_size`, or the
/// initial default if it was never called. The load uses relaxed ordering.
pub fn get_stack_size() -> usize {
    DEFAULT_STACK_SIZE.load(Ordering::Relaxed)
}
89
/// Pool of pre-allocated coroutine stacks to avoid repeated mmap syscalls.
///
/// The queue itself is created lazily on first access; `drain_stack_pool`
/// empties it.
static STACK_POOL: LazyLock<crossbeam_queue::SegQueue<DefaultStack>> =
    LazyLock::new(crossbeam_queue::SegQueue::new);
93
94/// Drains the coroutine stack pool at the moment it runs.
95///
96/// This is intended to be called before retrying with a larger stack size so
97/// that the pool does not keep serving cached undersized stacks.
98///
99/// Note that `STACK_POOL` is a global, concurrently used queue. Other threads
100/// may push stacks back into the pool (for example, when their Wasm execution
101/// finishes) while or after this function is running. As a result, this
102/// function provides only a best-effort drain of the pool: there is no
103/// guarantee that no undersized stacks exist immediately after it returns
104/// unless the caller ensures, via external synchronization, that no other
105/// Wasm executions can return stacks to the pool while this function runs.
106pub fn drain_stack_pool() {
107    while STACK_POOL.pop().is_some() {}
108}
109
cfg_if::cfg_if! {
    if #[cfg(unix)] {
        /// Function which may handle custom signals while processing traps.
        ///
        /// It is invoked with the signal number, the `siginfo_t` pointer and
        /// the raw context pointer that the signal handler received.
        /// NOTE(review): the `bool` result presumably reports whether the
        /// callback handled the signal — confirm against
        /// `TrapHandlerContext::handle_trap`.
        pub type TrapHandlerFn<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + Send + Sync + 'a;
    } else if #[cfg(target_os = "windows")] {
        /// Function which may handle custom exceptions while processing traps.
        ///
        /// It is invoked with the `EXCEPTION_POINTERS` passed to the vectored
        /// exception handler.
        /// NOTE(review): the `bool` result presumably reports whether the
        /// callback handled the exception — confirm against
        /// `TrapHandlerContext::handle_trap`.
        pub type TrapHandlerFn<'a> = dyn Fn(*mut windows_sys::Win32::System::Diagnostics::Debug::EXCEPTION_POINTERS) -> bool + Send + Sync + 'a;
    }
}
119
120// Process an IllegalOpcode to see if it has a TrapCode payload
121unsafe fn process_illegal_op(addr: usize) -> Option<TrapCode> {
122    let mut val: Option<u8> = None;
123    unsafe {
124        if cfg!(target_arch = "x86_64") {
125            val = if read(addr as *mut u8) & 0xf0 == 0x40
126                && read((addr + 1) as *mut u8) == 0x0f
127                && read((addr + 2) as *mut u8) == 0xb9
128            {
129                Some(read((addr + 3) as *mut u8))
130            } else if read(addr as *mut u8) == 0x0f && read((addr + 1) as *mut u8) == 0xb9 {
131                Some(read((addr + 2) as *mut u8))
132            } else {
133                None
134            }
135        }
136        if cfg!(target_arch = "aarch64") {
137            val = if read_unaligned(addr as *mut u32) & 0xffff0000 == 0 {
138                Some(read(addr as *mut u8))
139            } else {
140                None
141            }
142        }
143        if cfg!(target_arch = "riscv64") {
144            let addr = addr as *mut u32;
145            // Check if 'unimp' instruction
146            val = if read(addr) == 0xc0001073 {
147                // Read from the instruction we emitted: 'addi a0, xzero, $payload'
148                // and take the encoded immediate value (upper 12-bits).
149                let prev_insn = read(addr.sub(1));
150                if (prev_insn & 0xffff) == 0x0513 {
151                    Some((prev_insn >> 20) as u8)
152                } else {
153                    None
154                }
155            } else {
156                None
157            };
158        }
159    }
160
161    // The direct encoding of a trap into the instruction is unused on RISC-V:
162    if cfg!(target_arch = "x86_64") || cfg!(target_arch = "aarch64") {
163        val = val.and_then(|val| {
164            if val & MAGIC == MAGIC {
165                Some(val & 0xf)
166            } else {
167                None
168            }
169        });
170    }
171
172    match val {
173        None => None,
174        Some(val) => match val {
175            0 => Some(TrapCode::StackOverflow),
176            1 => Some(TrapCode::HeapAccessOutOfBounds),
177            2 => Some(TrapCode::HeapMisaligned),
178            3 => Some(TrapCode::TableAccessOutOfBounds),
179            4 => Some(TrapCode::IndirectCallToNull),
180            5 => Some(TrapCode::BadSignature),
181            6 => Some(TrapCode::IntegerOverflow),
182            7 => Some(TrapCode::IntegerDivisionByZero),
183            8 => Some(TrapCode::BadConversionToInteger),
184            9 => Some(TrapCode::UnreachableCodeReached),
185            10 => Some(TrapCode::UnalignedAtomic),
186            _ => None,
187        },
188    }
189}
190
191cfg_if::cfg_if! {
192    if #[cfg(unix)] {
        // Signal dispositions that were installed before ours. `platform_init`
        // fills a slot when it registers the handler for that signal (some
        // signals are only registered on certain targets), and `trap_handler`
        // reads the slot to forward signals it did not handle.
        static mut PREV_SIGSEGV: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
        static mut PREV_SIGBUS: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
        static mut PREV_SIGILL: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
        static mut PREV_SIGFPE: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();

        #[cfg(feature = "experimental-host-interrupt")]
        static mut PREV_SIGUSR1: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
200
        /// Installs `trap_handler` as the process signal handler for the
        /// signals Wasm traps are reported through, saving each previous
        /// disposition in the matching `PREV_*` static.
        ///
        /// On Apple targets it additionally clears the task exception ports
        /// for bad-access, arithmetic and bad-instruction exceptions.
        ///
        /// # Panics
        ///
        /// Panics if `sigaction` fails for any of the signals.
        ///
        /// # Safety
        ///
        /// Writes to the `PREV_*` `static mut`s without synchronization and
        /// changes process-wide signal dispositions.
        /// NOTE(review): presumably callers guarantee this runs only once
        /// (e.g. through a `Once`) — confirm at the call site.
        unsafe fn platform_init() { unsafe {
            // Installs `trap_handler` for `signal`, storing the old action in `slot`.
            let register = |slot: &mut MaybeUninit<libc::sigaction>, signal: i32, nodefer: bool| {
                let mut handler: libc::sigaction = mem::zeroed();
                // The flags here are relatively careful, and they are...
                //
                // SA_SIGINFO gives us access to information like the program
                // counter from where the fault happened.
                //
                // SA_ONSTACK allows us to handle signals on an alternate stack,
                // so that the handler can run in response to running out of
                // stack space on the main stack. Rust installs an alternate
                // stack with sigaltstack, so we rely on that.
                //
                // SA_NODEFER allows us to reenter the signal handler if we
                // crash while handling the signal, and fall through to the
                // Breakpad handler by testing handlingSegFault.
                // NOTE(review): the Breakpad / handlingSegFault reference has
                // no counterpart in this file; it looks inherited from the
                // code this was derived from.
                handler.sa_flags = libc::SA_SIGINFO | libc::SA_ONSTACK;
                if nodefer {
                    handler.sa_flags |= libc::SA_NODEFER;
                }
                handler.sa_sigaction = trap_handler as *const () as usize;
                // No additional signals are blocked while the handler runs.
                libc::sigemptyset(&mut handler.sa_mask);
                if libc::sigaction(signal, &handler, slot.as_mut_ptr()) != 0 {
                    panic!(
                        "unable to install signal handler: {}",
                        io::Error::last_os_error(),
                    );
                }
            };

            // Allow handling OOB with signals on all architectures
            register(&mut PREV_SIGSEGV, libc::SIGSEGV, true);

            // Handle illegal instructions, e.g. the undefined-instruction
            // traps whose payload `process_illegal_op` decodes.
            register(&mut PREV_SIGILL, libc::SIGILL, true);

            // SIGUSR1 is used to interrupt long-running WASM code.
            // It doesn't use NODEFER since, if a second interruption
            // request comes in while one is already being processed,
            // there's nothing meaningful we can do.
            #[cfg(feature = "experimental-host-interrupt")]
            register(&mut PREV_SIGUSR1, libc::SIGUSR1, false);

            // x86 uses SIGFPE to report division by zero
            if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
                register(&mut PREV_SIGFPE, libc::SIGFPE, true);
            }

            // On ARM, handle Unaligned Accesses.
            // On Darwin, guard page accesses are raised as SIGBUS.
            if cfg!(target_arch = "arm") || cfg!(target_vendor = "apple") {
                register(&mut PREV_SIGBUS, libc::SIGBUS, true);
            }

            // This is necessary to support debugging under LLDB on Darwin.
            // For more details see https://github.com/mono/mono/commit/8e75f5a28e6537e56ad70bf870b86e22539c2fb7
            #[cfg(target_vendor = "apple")]
            {
                use mach2::exception_types::*;
                use mach2::kern_return::*;
                use mach2::port::*;
                use mach2::thread_status::*;
                use mach2::traps::*;
                use mach2::mach_types::*;

                unsafe extern "C" {
                    fn task_set_exception_ports(
                        task: task_t,
                        exception_mask: exception_mask_t,
                        new_port: mach_port_t,
                        behavior: exception_behavior_t,
                        new_flavor: thread_state_flavor_t,
                    ) -> kern_return_t;
                }

                #[allow(non_snake_case)]
                #[cfg(target_arch = "x86_64")]
                let MACHINE_THREAD_STATE = x86_THREAD_STATE64;
                #[allow(non_snake_case)]
                #[cfg(target_arch = "aarch64")]
                let MACHINE_THREAD_STATE = 6;

                // Passing MACH_PORT_NULL removes the task-level exception port
                // for these exception classes.
                task_set_exception_ports(
                    mach_task_self(),
                    EXC_MASK_BAD_ACCESS | EXC_MASK_ARITHMETIC | EXC_MASK_BAD_INSTRUCTION,
                    MACH_PORT_NULL,
                    EXCEPTION_STATE_IDENTITY as exception_behavior_t,
                    MACHINE_THREAD_STATE,
                );
            }
        }}
292
        /// Signal handler installed by `platform_init`.
        ///
        /// It gathers the fault address (SIGSEGV/SIGBUS), an optional trap
        /// code (decoded from the faulting instruction for SIGILL, or
        /// `HostInterrupt` for SIGUSR1 when that feature is enabled) and the
        /// interrupted pc/sp, then hands them to
        /// `TrapHandlerContext::handle_trap`. If that reports the trap as
        /// handled the handler returns; otherwise the signal is forwarded to
        /// the previously installed disposition.
        ///
        /// # Safety
        ///
        /// Must only be invoked by the OS as an `SA_SIGINFO` handler:
        /// `siginfo` and `context` are dereferenced as valid `siginfo_t` and
        /// `ucontext_t` pointers, and the `PREV_*` slot for `signum` must
        /// have been initialised by `platform_init`.
        unsafe extern "C" fn trap_handler(
            signum: libc::c_int,
            siginfo: *mut libc::siginfo_t,
            context: *mut libc::c_void,
        ) { unsafe {
            // Select the saved previous disposition for this signal; only
            // used if the trap turns out not to be ours.
            let previous = match signum {
                libc::SIGSEGV => &PREV_SIGSEGV,
                libc::SIGBUS => &PREV_SIGBUS,
                libc::SIGFPE => &PREV_SIGFPE,
                libc::SIGILL => &PREV_SIGILL,
                #[cfg(feature = "experimental-host-interrupt")]
                libc::SIGUSR1 => &PREV_SIGUSR1,
                _ => panic!("unknown signal: {signum}"),
            };
            // We try to get the fault address associated to this signal
            let maybe_fault_address = match signum {
                libc::SIGSEGV | libc::SIGBUS => {
                    Some((*siginfo).si_addr() as usize)
                }
                _ => None,
            };
            let trap_code = match signum {
                // check if it was caused by a UD and if the Trap info is a payload to it
                libc::SIGILL => {
                    let addr = (*siginfo).si_addr() as usize;
                    process_illegal_op(addr)
                }
                #[cfg(feature = "experimental-host-interrupt")]
                libc::SIGUSR1 => {
                    // If we're not running WASM code from the specific store for which
                    // an interrupt was requested, there's nothing to do.
                    if !interrupt_registry::on_interrupted() {
                        return;
                    }
                    Some(TrapCode::HostInterrupt)
                }
                _ => None,
            };
            let ucontext = &mut *(context as *mut ucontext_t);
            let (pc, sp) = get_pc_sp(ucontext);
            // The first closure lets the trap machinery rewrite the saved
            // registers; the second invokes a user-provided `TrapHandlerFn`
            // with the raw signal arguments.
            let handled = TrapHandlerContext::handle_trap(
                pc,
                sp,
                maybe_fault_address,
                trap_code,
                |regs| update_context(ucontext, regs),
                |handler| handler(signum, siginfo, context),
            );

            if handled {
                return;
            }

            // If we're not running WASM code at all, there's nothing to
            // do for an interrupt.
            #[cfg(feature = "experimental-host-interrupt")]
            if signum == libc::SIGUSR1 {
                return;
            }

            // This signal is not for any compiled wasm code we expect, so we
            // need to forward the signal to the next handler. If there is no
            // next handler (SIG_IGN or SIG_DFL), then it's time to crash. To do
            // this, we set the signal back to its original disposition and
            // return. This will cause the faulting op to be re-executed which
            // will crash in the normal way. If there is a next handler, call
            // it. It will either crash synchronously, fix up the instruction
            // so that execution can continue and return, or trigger a crash by
            // returning the signal to its original disposition and returning.
            let previous = &*previous.as_ptr();
            if previous.sa_flags & libc::SA_SIGINFO != 0 {
                // Previous handler uses the three-argument SA_SIGINFO form.
                mem::transmute::<
                    usize,
                    extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut libc::c_void),
                >(previous.sa_sigaction)(signum, siginfo, context)
            } else if previous.sa_sigaction == libc::SIG_DFL
            {
                // Restore the default disposition; the re-executed fault then crashes.
                libc::sigaction(signum, previous, ptr::null_mut());
            } else if previous.sa_sigaction != libc::SIG_IGN {
                // Previous handler uses the classic one-argument form.
                mem::transmute::<usize, extern "C" fn(libc::c_int)>(
                    previous.sa_sigaction
                )(signum)
            }
        }}
377
378        unsafe fn get_pc_sp(context: &ucontext_t) -> (usize, usize) {
379            let (pc, sp);
380            cfg_if::cfg_if! {
381                if #[cfg(all(
382                    any(target_os = "linux", target_os = "android"),
383                    target_arch = "x86_64",
384                ))] {
385                    pc = context.uc_mcontext.gregs[libc::REG_RIP as usize] as usize;
386                    sp = context.uc_mcontext.gregs[libc::REG_RSP as usize] as usize;
387                } else if #[cfg(all(
388                    any(target_os = "linux", target_os = "android"),
389                    target_arch = "x86",
390                ))] {
391                    pc = context.uc_mcontext.gregs[libc::REG_EIP as usize] as usize;
392                    sp = context.uc_mcontext.gregs[libc::REG_ESP as usize] as usize;
393                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86"))] {
394                    pc = context.uc_mcontext.mc_eip as usize;
395                    sp = context.uc_mcontext.mc_esp as usize;
396                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
397                    pc = context.uc_mcontext.mc_rip as usize;
398                    sp = context.uc_mcontext.mc_rsp as usize;
399                } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
400                    let mcontext = unsafe { &*context.uc_mcontext };
401                    pc = mcontext.__ss.__rip as usize;
402                    sp = mcontext.__ss.__rsp as usize;
403                } else if #[cfg(all(
404                        any(target_os = "linux", target_os = "android"),
405                        target_arch = "aarch64",
406                    ))] {
407                    pc = context.uc_mcontext.pc as usize;
408                    sp = context.uc_mcontext.sp as usize;
409                } else if #[cfg(all(
410                    any(target_os = "linux", target_os = "android"),
411                    target_arch = "arm",
412                ))] {
413                    pc = context.uc_mcontext.arm_pc as usize;
414                    sp = context.uc_mcontext.arm_sp as usize;
415                } else if #[cfg(all(
416                    any(target_os = "linux", target_os = "android"),
417                    any(target_arch = "riscv64", target_arch = "riscv32"),
418                ))] {
419                    pc = context.uc_mcontext.__gregs[libc::REG_PC] as usize;
420                    sp = context.uc_mcontext.__gregs[libc::REG_SP] as usize;
421                } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
422                    let mcontext = unsafe { &*context.uc_mcontext };
423                    pc = mcontext.__ss.__pc as usize;
424                    sp = mcontext.__ss.__sp as usize;
425                } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
426                    pc = context.uc_mcontext.mc_gpregs.gp_elr as usize;
427                    sp = context.uc_mcontext.mc_gpregs.gp_sp as usize;
428                } else if #[cfg(all(target_os = "linux", target_arch = "loongarch64"))] {
429                    pc = context.uc_mcontext.__gregs[1] as usize;
430                    sp = context.uc_mcontext.__gregs[3] as usize;
431                } else if #[cfg(all(target_os = "linux", target_arch = "powerpc64"))] {
432                    pc = (*context.uc_mcontext.regs).nip as usize;
433                    sp = (*context.uc_mcontext.regs).gpr[1] as usize;
434                } else {
435                    compile_error!("Unsupported platform");
436                }
437            };
438            (pc, sp)
439        }
440
        /// Writes the register values in `regs` into the signal `context`, so
        /// that they take effect when the signal handler returns.
        ///
        /// Which registers `TrapHandlerRegs` carries depends on the
        /// architecture; each branch destructures the target-specific set
        /// and stores every value into the matching slot of that platform's
        /// `mcontext`. Unsupported platforms are a compile error.
        ///
        /// # Safety
        ///
        /// `context` must be the context handed to a signal handler by the
        /// OS; on some platforms `uc_mcontext` holds raw pointers that are
        /// dereferenced here.
        unsafe fn update_context(context: &mut ucontext_t, regs: TrapHandlerRegs) {
            cfg_if::cfg_if! {
                if #[cfg(all(
                        any(target_os = "linux", target_os = "android"),
                        target_arch = "x86_64",
                    ))] {
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    context.uc_mcontext.gregs[libc::REG_RIP as usize] = rip as i64;
                    context.uc_mcontext.gregs[libc::REG_RSP as usize] = rsp as i64;
                    context.uc_mcontext.gregs[libc::REG_RBP as usize] = rbp as i64;
                    context.uc_mcontext.gregs[libc::REG_RDI as usize] = rdi as i64;
                    context.uc_mcontext.gregs[libc::REG_RSI as usize] = rsi as i64;
                } else if #[cfg(all(
                    any(target_os = "linux", target_os = "android"),
                    target_arch = "x86",
                ))] {
                    let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
                    context.uc_mcontext.gregs[libc::REG_EIP as usize] = eip as i32;
                    context.uc_mcontext.gregs[libc::REG_ESP as usize] = esp as i32;
                    context.uc_mcontext.gregs[libc::REG_EBP as usize] = ebp as i32;
                    context.uc_mcontext.gregs[libc::REG_ECX as usize] = ecx as i32;
                    context.uc_mcontext.gregs[libc::REG_EDX as usize] = edx as i32;
                } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    let mcontext = unsafe { &mut *context.uc_mcontext };
                    mcontext.__ss.__rip = rip;
                    mcontext.__ss.__rsp = rsp;
                    mcontext.__ss.__rbp = rbp;
                    mcontext.__ss.__rdi = rdi;
                    mcontext.__ss.__rsi = rsi;
                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86"))] {
                    let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
                    context.uc_mcontext.mc_eip = eip as libc::register_t;
                    context.uc_mcontext.mc_esp = esp as libc::register_t;
                    context.uc_mcontext.mc_ebp = ebp as libc::register_t;
                    context.uc_mcontext.mc_ecx = ecx as libc::register_t;
                    context.uc_mcontext.mc_edx = edx as libc::register_t;
                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    context.uc_mcontext.mc_rip = rip as libc::register_t;
                    context.uc_mcontext.mc_rsp = rsp as libc::register_t;
                    context.uc_mcontext.mc_rbp = rbp as libc::register_t;
                    context.uc_mcontext.mc_rdi = rdi as libc::register_t;
                    context.uc_mcontext.mc_rsi = rsi as libc::register_t;
                } else if #[cfg(all(
                        any(target_os = "linux", target_os = "android"),
                        target_arch = "aarch64",
                    ))] {
                    let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
                    context.uc_mcontext.pc = pc;
                    context.uc_mcontext.sp = sp;
                    context.uc_mcontext.regs[0] = x0;
                    context.uc_mcontext.regs[1] = x1;
                    context.uc_mcontext.regs[29] = x29;
                    // `lr` is stored in general-purpose register slot 30.
                    context.uc_mcontext.regs[30] = lr;
                } else if #[cfg(all(
                        any(target_os = "linux", target_os = "android"),
                        target_arch = "arm",
                    ))] {
                    let TrapHandlerRegs {
                        pc,
                        r0,
                        r1,
                        r7,
                        r11,
                        r13,
                        r14,
                        cpsr_thumb,
                        cpsr_endian,
                    } = regs;
                    context.uc_mcontext.arm_pc = pc;
                    context.uc_mcontext.arm_r0 = r0;
                    context.uc_mcontext.arm_r1 = r1;
                    context.uc_mcontext.arm_r7 = r7;
                    context.uc_mcontext.arm_fp = r11;
                    context.uc_mcontext.arm_sp = r13;
                    context.uc_mcontext.arm_lr = r14;
                    // Set or clear CPSR bit 0x20 according to `cpsr_thumb`.
                    if cpsr_thumb {
                        context.uc_mcontext.arm_cpsr |= 0x20;
                    } else {
                        context.uc_mcontext.arm_cpsr &= !0x20;
                    }
                    // Set or clear CPSR bit 0x200 according to `cpsr_endian`.
                    if cpsr_endian {
                        context.uc_mcontext.arm_cpsr |= 0x200;
                    } else {
                        context.uc_mcontext.arm_cpsr &= !0x200;
                    }
                } else if #[cfg(all(
                    any(target_os = "linux", target_os = "android"),
                    any(target_arch = "riscv64", target_arch = "riscv32"),
                ))] {
                    let TrapHandlerRegs { pc, ra, sp, a0, a1, s0 } = regs;
                    context.uc_mcontext.__gregs[libc::REG_PC] = pc as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_RA] = ra as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_SP] = sp as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_A0] = a0 as libc::c_ulong;
                    // `a1` is stored in the slot directly after `a0`.
                    context.uc_mcontext.__gregs[libc::REG_A0 + 1] = a1 as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_S0] = s0 as libc::c_ulong;
                } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
                    let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
                    let mcontext = unsafe { &mut *context.uc_mcontext };
                    mcontext.__ss.__pc = pc;
                    mcontext.__ss.__sp = sp;
                    mcontext.__ss.__x[0] = x0;
                    mcontext.__ss.__x[1] = x1;
                    mcontext.__ss.__fp = x29;
                    mcontext.__ss.__lr = lr;
                } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
                    let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
                    context.uc_mcontext.mc_gpregs.gp_elr = pc as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_sp = sp as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_x[0] = x0 as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_x[1] = x1 as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_x[29] = x29 as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_lr = lr as libc::register_t;
                } else if #[cfg(all(target_os = "linux", target_arch = "loongarch64"))] {
                    let TrapHandlerRegs { pc, sp, a0, a1, fp, ra } = regs;
                    // `pc` has its own field; the others go into numbered
                    // general-purpose register slots.
                    context.uc_mcontext.__pc = pc;
                    context.uc_mcontext.__gregs[1] = ra;
                    context.uc_mcontext.__gregs[3] = sp;
                    context.uc_mcontext.__gregs[4] = a0;
                    context.uc_mcontext.__gregs[5] = a1;
                    context.uc_mcontext.__gregs[22] = fp;
                } else if #[cfg(all(target_os = "linux", target_arch = "powerpc64"))] {
                    let TrapHandlerRegs { pc, sp, r3, r4, r31, lr } = regs;
                    (*context.uc_mcontext.regs).nip = pc;
                    (*context.uc_mcontext.regs).gpr[1] = sp;
                    (*context.uc_mcontext.regs).gpr[3] = r3;
                    (*context.uc_mcontext.regs).gpr[4] = r4;
                    (*context.uc_mcontext.regs).gpr[31] = r31;
                    (*context.uc_mcontext.regs).link = lr;
                } else {
                    compile_error!("Unsupported platform");
                }
            };
        }
577    } else if #[cfg(target_os = "windows")] {
578        use windows_sys::Win32::System::Diagnostics::Debug::{
579            AddVectoredExceptionHandler,
580            CONTEXT,
581            EXCEPTION_CONTINUE_EXECUTION,
582            EXCEPTION_CONTINUE_SEARCH,
583            EXCEPTION_POINTERS,
584        };
585        use windows_sys::Win32::Foundation::{
586            EXCEPTION_ACCESS_VIOLATION,
587            EXCEPTION_ILLEGAL_INSTRUCTION,
588            EXCEPTION_INT_DIVIDE_BY_ZERO,
589            EXCEPTION_INT_OVERFLOW,
590            EXCEPTION_STACK_OVERFLOW,
591        };
592
593        unsafe fn platform_init() {
594            unsafe {
595                // our trap handler needs to go first, so that we can recover from
596                // wasm faults and continue execution, so pass `1` as a true value
597                // here.
598                let handler = AddVectoredExceptionHandler(1, Some(exception_handler));
599                if handler.is_null() {
600                    panic!("failed to add exception handler: {}", io::Error::last_os_error());
601                }
602            }
603        }
604
605        unsafe extern "system" fn exception_handler(
606            exception_info: *mut EXCEPTION_POINTERS
607        ) -> i32 {
608            unsafe {
609                // Check the kind of exception, since we only handle a subset within
610                // wasm code. If anything else happens we want to defer to whatever
611                // the rest of the system wants to do for this exception.
612                let record = &*(*exception_info).ExceptionRecord;
613                if record.ExceptionCode != EXCEPTION_ACCESS_VIOLATION &&
614                    record.ExceptionCode != EXCEPTION_ILLEGAL_INSTRUCTION &&
615                    record.ExceptionCode != EXCEPTION_STACK_OVERFLOW &&
616                    record.ExceptionCode != EXCEPTION_INT_DIVIDE_BY_ZERO &&
617                    record.ExceptionCode != EXCEPTION_INT_OVERFLOW
618                {
619                    return EXCEPTION_CONTINUE_SEARCH;
620                }
621
622                // FIXME: this is what the previous C++ did to make sure that TLS
623                // works by the time we execute this trap handling code. This isn't
624                // exactly super easy to call from Rust though and it's not clear we
625                // necessarily need to do so. Leaving this here in case we need this
626                // in the future, but for now we can probably wait until we see a
627                // strange fault before figuring out how to reimplement this in
628                // Rust.
629                //
630                // if (!NtCurrentTeb()->Reserved1[sThreadLocalArrayPointerIndex]) {
631                //     return EXCEPTION_CONTINUE_SEARCH;
632                // }
633
634                let context = &mut *(*exception_info).ContextRecord;
635                let (pc, sp) = get_pc_sp(context);
636
637                // We try to get the fault address associated to this exception.
638                let maybe_fault_address = match record.ExceptionCode {
639                    EXCEPTION_ACCESS_VIOLATION => Some(record.ExceptionInformation[1]),
640                    EXCEPTION_STACK_OVERFLOW => Some(sp),
641                    _ => None,
642                };
643                let trap_code = match record.ExceptionCode {
644                    // check if it was cased by a UD and if the Trap info is a payload to it
645                    EXCEPTION_ILLEGAL_INSTRUCTION => {
646                        process_illegal_op(pc)
647                    }
648                    _ => None,
649                };
650                // This is basically the same as the unix version above, only with a
651                // few parameters tweaked here and there.
652                let handled = TrapHandlerContext::handle_trap(
653                    pc,
654                    sp,
655                    maybe_fault_address,
656                    trap_code,
657                    |regs| update_context(context, regs),
658                    |handler| handler(exception_info),
659                );
660
661                if handled {
662                    EXCEPTION_CONTINUE_EXECUTION
663                } else {
664                    EXCEPTION_CONTINUE_SEARCH
665                }
666            }
667        }
668
        /// Reads the program counter and stack pointer out of a thread
        /// `CONTEXT` record and returns them as `(pc, sp)`.
        ///
        /// Only x86_64 (`Rip`/`Rsp`) and x86 (`Eip`/`Esp`) are supported; any
        /// other target architecture fails to compile.
        unsafe fn get_pc_sp(context: &CONTEXT) -> (usize, usize) {
            // Declared up front so that whichever `cfg_if` branch is compiled
            // in performs the single assignment.
            let (pc, sp);
            cfg_if::cfg_if! {
                if #[cfg(target_arch = "x86_64")] {
                    pc = context.Rip as usize;
                    sp = context.Rsp as usize;
                } else if #[cfg(target_arch = "x86")] {
                    pc = context.Eip as usize;
                    sp = context.Esp as usize;
                } else {
                    compile_error!("Unsupported platform");
                }
            };
            (pc, sp)
        }
684
        /// Copies the register values carried by `regs` into `context`.
        ///
        /// The handler passes this as the `update_regs` callback of
        /// `TrapHandlerContext::handle_trap`, so the registers written here
        /// are the ones produced by the coroutine trap handler. Only x86_64
        /// and x86 are supported; any other target architecture fails to
        /// compile.
        unsafe fn update_context(context: &mut CONTEXT, regs: TrapHandlerRegs) {
            cfg_if::cfg_if! {
                if #[cfg(target_arch = "x86_64")] {
                    // Exhaustive destructuring: a new field in
                    // `TrapHandlerRegs` becomes a compile error here.
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    context.Rip = rip;
                    context.Rsp = rsp;
                    context.Rbp = rbp;
                    context.Rdi = rdi;
                    context.Rsi = rsi;
                } else if #[cfg(target_arch = "x86")] {
                    let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
                    context.Eip = eip;
                    context.Esp = esp;
                    context.Ebp = ebp;
                    context.Ecx = ecx;
                    context.Edx = edx;
                } else {
                    compile_error!("Unsupported platform");
                }
            };
        }
706    }
707}
708
709/// This function is required to be called before any WebAssembly is entered.
710/// This will configure global state such as signal handlers to prepare the
711/// process to receive wasm traps.
712///
713/// This function must not only be called globally once before entering
714/// WebAssembly but it must also be called once-per-thread that enters
715/// WebAssembly. Currently in wasmer's integration this function is called on
716/// creation of a `Store`.
717pub fn init_traps() {
718    static INIT: Once = Once::new();
719    INIT.call_once(|| unsafe {
720        platform_init();
721    });
722}
723
724/// Raises a user-defined trap immediately.
725///
726/// This function performs as-if a wasm trap was just executed, only the trap
727/// has a dynamic payload associated with it which is user-provided. This trap
728/// payload is then returned from `catch_traps` below.
729///
730/// # Safety
731///
732/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
733/// have been previous called and not yet returned.
734/// Additionally no Rust destructors may be on the stack.
735/// They will be skipped and not executed.
736pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
737    unsafe { unwind_with(UnwindReason::UserTrap(data)) }
738}
739
740/// Raises a trap from inside library code immediately.
741///
742/// This function performs as-if a wasm trap was just executed. This trap
743/// payload is then returned from `catch_traps` below.
744///
745/// # Safety
746///
747/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
748/// have been previous called and not yet returned.
749/// Additionally no Rust destructors may be on the stack.
750/// They will be skipped and not executed.
751pub unsafe fn raise_lib_trap(trap: Trap) -> ! {
752    unsafe { unwind_with(UnwindReason::LibTrap(trap)) }
753}
754
755/// Carries a Rust panic across wasm code and resumes the panic on the other
756/// side.
757///
758/// # Safety
759///
760/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
761/// have been previously called and not returned. Additionally no Rust destructors may be on the
762/// stack. They will be skipped and not executed.
763pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
764    unsafe { unwind_with(UnwindReason::Panic(payload)) }
765}
766
767/// Call the wasm function pointed to by `callee`.
768///
769/// * `vmctx` - the callee vmctx argument
770/// * `caller_vmctx` - the caller vmctx argument
771/// * `trampoline` - the jit-generated trampoline whose ABI takes 4 values, the
772///   callee vmctx, the caller vmctx, the `callee` argument below, and then the
773///   `values_vec` argument.
774/// * `callee` - the third argument to the `trampoline` function
775/// * `values_vec` - points to a buffer which holds the incoming arguments, and to
776///   which the outgoing return values will be written.
777///
778/// # Safety
779///
780/// Wildly unsafe because it calls raw function pointers and reads/writes raw
781/// function pointers.
782pub unsafe fn wasmer_call_trampoline(
783    trap_handler: Option<*const TrapHandlerFn<'static>>,
784    config: &VMConfig,
785    vmctx: VMFunctionContext,
786    trampoline: VMTrampoline,
787    callee: *const VMFunctionBody,
788    values_vec: *mut u8,
789) -> Result<(), Trap> {
790    unsafe {
791        catch_traps(trap_handler, config, move || {
792            mem::transmute::<
793                unsafe extern "C" fn(
794                    *mut VMContext,
795                    *const VMFunctionBody,
796                    *mut wasmer_types::RawValue,
797                ),
798                extern "C" fn(VMFunctionContext, *const VMFunctionBody, *mut u8),
799            >(trampoline)(vmctx, callee, values_vec);
800        })
801    }
802}
803
804/// Catches any wasm traps that happen within the execution of `closure`,
805/// returning them as a `Result`.
806///
807/// # Safety
808///
809/// Highly unsafe since `closure` won't have any dtors run.
810pub unsafe fn catch_traps<F, R: 'static>(
811    trap_handler: Option<*const TrapHandlerFn<'static>>,
812    config: &VMConfig,
813    closure: F,
814) -> Result<R, Trap>
815where
816    F: FnOnce() -> R + 'static,
817{
818    // Ensure that per-thread initialization is done.
819    lazy_per_thread_init()?;
820    let stack_size = config
821        .wasm_stack_size
822        .unwrap_or_else(|| DEFAULT_STACK_SIZE.load(Ordering::Relaxed));
823    on_wasm_stack(stack_size, trap_handler, closure).map_err(UnwindReason::into_trap)
824}
825
// We need two separate thread-local variables here:
// - YIELDER is set within the new stack and is used to unwind back to the root
//   of the stack from inside it. It is `None` whenever the thread is not
//   running on a Wasm stack.
// - TRAP_HANDLER is set from outside the new stack and is solely used from
//   signal handlers. It must be atomic since it is used by signal handlers.
//   It is null while no `TrapHandlerContext` is installed on the thread.
//
// Per-thread signal stack initialization is not tied to these variables; it
// is performed by `lazy_per_thread_init`.
thread_local! {
    static YIELDER: Cell<Option<NonNull<Yielder<(), UnwindReason>>>> = const { Cell::new(None) };
    static TRAP_HANDLER: AtomicPtr<TrapHandlerContext> = const { AtomicPtr::new(ptr::null_mut()) };
}
838
/// Read-only information that is used by signal handlers to handle and recover
/// from traps.
///
/// The struct is deliberately not generic so that a pointer to it can be kept
/// in the `TRAP_HANDLER` thread-local; the generic part lives behind `inner`.
#[allow(clippy::type_complexity)]
struct TrapHandlerContext {
    /// Type-erased pointer to a `TrapHandlerContextInner<T>`.
    inner: *const u8,
    /// Function that casts `inner` back to its concrete type and forwards the
    /// arguments (`inner`, pc, sp, optional fault address, optional trap code,
    /// register update callback) to `TrapHandlerContextInner::handle_trap`.
    /// Returns whether the trap was handled.
    handle_trap: fn(
        *const u8,
        usize,
        usize,
        Option<usize>,
        Option<TrapCode>,
        &mut dyn FnMut(TrapHandlerRegs),
    ) -> bool,
    /// Optional custom trap handler that is given the first chance to handle
    /// a trap.
    custom_trap: Option<*const TrapHandlerFn<'static>>,
}
/// Generic part of the trap handler context, reached through the type-erased
/// `inner` pointer of `TrapHandlerContext`. `T` is the success type of the
/// coroutine's result.
struct TrapHandlerContextInner<T> {
    /// Information about the currently running coroutine. This is used to
    /// reset execution to the root of the coroutine when a trap is handled.
    coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
}
859
impl TrapHandlerContext {
    /// Runs the given function with a trap handler context. The previous
    /// trap handler context is preserved and restored afterwards.
    ///
    /// * `custom_trap` - optional custom handler stored in the context.
    /// * `coro_trap_handler` - trap handler of the coroutine that `f` drives.
    /// * `f` - the function to run while the context is installed; its return
    ///   value is returned unchanged.
    fn install<T, R>(
        custom_trap: Option<*const TrapHandlerFn<'static>>,
        coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
        f: impl FnOnce() -> R,
    ) -> R {
        // Type-erase the trap handler function so that it can be placed in TLS.
        // `func::<T>` remembers the concrete `T` and casts the opaque pointer
        // back before forwarding the call.
        fn func<T>(
            ptr: *const u8,
            pc: usize,
            sp: usize,
            maybe_fault_address: Option<usize>,
            trap_code: Option<TrapCode>,
            update_regs: &mut dyn FnMut(TrapHandlerRegs),
        ) -> bool {
            unsafe {
                (*(ptr as *const TrapHandlerContextInner<T>)).handle_trap(
                    pc,
                    sp,
                    maybe_fault_address,
                    trap_code,
                    update_regs,
                )
            }
        }
        // Both `inner` and `ctx` live in this stack frame; TRAP_HANDLER only
        // points at them for the duration of `f`.
        let inner = TrapHandlerContextInner { coro_trap_handler };
        let ctx = Self {
            inner: &inner as *const _ as *const u8,
            handle_trap: func::<T>,
            custom_trap,
        };

        // Keep the compiler from moving the initialization of `ctx` past the
        // point where its address is published to the signal handler.
        compiler_fence(Ordering::Release);
        let prev = TRAP_HANDLER.with(|ptr| {
            let prev = ptr.load(Ordering::Relaxed);
            ptr.store(&ctx as *const Self as *mut Self, Ordering::Relaxed);
            prev
        });

        // Restore the previous context when leaving this scope, whether `f`
        // returns normally or unwinds. The guard is declared after `ctx` and
        // `inner`, so it runs before those locals are dropped.
        defer! {
            TRAP_HANDLER.with(|ptr| ptr.store(prev, Ordering::Relaxed));
            compiler_fence(Ordering::Acquire);
        }

        f()
    }

    /// Attempts to handle the trap if it's a wasm trap.
    ///
    /// Returns `false` when no context is installed on the current thread.
    /// Otherwise the custom trap handler, if any, is called through
    /// `call_handler` first; if it reports the trap as handled this returns
    /// `true`. In all remaining cases the decision is delegated to the
    /// installed context, which may call `update_regs`.
    ///
    /// * `pc`, `sp` - program counter and stack pointer at the trap.
    /// * `maybe_fault_address` - faulting address, when the trap has one.
    /// * `trap_code` - trap code, when one is already known.
    unsafe fn handle_trap(
        pc: usize,
        sp: usize,
        maybe_fault_address: Option<usize>,
        trap_code: Option<TrapCode>,
        mut update_regs: impl FnMut(TrapHandlerRegs),
        call_handler: impl Fn(&TrapHandlerFn<'static>) -> bool,
    ) -> bool {
        unsafe {
            // A null pointer means no context is installed on this thread.
            let ptr = TRAP_HANDLER.with(|ptr| ptr.load(Ordering::Relaxed));
            if ptr.is_null() {
                return false;
            }

            let ctx = &*ptr;

            // Check if this trap is handled by a custom trap handler.
            if let Some(trap_handler) = ctx.custom_trap
                && call_handler(&*trap_handler)
            {
                return true;
            }

            // Fall back to the type-erased handler of the installed context.
            (ctx.handle_trap)(
                ctx.inner,
                pc,
                sp,
                maybe_fault_address,
                trap_code,
                &mut update_regs,
            )
        }
    }
}
944
945impl<T> TrapHandlerContextInner<T> {
946    unsafe fn handle_trap(
947        &self,
948        pc: usize,
949        sp: usize,
950        maybe_fault_address: Option<usize>,
951        trap_code: Option<TrapCode>,
952        update_regs: &mut dyn FnMut(TrapHandlerRegs),
953    ) -> bool {
954        unsafe {
955            // Check if this trap occurred while executing on the Wasm stack. We can
956            // only recover from traps if that is the case.
957            if !self.coro_trap_handler.stack_ptr_in_bounds(sp) {
958                return false;
959            }
960
961            let signal_trap = trap_code.or_else(|| {
962                maybe_fault_address.map(|addr| {
963                    if self.coro_trap_handler.stack_ptr_in_bounds(addr) {
964                        TrapCode::StackOverflow
965                    } else {
966                        TrapCode::HeapAccessOutOfBounds
967                    }
968                })
969            });
970
971            // Don't try to generate a backtrace for stack overflows: unwinding
972            // information is often not precise enough to properly describe what is
973            // happening during a function prologue, which can lead the unwinder to
974            // read invalid memory addresses.
975            //
976            // See: https://github.com/rust-lang/backtrace-rs/pull/357
977            let backtrace = if signal_trap == Some(TrapCode::StackOverflow) {
978                Backtrace::from(vec![])
979            } else {
980                Backtrace::new_unresolved()
981            };
982
983            // Set up the register state for exception return to force the
984            // coroutine to return to its caller with UnwindReason::WasmTrap.
985            let unwind = UnwindReason::WasmTrap {
986                backtrace,
987                signal_trap,
988                pc,
989            };
990            let regs = self
991                .coro_trap_handler
992                .setup_trap_handler(move || Err(unwind));
993            update_regs(regs);
994            true
995        }
996    }
997}
998
/// Why execution was unwound out of the Wasm stack. `on_wasm_stack` returns
/// it as its error value, and `into_trap` converts it into a `Trap` (or
/// resumes the panic).
enum UnwindReason {
    /// A panic caused by the host
    Panic(Box<dyn Any + Send>),
    /// A custom error triggered by the user
    UserTrap(Box<dyn Error + Send + Sync>),
    /// A Trap triggered by a wasm libcall
    LibTrap(Trap),
    /// A trap caused by the Wasm generated code
    WasmTrap {
        /// Backtrace captured when the trap was handled; empty for stack
        /// overflows.
        backtrace: Backtrace,
        /// Program counter at which the trap occurred.
        pc: usize,
        /// Trap code, when one could be determined.
        signal_trap: Option<TrapCode>,
    },
}
1013
1014impl UnwindReason {
1015    fn into_trap(self) -> Trap {
1016        match self {
1017            Self::UserTrap(data) => Trap::User(data),
1018            Self::LibTrap(trap) => trap,
1019            Self::WasmTrap {
1020                backtrace,
1021                pc,
1022                signal_trap,
1023            } => Trap::wasm(pc, backtrace, signal_trap),
1024            Self::Panic(panic) => std::panic::resume_unwind(panic),
1025        }
1026    }
1027}
1028
1029unsafe fn unwind_with(reason: UnwindReason) -> ! {
1030    unsafe {
1031        let yielder = YIELDER
1032            .with(|cell| cell.replace(None))
1033            .expect("not running on Wasm stack");
1034
1035        yielder.as_ref().suspend(reason);
1036
1037        // on_wasm_stack will forcibly reset the coroutine stack after yielding.
1038        unreachable!();
1039    }
1040}
1041
1042/// Runs the given function on a separate stack so that its stack usage can be
1043/// bounded. Stack overflows and other traps can be caught and execution
1044/// returned to the root of the stack.
1045fn on_wasm_stack<F: FnOnce() -> T + 'static, T: 'static>(
1046    stack_size: usize,
1047    trap_handler: Option<*const TrapHandlerFn<'static>>,
1048    f: F,
1049) -> Result<T, UnwindReason> {
1050    // Reuse a cached stack from the pool if it is large enough, otherwise
1051    // allocate a fresh one. The size check prevents using undersized stacks
1052    // that were returned by threads still running at the old size after
1053    // `drain_stack_pool()` was called. `base() - limit()` is the full mmap
1054    // region (including guard page), which is always >= the requested size
1055    // for stacks allocated with that size.
1056    let stack = STACK_POOL
1057        .pop()
1058        .filter(|s| s.size() >= stack_size)
1059        .unwrap_or_else(|| DefaultStack::new(stack_size).unwrap());
1060    let mut stack = scopeguard::guard(stack, |stack| STACK_POOL.push(stack));
1061
1062    // Create a coroutine with a new stack to run the function on.
1063    let coro = ScopedCoroutine::with_stack(&mut *stack, move |yielder, ()| {
1064        // Save the yielder to TLS so that it can be used later.
1065        YIELDER.with(|cell| cell.set(Some(yielder.into())));
1066
1067        Ok(f())
1068    });
1069
1070    // Ensure that YIELDER is reset on exit even if the coroutine panics,
1071    defer! {
1072        YIELDER.with(|cell| cell.set(None));
1073    }
1074
1075    coro.scope(|mut coro_ref| {
1076        // Set up metadata for the trap handler for the duration of the coroutine
1077        // execution. This is restored to its previous value afterwards.
1078        TrapHandlerContext::install(trap_handler, coro_ref.trap_handler(), || {
1079            match coro_ref.resume(()) {
1080                CoroutineResult::Yield(trap) => {
1081                    // This came from unwind_with which requires that there be only
1082                    // Wasm code on the stack.
1083                    unsafe {
1084                        coro_ref.force_reset();
1085                    }
1086                    Err(trap)
1087                }
1088                CoroutineResult::Return(result) => result,
1089            }
1090        })
1091    })
1092}
1093
1094/// When executing on the Wasm stack, temporarily switch back to the host stack
1095/// to perform an operation that should not be constrained by the Wasm stack
1096/// limits.
1097///
1098/// This is particularly important since the usage of the Wasm stack is under
1099/// the control of untrusted code. Malicious code could artificially induce a
1100/// stack overflow in the middle of a sensitive host operations (e.g. growing
1101/// a memory) which would be hard to recover from.
1102pub fn on_host_stack<F: FnOnce() -> T, T>(f: F) -> T {
1103    // Reset YIEDER to None for the duration of this call to indicate that we
1104    // are no longer on the Wasm stack.
1105    let yielder_ptr = YIELDER.with(|cell| cell.replace(None));
1106
1107    // If we are already on the host stack, execute the function directly. This
1108    // happens if a host function is called directly from the API.
1109    let yielder = match yielder_ptr {
1110        Some(ptr) => unsafe { ptr.as_ref() },
1111        None => return f(),
1112    };
1113
1114    // Restore YIELDER upon exiting normally or unwinding.
1115    defer! {
1116        YIELDER.with(|cell| cell.set(yielder_ptr));
1117    }
1118
1119    // on_parent_stack requires the closure to be Send so that the Yielder
1120    // cannot be called from the parent stack. This is not a problem for us
1121    // since we don't expose the Yielder.
1122    struct SendWrapper<T>(T);
1123    unsafe impl<T> Send for SendWrapper<T> {}
1124    let wrapped = SendWrapper(f);
1125    yielder.on_parent_stack(move || {
1126        let wrapped = wrapped;
1127        (wrapped.0)()
1128    })
1129}
1130
/// Per-thread initialization: reserves extra stack space on the calling
/// thread for handling stack overflow exceptions.
///
/// # Panics
///
/// Panics if the thread stack guarantee cannot be set.
#[cfg(windows)]
pub fn lazy_per_thread_init() -> Result<(), Trap> {
    use windows_sys::Win32::System::Threading::SetThreadStackGuarantee;

    // Handling stack overflow exceptions needs additional space on the
    // stack. Rust's initialization code sets this to 0x5000 but this seems to
    // be insufficient in practice.
    let mut guarantee = 0x10000;
    let status = unsafe { SetThreadStackGuarantee(&mut guarantee) };
    assert!(status != 0, "failed to set thread stack guarantee");

    Ok(())
}
1143
/// Ensures that the calling thread has an alternate signal stack
/// (sigaltstack) that is large enough for our signal handling code.
///
/// Rust's libstd installs an alternate stack with size `SIGSTKSZ`, which is not
/// always large enough for our signal handling code. Override it by creating
/// and registering our own alternate stack that is large enough and has a guard
/// page.
///
/// The work is done once per thread, the first time this function is called
/// on that thread; later calls only inspect the stored outcome. A stack
/// allocated here is unmapped when the thread-local state is dropped.
///
/// # Errors
///
/// Returns `Trap::oom()` if the memory for the alternate stack could not be
/// mapped.
///
/// # Panics
///
/// Panics if querying or registering the sigaltstack fails, or if the mapped
/// memory cannot be made readable and writable.
#[cfg(unix)]
pub fn lazy_per_thread_init() -> Result<(), Trap> {
    use std::ptr::null_mut;

    thread_local! {
        /// Thread-local state is lazy-initialized on the first time it's used,
        /// and dropped when the thread exits.
        static TLS: Tls = unsafe { init_sigstack() };
    }

    /// The size of the sigaltstack (not including the guard, which will be
    /// added). Make this large enough to run our signal handlers.
    const MIN_STACK_SIZE: usize = ByteSize::kib(64).as_u64() as usize;

    /// Outcome of the per-thread sigaltstack setup.
    enum Tls {
        /// Mapping memory for a new stack failed.
        OutOfMemory,
        /// A new stack was mapped and registered; the mapping is released on
        /// drop.
        Allocated {
            /// Start of the whole mapping, guard page included.
            mmap_ptr: *mut libc::c_void,
            /// Size of the whole mapping, guard page included.
            mmap_size: usize,
        },
        /// The already registered sigaltstack was large enough; nothing was
        /// allocated.
        BigEnough,
    }

    /// Inspects the current sigaltstack and, if it is disabled or smaller
    /// than `MIN_STACK_SIZE`, maps and registers a new one.
    unsafe fn init_sigstack() -> Tls {
        unsafe {
            // Check to see if the existing sigaltstack, if it exists, is big
            // enough. If so we don't need to allocate our own.
            let mut old_stack = mem::zeroed();
            let r = libc::sigaltstack(ptr::null(), &mut old_stack);
            assert_eq!(r, 0, "learning about sigaltstack failed");
            if old_stack.ss_flags & libc::SS_DISABLE == 0 && old_stack.ss_size >= MIN_STACK_SIZE {
                return Tls::BigEnough;
            }

            // ... but failing that we need to allocate our own, so do all that
            // here.
            let page_size: usize = region::page::size();
            let guard_size = page_size;
            let alloc_size = guard_size + MIN_STACK_SIZE;

            // Map the whole region without any access rights first; the guard
            // page at the start of the mapping keeps it that way.
            let ptr = libc::mmap(
                null_mut(),
                alloc_size,
                libc::PROT_NONE,
                libc::MAP_PRIVATE | libc::MAP_ANON,
                -1,
                0,
            );
            if ptr == libc::MAP_FAILED {
                return Tls::OutOfMemory;
            }

            // Prepare the stack with readable/writable memory and then register it
            // with `sigaltstack`.
            let stack_ptr = (ptr as usize + guard_size) as *mut libc::c_void;
            let r = libc::mprotect(
                stack_ptr,
                MIN_STACK_SIZE,
                libc::PROT_READ | libc::PROT_WRITE,
            );
            assert_eq!(r, 0, "mprotect to configure memory for sigaltstack failed");
            let new_stack = libc::stack_t {
                ss_sp: stack_ptr,
                ss_flags: 0,
                ss_size: MIN_STACK_SIZE,
            };
            let r = libc::sigaltstack(&new_stack, ptr::null_mut());
            assert_eq!(r, 0, "registering new sigaltstack failed");

            Tls::Allocated {
                mmap_ptr: ptr,
                mmap_size: alloc_size,
            }
        }
    }

    // Ensure TLS runs its initializer and return an error if it failed to
    // set up a separate stack for signal handlers.
    return TLS.with(|tls| {
        if let Tls::OutOfMemory = tls {
            Err(Trap::oom())
        } else {
            Ok(())
        }
    });

    // Items may follow the `return` above; this impl only exists so that an
    // allocated stack is unmapped when the thread-local is dropped.
    impl Drop for Tls {
        fn drop(&mut self) {
            // Only the `Allocated` variant owns memory.
            let (ptr, size) = match self {
                Self::Allocated {
                    mmap_ptr,
                    mmap_size,
                } => (*mmap_ptr, *mmap_size),
                _ => return,
            };
            unsafe {
                // Deallocate the stack memory.
                let r = libc::munmap(ptr, size);
                debug_assert_eq!(r, 0, "munmap failed during thread shutdown");
            }
        }
    }
}
1253
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Guards tests that mutate global state (DEFAULT_STACK_SIZE, STACK_POOL).
    // Rust runs tests in parallel by default; this mutex serializes them so
    // they don't step on each other.
    static GLOBAL_STATE: Mutex<()> = Mutex::new(());

    /// Acquires `GLOBAL_STATE`, recovering the guard if the mutex is poisoned.
    ///
    /// The mutex protects no data of its own, so poisoning carries no
    /// information. Recovering keeps one failing test from making every other
    /// serialized test fail with a `PoisonError` that hides the real failure.
    fn lock_global_state() -> MutexGuard<'static, ()> {
        GLOBAL_STATE
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Saves the current stack size and restores it on drop (even on panic).
    struct RestoreStackSize(usize);
    impl Drop for RestoreStackSize {
        fn drop(&mut self) {
            set_stack_size(self.0);
        }
    }

    #[test]
    fn max_stack_size_is_100mb() {
        assert_eq!(MAX_STACK_SIZE, ByteSize::mib(100).as_u64() as usize);
    }

    #[test]
    fn get_set_stack_size_roundtrip() {
        let _lock = lock_global_state();
        let _restore = RestoreStackSize(get_stack_size());
        let new_size = ByteSize::mib(4).as_u64() as usize;
        set_stack_size(new_size);
        assert_eq!(get_stack_size(), new_size);
    }

    #[test]
    fn set_stack_size_clamps_to_min() {
        let _lock = lock_global_state();
        let _restore = RestoreStackSize(get_stack_size());
        set_stack_size(1); // way below 8 KiB minimum
        assert_eq!(get_stack_size(), ByteSize::kib(8).as_u64() as usize);
    }

    #[test]
    fn set_stack_size_clamps_to_max() {
        let _lock = lock_global_state();
        let _restore = RestoreStackSize(get_stack_size());
        set_stack_size(usize::MAX);
        assert_eq!(get_stack_size(), MAX_STACK_SIZE);
    }

    #[test]
    fn drain_stack_pool_empties_pool() {
        let _lock = lock_global_state();
        let stack = DefaultStack::new(ByteSize::mib(1).as_u64() as usize).unwrap();
        STACK_POOL.push(stack);
        assert!(!STACK_POOL.is_empty());
        drain_stack_pool();
        assert!(STACK_POOL.is_empty());
    }

    #[test]
    fn drain_stack_pool_is_idempotent() {
        let _lock = lock_global_state();
        drain_stack_pool();
        drain_stack_pool(); // second call on empty pool should not panic
        assert!(STACK_POOL.is_empty());
    }

    /// The stack pool is not size-aware, so after a stack size increase it keeps
    /// serving cached undersized stacks. `drain_stack_pool()` breaks the cycle.
    ///
    /// 1. A call fills the pool with 500 KiB stacks (simulating normal execution).
    /// 2. The caller doubles the default to 1 MiB (simulating overflow retry).
    /// 3. WITHOUT draining, the pool still hands back a 500 KiB stack — the
    ///    retry would overflow again, creating an infinite loop.
    /// 4. After `drain_stack_pool()`, the pool is empty and the next allocation
    ///    must use the new, larger size.
    #[test]
    fn pool_returns_stale_stack_without_drain() {
        let _lock = lock_global_state();
        let _restore = RestoreStackSize(get_stack_size());
        drain_stack_pool();

        // --- Phase 1: simulate normal execution that returns a 500 KiB stack ---
        let small_size = ByteSize::kib(500).as_u64() as usize;
        let small_stack = DefaultStack::new(small_size).unwrap();
        STACK_POOL.push(small_stack);

        // --- Phase 2: "overflow detected" — caller doubles the default ---
        let big_size = ByteSize::mib(1).as_u64() as usize;
        set_stack_size(big_size);
        assert_eq!(get_stack_size(), big_size);

        // --- Phase 3: WITHOUT drain, pool still returns the old small stack ---
        // This is the bug: the caller asked for a bigger stack but the pool
        // serves a cached undersized one, causing the retry to overflow again.
        let stale = STACK_POOL.pop();
        assert!(
            stale.is_some(),
            "pool should still contain the old stack (the bug scenario)"
        );

        // --- Phase 4: with drain, pool is empty — next alloc uses new size ---
        STACK_POOL.push(stale.unwrap());
        drain_stack_pool();
        assert!(
            STACK_POOL.pop().is_none(),
            "after drain, pool must be empty so a fresh stack is allocated at the new size"
        );
    }

    /// `on_wasm_stack` discards undersized stacks from the pool and allocates
    /// a fresh one instead of blindly reusing whatever the pool returns.
    #[test]
    fn on_wasm_stack_discards_undersized_stack() {
        let _lock = lock_global_state();
        let _restore = RestoreStackSize(get_stack_size());
        drain_stack_pool();

        // Push an undersized stack into the pool.
        let small_size = ByteSize::kib(500).as_u64() as usize;
        let small_stack = DefaultStack::new(small_size).unwrap();
        STACK_POOL.push(small_stack);

        // Request a larger stack via on_wasm_stack.
        let big_size = ByteSize::mib(1).as_u64() as usize;
        let result = on_wasm_stack(big_size, None, || 42);

        assert_eq!(result.ok().expect("on_wasm_stack should succeed"), 42);
        // The undersized stack was discarded; the pool should now contain
        // the correctly-sized stack that was allocated for this call.
        let returned = STACK_POOL
            .pop()
            .expect("stack should have been returned to pool");
        assert!(
            returned.size() >= big_size,
            "returned stack must be at least as large as the requested size"
        );
    }
}