1#![allow(static_mut_refs)]
5
6use super::trap::UnwindReason;
10use crate::Trap;
11#[cfg(all(unix, feature = "experimental-host-interrupt"))]
12use crate::interrupt_registry;
13use backtrace::Backtrace;
14use bytesize::ByteSize;
15use core::ptr::{read, read_unaligned};
16use corosensei::stack::{DefaultStack, Stack};
17use corosensei::trap::{CoroutineTrapHandler, TrapHandlerRegs};
18use corosensei::{CoroutineResult, ScopedCoroutine, Yielder};
19use scopeguard::defer;
20use std::any::Any;
21use std::cell::Cell;
22use std::error::Error;
23use std::io;
24use std::mem;
25#[cfg(unix)]
26use std::mem::MaybeUninit;
27use std::ptr::{self, NonNull};
28use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering, compiler_fence};
29use std::sync::{LazyLock, Once};
30use wasmer_types::TrapCode;
31
32trait StackExt: Stack {
34 fn size(&self) -> usize {
36 self.base().get() - self.limit().get()
37 }
38}
39impl<T: Stack> StackExt for T {}
40
/// Configuration consulted by `catch_traps` when running Wasm code.
pub struct VMConfig {
    /// Requested size, in bytes, of the stack Wasm code runs on. When `None`,
    /// `catch_traps` falls back to the process-wide default stack size.
    pub wasm_stack_size: Option<usize>,
}
47
/// Marker bits that must all be set in the payload byte extracted by
/// `process_illegal_op` (x86_64 and aarch64) before the low nibble of that
/// byte is interpreted as a trap code.
static MAGIC: u8 = 0xc0;

/// Stack size, in bytes, used for Wasm execution when `VMConfig` does not
/// specify one. Starts at 1 MiB; updated through `set_stack_size`.
static DEFAULT_STACK_SIZE: AtomicUsize = AtomicUsize::new(ByteSize::mib(1).as_u64() as usize);

/// Upper bound (100 MiB) that `set_stack_size` clamps requested sizes to.
pub const MAX_STACK_SIZE: usize = ByteSize::mib(100).as_u64() as usize;
58
/// Local `ucontext_t` declaration used on aarch64 macOS in place of the one
/// exported by `libc` (which is imported on every other Unix target).
// NOTE(review): presumably mirrors the Darwin C layout field-for-field —
// confirm against the system headers before touching any field.
#[repr(C)]
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
#[allow(non_camel_case_types)]
struct ucontext_t {
    uc_onstack: libc::c_int,
    uc_sigmask: libc::sigset_t,
    uc_stack: libc::stack_t,
    uc_link: *mut libc::ucontext_t,
    uc_mcsize: usize,
    // Dereferenced by `get_pc_sp` / `update_context` to reach the saved registers.
    uc_mcontext: libc::mcontext_t,
}
72
73#[cfg(all(unix, not(all(target_arch = "aarch64", target_os = "macos"))))]
74use libc::ucontext_t;
75
76pub fn set_stack_size(size: usize) {
79 DEFAULT_STACK_SIZE.store(
80 size.clamp(ByteSize::kib(8).as_u64() as usize, MAX_STACK_SIZE),
81 Ordering::Relaxed,
82 );
83}
84
/// Returns the current process-wide default Wasm stack size, in bytes, as
/// last stored by `set_stack_size` (or the initial 1 MiB default).
pub fn get_stack_size() -> usize {
    DEFAULT_STACK_SIZE.load(Ordering::Relaxed)
}
89
/// Process-wide queue of idle coroutine stacks available for reuse. It is
/// created lazily on first access and shared by all threads; stacks land here
/// when they are displaced from, or left behind in, a thread-local cache.
static STACK_POOL: LazyLock<crossbeam_queue::SegQueue<DefaultStack>> =
    LazyLock::new(crossbeam_queue::SegQueue::new);
95
96struct StackCache(Cell<Option<DefaultStack>>);
104
105impl Drop for StackCache {
106 fn drop(&mut self) {
107 if let Some(stack) = self.0.take() {
108 STACK_POOL.push(stack);
109 }
110 }
111}
112
thread_local! {
    /// Thread-local single-entry stack cache, initially empty. It is consulted
    /// first by `acquire_stack` and refilled by `release_stack`.
    static TLS_STACK: StackCache = const { StackCache(Cell::new(None)) };
}
116
117fn acquire_stack(min_size: usize) -> DefaultStack {
121 if let Some(stack) = TLS_STACK.with(|cache| cache.0.take()) {
124 if stack.size() >= min_size {
125 return stack;
126 }
127 drop(stack);
130 }
131 STACK_POOL
134 .pop()
135 .filter(|s| s.size() >= min_size)
136 .unwrap_or_else(|| DefaultStack::new(min_size).unwrap())
137}
138
139fn release_stack(stack: DefaultStack) {
143 let displaced = TLS_STACK.with(|cache| cache.0.replace(Some(stack)));
144 if let Some(displaced) = displaced {
145 STACK_POOL.push(displaced);
146 }
147}
148
149pub fn drain_stack_pool() {
165 if let Some(stack) = TLS_STACK.with(|cache| cache.0.take()) {
168 drop(stack);
169 }
170 while STACK_POOL.pop().is_some() {}
171}
172
cfg_if::cfg_if! {
    if #[cfg(unix)] {
        /// Custom trap handler on Unix: receives the signal number, the
        /// `siginfo_t` pointer and the raw context pointer, and returns `true`
        /// when it has handled the trap.
        pub type TrapHandlerFn<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + Send + Sync + 'a;
    } else if #[cfg(target_os = "windows")] {
        /// Custom trap handler on Windows: receives the `EXCEPTION_POINTERS`
        /// and returns `true` when it has handled the trap.
        pub type TrapHandlerFn<'a> = dyn Fn(*mut windows_sys::Win32::System::Diagnostics::Debug::EXCEPTION_POINTERS) -> bool + Send + Sync + 'a;
    }
}
182
183unsafe fn process_illegal_op(addr: usize) -> Option<TrapCode> {
185 let mut val: Option<u8> = None;
186 unsafe {
187 if cfg!(target_arch = "x86_64") {
188 val = if read(addr as *mut u8) & 0xf0 == 0x40
189 && read((addr + 1) as *mut u8) == 0x0f
190 && read((addr + 2) as *mut u8) == 0xb9
191 {
192 Some(read((addr + 3) as *mut u8))
193 } else if read(addr as *mut u8) == 0x0f && read((addr + 1) as *mut u8) == 0xb9 {
194 Some(read((addr + 2) as *mut u8))
195 } else {
196 None
197 }
198 }
199 if cfg!(target_arch = "aarch64") {
200 val = if read_unaligned(addr as *mut u32) & 0xffff0000 == 0 {
201 Some(read(addr as *mut u8))
202 } else {
203 None
204 }
205 }
206 if cfg!(target_arch = "riscv64") {
207 let addr = addr as *mut u32;
208 val = if read(addr) == 0xc0001073 {
210 let prev_insn = read(addr.sub(1));
213 if (prev_insn & 0xffff) == 0x0513 {
214 Some((prev_insn >> 20) as u8)
215 } else {
216 None
217 }
218 } else {
219 None
220 };
221 }
222 }
223
224 if cfg!(target_arch = "x86_64") || cfg!(target_arch = "aarch64") {
226 val = val.and_then(|val| {
227 if val & MAGIC == MAGIC {
228 Some(val & 0xf)
229 } else {
230 None
231 }
232 });
233 }
234
235 match val {
236 None => None,
237 Some(val) => match val {
238 0 => Some(TrapCode::StackOverflow),
239 1 => Some(TrapCode::HeapAccessOutOfBounds),
240 2 => Some(TrapCode::HeapMisaligned),
241 3 => Some(TrapCode::TableAccessOutOfBounds),
242 4 => Some(TrapCode::IndirectCallToNull),
243 5 => Some(TrapCode::BadSignature),
244 6 => Some(TrapCode::IntegerOverflow),
245 7 => Some(TrapCode::IntegerDivisionByZero),
246 8 => Some(TrapCode::BadConversionToInteger),
247 9 => Some(TrapCode::UnreachableCodeReached),
248 10 => Some(TrapCode::UnalignedAtomic),
249 _ => None,
250 },
251 }
252}
253
254cfg_if::cfg_if! {
255 if #[cfg(unix)] {
// Signal dispositions that were installed before `platform_init` replaced
// them. Each slot is written by `platform_init` (through `sigaction`) and
// read by `trap_handler` when a signal has to be forwarded.
static mut PREV_SIGSEGV: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGBUS: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGILL: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGFPE: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();

// Previous SIGUSR1 disposition; only present with host-interrupt support.
#[cfg(feature = "experimental-host-interrupt")]
static mut PREV_SIGUSR1: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
263
/// Installs `trap_handler` for every signal this module intercepts, saving
/// each previous disposition in the matching `PREV_*` static.
///
/// Panics if `sigaction` fails for any of the signals.
///
/// # Safety
///
/// Writes the `PREV_*` mutable statics without synchronization; the caller
/// visible in this file (`init_traps`) runs it under a `Once`.
unsafe fn platform_init() { unsafe {
    // Registers `trap_handler` for `signal` and stores the old action in `slot`.
    let register = |slot: &mut MaybeUninit<libc::sigaction>, signal: i32, nodefer: bool| {
        let mut handler: libc::sigaction = mem::zeroed();
        // Three-argument handler, executed on the alternate signal stack.
        handler.sa_flags = libc::SA_SIGINFO | libc::SA_ONSTACK;
        if nodefer {
            // Leave the signal unblocked while its own handler is running.
            handler.sa_flags |= libc::SA_NODEFER;
        }
        handler.sa_sigaction = trap_handler as *const () as usize;
        libc::sigemptyset(&mut handler.sa_mask);
        if libc::sigaction(signal, &handler, slot.as_mut_ptr()) != 0 {
            panic!(
                "unable to install signal handler: {}",
                io::Error::last_os_error(),
            );
        }
    };

    // Registered on every Unix target.
    register(&mut PREV_SIGSEGV, libc::SIGSEGV, true);

    // Registered on every Unix target; `trap_handler` decodes the trap code
    // from the faulting instruction for this signal.
    register(&mut PREV_SIGILL, libc::SIGILL, true);

    // Host-interrupt signal; registered without SA_NODEFER.
    #[cfg(feature = "experimental-host-interrupt")]
    register(&mut PREV_SIGUSR1, libc::SIGUSR1, false);

    // SIGFPE is only intercepted on x86 / x86_64.
    if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
        register(&mut PREV_SIGFPE, libc::SIGFPE, true);
    }

    // SIGBUS is only intercepted on 32-bit ARM and on Apple platforms.
    if cfg!(target_arch = "arm") || cfg!(target_vendor = "apple") {
        register(&mut PREV_SIGBUS, libc::SIGBUS, true);
    }

    #[cfg(target_vendor = "apple")]
    {
        use mach2::exception_types::*;
        use mach2::kern_return::*;
        use mach2::port::*;
        use mach2::thread_status::*;
        use mach2::traps::*;
        use mach2::mach_types::*;

        unsafe extern "C" {
            fn task_set_exception_ports(
                task: task_t,
                exception_mask: exception_mask_t,
                new_port: mach_port_t,
                behavior: exception_behavior_t,
                new_flavor: thread_state_flavor_t,
            ) -> kern_return_t;
        }

        #[allow(non_snake_case)]
        #[cfg(target_arch = "x86_64")]
        let MACHINE_THREAD_STATE = x86_THREAD_STATE64;
        // NOTE(review): presumably the numeric value of ARM_THREAD_STATE64 — confirm.
        #[allow(non_snake_case)]
        #[cfg(target_arch = "aarch64")]
        let MACHINE_THREAD_STATE = 6;

        // Sets the task's exception port for these exception classes to
        // MACH_PORT_NULL. The return value is not checked.
        // NOTE(review): presumably so these faults reach the Unix signal
        // handlers installed above instead of a Mach exception port — confirm.
        task_set_exception_ports(
            mach_task_self(),
            EXC_MASK_BAD_ACCESS | EXC_MASK_ARITHMETIC | EXC_MASK_BAD_INSTRUCTION,
            MACH_PORT_NULL,
            EXCEPTION_STATE_IDENTITY as exception_behavior_t,
            MACHINE_THREAD_STATE,
        );
    }
}}
355
/// Signal handler installed by `platform_init`.
///
/// It gathers the fault address (SIGSEGV/SIGBUS) or decoded trap code
/// (SIGILL, host interrupt), then asks `TrapHandlerContext::handle_trap` to
/// handle the trap. If that declines, the signal is forwarded to the
/// disposition that was installed before this module took over.
unsafe extern "C" fn trap_handler(
    signum: libc::c_int,
    siginfo: *mut libc::siginfo_t,
    context: *mut libc::c_void,
) { unsafe {
    // Pick the saved previous disposition for this signal.
    let previous = match signum {
        libc::SIGSEGV => &PREV_SIGSEGV,
        libc::SIGBUS => &PREV_SIGBUS,
        libc::SIGFPE => &PREV_SIGFPE,
        libc::SIGILL => &PREV_SIGILL,
        #[cfg(feature = "experimental-host-interrupt")]
        libc::SIGUSR1 => &PREV_SIGUSR1,
        _ => panic!("unknown signal: {signum}"),
    };
    // Only memory faults carry a meaningful faulting address.
    let maybe_fault_address = match signum {
        libc::SIGSEGV | libc::SIGBUS => {
            Some((*siginfo).si_addr() as usize)
        }
        _ => None,
    };
    let trap_code = match signum {
        libc::SIGILL => {
            // For SIGILL, `si_addr` is used as the address of the trapping instruction.
            let addr = (*siginfo).si_addr() as usize;
            process_illegal_op(addr)
        }
        #[cfg(feature = "experimental-host-interrupt")]
        libc::SIGUSR1 => {
            // Return without doing anything when the registry does not report an interrupt.
            if !interrupt_registry::on_interrupted() {
                return;
            }
            Some(TrapCode::HostInterrupt)
        }
        _ => None,
    };
    let ucontext = &mut *(context as *mut ucontext_t);
    let (pc, sp) = get_pc_sp(ucontext);
    let handled = TrapHandlerContext::handle_trap(
        pc,
        sp,
        maybe_fault_address,
        trap_code,
        |regs| update_context(ucontext, regs),
        |handler| handler(signum, siginfo, context),
    );

    if handled {
        return;
    }

    // An unhandled host-interrupt signal is never forwarded to the previous handler.
    #[cfg(feature = "experimental-host-interrupt")]
    if signum == libc::SIGUSR1 {
        return;
    }

    // Forward to whatever was installed before us.
    let previous = &*previous.as_ptr();
    if previous.sa_flags & libc::SA_SIGINFO != 0 {
        // Previous handler uses the three-argument signature.
        mem::transmute::<
            usize,
            extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut libc::c_void),
        >(previous.sa_sigaction)(signum, siginfo, context)
    } else if previous.sa_sigaction == libc::SIG_DFL
    {
        // Default disposition: reinstall it and return.
        // NOTE(review): presumably the faulting instruction then re-executes
        // and the default action fires — confirm.
        libc::sigaction(signum, previous, ptr::null_mut());
    } else if previous.sa_sigaction != libc::SIG_IGN {
        // Previous handler uses the one-argument signature.
        mem::transmute::<usize, extern "C" fn(libc::c_int)>(
            previous.sa_sigaction
        )(signum)
    }
}}
440
441 unsafe fn get_pc_sp(context: &ucontext_t) -> (usize, usize) {
442 let (pc, sp);
443 cfg_if::cfg_if! {
444 if #[cfg(all(
445 any(target_os = "linux", target_os = "android"),
446 target_arch = "x86_64",
447 ))] {
448 pc = context.uc_mcontext.gregs[libc::REG_RIP as usize] as usize;
449 sp = context.uc_mcontext.gregs[libc::REG_RSP as usize] as usize;
450 } else if #[cfg(all(
451 any(target_os = "linux", target_os = "android"),
452 target_arch = "x86",
453 ))] {
454 pc = context.uc_mcontext.gregs[libc::REG_EIP as usize] as usize;
455 sp = context.uc_mcontext.gregs[libc::REG_ESP as usize] as usize;
456 } else if #[cfg(all(target_os = "freebsd", target_arch = "x86"))] {
457 pc = context.uc_mcontext.mc_eip as usize;
458 sp = context.uc_mcontext.mc_esp as usize;
459 } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
460 pc = context.uc_mcontext.mc_rip as usize;
461 sp = context.uc_mcontext.mc_rsp as usize;
462 } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
463 let mcontext = unsafe { &*context.uc_mcontext };
464 pc = mcontext.__ss.__rip as usize;
465 sp = mcontext.__ss.__rsp as usize;
466 } else if #[cfg(all(
467 any(target_os = "linux", target_os = "android"),
468 target_arch = "aarch64",
469 ))] {
470 pc = context.uc_mcontext.pc as usize;
471 sp = context.uc_mcontext.sp as usize;
472 } else if #[cfg(all(
473 any(target_os = "linux", target_os = "android"),
474 target_arch = "arm",
475 ))] {
476 pc = context.uc_mcontext.arm_pc as usize;
477 sp = context.uc_mcontext.arm_sp as usize;
478 } else if #[cfg(all(
479 any(target_os = "linux", target_os = "android"),
480 any(target_arch = "riscv64", target_arch = "riscv32"),
481 ))] {
482 pc = context.uc_mcontext.__gregs[libc::REG_PC] as usize;
483 sp = context.uc_mcontext.__gregs[libc::REG_SP] as usize;
484 } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
485 let mcontext = unsafe { &*context.uc_mcontext };
486 pc = mcontext.__ss.__pc as usize;
487 sp = mcontext.__ss.__sp as usize;
488 } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
489 pc = context.uc_mcontext.mc_gpregs.gp_elr as usize;
490 sp = context.uc_mcontext.mc_gpregs.gp_sp as usize;
491 } else if #[cfg(all(target_os = "linux", target_arch = "loongarch64"))] {
492 pc = context.uc_mcontext.__gregs[1] as usize;
493 sp = context.uc_mcontext.__gregs[3] as usize;
494 } else if #[cfg(all(target_os = "linux", target_arch = "powerpc64"))] {
495 pc = (*context.uc_mcontext.regs).nip as usize;
496 sp = (*context.uc_mcontext.regs).gpr[1] as usize;
497 } else {
498 compile_error!("Unsupported platform");
499 }
500 };
501 (pc, sp)
502 }
503
/// Writes the register values in `regs` into the saved signal `context`, so
/// that execution continues with those values once the signal handler returns.
///
/// The set of registers in `TrapHandlerRegs` depends on the architecture;
/// each branch destructures it exhaustively and stores every field into the
/// platform's `mcontext` layout. Unsupported platforms fail to compile.
///
/// # Safety
///
/// `context` must be the context passed to a signal handler. On Apple and
/// powerpc64 targets the register block is reached through a raw pointer that
/// must be valid.
unsafe fn update_context(context: &mut ucontext_t, regs: TrapHandlerRegs) {
    cfg_if::cfg_if! {
        if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "x86_64",
        ))] {
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            context.uc_mcontext.gregs[libc::REG_RIP as usize] = rip as i64;
            context.uc_mcontext.gregs[libc::REG_RSP as usize] = rsp as i64;
            context.uc_mcontext.gregs[libc::REG_RBP as usize] = rbp as i64;
            context.uc_mcontext.gregs[libc::REG_RDI as usize] = rdi as i64;
            context.uc_mcontext.gregs[libc::REG_RSI as usize] = rsi as i64;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "x86",
        ))] {
            let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
            context.uc_mcontext.gregs[libc::REG_EIP as usize] = eip as i32;
            context.uc_mcontext.gregs[libc::REG_ESP as usize] = esp as i32;
            context.uc_mcontext.gregs[libc::REG_EBP as usize] = ebp as i32;
            context.uc_mcontext.gregs[libc::REG_ECX as usize] = ecx as i32;
            context.uc_mcontext.gregs[libc::REG_EDX as usize] = edx as i32;
        } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            // On Apple targets `uc_mcontext` is a pointer to the register block.
            let mcontext = unsafe { &mut *context.uc_mcontext };
            mcontext.__ss.__rip = rip;
            mcontext.__ss.__rsp = rsp;
            mcontext.__ss.__rbp = rbp;
            mcontext.__ss.__rdi = rdi;
            mcontext.__ss.__rsi = rsi;
        } else if #[cfg(all(target_os = "freebsd", target_arch = "x86"))] {
            let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
            context.uc_mcontext.mc_eip = eip as libc::register_t;
            context.uc_mcontext.mc_esp = esp as libc::register_t;
            context.uc_mcontext.mc_ebp = ebp as libc::register_t;
            context.uc_mcontext.mc_ecx = ecx as libc::register_t;
            context.uc_mcontext.mc_edx = edx as libc::register_t;
        } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            context.uc_mcontext.mc_rip = rip as libc::register_t;
            context.uc_mcontext.mc_rsp = rsp as libc::register_t;
            context.uc_mcontext.mc_rbp = rbp as libc::register_t;
            context.uc_mcontext.mc_rdi = rdi as libc::register_t;
            context.uc_mcontext.mc_rsi = rsi as libc::register_t;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "aarch64",
        ))] {
            let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
            context.uc_mcontext.pc = pc;
            context.uc_mcontext.sp = sp;
            context.uc_mcontext.regs[0] = x0;
            context.uc_mcontext.regs[1] = x1;
            context.uc_mcontext.regs[29] = x29;
            // x30 is the link register.
            context.uc_mcontext.regs[30] = lr;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "arm",
        ))] {
            let TrapHandlerRegs {
                pc,
                r0,
                r1,
                r7,
                r11,
                r13,
                r14,
                cpsr_thumb,
                cpsr_endian,
            } = regs;
            context.uc_mcontext.arm_pc = pc;
            context.uc_mcontext.arm_r0 = r0;
            context.uc_mcontext.arm_r1 = r1;
            context.uc_mcontext.arm_r7 = r7;
            context.uc_mcontext.arm_fp = r11;
            context.uc_mcontext.arm_sp = r13;
            context.uc_mcontext.arm_lr = r14;
            // Bit 0x20 of CPSR follows `cpsr_thumb`.
            if cpsr_thumb {
                context.uc_mcontext.arm_cpsr |= 0x20;
            } else {
                context.uc_mcontext.arm_cpsr &= !0x20;
            }
            // Bit 0x200 of CPSR follows `cpsr_endian`.
            if cpsr_endian {
                context.uc_mcontext.arm_cpsr |= 0x200;
            } else {
                context.uc_mcontext.arm_cpsr &= !0x200;
            }
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            any(target_arch = "riscv64", target_arch = "riscv32"),
        ))] {
            let TrapHandlerRegs { pc, ra, sp, a0, a1, s0 } = regs;
            context.uc_mcontext.__gregs[libc::REG_PC] = pc as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_RA] = ra as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_SP] = sp as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_A0] = a0 as libc::c_ulong;
            // a1 is stored in the slot immediately after a0.
            context.uc_mcontext.__gregs[libc::REG_A0 + 1] = a1 as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_S0] = s0 as libc::c_ulong;
        } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
            let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
            // On Apple targets `uc_mcontext` is a pointer to the register block.
            let mcontext = unsafe { &mut *context.uc_mcontext };
            mcontext.__ss.__pc = pc;
            mcontext.__ss.__sp = sp;
            mcontext.__ss.__x[0] = x0;
            mcontext.__ss.__x[1] = x1;
            mcontext.__ss.__fp = x29;
            mcontext.__ss.__lr = lr;
        } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
            let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
            context.uc_mcontext.mc_gpregs.gp_elr = pc as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_sp = sp as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_x[0] = x0 as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_x[1] = x1 as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_x[29] = x29 as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_lr = lr as libc::register_t;
        } else if #[cfg(all(target_os = "linux", target_arch = "loongarch64"))] {
            let TrapHandlerRegs { pc, sp, a0, a1, fp, ra } = regs;
            context.uc_mcontext.__pc = pc;
            // General registers by index: 1 = ra, 3 = sp, 4 = a0, 5 = a1, 22 = fp.
            context.uc_mcontext.__gregs[1] = ra;
            context.uc_mcontext.__gregs[3] = sp;
            context.uc_mcontext.__gregs[4] = a0;
            context.uc_mcontext.__gregs[5] = a1;
            context.uc_mcontext.__gregs[22] = fp;
        } else if #[cfg(all(target_os = "linux", target_arch = "powerpc64"))] {
            let TrapHandlerRegs { pc, sp, r3, r4, r31, lr } = regs;
            // `regs` is a raw pointer to the saved register block.
            (*context.uc_mcontext.regs).nip = pc;
            (*context.uc_mcontext.regs).gpr[1] = sp;
            (*context.uc_mcontext.regs).gpr[3] = r3;
            (*context.uc_mcontext.regs).gpr[4] = r4;
            (*context.uc_mcontext.regs).gpr[31] = r31;
            (*context.uc_mcontext.regs).link = lr;
        } else {
            compile_error!("Unsupported platform");
        }
    };
}
640 } else if #[cfg(target_os = "windows")] {
641 use windows_sys::Win32::System::Diagnostics::Debug::{
642 AddVectoredExceptionHandler,
643 CONTEXT,
644 EXCEPTION_CONTINUE_EXECUTION,
645 EXCEPTION_CONTINUE_SEARCH,
646 EXCEPTION_POINTERS,
647 };
648 use windows_sys::Win32::Foundation::{
649 EXCEPTION_ACCESS_VIOLATION,
650 EXCEPTION_ILLEGAL_INSTRUCTION,
651 EXCEPTION_INT_DIVIDE_BY_ZERO,
652 EXCEPTION_INT_OVERFLOW,
653 EXCEPTION_STACK_OVERFLOW,
654 };
655
656 unsafe fn platform_init() {
657 unsafe {
658 let handler = AddVectoredExceptionHandler(1, Some(exception_handler));
662 if handler.is_null() {
663 panic!("failed to add exception handler: {}", io::Error::last_os_error());
664 }
665 }
666 }
667
668 unsafe extern "system" fn exception_handler(
669 exception_info: *mut EXCEPTION_POINTERS
670 ) -> i32 {
671 unsafe {
672 let record = &*(*exception_info).ExceptionRecord;
676 if record.ExceptionCode != EXCEPTION_ACCESS_VIOLATION &&
677 record.ExceptionCode != EXCEPTION_ILLEGAL_INSTRUCTION &&
678 record.ExceptionCode != EXCEPTION_STACK_OVERFLOW &&
679 record.ExceptionCode != EXCEPTION_INT_DIVIDE_BY_ZERO &&
680 record.ExceptionCode != EXCEPTION_INT_OVERFLOW
681 {
682 return EXCEPTION_CONTINUE_SEARCH;
683 }
684
685 let context = &mut *(*exception_info).ContextRecord;
698 let (pc, sp) = get_pc_sp(context);
699
700 let maybe_fault_address = match record.ExceptionCode {
702 EXCEPTION_ACCESS_VIOLATION => Some(record.ExceptionInformation[1]),
703 EXCEPTION_STACK_OVERFLOW => Some(sp),
704 _ => None,
705 };
706 let trap_code = match record.ExceptionCode {
707 EXCEPTION_ILLEGAL_INSTRUCTION => {
709 process_illegal_op(pc)
710 }
711 _ => None,
712 };
713 let handled = TrapHandlerContext::handle_trap(
716 pc,
717 sp,
718 maybe_fault_address,
719 trap_code,
720 |regs| update_context(context, regs),
721 |handler| handler(exception_info),
722 );
723
724 if handled {
725 EXCEPTION_CONTINUE_EXECUTION
726 } else {
727 EXCEPTION_CONTINUE_SEARCH
728 }
729 }
730 }
731
/// Extracts the program counter and stack pointer from a Windows thread
/// `CONTEXT`. Only x86_64 and x86 are supported; anything else fails to
/// compile.
unsafe fn get_pc_sp(context: &CONTEXT) -> (usize, usize) {
    let (pc, sp);
    cfg_if::cfg_if! {
        if #[cfg(target_arch = "x86_64")] {
            pc = context.Rip as usize;
            sp = context.Rsp as usize;
        } else if #[cfg(target_arch = "x86")] {
            pc = context.Eip as usize;
            sp = context.Esp as usize;
        } else {
            compile_error!("Unsupported platform");
        }
    };
    (pc, sp)
}
747
/// Stores the register values from `regs` into a Windows thread `CONTEXT`.
/// Only x86_64 and x86 are supported; anything else fails to compile.
unsafe fn update_context(context: &mut CONTEXT, regs: TrapHandlerRegs) {
    cfg_if::cfg_if! {
        if #[cfg(target_arch = "x86_64")] {
            // Exhaustive destructuring: a new field in `TrapHandlerRegs`
            // becomes a compile error here.
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            context.Rip = rip;
            context.Rsp = rsp;
            context.Rbp = rbp;
            context.Rdi = rdi;
            context.Rsi = rsi;
        } else if #[cfg(target_arch = "x86")] {
            let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
            context.Eip = eip;
            context.Esp = esp;
            context.Ebp = ebp;
            context.Ecx = ecx;
            context.Edx = edx;
        } else {
            compile_error!("Unsupported platform");
        }
    };
}
769 }
770}
771
772pub fn init_traps() {
781 static INIT: Once = Once::new();
782 INIT.call_once(|| unsafe {
783 platform_init();
784 });
785}
786
/// Unwinds out of the running Wasm code with a user-supplied error, reported
/// as `UnwindReason::UserTrap`. Never returns.
///
/// # Safety
///
/// Delegates to `unwind_with`, which panics with "not running on Wasm stack"
/// unless a yielder is registered for the current thread. Frames on the Wasm
/// stack are abandoned.
// NOTE(review): presumably destructors of those abandoned frames do not run — confirm.
pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
    unsafe { unwind_with(UnwindReason::UserTrap(data)) }
}
802
/// Unwinds out of the running Wasm code with an already-built `Trap`,
/// reported as `UnwindReason::LibTrap`. Never returns.
///
/// # Safety
///
/// Delegates to `unwind_with`, which panics with "not running on Wasm stack"
/// unless a yielder is registered for the current thread. Frames on the Wasm
/// stack are abandoned.
pub unsafe fn raise_lib_trap(trap: Trap) -> ! {
    unsafe { unwind_with(UnwindReason::LibTrap(trap)) }
}
817
/// Unwinds out of the running Wasm code carrying a panic `payload`, reported
/// as `UnwindReason::Panic`. Never returns.
///
/// # Safety
///
/// Delegates to `unwind_with`, which panics with "not running on Wasm stack"
/// unless a yielder is registered for the current thread. Frames on the Wasm
/// stack are abandoned.
pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
    unsafe { unwind_with(UnwindReason::Panic(payload)) }
}
829
830pub unsafe fn catch_traps<F, R: 'static>(
837 trap_handler: Option<*const TrapHandlerFn<'static>>,
838 config: &VMConfig,
839 closure: F,
840) -> Result<R, Trap>
841where
842 F: FnOnce() -> R + 'static,
843{
844 lazy_per_thread_init()?;
846 let stack_size = config
847 .wasm_stack_size
848 .unwrap_or_else(|| DEFAULT_STACK_SIZE.load(Ordering::Relaxed));
849 on_wasm_stack(stack_size, trap_handler, closure).map_err(UnwindReason::into_trap)
850}
851
thread_local! {
    /// Yielder of the coroutine currently running Wasm code on this thread, or
    /// `None` when the thread is not executing on a Wasm stack. It is set by
    /// `on_wasm_stack` and taken by `unwind_with` / `on_host_stack`.
    static YIELDER: Cell<Option<NonNull<Yielder<(), UnwindReason>>>> = const { Cell::new(None) };
    /// Pointer to the `TrapHandlerContext` currently installed on this thread,
    /// or null when none is installed. Read by `TrapHandlerContext::handle_trap`.
    static TRAP_HANDLER: AtomicPtr<TrapHandlerContext> = const { AtomicPtr::new(ptr::null_mut()) };
}
864
/// Type-erased description of how to handle a trap for the coroutine that is
/// currently running; a pointer to it is published through `TRAP_HANDLER`.
#[allow(clippy::type_complexity)]
struct TrapHandlerContext {
    /// Type-erased pointer to a `TrapHandlerContextInner<T>`.
    inner: *const u8,
    /// Monomorphised function that casts `inner` back to its concrete type and
    /// forwards to `TrapHandlerContextInner::handle_trap`. Its arguments are
    /// `inner`, pc, sp, optional fault address, optional trap code, and the
    /// callback that applies new register values.
    handle_trap: fn(
        *const u8,
        usize,
        usize,
        Option<usize>,
        Option<TrapCode>,
        &mut dyn FnMut(TrapHandlerRegs),
    ) -> bool,
    /// Optional user-provided handler that is tried before `handle_trap`.
    custom_trap: Option<*const TrapHandlerFn<'static>>,
}
/// Typed part of the trap handling state. `T` is the success type of the
/// coroutine whose traps are being handled.
struct TrapHandlerContextInner<T> {
    /// Coroutine trap handler used to check stack bounds and to compute the
    /// registers that unwind the coroutine.
    coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
}
885
impl TrapHandlerContext {
    /// Publishes a trap handling context for the current thread while `f`
    /// runs, then restores whatever context was installed before.
    ///
    /// The context lives on this function's stack frame, so the pointer stored
    /// in `TRAP_HANDLER` is only valid until `install` returns. The `defer!`
    /// block restores the previous pointer on every exit path of this scope.
    fn install<T, R>(
        custom_trap: Option<*const TrapHandlerFn<'static>>,
        coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
        f: impl FnOnce() -> R,
    ) -> R {
        // Monomorphised trampoline: recovers the concrete
        // `TrapHandlerContextInner<T>` from the erased pointer.
        fn func<T>(
            ptr: *const u8,
            pc: usize,
            sp: usize,
            maybe_fault_address: Option<usize>,
            trap_code: Option<TrapCode>,
            update_regs: &mut dyn FnMut(TrapHandlerRegs),
        ) -> bool {
            unsafe {
                (*(ptr as *const TrapHandlerContextInner<T>)).handle_trap(
                    pc,
                    sp,
                    maybe_fault_address,
                    trap_code,
                    update_regs,
                )
            }
        }
        let inner = TrapHandlerContextInner { coro_trap_handler };
        let ctx = Self {
            inner: &inner as *const _ as *const u8,
            handle_trap: func::<T>,
            custom_trap,
        };

        // Compiler-level fence: keep the initialisation of `ctx` / `inner`
        // from being reordered after the pointer store below.
        compiler_fence(Ordering::Release);
        let prev = TRAP_HANDLER.with(|ptr| {
            let prev = ptr.load(Ordering::Relaxed);
            ptr.store(&ctx as *const Self as *mut Self, Ordering::Relaxed);
            prev
        });

        defer! {
            // Un-publish our context first, then fence so later code is not
            // reordered before the restore.
            TRAP_HANDLER.with(|ptr| ptr.store(prev, Ordering::Relaxed));
            compiler_fence(Ordering::Acquire);
        }

        f()
    }

    /// Attempts to handle a trap using the context installed on this thread.
    ///
    /// Returns `false` when no context is installed or when neither the custom
    /// handler nor the built-in handler accepts the trap. `call_handler`
    /// invokes the custom handler with the platform-specific arguments;
    /// `update_regs` applies the registers computed by the built-in handler.
    ///
    /// # Safety
    ///
    /// Dereferences the raw pointers stored in `TRAP_HANDLER` and in the
    /// installed context, which must still be valid.
    unsafe fn handle_trap(
        pc: usize,
        sp: usize,
        maybe_fault_address: Option<usize>,
        trap_code: Option<TrapCode>,
        mut update_regs: impl FnMut(TrapHandlerRegs),
        call_handler: impl Fn(&TrapHandlerFn<'static>) -> bool,
    ) -> bool {
        unsafe {
            let ptr = TRAP_HANDLER.with(|ptr| ptr.load(Ordering::Relaxed));
            if ptr.is_null() {
                // No context installed on this thread.
                return false;
            }

            let ctx = &*ptr;

            // Give the custom handler, if any, the first chance.
            if let Some(trap_handler) = ctx.custom_trap
                && call_handler(&*trap_handler)
            {
                return true;
            }

            (ctx.handle_trap)(
                ctx.inner,
                pc,
                sp,
                maybe_fault_address,
                trap_code,
                &mut update_regs,
            )
        }
    }
}
970
971impl<T> TrapHandlerContextInner<T> {
972 unsafe fn handle_trap(
973 &self,
974 pc: usize,
975 sp: usize,
976 maybe_fault_address: Option<usize>,
977 trap_code: Option<TrapCode>,
978 update_regs: &mut dyn FnMut(TrapHandlerRegs),
979 ) -> bool {
980 unsafe {
981 if !self.coro_trap_handler.stack_ptr_in_bounds(sp) {
984 return false;
985 }
986
987 let signal_trap = trap_code.or_else(|| {
988 maybe_fault_address.map(|addr| {
989 if self.coro_trap_handler.stack_ptr_in_bounds(addr) {
990 TrapCode::StackOverflow
991 } else {
992 TrapCode::HeapAccessOutOfBounds
993 }
994 })
995 });
996
997 let backtrace = if signal_trap == Some(TrapCode::StackOverflow) {
1004 Backtrace::from(vec![])
1005 } else {
1006 Backtrace::new_unresolved()
1007 };
1008
1009 let unwind = UnwindReason::WasmTrap {
1012 backtrace,
1013 signal_trap,
1014 pc,
1015 };
1016 let regs = self
1017 .coro_trap_handler
1018 .setup_trap_handler(move || Err(unwind));
1019 update_regs(regs);
1020 true
1021 }
1022 }
1023}
1024
1025unsafe fn unwind_with(reason: UnwindReason) -> ! {
1026 unsafe {
1027 let yielder = YIELDER
1028 .with(|cell| cell.replace(None))
1029 .expect("not running on Wasm stack");
1030
1031 yielder.as_ref().suspend(reason);
1032
1033 unreachable!();
1035 }
1036}
1037
/// Runs `f` inside a coroutine on a separate stack of at least `stack_size`
/// bytes, with trap handling installed for the duration.
///
/// Returns `Ok` with the value of `f`, or `Err` with the reason the coroutine
/// yielded (which is how traps and explicit unwinds are reported).
fn on_wasm_stack<F: FnOnce() -> T + 'static, T: 'static>(
    stack_size: usize,
    trap_handler: Option<*const TrapHandlerFn<'static>>,
    f: F,
) -> Result<T, UnwindReason> {
    let stack = acquire_stack(stack_size);
    // The guard hands the stack to `release_stack` when this function exits.
    let mut stack = scopeguard::guard(stack, release_stack);

    let coro = ScopedCoroutine::with_stack(&mut *stack, move |yielder, ()| {
        // Make the yielder reachable from `unwind_with` / `on_host_stack`.
        YIELDER.with(|cell| cell.set(Some(yielder.into())));

        Ok(f())
    });

    defer! {
        // Clear the yielder slot when this scope exits, whatever the outcome.
        YIELDER.with(|cell| cell.set(None));
    }

    coro.scope(|mut coro_ref| {
        TrapHandlerContext::install(trap_handler, coro_ref.trap_handler(), || {
            match coro_ref.resume(()) {
                CoroutineResult::Yield(trap) => {
                    // The coroutine yielded and will not be resumed, so reset
                    // it forcibly.
                    unsafe {
                        coro_ref.force_reset();
                    }
                    Err(trap)
                }
                CoroutineResult::Return(result) => result,
            }
        })
    })
}
1086
1087pub fn on_host_stack<F: FnOnce() -> T, T>(f: F) -> T {
1096 let yielder_ptr = YIELDER.with(|cell| cell.replace(None));
1099
1100 let yielder = match yielder_ptr {
1103 Some(ptr) => unsafe { ptr.as_ref() },
1104 None => return f(),
1105 };
1106
1107 defer! {
1109 YIELDER.with(|cell| cell.set(yielder_ptr));
1110 }
1111
1112 struct SendWrapper<T>(T);
1116 unsafe impl<T> Send for SendWrapper<T> {}
1117 let wrapped = SendWrapper(f);
1118 yielder.on_parent_stack(move || {
1119 let wrapped = wrapped;
1120 (wrapped.0)()
1121 })
1122}
1123
/// Per-thread initialization on Windows: asks for a thread stack guarantee of
/// `0x10000` bytes.
///
/// Panics if `SetThreadStackGuarantee` reports failure; otherwise returns
/// `Ok(())`.
#[cfg(windows)]
pub fn lazy_per_thread_init() -> Result<(), Trap> {
    use windows_sys::Win32::System::Threading::SetThreadStackGuarantee;

    let mut guarantee = 0x10000;
    let status = unsafe { SetThreadStackGuarantee(&mut guarantee) };
    if status == 0 {
        panic!("failed to set thread stack guarantee");
    }

    Ok(())
}
1136
/// Per-thread initialization on Unix: makes sure the calling thread has an
/// alternate signal stack of at least 64 KiB.
///
/// The work happens once per thread, on first access to a thread-local. If the
/// thread already has a large enough alternate stack, nothing is allocated.
/// Otherwise a new one is mapped (with a leading guard page) and registered.
///
/// # Errors
///
/// Returns `Trap::oom()` when the memory for the alternate stack could not be
/// mapped.
#[cfg(unix)]
pub fn lazy_per_thread_init() -> Result<(), Trap> {
    use std::ptr::null_mut;

    thread_local! {
        // Lazily initialised per thread; its `Drop` unmaps the stack we allocated.
        static TLS: Tls = unsafe { init_sigstack() };
    }

    /// Minimum size of the alternate signal stack (64 KiB).
    const MIN_STACK_SIZE: usize = ByteSize::kib(64).as_u64() as usize;

    /// Outcome of the per-thread alternate stack setup.
    enum Tls {
        /// Mapping memory for the stack failed.
        OutOfMemory,
        /// A stack was mapped by us and must be unmapped on drop.
        Allocated {
            mmap_ptr: *mut libc::c_void,
            mmap_size: usize,
        },
        /// The thread already had an adequate alternate stack.
        BigEnough,
    }

    /// Queries the current alternate stack and installs a new one if needed.
    unsafe fn init_sigstack() -> Tls {
        unsafe {
            let mut old_stack = mem::zeroed();
            // A null new-stack pointer only queries the current setting.
            let r = libc::sigaltstack(ptr::null(), &mut old_stack);
            assert_eq!(r, 0, "learning about sigaltstack failed");
            if old_stack.ss_flags & libc::SS_DISABLE == 0 && old_stack.ss_size >= MIN_STACK_SIZE {
                return Tls::BigEnough;
            }

            let page_size: usize = region::page::size();
            // One inaccessible page in front of the usable stack area.
            let guard_size = page_size;
            let alloc_size = guard_size + MIN_STACK_SIZE;

            // Map everything inaccessible first; the usable part is opened up below.
            let ptr = libc::mmap(
                null_mut(),
                alloc_size,
                libc::PROT_NONE,
                libc::MAP_PRIVATE | libc::MAP_ANON,
                -1,
                0,
            );
            if ptr == libc::MAP_FAILED {
                return Tls::OutOfMemory;
            }

            // Skip the guard page, then make the rest readable and writable.
            let stack_ptr = (ptr as usize + guard_size) as *mut libc::c_void;
            let r = libc::mprotect(
                stack_ptr,
                MIN_STACK_SIZE,
                libc::PROT_READ | libc::PROT_WRITE,
            );
            assert_eq!(r, 0, "mprotect to configure memory for sigaltstack failed");
            let new_stack = libc::stack_t {
                ss_sp: stack_ptr,
                ss_flags: 0,
                ss_size: MIN_STACK_SIZE,
            };
            let r = libc::sigaltstack(&new_stack, ptr::null_mut());
            assert_eq!(r, 0, "registering new sigaltstack failed");

            Tls::Allocated {
                mmap_ptr: ptr,
                mmap_size: alloc_size,
            }
        }
    }

    // Touching the thread-local triggers `init_sigstack` on first use.
    return TLS.with(|tls| {
        if let Tls::OutOfMemory = tls {
            Err(Trap::oom())
        } else {
            Ok(())
        }
    });

    impl Drop for Tls {
        /// Unmaps the alternate stack if this thread allocated one.
        fn drop(&mut self) {
            let (ptr, size) = match self {
                Self::Allocated {
                    mmap_ptr,
                    mmap_size,
                } => (*mmap_ptr, *mmap_size),
                _ => return,
            };
            unsafe {
                let r = libc::munmap(ptr, size);
                // Only checked in debug builds.
                debug_assert_eq!(r, 0, "munmap failed during thread shutdown");
            }
        }
    }
}
1246
1247#[cfg(test)]
1248mod tests {
1249 use super::*;
1250 use std::sync::Mutex;
1251
// Serialises tests that touch process-wide state (the default stack size and
// the shared stack pool); each such test holds this lock for its duration.
static GLOBAL_STATE: Mutex<()> = Mutex::new(());
1256
/// Guard that remembers a stack size and passes it back to `set_stack_size`
/// when dropped, so a test leaves the global default as it found it.
struct RestoreStackSize(usize);
impl Drop for RestoreStackSize {
    fn drop(&mut self) {
        set_stack_size(self.0);
    }
}
1264
/// `MAX_STACK_SIZE` is exactly 100 MiB.
#[test]
fn max_stack_size_is_100mb() {
    let hundred_mib = ByteSize::mib(100).as_u64() as usize;
    assert_eq!(MAX_STACK_SIZE, hundred_mib);
}
1269
/// An in-range value passed to `set_stack_size` is read back unchanged.
#[test]
fn get_set_stack_size_roundtrip() {
    let _lock = GLOBAL_STATE.lock().unwrap();
    let _restore = RestoreStackSize(get_stack_size());

    let four_mib = ByteSize::mib(4).as_u64() as usize;
    set_stack_size(four_mib);

    let observed = get_stack_size();
    assert_eq!(observed, four_mib);
}
1278
/// Requests below the 8 KiB floor are raised to the floor.
#[test]
fn set_stack_size_clamps_to_min() {
    let _lock = GLOBAL_STATE.lock().unwrap();
    let _restore = RestoreStackSize(get_stack_size());

    set_stack_size(1);

    let floor = ByteSize::kib(8).as_u64() as usize;
    assert_eq!(get_stack_size(), floor);
}
1286
/// Requests above `MAX_STACK_SIZE` are lowered to it.
#[test]
fn set_stack_size_clamps_to_max() {
    let _lock = GLOBAL_STATE.lock().unwrap();
    let _restore = RestoreStackSize(get_stack_size());

    set_stack_size(usize::MAX);

    let observed = get_stack_size();
    assert_eq!(observed, MAX_STACK_SIZE);
}
1294
/// A stack pushed onto the shared pool is gone after `drain_stack_pool`.
#[test]
fn drain_stack_pool_empties_pool() {
    let _lock = GLOBAL_STATE.lock().unwrap();
    // Seed the pool with one freshly allocated 1 MiB stack.
    let stack = DefaultStack::new(ByteSize::mib(1).as_u64() as usize).unwrap();
    STACK_POOL.push(stack);
    assert!(!STACK_POOL.is_empty());
    drain_stack_pool();
    assert!(STACK_POOL.is_empty());
}
1304
/// Draining twice in a row is harmless and leaves the pool empty.
#[test]
fn drain_stack_pool_is_idempotent() {
    let _lock = GLOBAL_STATE.lock().unwrap();
    drain_stack_pool();
    // Second call runs against an already-empty pool.
    drain_stack_pool();
    assert!(STACK_POOL.is_empty());
}
1312
/// Changing the default stack size does not evict stacks already in the pool;
/// only `drain_stack_pool` removes them.
#[test]
fn pool_returns_stale_stack_without_drain() {
    let _lock = GLOBAL_STATE.lock().unwrap();
    let _restore = RestoreStackSize(get_stack_size());
    // Start from an empty pool.
    drain_stack_pool();

    // Put a 500 KiB stack into the pool.
    let small_size = ByteSize::kib(500).as_u64() as usize;
    let small_stack = DefaultStack::new(small_size).unwrap();
    STACK_POOL.push(small_stack);

    // Raise the default size to 1 MiB.
    let big_size = ByteSize::mib(1).as_u64() as usize;
    set_stack_size(big_size);
    assert_eq!(get_stack_size(), big_size);

    // The smaller stack is still sitting in the pool.
    let stale = STACK_POOL.pop();
    assert!(
        stale.is_some(),
        "pool should still contain the old stack (the bug scenario)"
    );

    // Put it back, then verify that draining removes it.
    STACK_POOL.push(stale.unwrap());
    drain_stack_pool();
    assert!(
        STACK_POOL.pop().is_none(),
        "after drain, pool must be empty so a fresh stack is allocated at the new size"
    );
}
1355
/// When the only pooled stack is smaller than the requested size,
/// `on_wasm_stack` must still succeed and the stack it hands back afterwards
/// (via TLS or the pool) must be at least as large as requested.
#[test]
fn on_wasm_stack_discards_undersized_stack() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    // Seed the pool with a stack smaller than what the call will ask for.
    let undersized = ByteSize::kib(500).as_u64() as usize;
    STACK_POOL.push(DefaultStack::new(undersized).unwrap());

    let requested = ByteSize::mib(1).as_u64() as usize;
    let outcome = on_wasm_stack(requested, None, || 42);
    assert_eq!(outcome.ok().expect("on_wasm_stack should succeed"), 42);

    // The stack used by the call ends up in the TLS slot or, failing that, in the pool.
    let from_tls = TLS_STACK.with(|cache| cache.0.take());
    let returned = match from_tls {
        Some(stack) => Some(stack),
        None => STACK_POOL.pop(),
    }
    .expect("stack should have been returned to TLS cache or pool");
    assert!(
        requested <= returned.size(),
        "returned stack must be at least as large as the requested size"
    );

    clear_tls_stack();
}
1392
/// Starting from an empty TLS slot and an empty pool, each `on_wasm_stack`
/// call must leave its stack in the calling thread's TLS slot and never push
/// anything to the global pool.
#[test]
fn tls_stack_caches_after_first_call() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    // Peeks at the TLS slot without consuming it: take the value, note whether
    // it was there, and put it straight back.
    let tls_occupied = || {
        TLS_STACK.with(|cache| {
            let slot = cache.0.take();
            let occupied = slot.is_some();
            cache.0.set(slot);
            occupied
        })
    };

    let size = get_stack_size();

    let first = on_wasm_stack(size, None, || ());
    assert!(first.is_ok());
    assert!(
        STACK_POOL.is_empty(),
        "pool should still be empty after a TLS-served call"
    );
    assert!(tls_occupied(), "TLS slot should hold the post-call stack");

    let second = on_wasm_stack(size, None, || ());
    assert!(second.is_ok());
    assert!(
        STACK_POOL.is_empty(),
        "second call must not push to the global pool"
    );
    assert!(
        tls_occupied(),
        "TLS slot should still hold a stack after the second call"
    );

    clear_tls_stack();
}
1444
/// A stack cached in a worker thread's TLS slot must be handed to the global
/// pool when that thread exits.
///
/// The `GLOBAL_STATE` guard is held for the entire test, including while the
/// worker runs. The worker never touches `GLOBAL_STATE`, so this cannot
/// deadlock. Releasing the guard around the spawn would let another test drain
/// or pop the pool between `join` and the final `pop` below, or observe the
/// worker's stack in a pool it expects to be empty. Holding a single guard
/// also guarantees the stack size is restored while the lock is still held.
#[test]
fn tls_stack_returns_to_pool_on_thread_exit() {
    let _lock = GLOBAL_STATE.lock().unwrap();
    let _restore = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();

    // The worker performs one call and then terminates, which runs its
    // thread-local destructors.
    let handle = std::thread::spawn(move || {
        assert!(on_wasm_stack(size, None, || ()).is_ok());
    });
    handle.join().unwrap();

    let returned = STACK_POOL
        .pop()
        .expect("thread exit should return TLS-cached stack to the global pool");
    assert!(returned.size() >= size);
}
1482
1483 fn clear_tls_stack() {
1491 TLS_STACK.with(|cache| cache.0.set(None));
1492 }
1493
1494 fn stack_id(stack: &DefaultStack) -> usize {
1499 stack.base().get()
1500 }
1501
/// With nothing cached in TLS or the pool, `acquire_stack` must still return
/// a stack that is at least as large as requested.
#[test]
fn acquire_allocates_fresh_when_tls_and_pool_empty() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let min_size = get_stack_size();
    let fresh = acquire_stack(min_size);
    assert!(
        min_size <= fresh.size(),
        "freshly allocated stack must satisfy min_size"
    );

    // Leave no trace behind for the next test.
    drop(fresh);
    clear_tls_stack();
    drain_stack_pool();
}
1524
/// When both the TLS slot and the pool hold a suitable stack, `acquire_stack`
/// must hand out the TLS one.
#[test]
fn acquire_prefers_tls_over_pool() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();

    // One candidate goes into the TLS slot ...
    let cached = DefaultStack::new(size).unwrap();
    let cached_id = stack_id(&cached);
    TLS_STACK.with(|cache| cache.0.set(Some(cached)));

    // ... and a second one into the global pool.
    let pooled = DefaultStack::new(size).unwrap();
    let pooled_id = stack_id(&pooled);
    STACK_POOL.push(pooled);

    let acquired = acquire_stack(size);
    let acquired_id = stack_id(&acquired);
    assert_eq!(acquired_id, cached_id, "acquire must prefer TLS over pool");
    assert_ne!(acquired_id, pooled_id);

    drop(acquired);
    clear_tls_stack();
    drain_stack_pool();
}
1549
/// With an empty TLS slot, `acquire_stack` must take the stack waiting in the
/// global pool and remove it from there.
#[test]
fn acquire_uses_pool_when_tls_empty() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();
    let pooled = DefaultStack::new(size).unwrap();
    let pooled_id = stack_id(&pooled);
    STACK_POOL.push(pooled);

    let acquired = acquire_stack(size);
    let acquired_id = stack_id(&acquired);
    assert_eq!(
        acquired_id, pooled_id,
        "acquire must consume from pool when TLS is empty"
    );
    assert!(
        STACK_POOL.is_empty(),
        "pool stack must be removed when used"
    );

    drop(acquired);
    clear_tls_stack();
    drain_stack_pool();
}
1577
/// A TLS-cached stack that is too small for the request must be removed from
/// TLS, must not be moved to the pool, and the caller must get a big enough stack.
#[test]
fn acquire_discards_undersized_tls_then_allocates() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    // Cache a 512 KiB stack in TLS, then ask for 2 MiB.
    let cached_size = ByteSize::kib(512).as_u64() as usize;
    TLS_STACK.with(|cache| cache.0.set(Some(DefaultStack::new(cached_size).unwrap())));

    let big_size = ByteSize::mib(2).as_u64() as usize;
    let acquired = acquire_stack(big_size);

    assert!(
        big_size <= acquired.size(),
        "acquired stack must satisfy big_size"
    );

    // Non-destructive peek at the TLS slot: swap the value out, inspect, restore.
    let tls_vacant = TLS_STACK.with(|cache| {
        let slot = cache.0.replace(None);
        let vacant = slot.is_none();
        cache.0.set(slot);
        vacant
    });
    assert!(
        tls_vacant,
        "undersized TLS stack must have been taken and discarded"
    );
    assert!(
        STACK_POOL.is_empty(),
        "undersized TLS stack must be discarded, not pushed to the pool",
    );

    drop(acquired);
    clear_tls_stack();
    drain_stack_pool();
}
1622
/// A pooled stack that is too small for the request must be removed from the
/// pool, and the caller must get a big enough stack.
#[test]
fn acquire_discards_undersized_pool_then_allocates() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    // Park a 512 KiB stack in the pool, then ask for 2 MiB.
    let pooled_size = ByteSize::kib(512).as_u64() as usize;
    STACK_POOL.push(DefaultStack::new(pooled_size).unwrap());

    let big_size = ByteSize::mib(2).as_u64() as usize;
    let acquired = acquire_stack(big_size);

    assert!(
        big_size <= acquired.size(),
        "acquired stack must satisfy big_size"
    );
    assert!(
        STACK_POOL.is_empty(),
        "undersized pool stack must have been popped, filtered out and dropped",
    );

    drop(acquired);
    clear_tls_stack();
    drain_stack_pool();
}
1654
/// Releasing a stack while the TLS slot is empty must store that very stack
/// in TLS and leave the global pool untouched.
#[test]
fn release_into_empty_tls_caches_there() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    drain_stack_pool();
    clear_tls_stack();

    let released = DefaultStack::new(get_stack_size()).unwrap();
    let released_id = stack_id(&released);
    release_stack(released);

    let slot = TLS_STACK.with(|cache| cache.0.take());
    let in_tls = slot.expect("release into empty TLS should leave the stack in TLS");
    assert_eq!(stack_id(&in_tls), released_id);
    assert!(
        STACK_POOL.is_empty(),
        "pool must not be touched when TLS is empty"
    );

    drain_stack_pool();
}
1681
/// Releasing a stack while the TLS slot is already occupied must put the
/// newly released stack into TLS and push the previous occupant to the pool.
#[test]
fn release_into_occupied_tls_displaces_older_to_pool() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();

    // Occupy the TLS slot first.
    let occupant = DefaultStack::new(size).unwrap();
    let occupant_id = stack_id(&occupant);
    TLS_STACK.with(|cache| cache.0.set(Some(occupant)));

    // Then release a second stack on top of it.
    let incoming = DefaultStack::new(size).unwrap();
    let incoming_id = stack_id(&incoming);
    release_stack(incoming);

    let slot = TLS_STACK.with(|cache| cache.0.take());
    let in_tls = slot.expect("TLS should hold the newly-released stack");
    assert_eq!(
        stack_id(&in_tls),
        incoming_id,
        "newer stack must displace into TLS"
    );

    let popped = STACK_POOL.pop();
    let displaced = popped.expect("older stack should have been pushed to global pool");
    assert_eq!(
        stack_id(&displaced),
        occupant_id,
        "displaced stack must be the older one"
    );

    drain_stack_pool();
}
1717
/// `drain_stack_pool` must empty the calling thread's TLS slot in addition to
/// the global pool.
#[test]
fn drain_stack_pool_clears_calling_thread_tls_slot() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    drain_stack_pool();
    clear_tls_stack();

    // Plant a stack in TLS so the drain has something to remove.
    let planted = DefaultStack::new(get_stack_size()).unwrap();
    TLS_STACK.with(|cache| cache.0.set(Some(planted)));

    drain_stack_pool();

    let tls_vacant = TLS_STACK.with(|cache| cache.0.take()).is_none();
    assert!(
        tls_vacant,
        "drain_stack_pool must also clear current thread's TLS slot"
    );
    assert!(STACK_POOL.is_empty());
}
1740
/// The value produced by the closure must come back out of `on_wasm_stack`
/// as the `Ok` payload.
#[test]
fn on_wasm_stack_passes_closure_value_back() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();
    let returned = on_wasm_stack(size, None, || 12345u32).ok();
    assert_eq!(returned, Some(12345));

    clear_tls_stack();
    drain_stack_pool();
}
1758
/// A heap-owning value (here a `Vec<u8>`) built inside the closure must be
/// returned intact through `on_wasm_stack`.
#[test]
fn on_wasm_stack_passes_owning_result_back() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();
    let returned = on_wasm_stack(size, None, || vec![0u8, 1, 2, 3, 4]).ok();
    assert_eq!(returned, Some(vec![0u8, 1, 2, 3, 4]));

    clear_tls_stack();
    drain_stack_pool();
}
1775
/// A thousand back-to-back calls on one thread must all succeed and must not
/// leave anything in the global pool.
#[test]
fn many_calls_do_not_grow_global_pool() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    (0..1000).for_each(|_| {
        let outcome = on_wasm_stack(get_stack_size(), None, || ());
        assert!(outcome.is_ok());
    });
    assert!(
        STACK_POOL.is_empty(),
        "1000 sequential calls should not grow the global pool (TLS handles reuse)"
    );

    clear_tls_stack();
    drain_stack_pool();
}
1796
/// Calling `raise_user_trap` inside the closure must make `on_wasm_stack`
/// return `Err`.
#[test]
fn raise_user_trap_yields_err() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();
    // SAFETY (assumed; confirm against `raise_user_trap`'s contract): the trap
    // is raised from within an `on_wasm_stack` closure.
    let outcome: Result<(), UnwindReason> = on_wasm_stack(size, None, || unsafe {
        raise_user_trap(Box::new(io::Error::other("user trap from test")));
    });
    assert!(outcome.is_err(), "raise_user_trap must produce Err");

    clear_tls_stack();
    drain_stack_pool();
}
1816
/// Calling `raise_lib_trap` inside the closure must make `on_wasm_stack`
/// return `Err`.
#[test]
fn raise_lib_trap_yields_err() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();
    // SAFETY (assumed; confirm against `raise_lib_trap`'s contract): the trap
    // is raised from within an `on_wasm_stack` closure.
    let outcome: Result<(), UnwindReason> = on_wasm_stack(size, None, || unsafe {
        raise_lib_trap(Trap::lib(TrapCode::IntegerDivisionByZero));
    });
    assert!(outcome.is_err(), "raise_lib_trap must produce Err");

    clear_tls_stack();
    drain_stack_pool();
}
1832
/// Calling `resume_panic` inside the closure must be reported to the caller
/// of `on_wasm_stack` as `Err`; the test itself keeps running afterwards.
#[test]
fn resume_panic_yields_err_without_unwinding() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();
    // SAFETY (assumed; confirm against `resume_panic`'s contract): invoked
    // from within an `on_wasm_stack` closure.
    let outcome: Result<(), UnwindReason> = on_wasm_stack(size, None, || unsafe {
        resume_panic(Box::new("panic payload from test"));
    });
    assert!(
        outcome.is_err(),
        "resume_panic must surface as Err to on_wasm_stack"
    );

    clear_tls_stack();
    drain_stack_pool();
}
1855
/// A call that ends in a trap must not prevent the next call on the same
/// thread from running normally and returning its value.
#[test]
fn trap_does_not_corrupt_subsequent_calls() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    // First call: trap on purpose.
    // SAFETY (assumed; confirm against `raise_user_trap`'s contract): the trap
    // is raised from within an `on_wasm_stack` closure.
    let first: Result<(), UnwindReason> = on_wasm_stack(get_stack_size(), None, || unsafe {
        raise_user_trap(Box::new(io::Error::other("first call traps")));
    });
    assert!(first.is_err());

    // Second call: plain value, must come back untouched.
    let second = on_wasm_stack(get_stack_size(), None, || 7u32).ok();
    assert_eq!(second, Some(7), "calls after a trap must still work");

    clear_tls_stack();
    drain_stack_pool();
}
1878
/// `on_host_stack` called with no enclosing `on_wasm_stack` must return the
/// closure's value. (The test name says the closure runs inline in that case;
/// only the returned value is checked here.)
#[test]
fn on_host_stack_outside_coroutine_runs_inline() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    assert_eq!(on_host_stack(|| 99i32), 99);
}
1891
/// A value computed via `on_host_stack` from inside an `on_wasm_stack`
/// closure must propagate back out of both calls.
#[test]
fn on_host_stack_inside_wasm_switches_and_returns() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let size = get_stack_size();
    let returned = on_wasm_stack(size, None, || {
        let from_host = on_host_stack(|| 88i32);
        from_host
    })
    .ok();
    assert_eq!(returned, Some(88));

    clear_tls_stack();
    drain_stack_pool();
}
1905
/// An `on_wasm_stack` call nested directly inside another one must succeed,
/// and its value must reach the outermost caller.
#[test]
fn reentrant_call_returns_value() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let outer = on_wasm_stack(get_stack_size(), None, || {
        let inner = on_wasm_stack(get_stack_size(), None, || 42i32);
        inner.ok().expect("inner must succeed")
    });
    assert_eq!(outer.ok(), Some(42));

    clear_tls_stack();
    drain_stack_pool();
}
1927
/// With exactly one stack pre-seeded in the global pool and an empty TLS
/// slot, a nested `on_wasm_stack` call must still run to completion: the
/// inner closure executes exactly once and both calls report success.
#[test]
fn reentrant_calls_run_to_completion_under_pool_pressure() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering as O};

    let _lock = GLOBAL_STATE.lock().unwrap();
    let _restore = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    // Seed the pool with a single ready-made stack.
    let pre = DefaultStack::new(get_stack_size()).unwrap();
    STACK_POOL.push(pre);

    // Counts how many times the innermost closure actually ran.
    let inner_completed = Arc::new(AtomicUsize::new(0));
    let inner_completed_outer = inner_completed.clone();
    let outer = on_wasm_stack(get_stack_size(), None, move || {
        let inner_completed = inner_completed_outer.clone();
        let inner = on_wasm_stack(get_stack_size(), None, move || {
            inner_completed.fetch_add(1, O::Relaxed);
        });
        assert!(inner.is_ok(), "inner re-entrant call must succeed");
    });
    // The outer result used to be discarded; a failure of the outer call after
    // the inner closure had already run would have gone unnoticed.
    assert!(outer.is_ok(), "outer re-entrant call must succeed");
    assert_eq!(
        inner_completed.load(O::Relaxed),
        1,
        "inner closure must have executed exactly once",
    );

    clear_tls_stack();
    drain_stack_pool();
}
1967
/// A trap raised in a nested `on_wasm_stack` call must surface as `Err` to
/// the outer closure only; the outer closure continues and its own value is
/// returned normally.
#[test]
fn reentrant_inner_trap_does_not_kill_outer() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let outer = on_wasm_stack(get_stack_size(), None, || {
        // SAFETY (assumed; confirm against `raise_user_trap`'s contract): the
        // trap is raised from within an `on_wasm_stack` closure.
        let inner: Result<i32, UnwindReason> =
            on_wasm_stack(get_stack_size(), None, || unsafe {
                raise_user_trap(Box::new(io::Error::other("inner trap")));
            });
        if inner.is_err() {
            1234i32
        } else {
            panic!("inner should have trapped")
        }
    });
    assert_eq!(
        outer.ok(),
        Some(1234),
        "outer must recover after inner trap and run to completion"
    );

    clear_tls_stack();
    drain_stack_pool();
}
1995
/// Nesting order wasm → host → wasm: the innermost value must travel back
/// through all three layers.
#[test]
fn reentrant_with_on_host_stack_in_between() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    let outermost = on_wasm_stack(get_stack_size(), None, || {
        on_host_stack(|| {
            let innermost = on_wasm_stack(get_stack_size(), None, || 5i32);
            innermost.ok().expect("nested inner must succeed")
        })
    });
    assert_eq!(outermost.ok(), Some(5));

    clear_tls_stack();
    drain_stack_pool();
}
2018
/// Eight threads each perform 200 calls concurrently. Every call must return
/// its value, and once all threads have exited the global pool must contain
/// no more stacks than there were threads.
#[test]
fn many_threads_in_parallel_all_succeed() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering as O};

    const THREADS: usize = 8;
    const CALLS_PER_THREAD: usize = 200;

    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();

    // Shared tally of calls that returned the expected value.
    let counter = Arc::new(AtomicUsize::new(0));

    let mut workers = Vec::with_capacity(THREADS);
    for _ in 0..THREADS {
        let counter = Arc::clone(&counter);
        workers.push(std::thread::spawn(move || {
            let size = get_stack_size();
            for _ in 0..CALLS_PER_THREAD {
                if matches!(on_wasm_stack(size, None, || 1u32), Ok(1)) {
                    counter.fetch_add(1, O::Relaxed);
                }
            }
        }));
    }
    workers
        .into_iter()
        .for_each(|worker| worker.join().unwrap());

    assert_eq!(counter.load(O::Relaxed), THREADS * CALLS_PER_THREAD);

    // Empty the pool while counting what the exited threads left behind.
    let pooled = std::iter::from_fn(|| STACK_POOL.pop()).count();
    assert!(
        pooled <= THREADS,
        "pool should hold at most one stack per terminated thread (got {pooled} for {THREADS} threads)"
    );

    clear_tls_stack();
    drain_stack_pool();
}
2069
/// After a call at 1 MiB the TLS slot holds a stack of at least that size;
/// after raising the size to 4 MiB and calling again, the TLS slot must hold
/// a stack of at least 4 MiB.
#[test]
fn growing_request_discards_smaller_tls_stack() {
    let _guard = GLOBAL_STATE.lock().unwrap();
    let _reset = RestoreStackSize(get_stack_size());
    drain_stack_pool();
    clear_tls_stack();

    // Reports the size of the stack cached in TLS (0 when the slot is empty)
    // without consuming it: the value is taken, measured and put back.
    let tls_cached_size = || {
        TLS_STACK.with(|cache| {
            let slot = cache.0.take();
            let bytes = match &slot {
                Some(stack) => stack.size(),
                None => 0,
            };
            cache.0.set(slot);
            bytes
        })
    };

    let small = ByteSize::mib(1).as_u64() as usize;
    set_stack_size(small);
    assert!(on_wasm_stack(small, None, || ()).is_ok());
    assert!(
        tls_cached_size() >= small,
        "TLS should hold the small-sized stack"
    );

    let big = ByteSize::mib(4).as_u64() as usize;
    set_stack_size(big);
    assert!(on_wasm_stack(big, None, || ()).is_ok());
    assert!(
        tls_cached_size() >= big,
        "TLS should hold the bigger stack after size bump"
    );

    clear_tls_stack();
    drain_stack_pool();
}
2118}