wasmer_vm/threadconditions.rs

use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
    },
    time::Duration,
};

use dashmap::DashMap;
use fnv::FnvBuildHasher;
use parking_lot::{Condvar, Mutex};
use thiserror::Error;

/// Error that can occur during wait/notify calls.
// Non-exhaustive to allow for future variants without breaking changes!
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum WaiterError {
    /// Wait/Notify is not implemented for this memory
    Unimplemented,
    /// Too many waiters for an address
    TooManyWaiters,
    /// Atomic operations are disabled.
    AtomicsDisabled,
}

impl std::fmt::Display for WaiterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unimplemented => write!(f, "wait/notify is not implemented for this memory"),
            Self::TooManyWaiters => write!(f, "too many waiters for address"),
            Self::AtomicsDisabled => write!(f, "atomic operations are disabled"),
        }
    }
}

/// Expected value for atomic waits
pub enum ExpectedValue {
    /// No expected value; this is used for native waits only.
    None,

    /// 32-bit expected value
    U32(u32),

    /// 64-bit expected value
    U64(u64),
}

/// A location in memory for a Waiter
#[derive(Clone, Copy, Debug)]
pub struct NotifyLocation {
    /// The address of the Waiter location
    pub address: u32,
    /// The base of the memory this address is relative to
    pub memory_base: *mut u8,
}

#[derive(Debug, Default)]
struct NotifyMap {
    /// If set to true, all waits will fail with an error.
    closed: AtomicBool,

    // For each wait address, we store a mutex and a condvar. The condvar is
    // used to handle sleeping and waking, while the mutex stores the
    // (manually-updated) number of waiters on that address. This lets us
    // know when there are no more waiters so we can clean up the map entry.
    // Note that using a Weak here would be insufficient, since it can't
    // clean up the map entries for us, only the mutexes/condvars.
    map: DashMap<u32, Arc<(Mutex<u32>, Condvar)>, FnvBuildHasher>,
}

/// Map of waiters for the wait/notify opcodes
#[derive(Debug)]
pub struct ThreadConditions {
    inner: Arc<NotifyMap>, // The map of waiters, shared between clones
}

impl Clone for ThreadConditions {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl ThreadConditions {
    /// Create a new ThreadConditions
    pub fn new() -> Self {
        Self {
            inner: Arc::new(NotifyMap::default()),
        }
    }

    // To implement wait/notify, a concurrent hash map tracks the waiters for
    // each address; the map key is the wait address. The actual waiting is
    // implemented with a Condvar + Mutex pair stored behind an Arc, where the
    // mutex counts the waiters on that address. This count lets do_wait remove
    // a map entry once the last waiter for that address has left.

    /// Add the current thread to the waiter hash and block until it is
    /// notified, the timeout expires, or the expected value does not match.
    ///
    /// Returns `0` if the thread was notified, `1` if the value at the wait
    /// address did not match `expected`, and `2` if the wait timed out.
    ///
    /// # Safety
    ///
    /// If `expected` is [`ExpectedValue::None`], there are no safety requirements.
    /// Otherwise, the notify location must have a valid base address that belongs
    /// to a memory, and the address must be a valid offset within that memory.
    /// The offset must also be properly aligned for the expected value type:
    /// 4-byte aligned for [`ExpectedValue::U32`] or 8-byte aligned for
    /// [`ExpectedValue::U64`].
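    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore` since it blocks until another thread
    /// calls [`ThreadConditions::do_notify`]); the null `memory_base` is only
    /// acceptable together with [`ExpectedValue::None`]:
    ///
    /// ```ignore
    /// let mut conditions = ThreadConditions::new();
    /// let dst = NotifyLocation {
    ///     address: 0,
    ///     memory_base: std::ptr::null_mut(),
    /// };
    /// // Safe because `ExpectedValue::None` imposes no memory requirements.
    /// let ret = unsafe { conditions.do_wait(dst, ExpectedValue::None, None) }.unwrap();
    /// assert_eq!(ret, 0); // woken by a notify on address 0
    /// ```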
    pub unsafe fn do_wait(
        &mut self,
        dst: NotifyLocation,
        expected: ExpectedValue,
        timeout: Option<Duration>,
    ) -> Result<u32, WaiterError> {
        if self.inner.closed.load(Ordering::Acquire) {
            return Err(WaiterError::AtomicsDisabled);
        }

        // The map is keyed by u32 addresses, so refuse to track more than
        // 2^32 entries.
        if self.inner.map.len() as u64 >= 1u64 << 32 {
            return Err(WaiterError::TooManyWaiters);
        }

        // Step 1: lock the map key, so we know no one else can get/create a
        // different Arc than the one we're getting/creating.
        let entry = self.inner.map.entry(dst.address);
        let ref_mut = entry.or_default();
        let arc = ref_mut.clone();

        // Step 2: lock the mutex while still holding the map lock, so nobody
        // can delete the map key or make a new Arc.
        let mut mutex_guard = arc.0.lock();

        // Step 3: unlock the map key; we don't need it anymore.
        drop(ref_mut);

        // Once we lock the mutex, we can check the expected value. A notifying
        // thread will have written an updated value to the address *before*
        // doing the notify call, and that call has to acquire the same lock we're
        // holding. This means we can't miss an update to the expected value that
        // would prevent us from sleeping.
        // This logic mirrors how the Linux kernel's futex syscall works; see its
        // documentation for a fuller explanation.
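        //
        // As a sketch of the ordering (notifier on the left, this waiter on
        // the right), the shared lock forces one of two interleavings, and
        // neither loses a wakeup:
        //
        //   notifier                        waiter
        //   --------                        ------
        //   store new value at address      lock mutex
        //   lock mutex                      read address; changed? -> return 1
        //   notify                          condvar wait (atomically unlocks)
        //   unlock mutex                    ...woken by the notify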

        // Safety: the function's safety contract ensures that the memory location
        // is valid, properly aligned, and can be dereferenced.
        let should_sleep = match expected {
            ExpectedValue::None => true,
            ExpectedValue::U32(expected_val) => unsafe {
                let src = dst.memory_base.offset(dst.address as isize) as *mut u32;
                // Read the current value atomically. Note that wrapping the raw
                // pointer in a freshly-created `AtomicPtr` and loading it would
                // only load the *pointer* atomically; the read of the pointee
                // itself must go through an atomic integer type.
                let read_val = AtomicU32::from_ptr(src).load(Ordering::Acquire);
                read_val == expected_val
            },
            ExpectedValue::U64(expected_val) => unsafe {
                let src = dst.memory_base.offset(dst.address as isize) as *mut u64;
                let read_val = AtomicU64::from_ptr(src).load(Ordering::Acquire);
                read_val == expected_val
            },
        };

        let ret = if should_sleep {
            *mutex_guard += 1;

            let ret = if let Some(timeout) = timeout {
                let result = arc.1.wait_for(&mut mutex_guard, timeout);
                if result.timed_out() {
                    2 // timeout
                } else {
                    0 // notified
                }
            } else {
                arc.1.wait(&mut mutex_guard);
                0 // notified
            };

            *mutex_guard -= 1;

            ret
        } else {
            1 // value mismatch
        };

        {
            // Note we use two sets of locks: one for the map itself, and one per
            // wait address. Locking order must stay consistent at all times: map
            // first, then mutex. So we have to drop the mutex guard here and then
            // reacquire it after locking the map key to avoid deadlocks.
            drop(mutex_guard);

            // Same as above, first lock the map key...
            let entry = self.inner.map.entry(dst.address);
            if let dashmap::Entry::Occupied(occupied) = entry {
                // ... then lock the mutex.
                let arc = occupied.get().clone();
                let mutex_guard = arc.0.lock();

                if *mutex_guard == 0 {
                    // No more waiters, remove the map entry.
                    occupied.remove();
                }
            }
        }

        Ok(ret)
    }

    /// Notify up to `count` waiters waiting on address `dst`.
    ///
    /// Returns the number of waiters that were actually woken.
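    ///
    /// # Example
    ///
    /// A sketch (marked `ignore` as a doctest): with no waiters registered,
    /// the call wakes nobody and returns `0`.
    ///
    /// ```ignore
    /// let mut conditions = ThreadConditions::new();
    /// assert_eq!(conditions.do_notify(0, 1), 0); // nobody is waiting
    /// ```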
    pub fn do_notify(&mut self, dst: u32, count: u32) -> u32 {
        let mut count_token = 0u32;
        if let Some(v) = self.inner.map.get(&dst) {
            // Hold the mutex while notifying so this notify can't slip between
            // a concurrent waiter's value check and its sleep (see the ordering
            // sketch in do_wait).
            let mutex_guard = v.0.lock();
            for _ in 0..count {
                if !v.1.notify_one() {
                    break;
                }
                count_token += 1;
            }
            drop(mutex_guard);
        }
        count_token
    }

    /// Wake all the waiters, *without* marking them as notified.
    ///
    /// Useful on shutdown to resume execution in all waiters.
    pub fn wake_all_atomic_waiters(&self) {
        for item in self.inner.map.iter_mut() {
            let arc = item.value();
            let _mutex_guard = arc.0.lock();
            arc.1.notify_all();
        }
    }

    /// Disable the use of atomics, making all atomic waits fail with an
    /// error, which in turn raises a WebAssembly trap.
    ///
    /// Useful for force-closing instances that keep waiting on atomics.
    pub fn disable_atomics(&self) {
        self.inner.closed.store(true, Ordering::Release);
        self.wake_all_atomic_waiters();
    }

    /// Get a weak handle to this `ThreadConditions` instance.
    ///
    /// See [`ThreadConditionsHandle`] for more information.
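    ///
    /// # Example
    ///
    /// A sketch (marked `ignore` as a doctest) of the downgrade/upgrade
    /// round trip:
    ///
    /// ```ignore
    /// let conditions = ThreadConditions::new();
    /// let handle = conditions.downgrade();
    /// // Upgrading succeeds while a strong instance is alive...
    /// assert!(handle.upgrade().is_some());
    /// drop(conditions);
    /// // ...and fails once every strong instance has been dropped.
    /// assert!(handle.upgrade().is_none());
    /// ```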
    pub fn downgrade(&self) -> ThreadConditionsHandle {
        ThreadConditionsHandle {
            inner: Arc::downgrade(&self.inner),
        }
    }
}

/// A weak handle to a `ThreadConditions` instance, which does not prolong its
/// lifetime.
///
/// Internally holds a [`std::sync::Weak`] pointer.
pub struct ThreadConditionsHandle {
    inner: std::sync::Weak<NotifyMap>,
}

impl ThreadConditionsHandle {
    /// Attempt to upgrade this handle to a strong reference.
    ///
    /// Returns `None` if the original `ThreadConditions` instance has been dropped.
    pub fn upgrade(&self) -> Option<ThreadConditions> {
        self.inner.upgrade().map(|inner| ThreadConditions { inner })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        use std::thread;

        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        // Keep the handle so the waiter's assertion is propagated on join.
        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(10));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 1);
        waiter.join().unwrap();
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        use std::thread;

        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(1)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(50));
        // The waiter has long since timed out, so the notify wakes nobody.
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().unwrap();
    }

    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        use std::thread;

        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        // The waiter sleeps on address 8, but the notify targets address 0,
        // so the waiter must time out.
        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 8,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(10)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(1));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().unwrap();
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        use std::thread;

        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let mut threadcond2 = conditions.clone();

        let waiter1 = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        let waiter2 = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe { threadcond2.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(20));
        // Asking for up to 5 wakes only reaches the 2 actual waiters.
        let ret = conditions.do_notify(0, 5);
        assert_eq!(ret, 2);
        waiter1.join().unwrap();
        waiter2.join().unwrap();
    }

    #[test]
    fn threadconditions_value_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut data: u32 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u32) as *mut u8,
        };
        // The value at the address is 42, not 85, so the wait returns 1
        // (value mismatch) without sleeping.
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U32(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }
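
    // An added sketch: after disable_atomics, waits should fail immediately
    // with WaiterError::AtomicsDisabled instead of blocking.
    #[test]
    fn threadconditions_wait_after_disable() {
        let mut conditions = ThreadConditions::new();
        conditions.disable_atomics();
        let dst = NotifyLocation {
            address: 0,
            memory_base: std::ptr::null_mut(),
        };
        // Safe: `ExpectedValue::None` imposes no memory requirements.
        let ret = unsafe { conditions.do_wait(dst, ExpectedValue::None, None) };
        assert!(matches!(ret, Err(WaiterError::AtomicsDisabled)));
    }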
}