wasmer_vm/trap/
traphandlers.rs

1// This file contains code from external sources.
2// Attributions: https://github.com/wasmerio/wasmer/blob/main/docs/ATTRIBUTIONS.md
3
4#![allow(static_mut_refs)]
5
//! WebAssembly trap handling, which is built on top of the lower-level
//! signal-handling mechanisms.
8
9use super::trap::UnwindReason;
10use crate::Trap;
11#[cfg(all(unix, feature = "experimental-host-interrupt"))]
12use crate::interrupt_registry;
13use backtrace::Backtrace;
14use bytesize::ByteSize;
15use core::ptr::{read, read_unaligned};
16use corosensei::stack::{DefaultStack, Stack};
17use corosensei::trap::{CoroutineTrapHandler, TrapHandlerRegs};
18use corosensei::{CoroutineResult, ScopedCoroutine, Yielder};
19use scopeguard::defer;
20use std::any::Any;
21use std::cell::Cell;
22use std::error::Error;
23use std::io;
24use std::mem;
25#[cfg(unix)]
26use std::mem::MaybeUninit;
27use std::ptr::{self, NonNull};
28use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering, compiler_fence};
29use std::sync::{LazyLock, Once};
30use wasmer_types::TrapCode;
31
32/// Convenience extension for [`Stack`] that exposes the total mapped size.
33trait StackExt: Stack {
34    /// Returns the total size of the stack mapping (including guard page).
35    fn size(&self) -> usize {
36        self.base().get() - self.limit().get()
37    }
38}
39impl<T: Stack> StackExt for T {}
40
/// Runtime configuration for the VM.
///
/// The stack size is the only setting that can be tuned at the moment.
pub struct VMConfig {
    /// Optional VM stack size, in bytes. Values below 8K are rounded up to 8K.
    pub wasm_stack_size: Option<usize>,
}
47
// Trap information can be encoded directly in the undefined instruction.
// On x86_64, a `0xC?` byte selects the register form of the ModR/M operand
// of `ud1`, so no further bytes follow the instruction.
// On AArch64, `udf` takes a 16-bit immediate; the same `0xC?` pattern is used
// in its low byte to carry the trap information.
static MAGIC: u8 = 0b1100_0000;
52
/// Process-wide default stack size for new coroutines, in bytes.
/// Starts at 1 MiB and is updated through `set_stack_size`.
static DEFAULT_STACK_SIZE: AtomicUsize = AtomicUsize::new(1024 * 1024);

/// Upper bound (100 MiB) accepted by `set_stack_size` for the process-wide
/// default stack size.
pub const MAX_STACK_SIZE: usize = 100 * 1024 * 1024;
58
// The current definition of `ucontext_t` in the `libc` crate is incorrect
// on aarch64-apple-darwin, so it is defined here with a more accurate layout.
// `#[repr(C)]` makes the field order part of the layout: do not reorder or
// insert fields.
#[repr(C)]
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
#[allow(non_camel_case_types)]
struct ucontext_t {
    uc_onstack: libc::c_int,
    uc_sigmask: libc::sigset_t,
    uc_stack: libc::stack_t,
    uc_link: *mut libc::ucontext_t,
    uc_mcsize: usize,
    // Dereferenced as a pointer (`&*context.uc_mcontext`) by `get_pc_sp` and
    // `update_context` to reach the saved machine registers.
    uc_mcontext: libc::mcontext_t,
}
72
73#[cfg(all(unix, not(all(target_arch = "aarch64", target_os = "macos"))))]
74use libc::ucontext_t;
75
76/// Sets the process-wide default stack size for new Wasmer coroutines.
77/// The value is clamped to [8 KiB, MAX_STACK_SIZE].
78pub fn set_stack_size(size: usize) {
79    DEFAULT_STACK_SIZE.store(
80        size.clamp(ByteSize::kib(8).as_u64() as usize, MAX_STACK_SIZE),
81        Ordering::Relaxed,
82    );
83}
84
85/// Returns the process-wide default stack size in bytes.
86pub fn get_stack_size() -> usize {
87    DEFAULT_STACK_SIZE.load(Ordering::Relaxed)
88}
89
90/// Pool of pre-allocated coroutine stacks to avoid repeated mmap syscalls.
91/// Acts as the cross-thread overflow store; per-thread reuse is served by
92/// `TLS_STACK` to keep the hot path atomic-free.
93static STACK_POOL: LazyLock<crossbeam_queue::SegQueue<DefaultStack>> =
94    LazyLock::new(crossbeam_queue::SegQueue::new);
95
96/// Per-thread cache holding a single ready-to-use coroutine stack. The hot
97/// path of `on_wasm_stack` pops from here without touching the global
98/// `STACK_POOL`'s atomics; only the first call on a thread or re-entrant
99/// nested calls fall back to the pool.
100///
101/// On thread exit the held stack (if any) is pushed to `STACK_POOL` so memory
102/// cycles correctly across thread lifetimes (no mmap leaks).
103struct StackCache(Cell<Option<DefaultStack>>);
104
105impl Drop for StackCache {
106    fn drop(&mut self) {
107        if let Some(stack) = self.0.take() {
108            STACK_POOL.push(stack);
109        }
110    }
111}
112
113thread_local! {
114    static TLS_STACK: StackCache = const { StackCache(Cell::new(None)) };
115}
116
117/// Acquire a coroutine stack large enough for `min_size`. Prefers the
118/// thread-local cache (no atomics), falls back to the global `STACK_POOL`,
119/// then allocates a fresh stack.
120fn acquire_stack(min_size: usize) -> DefaultStack {
121    // Fast path: thread-local cache. Steady-state per-thread reuse never
122    // touches the SegQueue.
123    if let Some(stack) = TLS_STACK.with(|cache| cache.0.take()) {
124        if stack.size() >= min_size {
125            return stack;
126        }
127        // Undersized — discard (mirrors the existing `STACK_POOL.pop().filter(...)`
128        // behavior of not holding undersized stacks in rotation).
129        drop(stack);
130    }
131    // Cross-thread overflow pool. Single pop, single filter — same semantics
132    // as the pre-TLS implementation.
133    STACK_POOL
134        .pop()
135        .filter(|s| s.size() >= min_size)
136        .unwrap_or_else(|| DefaultStack::new(min_size).unwrap())
137}
138
139/// Release a coroutine stack. Prefers the thread-local slot if empty (no
140/// atomics); otherwise pushes to the global `STACK_POOL` so the stack is
141/// still reusable by other threads.
142fn release_stack(stack: DefaultStack) {
143    let displaced = TLS_STACK.with(|cache| cache.0.replace(Some(stack)));
144    if let Some(displaced) = displaced {
145        STACK_POOL.push(displaced);
146    }
147}
148
149/// Drains the coroutine stack pool at the moment it runs.
150///
151/// This is intended to be called before retrying with a larger stack size so
152/// that the pool does not keep serving cached undersized stacks.
153///
154/// Note that `STACK_POOL` is a global, concurrently used queue and that each
155/// thread also keeps a private cached stack in TLS. Other threads may push
156/// stacks back into the pool (for example, when their Wasm execution
157/// finishes) while or after this function is running, and TLS-cached stacks
158/// on other threads are not touched. As a result, this function provides
159/// only a best-effort drain: there is no guarantee that no undersized stacks
160/// exist immediately after it returns unless the caller ensures, via external
161/// synchronization, that no other Wasm executions can return stacks to the
162/// pool while this function runs. The current thread's TLS-cached stack is
163/// drained as part of this call.
164pub fn drain_stack_pool() {
165    // Drain the calling thread's TLS slot first (best-effort across threads
166    // still applies — other threads' caches aren't touched).
167    if let Some(stack) = TLS_STACK.with(|cache| cache.0.take()) {
168        drop(stack);
169    }
170    while STACK_POOL.pop().is_some() {}
171}
172
173cfg_if::cfg_if! {
174    if #[cfg(unix)] {
175        /// Function which may handle custom signals while processing traps.
176        pub type TrapHandlerFn<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + Send + Sync + 'a;
177    } else if #[cfg(target_os = "windows")] {
178        /// Function which may handle custom signals while processing traps.
179        pub type TrapHandlerFn<'a> = dyn Fn(*mut windows_sys::Win32::System::Diagnostics::Debug::EXCEPTION_POINTERS) -> bool + Send + Sync + 'a;
180    }
181}
182
183// Process an IllegalOpcode to see if it has a TrapCode payload
184unsafe fn process_illegal_op(addr: usize) -> Option<TrapCode> {
185    let mut val: Option<u8> = None;
186    unsafe {
187        if cfg!(target_arch = "x86_64") {
188            val = if read(addr as *mut u8) & 0xf0 == 0x40
189                && read((addr + 1) as *mut u8) == 0x0f
190                && read((addr + 2) as *mut u8) == 0xb9
191            {
192                Some(read((addr + 3) as *mut u8))
193            } else if read(addr as *mut u8) == 0x0f && read((addr + 1) as *mut u8) == 0xb9 {
194                Some(read((addr + 2) as *mut u8))
195            } else {
196                None
197            }
198        }
199        if cfg!(target_arch = "aarch64") {
200            val = if read_unaligned(addr as *mut u32) & 0xffff0000 == 0 {
201                Some(read(addr as *mut u8))
202            } else {
203                None
204            }
205        }
206        if cfg!(target_arch = "riscv64") {
207            let addr = addr as *mut u32;
208            // Check if 'unimp' instruction
209            val = if read(addr) == 0xc0001073 {
210                // Read from the instruction we emitted: 'addi a0, xzero, $payload'
211                // and take the encoded immediate value (upper 12-bits).
212                let prev_insn = read(addr.sub(1));
213                if (prev_insn & 0xffff) == 0x0513 {
214                    Some((prev_insn >> 20) as u8)
215                } else {
216                    None
217                }
218            } else {
219                None
220            };
221        }
222    }
223
224    // The direct encoding of a trap into the instruction is unused on RISC-V:
225    if cfg!(target_arch = "x86_64") || cfg!(target_arch = "aarch64") {
226        val = val.and_then(|val| {
227            if val & MAGIC == MAGIC {
228                Some(val & 0xf)
229            } else {
230                None
231            }
232        });
233    }
234
235    match val {
236        None => None,
237        Some(val) => match val {
238            0 => Some(TrapCode::StackOverflow),
239            1 => Some(TrapCode::HeapAccessOutOfBounds),
240            2 => Some(TrapCode::HeapMisaligned),
241            3 => Some(TrapCode::TableAccessOutOfBounds),
242            4 => Some(TrapCode::IndirectCallToNull),
243            5 => Some(TrapCode::BadSignature),
244            6 => Some(TrapCode::IntegerOverflow),
245            7 => Some(TrapCode::IntegerDivisionByZero),
246            8 => Some(TrapCode::BadConversionToInteger),
247            9 => Some(TrapCode::UnreachableCodeReached),
248            10 => Some(TrapCode::UnalignedAtomic),
249            _ => None,
250        },
251    }
252}
253
254cfg_if::cfg_if! {
255    if #[cfg(unix)] {
        // Signal dispositions that were installed before `platform_init` ran.
        // `platform_init` fills each slot through the `oldact` out-pointer of
        // `sigaction`. `trap_handler` reads the slot back to forward signals it
        // does not handle itself. A slot stays uninitialized when its signal is
        // not registered on this target (SIGFPE is registered only on x86/x86_64,
        // SIGBUS only on ARM/Apple).
        static mut PREV_SIGSEGV: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
        static mut PREV_SIGBUS: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
        static mut PREV_SIGILL: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
        static mut PREV_SIGFPE: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();

        // Previous SIGUSR1 disposition; SIGUSR1 is only hooked for host interrupts.
        #[cfg(feature = "experimental-host-interrupt")]
        static mut PREV_SIGUSR1: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
263
        /// Installs `trap_handler` for the signals that compiled Wasm code relies
        /// on. Each previously installed disposition is saved in the matching
        /// `PREV_*` slot, so `trap_handler` can forward signals it does not handle.
        ///
        /// # Panics
        ///
        /// Panics if `sigaction` fails for any of the registered signals.
        ///
        /// # Safety
        ///
        /// Writes the `PREV_*` `static mut` slots without synchronization.
        /// NOTE(review): presumably called exactly once per process (e.g. behind
        /// a `Once`), before any trap can occur — confirm at the call site.
        unsafe fn platform_init() { unsafe {
            // Installs `trap_handler` for `signal` and stores the old action in
            // `slot`. `nodefer` additionally sets SA_NODEFER (see below).
            let register = |slot: &mut MaybeUninit<libc::sigaction>, signal: i32, nodefer: bool| {
                let mut handler: libc::sigaction = mem::zeroed();
                // The flags here are relatively careful, and they are...
                //
                // SA_SIGINFO gives us access to information like the program
                // counter from where the fault happened.
                //
                // SA_ONSTACK allows us to handle signals on an alternate stack,
                // so that the handler can run in response to running out of
                // stack space on the main stack. Rust installs an alternate
                // stack with sigaltstack, so we rely on that.
                //
                // SA_NODEFER keeps the signal unblocked while its own handler
                // runs, so that a crash inside the handler re-enters it instead
                // of being held pending.
                handler.sa_flags = libc::SA_SIGINFO | libc::SA_ONSTACK;
                if nodefer {
                    handler.sa_flags |= libc::SA_NODEFER;
                }
                handler.sa_sigaction = trap_handler as *const () as usize;
                libc::sigemptyset(&mut handler.sa_mask);
                if libc::sigaction(signal, &handler, slot.as_mut_ptr()) != 0 {
                    panic!(
                        "unable to install signal handler: {}",
                        io::Error::last_os_error(),
                    );
                }
            };

            // Allow handling OOB with signals on all architectures
            register(&mut PREV_SIGSEGV, libc::SIGSEGV, true);

            // Illegal-instruction traps (`ud1` on x86_64, `udf` on AArch64,
            // `unimp` on RISC-V; see `process_illegal_op`) are how compiled code
            // raises many traps, e.g. `unreachable`.
            register(&mut PREV_SIGILL, libc::SIGILL, true);

            // SIGUSR1 is used to interrupt long-running WASM code.
            // It doesn't use NODEFER since, if a second interruption
            // request comes in while one is already being processed,
            // there's nothing meaningful we can do.
            #[cfg(feature = "experimental-host-interrupt")]
            register(&mut PREV_SIGUSR1, libc::SIGUSR1, false);

            // x86 uses SIGFPE to report division by zero
            if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
                register(&mut PREV_SIGFPE, libc::SIGFPE, true);
            }

            // On ARM, handle Unaligned Accesses.
            // On Darwin, guard page accesses are raised as SIGBUS.
            if cfg!(target_arch = "arm") || cfg!(target_vendor = "apple") {
                register(&mut PREV_SIGBUS, libc::SIGBUS, true);
            }

            // This is necessary to support debugging under LLDB on Darwin.
            // For more details see https://github.com/mono/mono/commit/8e75f5a28e6537e56ad70bf870b86e22539c2fb7
            #[cfg(target_vendor = "apple")]
            {
                use mach2::exception_types::*;
                use mach2::kern_return::*;
                use mach2::port::*;
                use mach2::thread_status::*;
                use mach2::traps::*;
                use mach2::mach_types::*;

                unsafe extern "C" {
                    fn task_set_exception_ports(
                        task: task_t,
                        exception_mask: exception_mask_t,
                        new_port: mach_port_t,
                        behavior: exception_behavior_t,
                        new_flavor: thread_state_flavor_t,
                    ) -> kern_return_t;
                }

                // Thread-state flavor for the current architecture; 6 is the
                // numeric value used for AArch64 (presumably ARM_THREAD_STATE64).
                #[allow(non_snake_case)]
                #[cfg(target_arch = "x86_64")]
                let MACHINE_THREAD_STATE = x86_THREAD_STATE64;
                #[allow(non_snake_case)]
                #[cfg(target_arch = "aarch64")]
                let MACHINE_THREAD_STATE = 6;

                // NOTE(review): the returned `kern_return_t` is ignored.
                task_set_exception_ports(
                    mach_task_self(),
                    EXC_MASK_BAD_ACCESS | EXC_MASK_ARITHMETIC | EXC_MASK_BAD_INSTRUCTION,
                    MACH_PORT_NULL,
                    EXCEPTION_STATE_IDENTITY as exception_behavior_t,
                    MACHINE_THREAD_STATE,
                );
            }
        }}
355
356        unsafe extern "C" fn trap_handler(
357            signum: libc::c_int,
358            siginfo: *mut libc::siginfo_t,
359            context: *mut libc::c_void,
360        ) { unsafe {
361            let previous = match signum {
362                libc::SIGSEGV => &PREV_SIGSEGV,
363                libc::SIGBUS => &PREV_SIGBUS,
364                libc::SIGFPE => &PREV_SIGFPE,
365                libc::SIGILL => &PREV_SIGILL,
366                #[cfg(feature = "experimental-host-interrupt")]
367                libc::SIGUSR1 => &PREV_SIGUSR1,
368                _ => panic!("unknown signal: {signum}"),
369            };
370            // We try to get the fault address associated to this signal
371            let maybe_fault_address = match signum {
372                libc::SIGSEGV | libc::SIGBUS => {
373                    Some((*siginfo).si_addr() as usize)
374                }
375                _ => None,
376            };
377            let trap_code = match signum {
378                // check if it was cased by a UD and if the Trap info is a payload to it
379                libc::SIGILL => {
380                    let addr = (*siginfo).si_addr() as usize;
381                    process_illegal_op(addr)
382                }
383                #[cfg(feature = "experimental-host-interrupt")]
384                libc::SIGUSR1 => {
385                    // If we're not running WASM code from the specific store for which
386                    // an interrupt was requested, there's nothing to do.
387                    if !interrupt_registry::on_interrupted() {
388                        return;
389                    }
390                    Some(TrapCode::HostInterrupt)
391                }
392                _ => None,
393            };
394            let ucontext = &mut *(context as *mut ucontext_t);
395            let (pc, sp) = get_pc_sp(ucontext);
396            let handled = TrapHandlerContext::handle_trap(
397                pc,
398                sp,
399                maybe_fault_address,
400                trap_code,
401                |regs| update_context(ucontext, regs),
402                |handler| handler(signum, siginfo, context),
403            );
404
405            if handled {
406                return;
407            }
408
409            // If we're not running WASM code at all, there's nothing to
410            // do for an interrupt.
411            #[cfg(feature = "experimental-host-interrupt")]
412            if signum == libc::SIGUSR1 {
413                return;
414            }
415
416            // This signal is not for any compiled wasm code we expect, so we
417            // need to forward the signal to the next handler. If there is no
418            // next handler (SIG_IGN or SIG_DFL), then it's time to crash. To do
419            // this, we set the signal back to its original disposition and
420            // return. This will cause the faulting op to be re-executed which
421            // will crash in the normal way. If there is a next handler, call
422            // it. It will either crash synchronously, fix up the instruction
423            // so that execution can continue and return, or trigger a crash by
424            // returning the signal to it's original disposition and returning.
425            let previous = &*previous.as_ptr();
426            if previous.sa_flags & libc::SA_SIGINFO != 0 {
427                mem::transmute::<
428                    usize,
429                    extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut libc::c_void),
430                >(previous.sa_sigaction)(signum, siginfo, context)
431            } else if previous.sa_sigaction == libc::SIG_DFL
432            {
433                libc::sigaction(signum, previous, ptr::null_mut());
434            } else if previous.sa_sigaction != libc::SIG_IGN {
435                mem::transmute::<usize, extern "C" fn(libc::c_int)>(
436                    previous.sa_sigaction
437                )(signum)
438            }
439        }}
440
441        unsafe fn get_pc_sp(context: &ucontext_t) -> (usize, usize) {
442            let (pc, sp);
443            cfg_if::cfg_if! {
444                if #[cfg(all(
445                    any(target_os = "linux", target_os = "android"),
446                    target_arch = "x86_64",
447                ))] {
448                    pc = context.uc_mcontext.gregs[libc::REG_RIP as usize] as usize;
449                    sp = context.uc_mcontext.gregs[libc::REG_RSP as usize] as usize;
450                } else if #[cfg(all(
451                    any(target_os = "linux", target_os = "android"),
452                    target_arch = "x86",
453                ))] {
454                    pc = context.uc_mcontext.gregs[libc::REG_EIP as usize] as usize;
455                    sp = context.uc_mcontext.gregs[libc::REG_ESP as usize] as usize;
456                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86"))] {
457                    pc = context.uc_mcontext.mc_eip as usize;
458                    sp = context.uc_mcontext.mc_esp as usize;
459                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
460                    pc = context.uc_mcontext.mc_rip as usize;
461                    sp = context.uc_mcontext.mc_rsp as usize;
462                } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
463                    let mcontext = unsafe { &*context.uc_mcontext };
464                    pc = mcontext.__ss.__rip as usize;
465                    sp = mcontext.__ss.__rsp as usize;
466                } else if #[cfg(all(
467                        any(target_os = "linux", target_os = "android"),
468                        target_arch = "aarch64",
469                    ))] {
470                    pc = context.uc_mcontext.pc as usize;
471                    sp = context.uc_mcontext.sp as usize;
472                } else if #[cfg(all(
473                    any(target_os = "linux", target_os = "android"),
474                    target_arch = "arm",
475                ))] {
476                    pc = context.uc_mcontext.arm_pc as usize;
477                    sp = context.uc_mcontext.arm_sp as usize;
478                } else if #[cfg(all(
479                    any(target_os = "linux", target_os = "android"),
480                    any(target_arch = "riscv64", target_arch = "riscv32"),
481                ))] {
482                    pc = context.uc_mcontext.__gregs[libc::REG_PC] as usize;
483                    sp = context.uc_mcontext.__gregs[libc::REG_SP] as usize;
484                } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
485                    let mcontext = unsafe { &*context.uc_mcontext };
486                    pc = mcontext.__ss.__pc as usize;
487                    sp = mcontext.__ss.__sp as usize;
488                } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
489                    pc = context.uc_mcontext.mc_gpregs.gp_elr as usize;
490                    sp = context.uc_mcontext.mc_gpregs.gp_sp as usize;
491                } else if #[cfg(all(target_os = "linux", target_arch = "loongarch64"))] {
492                    pc = context.uc_mcontext.__gregs[1] as usize;
493                    sp = context.uc_mcontext.__gregs[3] as usize;
494                } else if #[cfg(all(target_os = "linux", target_arch = "powerpc64"))] {
495                    pc = (*context.uc_mcontext.regs).nip as usize;
496                    sp = (*context.uc_mcontext.regs).gpr[1] as usize;
497                } else {
498                    compile_error!("Unsupported platform");
499                }
500            };
501            (pc, sp)
502        }
503
        /// Writes the register values chosen by the trap machinery (`regs`) into
        /// the signal `ucontext_t`. NOTE(review): presumably these values take
        /// effect when the signal handler returns and the kernel restores this
        /// context.
        ///
        /// Which registers are written is fixed per target by corosensei's
        /// `TrapHandlerRegs` (destructured in each branch below). Unsupported
        /// targets fail to compile.
        ///
        /// # Safety
        ///
        /// `context` must be the live context passed to the current signal
        /// handler. On Apple targets and on powerpc64 Linux, this dereferences
        /// the raw `uc_mcontext` / `regs` pointer it contains.
        unsafe fn update_context(context: &mut ucontext_t, regs: TrapHandlerRegs) {
            cfg_if::cfg_if! {
                if #[cfg(all(
                        any(target_os = "linux", target_os = "android"),
                        target_arch = "x86_64",
                    ))] {
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    context.uc_mcontext.gregs[libc::REG_RIP as usize] = rip as i64;
                    context.uc_mcontext.gregs[libc::REG_RSP as usize] = rsp as i64;
                    context.uc_mcontext.gregs[libc::REG_RBP as usize] = rbp as i64;
                    context.uc_mcontext.gregs[libc::REG_RDI as usize] = rdi as i64;
                    context.uc_mcontext.gregs[libc::REG_RSI as usize] = rsi as i64;
                } else if #[cfg(all(
                    any(target_os = "linux", target_os = "android"),
                    target_arch = "x86",
                ))] {
                    let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
                    context.uc_mcontext.gregs[libc::REG_EIP as usize] = eip as i32;
                    context.uc_mcontext.gregs[libc::REG_ESP as usize] = esp as i32;
                    context.uc_mcontext.gregs[libc::REG_EBP as usize] = ebp as i32;
                    context.uc_mcontext.gregs[libc::REG_ECX as usize] = ecx as i32;
                    context.uc_mcontext.gregs[libc::REG_EDX as usize] = edx as i32;
                } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    let mcontext = unsafe { &mut *context.uc_mcontext };
                    mcontext.__ss.__rip = rip;
                    mcontext.__ss.__rsp = rsp;
                    mcontext.__ss.__rbp = rbp;
                    mcontext.__ss.__rdi = rdi;
                    mcontext.__ss.__rsi = rsi;
                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86"))] {
                    let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
                    context.uc_mcontext.mc_eip = eip as libc::register_t;
                    context.uc_mcontext.mc_esp = esp as libc::register_t;
                    context.uc_mcontext.mc_ebp = ebp as libc::register_t;
                    context.uc_mcontext.mc_ecx = ecx as libc::register_t;
                    context.uc_mcontext.mc_edx = edx as libc::register_t;
                } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    context.uc_mcontext.mc_rip = rip as libc::register_t;
                    context.uc_mcontext.mc_rsp = rsp as libc::register_t;
                    context.uc_mcontext.mc_rbp = rbp as libc::register_t;
                    context.uc_mcontext.mc_rdi = rdi as libc::register_t;
                    context.uc_mcontext.mc_rsi = rsi as libc::register_t;
                } else if #[cfg(all(
                        any(target_os = "linux", target_os = "android"),
                        target_arch = "aarch64",
                    ))] {
                    // regs[29] is the frame pointer, regs[30] the link register.
                    let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
                    context.uc_mcontext.pc = pc;
                    context.uc_mcontext.sp = sp;
                    context.uc_mcontext.regs[0] = x0;
                    context.uc_mcontext.regs[1] = x1;
                    context.uc_mcontext.regs[29] = x29;
                    context.uc_mcontext.regs[30] = lr;
                } else if #[cfg(all(
                        any(target_os = "linux", target_os = "android"),
                        target_arch = "arm",
                    ))] {
                    let TrapHandlerRegs {
                        pc,
                        r0,
                        r1,
                        r7,
                        r11,
                        r13,
                        r14,
                        cpsr_thumb,
                        cpsr_endian,
                    } = regs;
                    context.uc_mcontext.arm_pc = pc;
                    context.uc_mcontext.arm_r0 = r0;
                    context.uc_mcontext.arm_r1 = r1;
                    context.uc_mcontext.arm_r7 = r7;
                    context.uc_mcontext.arm_fp = r11;
                    context.uc_mcontext.arm_sp = r13;
                    context.uc_mcontext.arm_lr = r14;
                    // 0x20 is the CPSR T (Thumb state) bit.
                    if cpsr_thumb {
                        context.uc_mcontext.arm_cpsr |= 0x20;
                    } else {
                        context.uc_mcontext.arm_cpsr &= !0x20;
                    }
                    // 0x200 is the CPSR E (data endianness) bit.
                    if cpsr_endian {
                        context.uc_mcontext.arm_cpsr |= 0x200;
                    } else {
                        context.uc_mcontext.arm_cpsr &= !0x200;
                    }
                } else if #[cfg(all(
                    any(target_os = "linux", target_os = "android"),
                    any(target_arch = "riscv64", target_arch = "riscv32"),
                ))] {
                    // `REG_A0 + 1` is the slot of a1.
                    let TrapHandlerRegs { pc, ra, sp, a0, a1, s0 } = regs;
                    context.uc_mcontext.__gregs[libc::REG_PC] = pc as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_RA] = ra as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_SP] = sp as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_A0] = a0 as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_A0 + 1] = a1 as libc::c_ulong;
                    context.uc_mcontext.__gregs[libc::REG_S0] = s0 as libc::c_ulong;
                } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
                    let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
                    let mcontext = unsafe { &mut *context.uc_mcontext };
                    mcontext.__ss.__pc = pc;
                    mcontext.__ss.__sp = sp;
                    mcontext.__ss.__x[0] = x0;
                    mcontext.__ss.__x[1] = x1;
                    mcontext.__ss.__fp = x29;
                    mcontext.__ss.__lr = lr;
                } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
                    let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
                    context.uc_mcontext.mc_gpregs.gp_elr = pc as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_sp = sp as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_x[0] = x0 as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_x[1] = x1 as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_x[29] = x29 as libc::register_t;
                    context.uc_mcontext.mc_gpregs.gp_lr = lr as libc::register_t;
                } else if #[cfg(all(target_os = "linux", target_arch = "loongarch64"))] {
                    // `__gregs[n]` is register rN: r1 = ra, r3 = sp, r4/r5 = a0/a1,
                    // r22 = fp. The PC lives in the separate `__pc` field.
                    let TrapHandlerRegs { pc, sp, a0, a1, fp, ra } = regs;
                    context.uc_mcontext.__pc = pc;
                    context.uc_mcontext.__gregs[1] = ra;
                    context.uc_mcontext.__gregs[3] = sp;
                    context.uc_mcontext.__gregs[4] = a0;
                    context.uc_mcontext.__gregs[5] = a1;
                    context.uc_mcontext.__gregs[22] = fp;
                } else if #[cfg(all(target_os = "linux", target_arch = "powerpc64"))] {
                    // `nip` is the PC and gpr[1] the stack pointer (as read by
                    // `get_pc_sp`); `link` is the link register.
                    let TrapHandlerRegs { pc, sp, r3, r4, r31, lr } = regs;
                    (*context.uc_mcontext.regs).nip = pc;
                    (*context.uc_mcontext.regs).gpr[1] = sp;
                    (*context.uc_mcontext.regs).gpr[3] = r3;
                    (*context.uc_mcontext.regs).gpr[4] = r4;
                    (*context.uc_mcontext.regs).gpr[31] = r31;
                    (*context.uc_mcontext.regs).link = lr;
                } else {
                    compile_error!("Unsupported platform");
                }
            };
        }
640    } else if #[cfg(target_os = "windows")] {
641        use windows_sys::Win32::System::Diagnostics::Debug::{
642            AddVectoredExceptionHandler,
643            CONTEXT,
644            EXCEPTION_CONTINUE_EXECUTION,
645            EXCEPTION_CONTINUE_SEARCH,
646            EXCEPTION_POINTERS,
647        };
648        use windows_sys::Win32::Foundation::{
649            EXCEPTION_ACCESS_VIOLATION,
650            EXCEPTION_ILLEGAL_INSTRUCTION,
651            EXCEPTION_INT_DIVIDE_BY_ZERO,
652            EXCEPTION_INT_OVERFLOW,
653            EXCEPTION_STACK_OVERFLOW,
654        };
655
656        unsafe fn platform_init() {
657            unsafe {
658                // our trap handler needs to go first, so that we can recover from
659                // wasm faults and continue execution, so pass `1` as a true value
660                // here.
661                let handler = AddVectoredExceptionHandler(1, Some(exception_handler));
662                if handler.is_null() {
663                    panic!("failed to add exception handler: {}", io::Error::last_os_error());
664                }
665            }
666        }
667
668        unsafe extern "system" fn exception_handler(
669            exception_info: *mut EXCEPTION_POINTERS
670        ) -> i32 {
671            unsafe {
672                // Check the kind of exception, since we only handle a subset within
673                // wasm code. If anything else happens we want to defer to whatever
674                // the rest of the system wants to do for this exception.
675                let record = &*(*exception_info).ExceptionRecord;
676                if record.ExceptionCode != EXCEPTION_ACCESS_VIOLATION &&
677                    record.ExceptionCode != EXCEPTION_ILLEGAL_INSTRUCTION &&
678                    record.ExceptionCode != EXCEPTION_STACK_OVERFLOW &&
679                    record.ExceptionCode != EXCEPTION_INT_DIVIDE_BY_ZERO &&
680                    record.ExceptionCode != EXCEPTION_INT_OVERFLOW
681                {
682                    return EXCEPTION_CONTINUE_SEARCH;
683                }
684
685                // FIXME: this is what the previous C++ did to make sure that TLS
686                // works by the time we execute this trap handling code. This isn't
687                // exactly super easy to call from Rust though and it's not clear we
688                // necessarily need to do so. Leaving this here in case we need this
689                // in the future, but for now we can probably wait until we see a
690                // strange fault before figuring out how to reimplement this in
691                // Rust.
692                //
693                // if (!NtCurrentTeb()->Reserved1[sThreadLocalArrayPointerIndex]) {
694                //     return EXCEPTION_CONTINUE_SEARCH;
695                // }
696
697                let context = &mut *(*exception_info).ContextRecord;
698                let (pc, sp) = get_pc_sp(context);
699
700                // We try to get the fault address associated to this exception.
701                let maybe_fault_address = match record.ExceptionCode {
702                    EXCEPTION_ACCESS_VIOLATION => Some(record.ExceptionInformation[1]),
703                    EXCEPTION_STACK_OVERFLOW => Some(sp),
704                    _ => None,
705                };
706                let trap_code = match record.ExceptionCode {
707                    // check if it was cased by a UD and if the Trap info is a payload to it
708                    EXCEPTION_ILLEGAL_INSTRUCTION => {
709                        process_illegal_op(pc)
710                    }
711                    _ => None,
712                };
713                // This is basically the same as the unix version above, only with a
714                // few parameters tweaked here and there.
715                let handled = TrapHandlerContext::handle_trap(
716                    pc,
717                    sp,
718                    maybe_fault_address,
719                    trap_code,
720                    |regs| update_context(context, regs),
721                    |handler| handler(exception_info),
722                );
723
724                if handled {
725                    EXCEPTION_CONTINUE_EXECUTION
726                } else {
727                    EXCEPTION_CONTINUE_SEARCH
728                }
729            }
730        }
731
        /// Reads the faulting instruction pointer and stack pointer out of a
        /// Windows exception `CONTEXT`, returned as `(pc, sp)`.
        unsafe fn get_pc_sp(context: &CONTEXT) -> (usize, usize) {
            let (pc, sp);
            // The register names differ per architecture; only x86_64 and x86
            // are supported on Windows.
            cfg_if::cfg_if! {
                if #[cfg(target_arch = "x86_64")] {
                    pc = context.Rip as usize;
                    sp = context.Rsp as usize;
                } else if #[cfg(target_arch = "x86")] {
                    pc = context.Eip as usize;
                    sp = context.Esp as usize;
                } else {
                    compile_error!("Unsupported platform");
                }
            };
            (pc, sp)
        }
747
        /// Writes the register state produced by the coroutine trap handler
        /// setup back into the exception `CONTEXT`, so that when the handler
        /// returns `EXCEPTION_CONTINUE_EXECUTION` execution resumes at the
        /// recovery point instead of the faulting instruction.
        unsafe fn update_context(context: &mut CONTEXT, regs: TrapHandlerRegs) {
            // `TrapHandlerRegs` has a different set of fields per architecture;
            // each field is copied into the matching `CONTEXT` register.
            cfg_if::cfg_if! {
                if #[cfg(target_arch = "x86_64")] {
                    let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
                    context.Rip = rip;
                    context.Rsp = rsp;
                    context.Rbp = rbp;
                    context.Rdi = rdi;
                    context.Rsi = rsi;
                } else if #[cfg(target_arch = "x86")] {
                    let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
                    context.Eip = eip;
                    context.Esp = esp;
                    context.Ebp = ebp;
                    context.Ecx = ecx;
                    context.Edx = edx;
                } else {
                    compile_error!("Unsupported platform");
                }
            };
        }
769    }
770}
771
772/// This function is required to be called before any WebAssembly is entered.
773/// This will configure global state such as signal handlers to prepare the
774/// process to receive wasm traps.
775///
776/// This function must not only be called globally once before entering
777/// WebAssembly but it must also be called once-per-thread that enters
778/// WebAssembly. Currently in wasmer's integration this function is called on
779/// creation of a `Store`.
780pub fn init_traps() {
781    static INIT: Once = Once::new();
782    INIT.call_once(|| unsafe {
783        platform_init();
784    });
785}
786
787/// Raises a user-defined trap immediately.
788///
789/// This function performs as-if a wasm trap was just executed, only the trap
790/// has a dynamic payload associated with it which is user-provided. This trap
791/// payload is then returned from `catch_traps` below.
792///
793/// # Safety
794///
795/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
796/// have been previous called and not yet returned.
797/// Additionally no Rust destructors may be on the stack.
798/// They will be skipped and not executed.
799pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
800    unsafe { unwind_with(UnwindReason::UserTrap(data)) }
801}
802
803/// Raises a trap from inside library code immediately.
804///
805/// This function performs as-if a wasm trap was just executed. This trap
806/// payload is then returned from `catch_traps` below.
807///
808/// # Safety
809///
810/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
811/// have been previous called and not yet returned.
812/// Additionally no Rust destructors may be on the stack.
813/// They will be skipped and not executed.
814pub unsafe fn raise_lib_trap(trap: Trap) -> ! {
815    unsafe { unwind_with(UnwindReason::LibTrap(trap)) }
816}
817
818/// Carries a Rust panic across wasm code and resumes the panic on the other
819/// side.
820///
821/// # Safety
822///
823/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
824/// have been previously called and not returned. Additionally no Rust destructors may be on the
825/// stack. They will be skipped and not executed.
826pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
827    unsafe { unwind_with(UnwindReason::Panic(payload)) }
828}
829
830/// Catches any wasm traps that happen within the execution of `closure`,
831/// returning them as a `Result`.
832///
833/// # Safety
834///
835/// Highly unsafe since `closure` won't have any dtors run.
836pub unsafe fn catch_traps<F, R: 'static>(
837    trap_handler: Option<*const TrapHandlerFn<'static>>,
838    config: &VMConfig,
839    closure: F,
840) -> Result<R, Trap>
841where
842    F: FnOnce() -> R + 'static,
843{
844    // Ensure that per-thread initialization is done.
845    lazy_per_thread_init()?;
846    let stack_size = config
847        .wasm_stack_size
848        .unwrap_or_else(|| DEFAULT_STACK_SIZE.load(Ordering::Relaxed));
849    on_wasm_stack(stack_size, trap_handler, closure).map_err(UnwindReason::into_trap)
850}
851
852// We need two separate thread-local variables here:
853// - YIELDER is set within the new stack and is used to unwind back to the root
854//   of the stack from inside it.
855// - TRAP_HANDLER is set from outside the new stack and is solely used from
856//   signal handlers. It must be atomic since it is used by signal handlers.
857//
858// We also do per-thread signal stack initialization on the first time
859// TRAP_HANDLER is accessed.
860thread_local! {
861    static YIELDER: Cell<Option<NonNull<Yielder<(), UnwindReason>>>> = const { Cell::new(None) };
862    static TRAP_HANDLER: AtomicPtr<TrapHandlerContext> = const { AtomicPtr::new(ptr::null_mut()) };
863}
864
865/// Read-only information that is used by signal handlers to handle and recover
866/// from traps.
867#[allow(clippy::type_complexity)]
868struct TrapHandlerContext {
869    inner: *const u8,
870    handle_trap: fn(
871        *const u8,
872        usize,
873        usize,
874        Option<usize>,
875        Option<TrapCode>,
876        &mut dyn FnMut(TrapHandlerRegs),
877    ) -> bool,
878    custom_trap: Option<*const TrapHandlerFn<'static>>,
879}
880struct TrapHandlerContextInner<T> {
881    /// Information about the currently running coroutine. This is used to
882    /// reset execution to the root of the coroutine when a trap is handled.
883    coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
884}
885
impl TrapHandlerContext {
    /// Runs the given function with a trap handler context. The previous
    /// trap handler context is preserved and restored afterwards.
    ///
    /// `custom_trap` is an optional user-supplied handler that is tried before
    /// the built-in Wasm trap handling; `coro_trap_handler` belongs to the
    /// coroutine that `f` resumes. Both the context and its type-erased inner
    /// state live in this function's stack frame, so the pointer published in
    /// `TRAP_HANDLER` is only valid while `f` runs. The deferred restore below
    /// un-publishes it before this frame goes away, including on unwinding.
    fn install<T, R>(
        custom_trap: Option<*const TrapHandlerFn<'static>>,
        coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
        f: impl FnOnce() -> R,
    ) -> R {
        // Type-erase the trap handler function so that it can be placed in TLS.
        // `ptr` must point at a `TrapHandlerContextInner<T>` with the same `T`
        // this function is instantiated with; `install` guarantees that by
        // pairing `inner` with `func::<T>` below.
        fn func<T>(
            ptr: *const u8,
            pc: usize,
            sp: usize,
            maybe_fault_address: Option<usize>,
            trap_code: Option<TrapCode>,
            update_regs: &mut dyn FnMut(TrapHandlerRegs),
        ) -> bool {
            unsafe {
                (*(ptr as *const TrapHandlerContextInner<T>)).handle_trap(
                    pc,
                    sp,
                    maybe_fault_address,
                    trap_code,
                    update_regs,
                )
            }
        }
        let inner = TrapHandlerContextInner { coro_trap_handler };
        let ctx = Self {
            inner: &inner as *const _ as *const u8,
            handle_trap: func::<T>,
            custom_trap,
        };

        // TRAP_HANDLER is thread-local and read from fault handlers on this
        // same thread, so compiler fences suffice: the Release fence keeps the
        // writes that build `ctx` from being moved after the pointer is
        // published.
        compiler_fence(Ordering::Release);
        let prev = TRAP_HANDLER.with(|ptr| {
            let prev = ptr.load(Ordering::Relaxed);
            ptr.store(&ctx as *const Self as *mut Self, Ordering::Relaxed);
            prev
        });

        // Restore the previous context on every exit path; the Acquire fence
        // keeps later accesses from being moved before the restore.
        defer! {
            TRAP_HANDLER.with(|ptr| ptr.store(prev, Ordering::Relaxed));
            compiler_fence(Ordering::Acquire);
        }

        f()
    }

    /// Attempts to handle the trap if it's a wasm trap.
    ///
    /// `pc`/`sp` come from the faulting register context; `maybe_fault_address`
    /// and `trap_code` describe the fault when the platform handler could
    /// determine them. `update_regs` receives the register state to resume
    /// with, and `call_handler` invokes a custom trap handler with the
    /// platform-specific arguments.
    ///
    /// Returns `true` if the trap was handled (execution should resume with
    /// the updated registers) and `false` if it should be passed on to the
    /// next handler, which includes the case where no context is installed
    /// on this thread.
    unsafe fn handle_trap(
        pc: usize,
        sp: usize,
        maybe_fault_address: Option<usize>,
        trap_code: Option<TrapCode>,
        mut update_regs: impl FnMut(TrapHandlerRegs),
        call_handler: impl Fn(&TrapHandlerFn<'static>) -> bool,
    ) -> bool {
        unsafe {
            let ptr = TRAP_HANDLER.with(|ptr| ptr.load(Ordering::Relaxed));
            if ptr.is_null() {
                return false;
            }

            let ctx = &*ptr;

            // Check if this trap is handled by a custom trap handler.
            if let Some(trap_handler) = ctx.custom_trap
                && call_handler(&*trap_handler)
            {
                return true;
            }

            // Fall back to the built-in handler for the installed coroutine.
            (ctx.handle_trap)(
                ctx.inner,
                pc,
                sp,
                maybe_fault_address,
                trap_code,
                &mut update_regs,
            )
        }
    }
}
970
971impl<T> TrapHandlerContextInner<T> {
972    unsafe fn handle_trap(
973        &self,
974        pc: usize,
975        sp: usize,
976        maybe_fault_address: Option<usize>,
977        trap_code: Option<TrapCode>,
978        update_regs: &mut dyn FnMut(TrapHandlerRegs),
979    ) -> bool {
980        unsafe {
981            // Check if this trap occurred while executing on the Wasm stack. We can
982            // only recover from traps if that is the case.
983            if !self.coro_trap_handler.stack_ptr_in_bounds(sp) {
984                return false;
985            }
986
987            let signal_trap = trap_code.or_else(|| {
988                maybe_fault_address.map(|addr| {
989                    if self.coro_trap_handler.stack_ptr_in_bounds(addr) {
990                        TrapCode::StackOverflow
991                    } else {
992                        TrapCode::HeapAccessOutOfBounds
993                    }
994                })
995            });
996
997            // Don't try to generate a backtrace for stack overflows: unwinding
998            // information is often not precise enough to properly describe what is
999            // happening during a function prologue, which can lead the unwinder to
1000            // read invalid memory addresses.
1001            //
1002            // See: https://github.com/rust-lang/backtrace-rs/pull/357
1003            let backtrace = if signal_trap == Some(TrapCode::StackOverflow) {
1004                Backtrace::from(vec![])
1005            } else {
1006                Backtrace::new_unresolved()
1007            };
1008
1009            // Set up the register state for exception return to force the
1010            // coroutine to return to its caller with UnwindReason::WasmTrap.
1011            let unwind = UnwindReason::WasmTrap {
1012                backtrace,
1013                signal_trap,
1014                pc,
1015            };
1016            let regs = self
1017                .coro_trap_handler
1018                .setup_trap_handler(move || Err(unwind));
1019            update_regs(regs);
1020            true
1021        }
1022    }
1023}
1024
1025unsafe fn unwind_with(reason: UnwindReason) -> ! {
1026    unsafe {
1027        let yielder = YIELDER
1028            .with(|cell| cell.replace(None))
1029            .expect("not running on Wasm stack");
1030
1031        yielder.as_ref().suspend(reason);
1032
1033        // on_wasm_stack will forcibly reset the coroutine stack after yielding.
1034        unreachable!();
1035    }
1036}
1037
/// Runs the given function on a separate stack so that its stack usage can be
/// bounded. Stack overflows and other traps can be caught and execution
/// returned to the root of the stack.
///
/// `stack_size` is the requested Wasm stack size and `trap_handler` an
/// optional custom trap handler installed for the duration of the call.
/// Returns `f`'s result, or the `UnwindReason` describing why execution was
/// aborted: a suspension via `unwind_with`, or a fault recovered by the trap
/// handler (which makes the coroutine return `Err`).
fn on_wasm_stack<F: FnOnce() -> T + 'static, T: 'static>(
    stack_size: usize,
    trap_handler: Option<*const TrapHandlerFn<'static>>,
    f: F,
) -> Result<T, UnwindReason> {
    // Reuse a cached stack — TLS first (atomic-free hot path), then the
    // cross-thread overflow pool, then allocate fresh. Size mismatches
    // (e.g. after `drain_stack_pool()` + a stack-size change) are filtered
    // inside `acquire_stack`. `base() - limit()` is the full mmap region
    // (including guard page), which is always >= the requested size for
    // stacks allocated with that size.
    let stack = acquire_stack(stack_size);
    // The scope guard hands the stack to `release_stack` when this function
    // exits, including when it unwinds.
    let mut stack = scopeguard::guard(stack, release_stack);

    // Create a coroutine with a new stack to run the function on.
    let coro = ScopedCoroutine::with_stack(&mut *stack, move |yielder, ()| {
        // Save the yielder to TLS so that it can be used later.
        YIELDER.with(|cell| cell.set(Some(yielder.into())));

        Ok(f())
    });

    // Ensure that YIELDER is reset on exit even if the coroutine panics.
    defer! {
        YIELDER.with(|cell| cell.set(None));
    }

    coro.scope(|mut coro_ref| {
        // Set up metadata for the trap handler for the duration of the coroutine
        // execution. This is restored to its previous value afterwards.
        TrapHandlerContext::install(trap_handler, coro_ref.trap_handler(), || {
            match coro_ref.resume(()) {
                CoroutineResult::Yield(trap) => {
                    // This came from unwind_with which requires that there be only
                    // Wasm code on the stack, so the suspended frames can be
                    // discarded without running any destructors.
                    unsafe {
                        coro_ref.force_reset();
                    }
                    Err(trap)
                }
                // Normal completion, or a fault recovered by the trap handler,
                // which makes the coroutine return `Err(UnwindReason::WasmTrap)`.
                CoroutineResult::Return(result) => result,
            }
        })
    })
}
1086
1087/// When executing on the Wasm stack, temporarily switch back to the host stack
1088/// to perform an operation that should not be constrained by the Wasm stack
1089/// limits.
1090///
1091/// This is particularly important since the usage of the Wasm stack is under
1092/// the control of untrusted code. Malicious code could artificially induce a
1093/// stack overflow in the middle of a sensitive host operations (e.g. growing
1094/// a memory) which would be hard to recover from.
1095pub fn on_host_stack<F: FnOnce() -> T, T>(f: F) -> T {
1096    // Reset YIEDER to None for the duration of this call to indicate that we
1097    // are no longer on the Wasm stack.
1098    let yielder_ptr = YIELDER.with(|cell| cell.replace(None));
1099
1100    // If we are already on the host stack, execute the function directly. This
1101    // happens if a host function is called directly from the API.
1102    let yielder = match yielder_ptr {
1103        Some(ptr) => unsafe { ptr.as_ref() },
1104        None => return f(),
1105    };
1106
1107    // Restore YIELDER upon exiting normally or unwinding.
1108    defer! {
1109        YIELDER.with(|cell| cell.set(yielder_ptr));
1110    }
1111
1112    // on_parent_stack requires the closure to be Send so that the Yielder
1113    // cannot be called from the parent stack. This is not a problem for us
1114    // since we don't expose the Yielder.
1115    struct SendWrapper<T>(T);
1116    unsafe impl<T> Send for SendWrapper<T> {}
1117    let wrapped = SendWrapper(f);
1118    yielder.on_parent_stack(move || {
1119        let wrapped = wrapped;
1120        (wrapped.0)()
1121    })
1122}
1123
/// Performs the per-thread setup required before Wasm runs on this thread.
///
/// On Windows this raises the thread's stack guarantee so that there is
/// enough stack left to handle `EXCEPTION_STACK_OVERFLOW`. Rust's
/// initialization code sets the guarantee to 0x5000, which seems to be
/// insufficient in practice. The guarantee only needs to be set once per
/// thread, so the result is cached in a thread-local flag and subsequent
/// calls (one per `catch_traps`) skip the system call.
///
/// # Panics
///
/// Panics, including the OS error, if the stack guarantee cannot be set.
#[cfg(windows)]
pub fn lazy_per_thread_init() -> Result<(), Trap> {
    use windows_sys::Win32::System::Threading::SetThreadStackGuarantee;

    thread_local! {
        /// Whether `SetThreadStackGuarantee` already succeeded on this thread.
        static STACK_GUARANTEE_SET: Cell<bool> = const { Cell::new(false) };
    }

    if STACK_GUARANTEE_SET.get() {
        return Ok(());
    }

    // In/out parameter: the requested guarantee on input, the previous one on
    // output (ignored).
    let mut guarantee: u32 = 0x10000;
    if unsafe { SetThreadStackGuarantee(&mut guarantee) } == 0 {
        panic!(
            "failed to set thread stack guarantee: {}",
            io::Error::last_os_error()
        );
    }
    STACK_GUARANTEE_SET.set(true);

    Ok(())
}
1136
1137/// A module for registering a custom alternate signal stack (sigaltstack).
1138///
1139/// Rust's libstd installs an alternate stack with size `SIGSTKSZ`, which is not
1140/// always large enough for our signal handling code. Override it by creating
1141/// and registering our own alternate stack that is large enough and has a guard
1142/// page.
1143#[cfg(unix)]
1144pub fn lazy_per_thread_init() -> Result<(), Trap> {
1145    use std::ptr::null_mut;
1146
1147    thread_local! {
1148        /// Thread-local state is lazy-initialized on the first time it's used,
1149        /// and dropped when the thread exits.
1150        static TLS: Tls = unsafe { init_sigstack() };
1151    }
1152
1153    /// The size of the sigaltstack (not including the guard, which will be
1154    /// added). Make this large enough to run our signal handlers.
1155    const MIN_STACK_SIZE: usize = ByteSize::kib(64).as_u64() as usize;
1156
1157    enum Tls {
1158        OutOfMemory,
1159        Allocated {
1160            mmap_ptr: *mut libc::c_void,
1161            mmap_size: usize,
1162        },
1163        BigEnough,
1164    }
1165
1166    unsafe fn init_sigstack() -> Tls {
1167        unsafe {
1168            // Check to see if the existing sigaltstack, if it exists, is big
1169            // enough. If so we don't need to allocate our own.
1170            let mut old_stack = mem::zeroed();
1171            let r = libc::sigaltstack(ptr::null(), &mut old_stack);
1172            assert_eq!(r, 0, "learning about sigaltstack failed");
1173            if old_stack.ss_flags & libc::SS_DISABLE == 0 && old_stack.ss_size >= MIN_STACK_SIZE {
1174                return Tls::BigEnough;
1175            }
1176
1177            // ... but failing that we need to allocate our own, so do all that
1178            // here.
1179            let page_size: usize = region::page::size();
1180            let guard_size = page_size;
1181            let alloc_size = guard_size + MIN_STACK_SIZE;
1182
1183            let ptr = libc::mmap(
1184                null_mut(),
1185                alloc_size,
1186                libc::PROT_NONE,
1187                libc::MAP_PRIVATE | libc::MAP_ANON,
1188                -1,
1189                0,
1190            );
1191            if ptr == libc::MAP_FAILED {
1192                return Tls::OutOfMemory;
1193            }
1194
1195            // Prepare the stack with readable/writable memory and then register it
1196            // with `sigaltstack`.
1197            let stack_ptr = (ptr as usize + guard_size) as *mut libc::c_void;
1198            let r = libc::mprotect(
1199                stack_ptr,
1200                MIN_STACK_SIZE,
1201                libc::PROT_READ | libc::PROT_WRITE,
1202            );
1203            assert_eq!(r, 0, "mprotect to configure memory for sigaltstack failed");
1204            let new_stack = libc::stack_t {
1205                ss_sp: stack_ptr,
1206                ss_flags: 0,
1207                ss_size: MIN_STACK_SIZE,
1208            };
1209            let r = libc::sigaltstack(&new_stack, ptr::null_mut());
1210            assert_eq!(r, 0, "registering new sigaltstack failed");
1211
1212            Tls::Allocated {
1213                mmap_ptr: ptr,
1214                mmap_size: alloc_size,
1215            }
1216        }
1217    }
1218
1219    // Ensure TLS runs its initializer and return an error if it failed to
1220    // set up a separate stack for signal handlers.
1221    return TLS.with(|tls| {
1222        if let Tls::OutOfMemory = tls {
1223            Err(Trap::oom())
1224        } else {
1225            Ok(())
1226        }
1227    });
1228
1229    impl Drop for Tls {
1230        fn drop(&mut self) {
1231            let (ptr, size) = match self {
1232                Self::Allocated {
1233                    mmap_ptr,
1234                    mmap_size,
1235                } => (*mmap_ptr, *mmap_size),
1236                _ => return,
1237            };
1238            unsafe {
1239                // Deallocate the stack memory.
1240                let r = libc::munmap(ptr, size);
1241                debug_assert_eq!(r, 0, "munmap failed during thread shutdown");
1242            }
1243        }
1244    }
1245}
1246
1247#[cfg(test)]
1248mod tests {
1249    use super::*;
1250    use std::sync::Mutex;
1251
1252    // Guards tests that mutate global state (DEFAULT_STACK_SIZE, STACK_POOL).
1253    // Rust runs tests in parallel by default; this mutex serializes them so
1254    // they don't step on each other.
1255    static GLOBAL_STATE: Mutex<()> = Mutex::new(());
1256
1257    /// Saves the current stack size and restores it on drop (even on panic).
1258    struct RestoreStackSize(usize);
1259    impl Drop for RestoreStackSize {
1260        fn drop(&mut self) {
1261            set_stack_size(self.0);
1262        }
1263    }
1264
1265    #[test]
1266    fn max_stack_size_is_100mb() {
1267        assert_eq!(MAX_STACK_SIZE, ByteSize::mib(100).as_u64() as usize);
1268    }
1269
1270    #[test]
1271    fn get_set_stack_size_roundtrip() {
1272        let _lock = GLOBAL_STATE.lock().unwrap();
1273        let _restore = RestoreStackSize(get_stack_size());
1274        let new_size = ByteSize::mib(4).as_u64() as usize;
1275        set_stack_size(new_size);
1276        assert_eq!(get_stack_size(), new_size);
1277    }
1278
1279    #[test]
1280    fn set_stack_size_clamps_to_min() {
1281        let _lock = GLOBAL_STATE.lock().unwrap();
1282        let _restore = RestoreStackSize(get_stack_size());
1283        set_stack_size(1); // way below 8 KiB minimum
1284        assert_eq!(get_stack_size(), ByteSize::kib(8).as_u64() as usize);
1285    }
1286
1287    #[test]
1288    fn set_stack_size_clamps_to_max() {
1289        let _lock = GLOBAL_STATE.lock().unwrap();
1290        let _restore = RestoreStackSize(get_stack_size());
1291        set_stack_size(usize::MAX);
1292        assert_eq!(get_stack_size(), MAX_STACK_SIZE);
1293    }
1294
1295    #[test]
1296    fn drain_stack_pool_empties_pool() {
1297        let _lock = GLOBAL_STATE.lock().unwrap();
1298        let stack = DefaultStack::new(ByteSize::mib(1).as_u64() as usize).unwrap();
1299        STACK_POOL.push(stack);
1300        assert!(!STACK_POOL.is_empty());
1301        drain_stack_pool();
1302        assert!(STACK_POOL.is_empty());
1303    }
1304
1305    #[test]
1306    fn drain_stack_pool_is_idempotent() {
1307        let _lock = GLOBAL_STATE.lock().unwrap();
1308        drain_stack_pool();
1309        drain_stack_pool(); // second call on empty pool should not panic
1310        assert!(STACK_POOL.is_empty());
1311    }
1312
1313    /// The stack pool is not size-aware, so after a stack size increase it keeps
1314    /// serving cached undersized stacks. `drain_stack_pool()` breaks the cycle.
1315    ///
1316    /// 1. A call fills the pool with 500 KiB stacks (simulating normal execution).
1317    /// 2. The caller doubles the default to 1 MiB (simulating overflow retry).
1318    /// 3. WITHOUT draining, the pool still hands back a 500 KiB stack — the
1319    ///    retry would overflow again, creating an infinite loop.
1320    /// 4. After `drain_stack_pool()`, the pool is empty and the next allocation
1321    ///    must use the new, larger size.
1322    #[test]
1323    fn pool_returns_stale_stack_without_drain() {
1324        let _lock = GLOBAL_STATE.lock().unwrap();
1325        let _restore = RestoreStackSize(get_stack_size());
1326        drain_stack_pool();
1327
1328        // --- Phase 1: simulate normal execution that returns a 500 KiB stack ---
1329        let small_size = ByteSize::kib(500).as_u64() as usize;
1330        let small_stack = DefaultStack::new(small_size).unwrap();
1331        STACK_POOL.push(small_stack);
1332
1333        // --- Phase 2: "overflow detected" — caller doubles the default ---
1334        let big_size = ByteSize::mib(1).as_u64() as usize;
1335        set_stack_size(big_size);
1336        assert_eq!(get_stack_size(), big_size);
1337
1338        // --- Phase 3: WITHOUT drain, pool still returns the old small stack ---
1339        // This is the bug: the caller asked for a bigger stack but the pool
1340        // serves a cached undersized one, causing the retry to overflow again.
1341        let stale = STACK_POOL.pop();
1342        assert!(
1343            stale.is_some(),
1344            "pool should still contain the old stack (the bug scenario)"
1345        );
1346
1347        // --- Phase 4: with drain, pool is empty — next alloc uses new size ---
1348        STACK_POOL.push(stale.unwrap());
1349        drain_stack_pool();
1350        assert!(
1351            STACK_POOL.pop().is_none(),
1352            "after drain, pool must be empty so a fresh stack is allocated at the new size"
1353        );
1354    }
1355
1356    /// `on_wasm_stack` discards undersized stacks from the pool and allocates
1357    /// a fresh one instead of blindly reusing whatever the pool returns.
1358    #[test]
1359    fn on_wasm_stack_discards_undersized_stack() {
1360        let _lock = GLOBAL_STATE.lock().unwrap();
1361        let _restore = RestoreStackSize(get_stack_size());
1362        drain_stack_pool();
1363        clear_tls_stack();
1364
1365        // Push an undersized stack into the pool.
1366        let small_size = ByteSize::kib(500).as_u64() as usize;
1367        let small_stack = DefaultStack::new(small_size).unwrap();
1368        STACK_POOL.push(small_stack);
1369
1370        // Request a larger stack via on_wasm_stack.
1371        let big_size = ByteSize::mib(1).as_u64() as usize;
1372        let result = on_wasm_stack(big_size, None, || 42);
1373
1374        assert_eq!(result.ok().expect("on_wasm_stack should succeed"), 42);
1375        // The undersized stack was discarded; the correctly-sized stack
1376        // allocated for the call now lives in the TLS cache (the hot path).
1377        // It will end up in the global pool only on thread exit or eviction.
1378        let returned = TLS_STACK
1379            .with(|cache| cache.0.take())
1380            .or_else(|| STACK_POOL.pop())
1381            .expect("stack should have been returned to TLS cache or pool");
1382        assert!(
1383            returned.size() >= big_size,
1384            "returned stack must be at least as large as the requested size"
1385        );
1386
1387        // Ensure no residual TLS state leaks into other tests sharing the
1388        // runner thread. `take()` above already cleared the slot, but be
1389        // explicit so future edits cannot drop this guarantee silently.
1390        clear_tls_stack();
1391    }
1392
1393    /// After a wasm call, the freshly-used stack stays in the thread-local
1394    /// cache so subsequent calls on the same thread reuse it without touching
1395    /// the global SegQueue.
1396    #[test]
1397    fn tls_stack_caches_after_first_call() {
1398        let _lock = GLOBAL_STATE.lock().unwrap();
1399        let _restore = RestoreStackSize(get_stack_size());
1400        drain_stack_pool();
1401        clear_tls_stack();
1402
1403        let size = get_stack_size();
1404
1405        // First call: TLS + pool both empty → allocate fresh; stack ends in TLS.
1406        assert!(on_wasm_stack(size, None, || ()).is_ok());
1407        assert!(
1408            STACK_POOL.is_empty(),
1409            "pool should still be empty after a TLS-served call"
1410        );
1411
1412        // Verify TLS holds a stack, then put it back.
1413        let cached_present = TLS_STACK.with(|cache| {
1414            let taken = cache.0.take();
1415            let present = taken.is_some();
1416            cache.0.set(taken);
1417            present
1418        });
1419        assert!(cached_present, "TLS slot should hold the post-call stack");
1420
1421        // Second call should consume from TLS; pool stays empty.
1422        assert!(on_wasm_stack(size, None, || ()).is_ok());
1423        assert!(
1424            STACK_POOL.is_empty(),
1425            "second call must not push to the global pool"
1426        );
1427        let still_cached = TLS_STACK.with(|cache| {
1428            let taken = cache.0.take();
1429            let present = taken.is_some();
1430            cache.0.set(taken);
1431            present
1432        });
1433        assert!(
1434            still_cached,
1435            "TLS slot should still hold a stack after the second call"
1436        );
1437
1438        // Cleanup: clear TLS so we don't leak into other tests. The cached
1439        // stack here is dropped rather than returned to the pool; the next
1440        // test starts with `drain_stack_pool()` anyway, so there is no
1441        // observable difference.
1442        clear_tls_stack();
1443    }
1444
1445    /// On thread exit, the TLS cache's `Drop` impl returns the held stack to
1446    /// the global pool so memory cycles correctly across thread lifetimes.
1447    #[test]
1448    fn tls_stack_returns_to_pool_on_thread_exit() {
1449        // GLOBAL_STATE is the test-suite mutex used to serialize tests that
1450        // touch shared global state (STACK_POOL, the configured stack size).
1451        // The spawned worker thread does NOT touch GLOBAL_STATE — it only
1452        // calls `on_wasm_stack`, which takes neither this mutex nor any
1453        // other lock that could contend with us.
1454        //
1455        // Even so, holding the guard across `handle.join()` is unnecessary:
1456        // the only thing that needs to be serialized against other tests is
1457        // the assertion on `STACK_POOL.pop()` AFTER the join. We release the
1458        // guard before joining so future edits to `on_wasm_stack` that
1459        // happen to touch this lock can't introduce a hard-to-debug
1460        // deadlock here.
1461        let lock = GLOBAL_STATE.lock().unwrap();
1462        let _restore = RestoreStackSize(get_stack_size());
1463        drain_stack_pool();
1464        clear_tls_stack();
1465
1466        let size = get_stack_size();
1467        drop(lock);
1468
1469        let handle = std::thread::spawn(move || {
1470            assert!(on_wasm_stack(size, None, || ()).is_ok());
1471        });
1472        handle.join().unwrap();
1473
1474        let _lock = GLOBAL_STATE.lock().unwrap();
1475        // The spawned thread's TLS cache was dropped on join; the stack must
1476        // have made it back to the global pool.
1477        let returned = STACK_POOL
1478            .pop()
1479            .expect("thread exit should return TLS-cached stack to the global pool");
1480        assert!(returned.size() >= size);
1481    }
1482
1483    // -----------------------------------------------------------------
1484    // Test helpers
1485    // -----------------------------------------------------------------
1486
1487    /// Clears the current thread's TLS slot so tests don't see state from
1488    /// previous tests (they share the same thread under cargo test's
1489    /// per-test serialization via `GLOBAL_STATE`).
1490    fn clear_tls_stack() {
1491        TLS_STACK.with(|cache| cache.0.set(None));
1492    }
1493
1494    /// `base().get() - limit().get()` (i.e. `Stack::size`) is constant per
1495    /// `DefaultStack` instance, but `base().get()` itself uniquely identifies
1496    /// the mmap allocation. We use it as a cheap identity check to see which
1497    /// stack was returned by acquire/release.
1498    fn stack_id(stack: &DefaultStack) -> usize {
1499        stack.base().get()
1500    }
1501
1502    // -----------------------------------------------------------------
1503    // acquire_stack mechanics
1504    // -----------------------------------------------------------------
1505
1506    #[test]
1507    fn acquire_allocates_fresh_when_tls_and_pool_empty() {
1508        let _lock = GLOBAL_STATE.lock().unwrap();
1509        let _restore = RestoreStackSize(get_stack_size());
1510        drain_stack_pool();
1511        clear_tls_stack();
1512
1513        let size = get_stack_size();
1514        let stack = acquire_stack(size);
1515        assert!(
1516            stack.size() >= size,
1517            "freshly allocated stack must satisfy min_size"
1518        );
1519
1520        drop(stack);
1521        clear_tls_stack();
1522        drain_stack_pool();
1523    }
1524
1525    #[test]
1526    fn acquire_prefers_tls_over_pool() {
1527        let _lock = GLOBAL_STATE.lock().unwrap();
1528        let _restore = RestoreStackSize(get_stack_size());
1529        drain_stack_pool();
1530        clear_tls_stack();
1531
1532        let size = get_stack_size();
1533        let tls_stack = DefaultStack::new(size).unwrap();
1534        let tls_id = stack_id(&tls_stack);
1535        TLS_STACK.with(|cache| cache.0.set(Some(tls_stack)));
1536
1537        let pool_stack = DefaultStack::new(size).unwrap();
1538        let pool_id = stack_id(&pool_stack);
1539        STACK_POOL.push(pool_stack);
1540
1541        let got = acquire_stack(size);
1542        assert_eq!(stack_id(&got), tls_id, "acquire must prefer TLS over pool");
1543        assert_ne!(stack_id(&got), pool_id);
1544
1545        drop(got);
1546        clear_tls_stack();
1547        drain_stack_pool();
1548    }
1549
1550    #[test]
1551    fn acquire_uses_pool_when_tls_empty() {
1552        let _lock = GLOBAL_STATE.lock().unwrap();
1553        let _restore = RestoreStackSize(get_stack_size());
1554        drain_stack_pool();
1555        clear_tls_stack();
1556
1557        let size = get_stack_size();
1558        let pool_stack = DefaultStack::new(size).unwrap();
1559        let pool_id = stack_id(&pool_stack);
1560        STACK_POOL.push(pool_stack);
1561
1562        let got = acquire_stack(size);
1563        assert_eq!(
1564            stack_id(&got),
1565            pool_id,
1566            "acquire must consume from pool when TLS is empty"
1567        );
1568        assert!(
1569            STACK_POOL.is_empty(),
1570            "pool stack must be removed when used"
1571        );
1572
1573        drop(got);
1574        clear_tls_stack();
1575        drain_stack_pool();
1576    }
1577
1578    #[test]
1579    fn acquire_discards_undersized_tls_then_allocates() {
1580        let _lock = GLOBAL_STATE.lock().unwrap();
1581        let _restore = RestoreStackSize(get_stack_size());
1582        drain_stack_pool();
1583        clear_tls_stack();
1584
1585        let small_size = ByteSize::kib(512).as_u64() as usize;
1586        let undersized = DefaultStack::new(small_size).unwrap();
1587        TLS_STACK.with(|cache| cache.0.set(Some(undersized)));
1588
1589        let big_size = ByteSize::mib(2).as_u64() as usize;
1590        let got = acquire_stack(big_size);
1591
1592        // The acquired stack must be at least the requested size. We do NOT
1593        // compare base addresses: the OS can reuse a freshly munmap'd
1594        // virtual address for the next mmap, so pointer identity is not a
1595        // reliable "is this a different stack" check across a drop+alloc.
1596        // The meaningful semantic is that the undersized stack was taken
1597        // out of rotation (TLS empty, not silently pushed to the pool) and
1598        // the returned stack is sized correctly.
1599        assert!(
1600            got.size() >= big_size,
1601            "acquired stack must satisfy big_size"
1602        );
1603        let tls_empty = TLS_STACK.with(|cache| {
1604            let s = cache.0.take();
1605            let empty = s.is_none();
1606            cache.0.set(s);
1607            empty
1608        });
1609        assert!(
1610            tls_empty,
1611            "undersized TLS stack must have been taken and discarded"
1612        );
1613        assert!(
1614            STACK_POOL.is_empty(),
1615            "undersized TLS stack must be discarded, not pushed to the pool",
1616        );
1617
1618        drop(got);
1619        clear_tls_stack();
1620        drain_stack_pool();
1621    }
1622
1623    #[test]
1624    fn acquire_discards_undersized_pool_then_allocates() {
1625        let _lock = GLOBAL_STATE.lock().unwrap();
1626        let _restore = RestoreStackSize(get_stack_size());
1627        drain_stack_pool();
1628        clear_tls_stack();
1629
1630        let small_size = ByteSize::kib(512).as_u64() as usize;
1631        let undersized = DefaultStack::new(small_size).unwrap();
1632        STACK_POOL.push(undersized);
1633
1634        let big_size = ByteSize::mib(2).as_u64() as usize;
1635        let got = acquire_stack(big_size);
1636
1637        // Same caveat as the TLS variant: mmap may reuse the virtual
1638        // address of the dropped undersized stack for the new big stack,
1639        // so we verify the semantic outcome — the pool was drained of the
1640        // undersized entry and the returned stack is sized correctly.
1641        assert!(
1642            got.size() >= big_size,
1643            "acquired stack must satisfy big_size"
1644        );
1645        assert!(
1646            STACK_POOL.is_empty(),
1647            "undersized pool stack must have been popped, filtered out and dropped",
1648        );
1649
1650        drop(got);
1651        clear_tls_stack();
1652        drain_stack_pool();
1653    }
1654
1655    // -----------------------------------------------------------------
1656    // release_stack mechanics
1657    // -----------------------------------------------------------------
1658
1659    #[test]
1660    fn release_into_empty_tls_caches_there() {
1661        let _lock = GLOBAL_STATE.lock().unwrap();
1662        drain_stack_pool();
1663        clear_tls_stack();
1664
1665        let size = get_stack_size();
1666        let stack = DefaultStack::new(size).unwrap();
1667        let id = stack_id(&stack);
1668        release_stack(stack);
1669
1670        let in_tls = TLS_STACK
1671            .with(|cache| cache.0.take())
1672            .expect("release into empty TLS should leave the stack in TLS");
1673        assert_eq!(stack_id(&in_tls), id);
1674        assert!(
1675            STACK_POOL.is_empty(),
1676            "pool must not be touched when TLS is empty"
1677        );
1678
1679        drain_stack_pool();
1680    }
1681
1682    #[test]
1683    fn release_into_occupied_tls_displaces_older_to_pool() {
1684        let _lock = GLOBAL_STATE.lock().unwrap();
1685        drain_stack_pool();
1686        clear_tls_stack();
1687
1688        let size = get_stack_size();
1689        let older = DefaultStack::new(size).unwrap();
1690        let older_id = stack_id(&older);
1691        TLS_STACK.with(|cache| cache.0.set(Some(older)));
1692
1693        let newer = DefaultStack::new(size).unwrap();
1694        let newer_id = stack_id(&newer);
1695        release_stack(newer);
1696
1697        let in_tls = TLS_STACK
1698            .with(|cache| cache.0.take())
1699            .expect("TLS should hold the newly-released stack");
1700        assert_eq!(
1701            stack_id(&in_tls),
1702            newer_id,
1703            "newer stack must displace into TLS"
1704        );
1705
1706        let displaced = STACK_POOL
1707            .pop()
1708            .expect("older stack should have been pushed to global pool");
1709        assert_eq!(
1710            stack_id(&displaced),
1711            older_id,
1712            "displaced stack must be the older one"
1713        );
1714
1715        drain_stack_pool();
1716    }
1717
1718    // -----------------------------------------------------------------
1719    // drain_stack_pool extended semantics
1720    // -----------------------------------------------------------------
1721
1722    #[test]
1723    fn drain_stack_pool_clears_calling_thread_tls_slot() {
1724        let _lock = GLOBAL_STATE.lock().unwrap();
1725        drain_stack_pool();
1726        clear_tls_stack();
1727
1728        let stack = DefaultStack::new(get_stack_size()).unwrap();
1729        TLS_STACK.with(|cache| cache.0.set(Some(stack)));
1730
1731        drain_stack_pool();
1732
1733        let tls_empty = TLS_STACK.with(|cache| cache.0.take().is_none());
1734        assert!(
1735            tls_empty,
1736            "drain_stack_pool must also clear current thread's TLS slot"
1737        );
1738        assert!(STACK_POOL.is_empty());
1739    }
1740
1741    // -----------------------------------------------------------------
1742    // on_wasm_stack functional behavior
1743    // -----------------------------------------------------------------
1744
1745    #[test]
1746    fn on_wasm_stack_passes_closure_value_back() {
1747        let _lock = GLOBAL_STATE.lock().unwrap();
1748        let _restore = RestoreStackSize(get_stack_size());
1749        drain_stack_pool();
1750        clear_tls_stack();
1751
1752        let r = on_wasm_stack(get_stack_size(), None, || 12345u32);
1753        assert_eq!(r.ok(), Some(12345));
1754
1755        clear_tls_stack();
1756        drain_stack_pool();
1757    }
1758
1759    #[test]
1760    fn on_wasm_stack_passes_owning_result_back() {
1761        let _lock = GLOBAL_STATE.lock().unwrap();
1762        let _restore = RestoreStackSize(get_stack_size());
1763        drain_stack_pool();
1764        clear_tls_stack();
1765
1766        // Use a heap-allocated value to verify the move-out path: the
1767        // closure produces a `Vec<u8>` that must travel from the coroutine
1768        // stack back to the host.
1769        let r = on_wasm_stack(get_stack_size(), None, || vec![0u8, 1, 2, 3, 4]);
1770        assert_eq!(r.ok(), Some(vec![0u8, 1, 2, 3, 4]));
1771
1772        clear_tls_stack();
1773        drain_stack_pool();
1774    }
1775
1776    #[test]
1777    fn many_calls_do_not_grow_global_pool() {
1778        let _lock = GLOBAL_STATE.lock().unwrap();
1779        let _restore = RestoreStackSize(get_stack_size());
1780        drain_stack_pool();
1781        clear_tls_stack();
1782
1783        // With TLS caching, repeated single-threaded calls must keep
1784        // reusing the same TLS stack and never push to the global pool.
1785        for _ in 0..1000 {
1786            assert!(on_wasm_stack(get_stack_size(), None, || ()).is_ok());
1787        }
1788        assert!(
1789            STACK_POOL.is_empty(),
1790            "1000 sequential calls should not grow the global pool (TLS handles reuse)"
1791        );
1792
1793        clear_tls_stack();
1794        drain_stack_pool();
1795    }
1796
1797    // -----------------------------------------------------------------
1798    // Trap and unwind paths
1799    // -----------------------------------------------------------------
1800
1801    #[test]
1802    fn raise_user_trap_yields_err() {
1803        let _lock = GLOBAL_STATE.lock().unwrap();
1804        let _restore = RestoreStackSize(get_stack_size());
1805        drain_stack_pool();
1806        clear_tls_stack();
1807
1808        let r: Result<(), UnwindReason> = on_wasm_stack(get_stack_size(), None, || unsafe {
1809            raise_user_trap(Box::new(io::Error::other("user trap from test")));
1810        });
1811        assert!(r.is_err(), "raise_user_trap must produce Err");
1812
1813        clear_tls_stack();
1814        drain_stack_pool();
1815    }
1816
1817    #[test]
1818    fn raise_lib_trap_yields_err() {
1819        let _lock = GLOBAL_STATE.lock().unwrap();
1820        let _restore = RestoreStackSize(get_stack_size());
1821        drain_stack_pool();
1822        clear_tls_stack();
1823
1824        let r: Result<(), UnwindReason> = on_wasm_stack(get_stack_size(), None, || unsafe {
1825            raise_lib_trap(Trap::lib(TrapCode::IntegerDivisionByZero));
1826        });
1827        assert!(r.is_err(), "raise_lib_trap must produce Err");
1828
1829        clear_tls_stack();
1830        drain_stack_pool();
1831    }
1832
1833    #[test]
1834    fn resume_panic_yields_err_without_unwinding() {
1835        // `resume_panic` packages the payload as `UnwindReason::Panic`. The
1836        // host-side panic resumption lives in `UnwindReason::into_trap`,
1837        // which we do NOT call here — `on_wasm_stack` itself just returns
1838        // the Err, so the test does not actually panic.
1839        let _lock = GLOBAL_STATE.lock().unwrap();
1840        let _restore = RestoreStackSize(get_stack_size());
1841        drain_stack_pool();
1842        clear_tls_stack();
1843
1844        let r: Result<(), UnwindReason> = on_wasm_stack(get_stack_size(), None, || unsafe {
1845            resume_panic(Box::new("panic payload from test"));
1846        });
1847        assert!(
1848            r.is_err(),
1849            "resume_panic must surface as Err to on_wasm_stack"
1850        );
1851
1852        clear_tls_stack();
1853        drain_stack_pool();
1854    }
1855
1856    #[test]
1857    fn trap_does_not_corrupt_subsequent_calls() {
1858        // After a trap, the per-call coroutine is force-reset and dropped.
1859        // The TLS stack cache and global pool must remain in a usable state
1860        // so that subsequent calls succeed.
1861        let _lock = GLOBAL_STATE.lock().unwrap();
1862        let _restore = RestoreStackSize(get_stack_size());
1863        drain_stack_pool();
1864        clear_tls_stack();
1865
1866        let trapped: Result<(), UnwindReason> = on_wasm_stack(get_stack_size(), None, || unsafe {
1867            raise_user_trap(Box::new(io::Error::other("first call traps")));
1868        });
1869        assert!(trapped.is_err());
1870
1871        // Subsequent normal call must still succeed.
1872        let ok = on_wasm_stack(get_stack_size(), None, || 7u32);
1873        assert_eq!(ok.ok(), Some(7), "calls after a trap must still work");
1874
1875        clear_tls_stack();
1876        drain_stack_pool();
1877    }
1878
1879    // -----------------------------------------------------------------
1880    // on_host_stack
1881    // -----------------------------------------------------------------
1882
1883    #[test]
1884    fn on_host_stack_outside_coroutine_runs_inline() {
1885        // `on_host_stack` outside any wasm coroutine just runs `f()` directly
1886        // (no stack switch); this asserts the no-yielder branch still works.
1887        let _lock = GLOBAL_STATE.lock().unwrap();
1888        let n = on_host_stack(|| 99i32);
1889        assert_eq!(n, 99);
1890    }
1891
1892    #[test]
1893    fn on_host_stack_inside_wasm_switches_and_returns() {
1894        let _lock = GLOBAL_STATE.lock().unwrap();
1895        let _restore = RestoreStackSize(get_stack_size());
1896        drain_stack_pool();
1897        clear_tls_stack();
1898
1899        let r = on_wasm_stack(get_stack_size(), None, || on_host_stack(|| 88i32));
1900        assert_eq!(r.ok(), Some(88));
1901
1902        clear_tls_stack();
1903        drain_stack_pool();
1904    }
1905
1906    // -----------------------------------------------------------------
1907    // Re-entrancy
1908    // -----------------------------------------------------------------
1909
1910    #[test]
1911    fn reentrant_call_returns_value() {
1912        let _lock = GLOBAL_STATE.lock().unwrap();
1913        let _restore = RestoreStackSize(get_stack_size());
1914        drain_stack_pool();
1915        clear_tls_stack();
1916
1917        let outer = on_wasm_stack(get_stack_size(), None, || {
1918            on_wasm_stack(get_stack_size(), None, || 42i32)
1919                .ok()
1920                .expect("inner must succeed")
1921        });
1922        assert_eq!(outer.ok(), Some(42));
1923
1924        clear_tls_stack();
1925        drain_stack_pool();
1926    }
1927
1928    #[test]
1929    fn reentrant_calls_run_to_completion_under_pool_pressure() {
1930        // The outer call's stack is held by its scopeguard for the duration
1931        // of the call; the inner call must therefore pop a separate stack
1932        // from the pool (or allocate one). With a pre-populated pool the
1933        // inner call should consume that stack; either way the nested call
1934        // chain must complete without deadlocking or panicking from
1935        // corosensei.
1936        use std::sync::Arc;
1937        use std::sync::atomic::{AtomicUsize, Ordering as O};
1938
1939        let _lock = GLOBAL_STATE.lock().unwrap();
1940        let _restore = RestoreStackSize(get_stack_size());
1941        drain_stack_pool();
1942        clear_tls_stack();
1943
1944        // Pre-populate the pool with one stack so the inner call can grab
1945        // it instead of allocating.
1946        let pre = DefaultStack::new(get_stack_size()).unwrap();
1947        STACK_POOL.push(pre);
1948
1949        let inner_completed = Arc::new(AtomicUsize::new(0));
1950        let inner_completed_outer = inner_completed.clone();
1951        let _ = on_wasm_stack(get_stack_size(), None, move || {
1952            let inner_completed = inner_completed_outer.clone();
1953            let inner = on_wasm_stack(get_stack_size(), None, move || {
1954                inner_completed.fetch_add(1, O::Relaxed);
1955            });
1956            assert!(inner.is_ok(), "inner re-entrant call must succeed");
1957        });
1958        assert_eq!(
1959            inner_completed.load(O::Relaxed),
1960            1,
1961            "inner closure must have executed exactly once",
1962        );
1963
1964        clear_tls_stack();
1965        drain_stack_pool();
1966    }
1967
1968    #[test]
1969    fn reentrant_inner_trap_does_not_kill_outer() {
1970        let _lock = GLOBAL_STATE.lock().unwrap();
1971        let _restore = RestoreStackSize(get_stack_size());
1972        drain_stack_pool();
1973        clear_tls_stack();
1974
1975        let outer = on_wasm_stack(get_stack_size(), None, || {
1976            let inner: Result<i32, UnwindReason> =
1977                on_wasm_stack(get_stack_size(), None, || unsafe {
1978                    raise_user_trap(Box::new(io::Error::other("inner trap")));
1979                });
1980            // Outer observes inner's Err and recovers.
1981            match inner {
1982                Err(_) => 1234i32,
1983                Ok(_) => panic!("inner should have trapped"),
1984            }
1985        });
1986        assert_eq!(
1987            outer.ok(),
1988            Some(1234),
1989            "outer must recover after inner trap and run to completion"
1990        );
1991
1992        clear_tls_stack();
1993        drain_stack_pool();
1994    }
1995
1996    #[test]
1997    fn reentrant_with_on_host_stack_in_between() {
1998        // Outer wasm → on_host_stack → inner wasm. This exercises the
1999        // YIELDER save/restore in `unwind_with` / `on_host_stack` against
2000        // a re-entrant boundary.
2001        let _lock = GLOBAL_STATE.lock().unwrap();
2002        let _restore = RestoreStackSize(get_stack_size());
2003        drain_stack_pool();
2004        clear_tls_stack();
2005
2006        let r = on_wasm_stack(get_stack_size(), None, || {
2007            on_host_stack(|| {
2008                on_wasm_stack(get_stack_size(), None, || 5i32)
2009                    .ok()
2010                    .expect("nested inner must succeed")
2011            })
2012        });
2013        assert_eq!(r.ok(), Some(5));
2014
2015        clear_tls_stack();
2016        drain_stack_pool();
2017    }
2018
2019    // -----------------------------------------------------------------
2020    // Concurrency
2021    // -----------------------------------------------------------------
2022
2023    #[test]
2024    fn many_threads_in_parallel_all_succeed() {
2025        let _lock = GLOBAL_STATE.lock().unwrap();
2026        let _restore = RestoreStackSize(get_stack_size());
2027        drain_stack_pool();
2028
2029        use std::sync::Arc;
2030        use std::sync::atomic::{AtomicUsize, Ordering as O};
2031
2032        let counter = Arc::new(AtomicUsize::new(0));
2033        const THREADS: usize = 8;
2034        const CALLS_PER_THREAD: usize = 200;
2035
2036        let handles: Vec<_> = (0..THREADS)
2037            .map(|_| {
2038                let counter = counter.clone();
2039                std::thread::spawn(move || {
2040                    let size = get_stack_size();
2041                    for _ in 0..CALLS_PER_THREAD {
2042                        if on_wasm_stack(size, None, || 1u32).ok() == Some(1) {
2043                            counter.fetch_add(1, O::Relaxed);
2044                        }
2045                    }
2046                })
2047            })
2048            .collect();
2049        for h in handles {
2050            h.join().unwrap();
2051        }
2052
2053        assert_eq!(counter.load(O::Relaxed), THREADS * CALLS_PER_THREAD);
2054        // Pool should now hold at most `THREADS` stacks (one per thread that
2055        // exited). Each thread also drops its TLS slot on exit, which pushes
2056        // the stack to the pool.
2057        let mut pooled = 0usize;
2058        while STACK_POOL.pop().is_some() {
2059            pooled += 1;
2060        }
2061        assert!(
2062            pooled <= THREADS,
2063            "pool should hold at most one stack per terminated thread (got {pooled} for {THREADS} threads)"
2064        );
2065
2066        clear_tls_stack();
2067        drain_stack_pool();
2068    }
2069
2070    // -----------------------------------------------------------------
2071    // Stack size dynamics
2072    // -----------------------------------------------------------------
2073
2074    #[test]
2075    fn growing_request_discards_smaller_tls_stack() {
2076        let _lock = GLOBAL_STATE.lock().unwrap();
2077        let _restore = RestoreStackSize(get_stack_size());
2078        drain_stack_pool();
2079        clear_tls_stack();
2080
2081        // First call at a small size populates TLS with a small stack.
2082        let small = ByteSize::mib(1).as_u64() as usize;
2083        set_stack_size(small);
2084        assert!(on_wasm_stack(small, None, || ()).is_ok());
2085
2086        let cached_size = TLS_STACK.with(|cache| {
2087            let s = cache.0.take();
2088            let sz = s.as_ref().map_or(0, |s| s.size());
2089            cache.0.set(s);
2090            sz
2091        });
2092        assert!(
2093            cached_size >= small,
2094            "TLS should hold the small-sized stack"
2095        );
2096
2097        // Now request a larger stack. acquire_stack should discard the TLS
2098        // entry and either pop a big-enough one from pool or allocate.
2099        let big = ByteSize::mib(4).as_u64() as usize;
2100        set_stack_size(big);
2101        assert!(on_wasm_stack(big, None, || ()).is_ok());
2102
2103        // The TLS slot should now hold a stack that's big enough.
2104        let cached_size = TLS_STACK.with(|cache| {
2105            let s = cache.0.take();
2106            let sz = s.as_ref().map_or(0, |s| s.size());
2107            cache.0.set(s);
2108            sz
2109        });
2110        assert!(
2111            cached_size >= big,
2112            "TLS should hold the bigger stack after size bump"
2113        );
2114
2115        clear_tls_stack();
2116        drain_stack_pool();
2117    }
2118}