// wasmer_vm/threadconditions.rs

1use std::{
2    sync::{
3        Arc,
4        atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
5    },
6    time::Duration,
7};
8
9use dashmap::DashMap;
10use fnv::FnvBuildHasher;
11use parking_lot::{Condvar, Mutex};
12use thiserror::Error;
13
14/// Error that can occur during wait/notify calls.
15// Non-exhaustive to allow for future variants without breaking changes!
16#[derive(Debug, Error)]
17#[non_exhaustive]
18pub enum WaiterError {
19    /// Wait/Notify is not implemented for this memory
20    Unimplemented,
21    /// To many waiter for an address
22    TooManyWaiters,
23    /// Atomic operations are disabled.
24    AtomicsDisabled,
25}
26
27impl std::fmt::Display for WaiterError {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "WaiterError")
30    }
31}
32
/// Expected value for atomic waits.
///
/// A waiter only goes to sleep if the value currently stored at the wait
/// location equals the expected value (or if there is no expected value).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedValue {
    /// No expected value; this is used for native waits only.
    None,

    /// 32-bit expected value
    U32(u32),

    /// 64-bit expected value
    U64(u64),
}
44
/// Memory location that a waiter blocks on.
///
/// The location is described as a byte offset (`address`) relative to the
/// start of a memory (`memory_base`).
#[derive(Copy, Clone, Debug)]
pub struct NotifyLocation {
    /// Byte offset of the waited-on location, relative to `memory_base`.
    pub address: u32,
    /// Start of the memory that `address` is an offset into.
    pub memory_base: *mut u8,
}
53
54#[derive(Debug, Default)]
55struct NotifyMap {
56    /// If set to true, all waits will fail with an error.
57    closed: AtomicBool,
58
59    // For each wait address, we store a mutex and a condvar. The condvar is
60    // used to handle sleeping and waking, while the mutex stores the
61    // (manually-updated) number of waiters on that address. This lets us
62    // know when there are no more waiters so we can clean up the map entry.
63    // note that using a Weak here would be insufficient since it can't
64    // clean up the map entries for us, only the mutexes/condvars.
65    map: DashMap<u32, Arc<(Mutex<u32>, Condvar)>, FnvBuildHasher>,
66}
67
68/// HashMap of Waiters for the Thread/Notify opcodes
69#[derive(Debug)]
70pub struct ThreadConditions {
71    inner: Arc<NotifyMap>, // The Hasmap with the Notify for the Notify/wait opcodes
72}
73
74impl Clone for ThreadConditions {
75    fn clone(&self) -> Self {
76        Self {
77            inner: Arc::clone(&self.inner),
78        }
79    }
80}
81
82impl ThreadConditions {
83    /// Create a new ThreadConditions
84    pub fn new() -> Self {
85        Self {
86            inner: Arc::new(NotifyMap::default()),
87        }
88    }
89
90    // To implement Wait / Notify, a HasMap, behind a mutex, will be used
91    // to track the address of waiter. The key of the hashmap is based on the memory.
92    // The actual waiting is implemented with a Condvar + Mutex pair. A Weak is stored
93    // in the hashmap to at least delete the condvar and mutex when there are no
94    // waiters for a given address. Map keys are currently not cleaned up.
95
96    /// Add current thread to the waiter hash
97    ///
98    /// # Safety
99    /// If `expected` is [`ExpectedValue::None`], no safety requirements.
100    /// The notify location must have a valid base address that belongs to a memory,
101    /// and the address must be a valid offset within that memory. The offset also
102    /// must be properly aligned for the expected value type; either 4-byte aligned for
103    /// [`ExpectedValue::U32`] or 8-byte aligned for [`ExpectedValue::U64`].
104    pub unsafe fn do_wait(
105        &mut self,
106        dst: NotifyLocation,
107        expected: ExpectedValue,
108        timeout: Option<Duration>,
109    ) -> Result<u32, WaiterError> {
110        if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
111            return Err(WaiterError::AtomicsDisabled);
112        }
113
114        if self.inner.map.len() as u64 >= 1u64 << 32 {
115            return Err(WaiterError::TooManyWaiters);
116        }
117
118        // Step 1: lock the map key, so we know no one else can get/create a
119        // different Arc than the one we're getting/creating
120        let entry = self.inner.map.entry(dst.address);
121        let ref_mut = entry.or_default();
122        let arc = ref_mut.clone();
123
124        // Step 2: lock the mutex while still holding the map lock, so nobody
125        // can delete the map key or make a new Arc
126        let mut mutex_guard = arc.0.lock();
127
128        // Step 3: unlock the map key, we don't need it anymore.
129        drop(ref_mut);
130
131        // Once we lock the mutex, we can check the expected value. A notifying
132        // thread will have written an updated value to the address *before*
133        // doing the notify call, and the call has to acquire the same lock we're
134        // holding. This means we can't miss an update to the expected value that
135        // would prevent us from sleeping.
136        // This logic mirrors how the linux kernel's futex syscall works, so see
137        // the documentation on that if I made zero sense here.
138
139        // Safety: the function's safety contract ensures that the memory location is valid
140        // and can be dereferenced.
141        let should_sleep = match expected {
142            ExpectedValue::None => true,
143            ExpectedValue::U32(expected_val) => unsafe {
144                let src = dst.memory_base.offset(dst.address as isize) as *mut u32;
145                let read_val = AtomicU32::from_ptr(src).load(Ordering::Acquire);
146                read_val == expected_val
147            },
148            ExpectedValue::U64(expected_val) => unsafe {
149                let src = dst.memory_base.offset(dst.address as isize) as *mut u64;
150                let read_val = AtomicU64::from_ptr(src).load(Ordering::Acquire);
151                read_val == expected_val
152            },
153        };
154
155        let ret = if should_sleep {
156            *mutex_guard += 1;
157
158            let ret = if let Some(timeout) = timeout {
159                let timeout = arc.1.wait_for(&mut mutex_guard, timeout);
160                if timeout.timed_out() {
161                    2 // timeout
162                } else {
163                    0 // notified
164                }
165            } else {
166                arc.1.wait(&mut mutex_guard);
167                0
168            };
169
170            *mutex_guard -= 1;
171
172            if self.inner.closed.load(Ordering::Acquire) {
173                return Err(WaiterError::AtomicsDisabled);
174            }
175
176            ret
177        } else {
178            1 // value mismatch
179        };
180
181        {
182            // Note we use two sets of locks; one for the map itself, and one per
183            // wait address. Locking order must stay consistent at all times: map
184            // first, then mutex. So we have to drop the mutex guard here and then
185            // reacquire it after locking the map key to avoid deadlocks.
186            drop(mutex_guard);
187
188            // Same as above, first lock the map key...
189            let entry = self.inner.map.entry(dst.address);
190            if let dashmap::Entry::Occupied(occupied) = entry {
191                // ... then lock the mutex.
192                let arc = occupied.get().clone();
193                let mutex_guard = arc.0.lock();
194
195                if *mutex_guard == 0 {
196                    // No more waiters, remove the map entry.
197                    occupied.remove();
198                }
199            }
200        }
201
202        Ok(ret)
203    }
204
205    /// Notify waiters from the wait list
206    pub fn do_notify(&mut self, dst: u32, count: u32) -> u32 {
207        let mut count_token = 0u32;
208        if let Some(v) = self.inner.map.get(&dst) {
209            let mutex_guard = v.0.lock();
210            for _ in 0..count {
211                if !v.1.notify_one() {
212                    break;
213                }
214                count_token += 1;
215            }
216            drop(mutex_guard);
217        }
218        count_token
219    }
220
221    /// Wake all the waiters, *without* marking them as notified.
222    ///
223    /// Useful on shutdown to resume execution in all waiters.
224    pub fn wake_all_atomic_waiters(&self) {
225        for item in self.inner.map.iter_mut() {
226            let arc = item.value();
227            let _mutex_guard = arc.0.lock();
228            arc.1.notify_all();
229        }
230    }
231
232    /// Disable the use of atomics, leading to all atomic waits failing with
233    /// an error, which leads to a Webassembly trap.
234    ///
235    /// NOTE: will also wake up all current waiters.
236    ///
237    /// Useful for force-closing instances that keep waiting on atomics.
238    pub fn disable_atomics(&self) {
239        self.inner
240            .closed
241            .store(true, std::sync::atomic::Ordering::Release);
242        self.wake_all_atomic_waiters();
243    }
244
245    /// Get a weak handle to this `ThreadConditions` instance.
246    ///
247    /// See [`ThreadConditionsHandle`] for more information.
248    pub fn downgrade(&self) -> ThreadConditionsHandle {
249        ThreadConditionsHandle {
250            inner: Arc::downgrade(&self.inner),
251        }
252    }
253}
254
255/// A weak handle to a `ThreadConditions` instance, which does not prolong its
256/// lifetime.
257///
258/// Internally holds a [`std::sync::Weak`] pointer.
259pub struct ThreadConditionsHandle {
260    inner: std::sync::Weak<NotifyMap>,
261}
262
263impl ThreadConditionsHandle {
264    /// Attempt to upgrade this handle to a strong reference.
265    ///
266    /// Returns `None` if the original `ThreadConditions` instance has been dropped.
267    pub fn upgrade(&self) -> Option<ThreadConditions> {
268        self.inner.upgrade().map(|inner| ThreadConditions { inner })
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Build a location that must never be dereferenced; only valid together
    /// with [`ExpectedValue::None`].
    fn unreadable_location(address: u32) -> NotifyLocation {
        NotifyLocation {
            address,
            memory_base: std::ptr::null_mut(),
        }
    }

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = unreadable_location(0);
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(10));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 1);
        // Join so that a failed assertion in the waiter fails this test.
        waiter.join().unwrap();
        // The last waiter is gone, so its map entry must have been removed.
        assert!(conditions.inner.map.is_empty());
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = unreadable_location(0);
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(1)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(50));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().unwrap();
    }

    /// A notify on one address must not wake a waiter on a different address.
    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = unreadable_location(8);
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(10)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(1));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().unwrap();
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let mut threadcond2 = conditions.clone();

        let waiter1 = thread::spawn(move || {
            let dst = unreadable_location(0);
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        let waiter2 = thread::spawn(move || {
            let dst = unreadable_location(0);
            let ret = unsafe { threadcond2.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(20));
        let ret = conditions.do_notify(0, 5);
        assert_eq!(ret, 2);
        waiter1.join().unwrap();
        waiter2.join().unwrap();
    }

    #[test]
    fn threadconditions_value_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut data: u32 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u32) as *mut u8,
        };
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U32(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }

    #[test]
    fn threadconditions_value_match_times_out() {
        let mut conditions = ThreadConditions::new();
        let mut data: u64 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u64) as *mut u8,
        };
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U64(42), Some(Duration::from_millis(1)))
                .unwrap()
        };
        assert_eq!(ret, 2);
    }

    #[test]
    fn threadconditions_map_entry_removed_after_wait() {
        let mut conditions = ThreadConditions::new();
        let ret = unsafe {
            conditions
                .do_wait(
                    unreadable_location(4),
                    ExpectedValue::None,
                    Some(Duration::from_millis(1)),
                )
                .unwrap()
        };
        assert_eq!(ret, 2);
        assert!(conditions.inner.map.is_empty());
    }

    #[test]
    fn threadconditions_wait_after_disable_fails() {
        let mut conditions = ThreadConditions::new();
        conditions.disable_atomics();
        let res = unsafe { conditions.do_wait(unreadable_location(0), ExpectedValue::None, None) };
        assert!(matches!(res, Err(WaiterError::AtomicsDisabled)));
    }

    #[test]
    fn threadconditions_disable_wakes_waiters() {
        let conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        // The timeout only bounds the test if the wake-up were ever missed.
        let waiter = thread::spawn(move || unsafe {
            threadcond.do_wait(
                unreadable_location(0),
                ExpectedValue::None,
                Some(Duration::from_secs(5)),
            )
        });
        thread::sleep(Duration::from_millis(10));
        conditions.disable_atomics();
        let res = waiter.join().unwrap();
        assert!(matches!(res, Err(WaiterError::AtomicsDisabled)));
    }
}