wasmer_vm/
threadconditions.rs1use std::{
2 sync::atomic::AtomicPtr,
3 sync::{
4 Arc,
5 atomic::{AtomicBool, Ordering},
6 },
7 time::Duration,
8};
9
10use dashmap::DashMap;
11use fnv::FnvBuildHasher;
12use parking_lot::{Condvar, Mutex};
13use thiserror::Error;
14
15#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum WaiterError {
20 Unimplemented,
22 TooManyWaiters,
24 AtomicsDisabled,
26}
27
28impl std::fmt::Display for WaiterError {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 write!(f, "WaiterError")
31 }
32}
33
/// The value a waiter expects to find at its wait location before it is
/// allowed to go to sleep.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedValue {
    /// No comparison is performed; the waiter always goes to sleep.
    None,

    /// Sleep only if the 32-bit value at the wait location equals this one.
    U32(u32),

    /// Sleep only if the 64-bit value at the wait location equals this one.
    U64(u64),
}
45
/// Identifies the memory location a thread waits on.
#[derive(Debug, Clone, Copy)]
pub struct NotifyLocation {
    /// Byte offset of the location from `memory_base`; also the key under
    /// which waiters are grouped.
    pub address: u32,

    /// Base pointer that `address` is added to when the current value has to
    /// be read for comparison.
    pub memory_base: *mut u8,
}
54
55#[derive(Debug, Default)]
56struct NotifyMap {
57 closed: AtomicBool,
59
60 map: DashMap<u32, Arc<(Mutex<u32>, Condvar)>, FnvBuildHasher>,
67}
68
69#[derive(Debug)]
71pub struct ThreadConditions {
72 inner: Arc<NotifyMap>, }
74
75impl Clone for ThreadConditions {
76 fn clone(&self) -> Self {
77 Self {
78 inner: Arc::clone(&self.inner),
79 }
80 }
81}
82
83impl ThreadConditions {
84 pub fn new() -> Self {
86 Self {
87 inner: Arc::new(NotifyMap::default()),
88 }
89 }
90
91 pub unsafe fn do_wait(
106 &mut self,
107 dst: NotifyLocation,
108 expected: ExpectedValue,
109 timeout: Option<Duration>,
110 ) -> Result<u32, WaiterError> {
111 if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
112 return Err(WaiterError::AtomicsDisabled);
113 }
114
115 if self.inner.map.len() as u64 >= 1u64 << 32 {
116 return Err(WaiterError::TooManyWaiters);
117 }
118
119 let entry = self.inner.map.entry(dst.address);
122 let ref_mut = entry.or_default();
123 let arc = ref_mut.clone();
124
125 let mut mutex_guard = arc.0.lock();
128
129 drop(ref_mut);
131
132 let should_sleep = match expected {
143 ExpectedValue::None => true,
144 ExpectedValue::U32(expected_val) => unsafe {
145 let src = dst.memory_base.offset(dst.address as isize) as *mut u32;
146 let atomic_src = AtomicPtr::new(src);
147 let read_val = *atomic_src.load(Ordering::Acquire);
148 read_val == expected_val
149 },
150 ExpectedValue::U64(expected_val) => unsafe {
151 let src = dst.memory_base.offset(dst.address as isize) as *mut u64;
152 let atomic_src = AtomicPtr::new(src);
153 let read_val = *atomic_src.load(Ordering::Acquire);
154 read_val == expected_val
155 },
156 };
157
158 let ret = if should_sleep {
159 *mutex_guard += 1;
160
161 let ret = if let Some(timeout) = timeout {
162 let timeout = arc.1.wait_for(&mut mutex_guard, timeout);
163 if timeout.timed_out() {
164 2 } else {
166 0 }
168 } else {
169 arc.1.wait(&mut mutex_guard);
170 0
171 };
172
173 *mutex_guard -= 1;
174
175 if self.inner.closed.load(Ordering::Acquire) {
176 return Err(WaiterError::AtomicsDisabled);
177 }
178
179 ret
180 } else {
181 1 };
183
184 {
185 drop(mutex_guard);
190
191 let entry = self.inner.map.entry(dst.address);
193 if let dashmap::Entry::Occupied(occupied) = entry {
194 let arc = occupied.get().clone();
196 let mutex_guard = arc.0.lock();
197
198 if *mutex_guard == 0 {
199 occupied.remove();
201 }
202 }
203 }
204
205 Ok(ret)
206 }
207
208 pub fn do_notify(&mut self, dst: u32, count: u32) -> u32 {
210 let mut count_token = 0u32;
211 if let Some(v) = self.inner.map.get(&dst) {
212 let mutex_guard = v.0.lock();
213 for _ in 0..count {
214 if !v.1.notify_one() {
215 break;
216 }
217 count_token += 1;
218 }
219 drop(mutex_guard);
220 }
221 count_token
222 }
223
224 pub fn wake_all_atomic_waiters(&self) {
228 for item in self.inner.map.iter_mut() {
229 let arc = item.value();
230 let _mutex_guard = arc.0.lock();
231 arc.1.notify_all();
232 }
233 }
234
235 pub fn disable_atomics(&self) {
242 self.inner
243 .closed
244 .store(true, std::sync::atomic::Ordering::Release);
245 self.wake_all_atomic_waiters();
246 }
247
248 pub fn downgrade(&self) -> ThreadConditionsHandle {
252 ThreadConditionsHandle {
253 inner: Arc::downgrade(&self.inner),
254 }
255 }
256}
257
258pub struct ThreadConditionsHandle {
263 inner: std::sync::Weak<NotifyMap>,
264}
265
266impl ThreadConditionsHandle {
267 pub fn upgrade(&self) -> Option<ThreadConditions> {
271 self.inner.upgrade().map(|inner| ThreadConditions { inner })
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a wait location with no backing memory. Only usable together
    /// with `ExpectedValue::None`, which never reads through the pointer.
    fn unbacked(address: u32) -> NotifyLocation {
        NotifyLocation {
            address,
            memory_base: std::ptr::null_mut(),
        }
    }

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let ret =
                unsafe { threadcond.do_wait(unbacked(0), ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(10));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 1);
        // Join so that the assertion inside the waiter is actually checked.
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let ret = unsafe {
                threadcond
                    .do_wait(
                        unbacked(0),
                        ExpectedValue::None,
                        Some(Duration::from_millis(1)),
                    )
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(50));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let ret = unsafe {
                threadcond
                    .do_wait(
                        unbacked(8),
                        ExpectedValue::None,
                        Some(Duration::from_millis(10)),
                    )
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(1));
        // Notifying a different address must not wake the waiter on 8.
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        // Joining replaces a fixed sleep: it waits exactly as long as the
        // waiter needs to time out and surfaces its assertion.
        waiter.join().expect("waiter thread panicked");
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let mut threadcond2 = conditions.clone();

        let first = thread::spawn(move || {
            let ret = unsafe {
                threadcond
                    .do_wait(unbacked(0), ExpectedValue::None, None)
                    .unwrap()
            };
            assert_eq!(ret, 0);
        });
        let second = thread::spawn(move || {
            let ret = unsafe {
                threadcond2
                    .do_wait(unbacked(0), ExpectedValue::None, None)
                    .unwrap()
            };
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(20));
        let ret = conditions.do_notify(0, 5);
        assert_eq!(ret, 2);
        first.join().expect("first waiter thread panicked");
        second.join().expect("second waiter thread panicked");
    }

    #[test]
    fn threadconditions_value_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut data: u32 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u32) as *mut u8,
        };
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U32(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }

    #[test]
    fn threadconditions_value_match_times_out() {
        let mut conditions = ThreadConditions::new();
        let mut data: u64 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u64) as *mut u8,
        };
        // The value matches, so the thread sleeps; nobody notifies it.
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U64(42), Some(Duration::from_millis(1)))
                .unwrap()
        };
        assert_eq!(ret, 2);
    }

    #[test]
    fn threadconditions_value_mismatch_u64() {
        let mut conditions = ThreadConditions::new();
        let mut data: u64 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u64) as *mut u8,
        };
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U64(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }
}