// wasmer_vm/threadconditions.rs

use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicPtr, AtomicU32, AtomicU64, Ordering},
    },
    time::Duration,
};

use dashmap::DashMap;
use fnv::FnvBuildHasher;
use parking_lot::{Condvar, Mutex};
use thiserror::Error;

15/// Error that can occur during wait/notify calls.
16// Non-exhaustive to allow for future variants without breaking changes!
17#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum WaiterError {
20    /// Wait/Notify is not implemented for this memory
21    Unimplemented,
22    /// To many waiter for an address
23    TooManyWaiters,
24    /// Atomic operations are disabled.
25    AtomicsDisabled,
26}
27
28impl std::fmt::Display for WaiterError {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        write!(f, "WaiterError")
31    }
32}
33
/// Expected value for atomic waits.
///
/// A waiting thread only goes to sleep if the value currently stored at the
/// wait location equals the expected value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedValue {
    /// No expected value; this is used for native waits only.
    ///
    /// The memory is not inspected and the caller always goes to sleep.
    None,

    /// 32-bit expected value
    U32(u32),

    /// 64-bit expected value
    U64(u64),
}

/// Identifies the memory cell a waiter sleeps on.
///
/// The cell lives at `memory_base + address`; waiters and notifiers are
/// matched up by `address` alone.
#[derive(Debug, Clone, Copy)]
pub struct NotifyLocation {
    /// Byte offset of the wait location, relative to `memory_base`.
    pub address: u32,
    /// Start of the memory that `address` is an offset into.
    pub memory_base: *mut u8,
}

55#[derive(Debug, Default)]
56struct NotifyMap {
57    /// If set to true, all waits will fail with an error.
58    closed: AtomicBool,
59
60    // For each wait address, we store a mutex and a condvar. The condvar is
61    // used to handle sleeping and waking, while the mutex stores the
62    // (manually-updated) number of waiters on that address. This lets us
63    // know when there are no more waiters so we can clean up the map entry.
64    // note that using a Weak here would be insufficient since it can't
65    // clean up the map entries for us, only the mutexes/condvars.
66    map: DashMap<u32, Arc<(Mutex<u32>, Condvar)>, FnvBuildHasher>,
67}
68
69/// HashMap of Waiters for the Thread/Notify opcodes
70#[derive(Debug)]
71pub struct ThreadConditions {
72    inner: Arc<NotifyMap>, // The Hasmap with the Notify for the Notify/wait opcodes
73}
74
75impl Clone for ThreadConditions {
76    fn clone(&self) -> Self {
77        Self {
78            inner: Arc::clone(&self.inner),
79        }
80    }
81}
82
83impl ThreadConditions {
84    /// Create a new ThreadConditions
85    pub fn new() -> Self {
86        Self {
87            inner: Arc::new(NotifyMap::default()),
88        }
89    }
90
91    // To implement Wait / Notify, a HasMap, behind a mutex, will be used
92    // to track the address of waiter. The key of the hashmap is based on the memory.
93    // The actual waiting is implemented with a Condvar + Mutex pair. A Weak is stored
94    // in the hashmap to at least delete the condvar and mutex when there are no
95    // waiters for a given address. Map keys are currently not cleaned up.
96
97    /// Add current thread to the waiter hash
98    ///
99    /// # Safety
100    /// If `expected` is [`ExpectedValue::None`], no safety requirements.
101    /// The notify location must have a valid base address that belongs to a memory,
102    /// and the address must be a valid offset within that memory. The offset also
103    /// must be properly aligned for the expected value type; either 4-byte aligned for
104    /// [`ExpectedValue::U32`] or 8-byte aligned for [`ExpectedValue::U64`].
105    pub unsafe fn do_wait(
106        &mut self,
107        dst: NotifyLocation,
108        expected: ExpectedValue,
109        timeout: Option<Duration>,
110    ) -> Result<u32, WaiterError> {
111        if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
112            return Err(WaiterError::AtomicsDisabled);
113        }
114
115        if self.inner.map.len() as u64 >= 1u64 << 32 {
116            return Err(WaiterError::TooManyWaiters);
117        }
118
119        // Step 1: lock the map key, so we know no one else can get/create a
120        // different Arc than the one we're getting/creating
121        let entry = self.inner.map.entry(dst.address);
122        let ref_mut = entry.or_default();
123        let arc = ref_mut.clone();
124
125        // Step 2: lock the mutex while still holding the map lock, so nobody
126        // can delete the map key or make a new Arc
127        let mut mutex_guard = arc.0.lock();
128
129        // Step 3: unlock the map key, we don't need it anymore.
130        drop(ref_mut);
131
132        // Once we lock the mutex, we can check the expected value. A notifying
133        // thread will have written an updated value to the address *before*
134        // doing the notify call, and the call has to acquire the same lock we're
135        // holding. This means we can't miss an update to the expected value that
136        // would prevent us from sleeping.
137        // This logic mirrors how the linux kernel's futex syscall works, so see
138        // the documentation on that if I made zero sense here.
139
140        // Safety: the function's safety contract ensures that the memory location is valid
141        // and can be dereferenced.
142        let should_sleep = match expected {
143            ExpectedValue::None => true,
144            ExpectedValue::U32(expected_val) => unsafe {
145                let src = dst.memory_base.offset(dst.address as isize) as *mut u32;
146                let atomic_src = AtomicPtr::new(src);
147                let read_val = *atomic_src.load(Ordering::Acquire);
148                read_val == expected_val
149            },
150            ExpectedValue::U64(expected_val) => unsafe {
151                let src = dst.memory_base.offset(dst.address as isize) as *mut u64;
152                let atomic_src = AtomicPtr::new(src);
153                let read_val = *atomic_src.load(Ordering::Acquire);
154                read_val == expected_val
155            },
156        };
157
158        let ret = if should_sleep {
159            *mutex_guard += 1;
160
161            let ret = if let Some(timeout) = timeout {
162                let timeout = arc.1.wait_for(&mut mutex_guard, timeout);
163                if timeout.timed_out() {
164                    2 // timeout
165                } else {
166                    0 // notified
167                }
168            } else {
169                arc.1.wait(&mut mutex_guard);
170                0
171            };
172
173            *mutex_guard -= 1;
174
175            if self.inner.closed.load(Ordering::Acquire) {
176                return Err(WaiterError::AtomicsDisabled);
177            }
178
179            ret
180        } else {
181            1 // value mismatch
182        };
183
184        {
185            // Note we use two sets of locks; one for the map itself, and one per
186            // wait address. Locking order must stay consistent at all times: map
187            // first, then mutex. So we have to drop the mutex guard here and then
188            // reacquire it after locking the map key to avoid deadlocks.
189            drop(mutex_guard);
190
191            // Same as above, first lock the map key...
192            let entry = self.inner.map.entry(dst.address);
193            if let dashmap::Entry::Occupied(occupied) = entry {
194                // ... then lock the mutex.
195                let arc = occupied.get().clone();
196                let mutex_guard = arc.0.lock();
197
198                if *mutex_guard == 0 {
199                    // No more waiters, remove the map entry.
200                    occupied.remove();
201                }
202            }
203        }
204
205        Ok(ret)
206    }
207
208    /// Notify waiters from the wait list
209    pub fn do_notify(&mut self, dst: u32, count: u32) -> u32 {
210        let mut count_token = 0u32;
211        if let Some(v) = self.inner.map.get(&dst) {
212            let mutex_guard = v.0.lock();
213            for _ in 0..count {
214                if !v.1.notify_one() {
215                    break;
216                }
217                count_token += 1;
218            }
219            drop(mutex_guard);
220        }
221        count_token
222    }
223
224    /// Wake all the waiters, *without* marking them as notified.
225    ///
226    /// Useful on shutdown to resume execution in all waiters.
227    pub fn wake_all_atomic_waiters(&self) {
228        for item in self.inner.map.iter_mut() {
229            let arc = item.value();
230            let _mutex_guard = arc.0.lock();
231            arc.1.notify_all();
232        }
233    }
234
235    /// Disable the use of atomics, leading to all atomic waits failing with
236    /// an error, which leads to a Webassembly trap.
237    ///
238    /// NOTE: will also wake up all current waiters.
239    ///
240    /// Useful for force-closing instances that keep waiting on atomics.
241    pub fn disable_atomics(&self) {
242        self.inner
243            .closed
244            .store(true, std::sync::atomic::Ordering::Release);
245        self.wake_all_atomic_waiters();
246    }
247
248    /// Get a weak handle to this `ThreadConditions` instance.
249    ///
250    /// See [`ThreadConditionsHandle`] for more information.
251    pub fn downgrade(&self) -> ThreadConditionsHandle {
252        ThreadConditionsHandle {
253            inner: Arc::downgrade(&self.inner),
254        }
255    }
256}
257
258/// A weak handle to a `ThreadConditions` instance, which does not prolong its
259/// lifetime.
260///
261/// Internally holds a [`std::sync::Weak`] pointer.
262pub struct ThreadConditionsHandle {
263    inner: std::sync::Weak<NotifyMap>,
264}
265
266impl ThreadConditionsHandle {
267    /// Attempt to upgrade this handle to a strong reference.
268    ///
269    /// Returns `None` if the original `ThreadConditions` instance has been dropped.
270    pub fn upgrade(&self) -> Option<ThreadConditions> {
271        self.inner.upgrade().map(|inner| ThreadConditions { inner })
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Instant;

    /// Calls `do_notify(addr, count)` until `expected` waiters have been woken
    /// in total, instead of hoping a fixed sleep gave the waiter threads enough
    /// time to start waiting. Panics after a generous deadline.
    fn notify_until(conditions: &mut ThreadConditions, addr: u32, count: u32, expected: u32) -> u32 {
        let deadline = Instant::now() + Duration::from_secs(10);
        let mut woken = 0;
        loop {
            woken += conditions.do_notify(addr, count);
            if woken >= expected {
                return woken;
            }
            assert!(Instant::now() < deadline, "waiters never started waiting");
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            unsafe { threadcond.do_wait(dst, ExpectedValue::None, None).unwrap() }
        });
        assert_eq!(notify_until(&mut conditions, 0, 1, 1), 1);
        // Joining makes a failure inside the waiter thread fail this test.
        assert_eq!(waiter.join().unwrap(), 0);
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(1)))
                    .unwrap()
            }
        });
        assert_eq!(waiter.join().unwrap(), 2);
        // The waiter has timed out, so there is nobody left to notify.
        assert_eq!(conditions.do_notify(0, 1), 0);
    }

    /// A notify on a different address must not wake the waiter.
    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 8,
                memory_base: std::ptr::null_mut(),
            };
            unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(10)))
                    .unwrap()
            }
        });
        thread::sleep(Duration::from_millis(1));
        assert_eq!(conditions.do_notify(0, 1), 0);
        assert_eq!(waiter.join().unwrap(), 2);
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();

        let waiters: Vec<_> = (0..2)
            .map(|_| {
                let mut threadcond = conditions.clone();
                thread::spawn(move || {
                    let dst = NotifyLocation {
                        address: 0,
                        memory_base: std::ptr::null_mut(),
                    };
                    unsafe { threadcond.do_wait(dst, ExpectedValue::None, None).unwrap() }
                })
            })
            .collect();

        assert_eq!(notify_until(&mut conditions, 0, 5, 2), 2);
        for waiter in waiters {
            assert_eq!(waiter.join().unwrap(), 0);
        }
    }

    #[test]
    fn threadconditions_value_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut data: u32 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u32) as *mut u8,
        };
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U32(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }

    #[test]
    fn threadconditions_value_u64() {
        // `do_wait` requires 8-byte alignment for 64-bit waits, which a plain
        // `u64` local does not guarantee on every target.
        #[repr(align(8))]
        struct Aligned(u64);

        let mut conditions = ThreadConditions::new();
        let mut data = Aligned(7);
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data.0 as *mut u64) as *mut u8,
        };

        // Matching value: the thread sleeps until the timeout elapses.
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U64(7), Some(Duration::from_millis(1)))
                .unwrap()
        };
        assert_eq!(ret, 2);

        // Mismatching value: returns immediately.
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U64(8), Some(Duration::from_millis(1)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }

    #[test]
    fn threadconditions_disable_atomics_wakes_waiters() {
        let conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            unsafe { threadcond.do_wait(dst, ExpectedValue::None, None) }
        });
        thread::sleep(Duration::from_millis(10));
        conditions.disable_atomics();
        let ret = waiter.join().unwrap();
        assert!(matches!(ret, Err(WaiterError::AtomicsDisabled)));

        // Waits started after disabling fail right away.
        let mut later = conditions.clone();
        let dst = NotifyLocation {
            address: 0,
            memory_base: std::ptr::null_mut(),
        };
        let ret = unsafe { later.do_wait(dst, ExpectedValue::None, None) };
        assert!(matches!(ret, Err(WaiterError::AtomicsDisabled)));
    }
}
397        assert_eq!(ret, 1);
398    }
399}