wasmer_vm/
threadconditions.rs1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
5 },
6 time::Duration,
7};
8
9use dashmap::DashMap;
10use fnv::FnvBuildHasher;
11use parking_lot::{Condvar, Mutex};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
17#[non_exhaustive]
18pub enum WaiterError {
19 Unimplemented,
21 TooManyWaiters,
23 AtomicsDisabled,
25}
26
27impl std::fmt::Display for WaiterError {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(f, "WaiterError")
30 }
31}
32
/// The value a waiter expects to find at the wait location.
///
/// `ThreadConditions::do_wait` compares this against memory before going to
/// sleep and returns immediately when the two differ.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedValue {
    /// No comparison is performed; the waiter always goes to sleep.
    None,

    /// Sleep only if the 32-bit value read from the location equals this one.
    U32(u32),

    /// Sleep only if the 64-bit value read from the location equals this one.
    U64(u64),
}
44
/// Identifies the location a thread waits on.
#[derive(Debug, Clone, Copy)]
pub struct NotifyLocation {
    /// Key under which waiters are registered; also the byte offset added to
    /// `memory_base` when `ThreadConditions::do_wait` reads the current value.
    pub address: u32,
    /// Base pointer that `address` is applied to. It is only dereferenced by
    /// `do_wait` when an expected value is supplied.
    pub memory_base: *mut u8,
}
53
54#[derive(Debug, Default)]
55struct NotifyMap {
56 closed: AtomicBool,
58
59 map: DashMap<u32, Arc<(Mutex<u32>, Condvar)>, FnvBuildHasher>,
66}
67
68#[derive(Debug)]
70pub struct ThreadConditions {
71 inner: Arc<NotifyMap>, }
73
74impl Clone for ThreadConditions {
75 fn clone(&self) -> Self {
76 Self {
77 inner: Arc::clone(&self.inner),
78 }
79 }
80}
81
82impl ThreadConditions {
83 pub fn new() -> Self {
85 Self {
86 inner: Arc::new(NotifyMap::default()),
87 }
88 }
89
90 pub unsafe fn do_wait(
105 &mut self,
106 dst: NotifyLocation,
107 expected: ExpectedValue,
108 timeout: Option<Duration>,
109 ) -> Result<u32, WaiterError> {
110 if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
111 return Err(WaiterError::AtomicsDisabled);
112 }
113
114 if self.inner.map.len() as u64 >= 1u64 << 32 {
115 return Err(WaiterError::TooManyWaiters);
116 }
117
118 let entry = self.inner.map.entry(dst.address);
121 let ref_mut = entry.or_default();
122 let arc = ref_mut.clone();
123
124 let mut mutex_guard = arc.0.lock();
127
128 drop(ref_mut);
130
131 let should_sleep = match expected {
142 ExpectedValue::None => true,
143 ExpectedValue::U32(expected_val) => unsafe {
144 let src = dst.memory_base.offset(dst.address as isize) as *mut u32;
145 let read_val = AtomicU32::from_ptr(src).load(Ordering::Acquire);
146 read_val == expected_val
147 },
148 ExpectedValue::U64(expected_val) => unsafe {
149 let src = dst.memory_base.offset(dst.address as isize) as *mut u64;
150 let read_val = AtomicU64::from_ptr(src).load(Ordering::Acquire);
151 read_val == expected_val
152 },
153 };
154
155 let ret = if should_sleep {
156 *mutex_guard += 1;
157
158 let ret = if let Some(timeout) = timeout {
159 let timeout = arc.1.wait_for(&mut mutex_guard, timeout);
160 if timeout.timed_out() {
161 2 } else {
163 0 }
165 } else {
166 arc.1.wait(&mut mutex_guard);
167 0
168 };
169
170 *mutex_guard -= 1;
171
172 if self.inner.closed.load(Ordering::Acquire) {
173 return Err(WaiterError::AtomicsDisabled);
174 }
175
176 ret
177 } else {
178 1 };
180
181 {
182 drop(mutex_guard);
187
188 let entry = self.inner.map.entry(dst.address);
190 if let dashmap::Entry::Occupied(occupied) = entry {
191 let arc = occupied.get().clone();
193 let mutex_guard = arc.0.lock();
194
195 if *mutex_guard == 0 {
196 occupied.remove();
198 }
199 }
200 }
201
202 Ok(ret)
203 }
204
205 pub fn do_notify(&mut self, dst: u32, count: u32) -> u32 {
207 let mut count_token = 0u32;
208 if let Some(v) = self.inner.map.get(&dst) {
209 let mutex_guard = v.0.lock();
210 for _ in 0..count {
211 if !v.1.notify_one() {
212 break;
213 }
214 count_token += 1;
215 }
216 drop(mutex_guard);
217 }
218 count_token
219 }
220
221 pub fn wake_all_atomic_waiters(&self) {
225 for item in self.inner.map.iter_mut() {
226 let arc = item.value();
227 let _mutex_guard = arc.0.lock();
228 arc.1.notify_all();
229 }
230 }
231
232 pub fn disable_atomics(&self) {
239 self.inner
240 .closed
241 .store(true, std::sync::atomic::Ordering::Release);
242 self.wake_all_atomic_waiters();
243 }
244
245 pub fn downgrade(&self) -> ThreadConditionsHandle {
249 ThreadConditionsHandle {
250 inner: Arc::downgrade(&self.inner),
251 }
252 }
253}
254
255pub struct ThreadConditionsHandle {
260 inner: std::sync::Weak<NotifyMap>,
261}
262
263impl ThreadConditionsHandle {
264 pub fn upgrade(&self) -> Option<ThreadConditions> {
268 self.inner.upgrade().map(|inner| ThreadConditions { inner })
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Every spawned waiter is joined at the end of its test. A panic inside a
    // detached thread is otherwise dropped silently, which would make the
    // assertions inside the waiter closures ineffective.
    //
    // The short sleeps only give the waiter time to park before the main
    // thread notifies; they are timing-based and kept from the original tests.

    /// Builds a location used purely as a map key. `memory_base` is null,
    /// so it must only be combined with `ExpectedValue::None`.
    fn key_only(address: u32) -> NotifyLocation {
        NotifyLocation {
            address,
            memory_base: std::ptr::null_mut(),
        }
    }

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            // SAFETY: `ExpectedValue::None` never dereferences `memory_base`.
            let ret =
                unsafe { threadcond.do_wait(key_only(0), ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(10));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 1);
        // The waiter has been woken at this point, so joining cannot hang.
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            // SAFETY: `ExpectedValue::None` never dereferences `memory_base`.
            let ret = unsafe {
                threadcond
                    .do_wait(
                        key_only(0),
                        ExpectedValue::None,
                        Some(Duration::from_millis(1)),
                    )
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(50));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            // SAFETY: `ExpectedValue::None` never dereferences `memory_base`.
            let ret = unsafe {
                threadcond
                    .do_wait(
                        key_only(8),
                        ExpectedValue::None,
                        Some(Duration::from_millis(10)),
                    )
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(1));
        // Notifying a different address must not wake the waiter on 8.
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        // Wait for the waiter to time out instead of sleeping a fixed time.
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let mut threadcond2 = conditions.clone();

        let first = thread::spawn(move || {
            // SAFETY: `ExpectedValue::None` never dereferences `memory_base`.
            let ret =
                unsafe { threadcond.do_wait(key_only(0), ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        let second = thread::spawn(move || {
            // SAFETY: `ExpectedValue::None` never dereferences `memory_base`.
            let ret =
                unsafe { threadcond2.do_wait(key_only(0), ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(20));
        let ret = conditions.do_notify(0, 5);
        assert_eq!(ret, 2);
        // Both waiters were woken above, so neither join can hang.
        first.join().expect("first waiter thread panicked");
        second.join().expect("second waiter thread panicked");
    }

    #[test]
    fn threadconditions_value_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut data: u32 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u32) as *mut u8,
        };
        // SAFETY: `memory_base + 0` points at `data`, a live and aligned
        // `u32` that outlives the call.
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U32(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }
}