wasmer_vm/
threadconditions.rs1use std::{
2 sync::atomic::AtomicPtr,
3 sync::{
4 Arc,
5 atomic::{AtomicBool, Ordering},
6 },
7 time::Duration,
8};
9
10use dashmap::DashMap;
11use fnv::FnvBuildHasher;
12use parking_lot::{Condvar, Mutex};
13use thiserror::Error;
14
15#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum WaiterError {
20 Unimplemented,
22 TooManyWaiters,
24 AtomicsDisabled,
26}
27
28impl std::fmt::Display for WaiterError {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 write!(f, "WaiterError")
31 }
32}
33
/// The value a waiter expects to find in memory before it goes to sleep.
///
/// The type is a small plain-data enum, so it derives the usual value traits
/// (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedValue {
    /// No value is checked; the waiter always goes to sleep.
    None,

    /// Sleep only if the 32-bit value at the wait location equals this value.
    U32(u32),

    /// Sleep only if the 64-bit value at the wait location equals this value.
    U64(u64),
}
45
/// Identifies the memory location a waiter sleeps on.
#[derive(Debug, Clone, Copy)]
pub struct NotifyLocation {
    /// Offset of the location from `memory_base`; also the key under which
    /// waiters for this location are registered.
    pub address: u32,
    /// Base pointer that `address` is added to when the current value at the
    /// location has to be read.
    pub memory_base: *mut u8,
}
54
55#[derive(Debug, Default)]
56struct NotifyMap {
57 closed: AtomicBool,
59
60 map: DashMap<u32, Arc<(Mutex<u32>, Condvar)>, FnvBuildHasher>,
67}
68
69#[derive(Debug)]
71pub struct ThreadConditions {
72 inner: Arc<NotifyMap>, }
74
75impl Clone for ThreadConditions {
76 fn clone(&self) -> Self {
77 Self {
78 inner: Arc::clone(&self.inner),
79 }
80 }
81}
82
83impl ThreadConditions {
84 pub fn new() -> Self {
86 Self {
87 inner: Arc::new(NotifyMap::default()),
88 }
89 }
90
91 pub unsafe fn do_wait(
106 &mut self,
107 dst: NotifyLocation,
108 expected: ExpectedValue,
109 timeout: Option<Duration>,
110 ) -> Result<u32, WaiterError> {
111 if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
112 return Err(WaiterError::AtomicsDisabled);
113 }
114
115 if self.inner.map.len() as u64 >= 1u64 << 32 {
116 return Err(WaiterError::TooManyWaiters);
117 }
118
119 let entry = self.inner.map.entry(dst.address);
122 let ref_mut = entry.or_default();
123 let arc = ref_mut.clone();
124
125 let mut mutex_guard = arc.0.lock();
128
129 drop(ref_mut);
131
132 let should_sleep = match expected {
143 ExpectedValue::None => true,
144 ExpectedValue::U32(expected_val) => unsafe {
145 let src = dst.memory_base.offset(dst.address as isize) as *mut u32;
146 let atomic_src = AtomicPtr::new(src);
147 let read_val = *atomic_src.load(Ordering::Acquire);
148 read_val == expected_val
149 },
150 ExpectedValue::U64(expected_val) => unsafe {
151 let src = dst.memory_base.offset(dst.address as isize) as *mut u64;
152 let atomic_src = AtomicPtr::new(src);
153 let read_val = *atomic_src.load(Ordering::Acquire);
154 read_val == expected_val
155 },
156 };
157
158 let ret = if should_sleep {
159 *mutex_guard += 1;
160
161 let ret = if let Some(timeout) = timeout {
162 let timeout = arc.1.wait_for(&mut mutex_guard, timeout);
163 if timeout.timed_out() {
164 2 } else {
166 0 }
168 } else {
169 arc.1.wait(&mut mutex_guard);
170 0
171 };
172
173 *mutex_guard -= 1;
174
175 ret
176 } else {
177 1 };
179
180 {
181 drop(mutex_guard);
186
187 let entry = self.inner.map.entry(dst.address);
189 if let dashmap::Entry::Occupied(occupied) = entry {
190 let arc = occupied.get().clone();
192 let mutex_guard = arc.0.lock();
193
194 if *mutex_guard == 0 {
195 occupied.remove();
197 }
198 }
199 }
200
201 Ok(ret)
202 }
203
204 pub fn do_notify(&mut self, dst: u32, count: u32) -> u32 {
206 let mut count_token = 0u32;
207 if let Some(v) = self.inner.map.get(&dst) {
208 let mutex_guard = v.0.lock();
209 for _ in 0..count {
210 if !v.1.notify_one() {
211 break;
212 }
213 count_token += 1;
214 }
215 drop(mutex_guard);
216 }
217 count_token
218 }
219
220 pub fn wake_all_atomic_waiters(&self) {
224 for item in self.inner.map.iter_mut() {
225 let arc = item.value();
226 let _mutex_guard = arc.0.lock();
227 arc.1.notify_all();
228 }
229 }
230
231 pub fn disable_atomics(&self) {
236 self.inner
237 .closed
238 .store(true, std::sync::atomic::Ordering::Release);
239 self.wake_all_atomic_waiters();
240 }
241
242 pub fn downgrade(&self) -> ThreadConditionsHandle {
246 ThreadConditionsHandle {
247 inner: Arc::downgrade(&self.inner),
248 }
249 }
250}
251
252pub struct ThreadConditionsHandle {
257 inner: std::sync::Weak<NotifyMap>,
258}
259
260impl ThreadConditionsHandle {
261 pub fn upgrade(&self) -> Option<ThreadConditions> {
265 self.inner.upgrade().map(|inner| ThreadConditions { inner })
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a wait location for `address` without any backing memory.
    ///
    /// Only valid together with `ExpectedValue::None`, which never reads
    /// through `memory_base`.
    fn location_without_memory(address: u32) -> NotifyLocation {
        NotifyLocation {
            address,
            memory_base: std::ptr::null_mut(),
        }
    }

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = location_without_memory(0);
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        // Give the waiter time to go to sleep before notifying it.
        thread::sleep(Duration::from_millis(10));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 1);
        // Joining surfaces assertion failures from the waiter thread, which
        // would otherwise be silently dropped together with the handle.
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = location_without_memory(0);
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(1)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        // By the time we notify, the waiter has long timed out.
        thread::sleep(Duration::from_millis(50));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            // Waits on address 8 while the notification below targets 0.
            let dst = location_without_memory(8);
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(10)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(1));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        // The waiter ends by itself once its timeout expires; joining both
        // waits for that and propagates its assertion result.
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let mut threadcond2 = conditions.clone();

        let first = thread::spawn(move || {
            let dst = location_without_memory(0);
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        let second = thread::spawn(move || {
            let dst = location_without_memory(0);
            let ret = unsafe { threadcond2.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(20));
        // Asking for more wake-ups than there are waiters wakes only the two.
        let ret = conditions.do_notify(0, 5);
        assert_eq!(ret, 2);
        first.join().expect("first waiter thread panicked");
        second.join().expect("second waiter thread panicked");
    }

    #[test]
    fn threadconditions_value_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut data: u32 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u32) as *mut u8,
        };
        // Memory holds 42 but 85 is expected, so the call must not sleep.
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U32(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }
}