// wasmer_wasix/runtime/module_cache/thread_local.rs

use std::{cell::RefCell, collections::HashMap};

use wasmer::{Engine, Module};
use wasmer_types::ModuleHash;

use crate::runtime::module_cache::{CacheError, ModuleCache};

8std::thread_local! {
9    static CACHED_MODULES: RefCell<HashMap<(ModuleHash, String), Module>>
10        = RefCell::new(HashMap::new());
11}
12
/// A [`ModuleCache`] that keeps compiled modules in thread-local storage.
///
/// Each thread has its own independent set of cached modules: a module saved
/// while running on one thread is not visible to lookups made on another.
/// The type itself carries no state, so every instance on a given thread
/// shares that thread's cached modules.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct ThreadLocalCache {}

18impl ThreadLocalCache {
19    fn lookup(&self, key: ModuleHash, deterministic_id: &str) -> Option<Module> {
20        let key = (key, deterministic_id.to_string());
21        CACHED_MODULES.with(|m| m.borrow().get(&key).cloned())
22    }
23
24    fn insert(&self, key: ModuleHash, module: &Module, deterministic_id: &str) {
25        let key = (key, deterministic_id.to_string());
26        CACHED_MODULES.with(|m| m.borrow_mut().insert(key, module.clone()));
27    }
28}
29
30#[async_trait::async_trait]
31impl ModuleCache for ThreadLocalCache {
32    #[tracing::instrument(level = "debug", skip_all, fields(%key))]
33    async fn load(&self, key: ModuleHash, engine: &Engine) -> Result<Module, CacheError> {
34        match self.lookup(key, &engine.deterministic_id()) {
35            Some(m) => {
36                tracing::debug!("Cache hit!");
37                Ok(m)
38            }
39            None => Err(CacheError::NotFound),
40        }
41    }
42
43    async fn contains(&self, key: ModuleHash, engine: &Engine) -> Result<bool, CacheError> {
44        let exists = self.lookup(key, &engine.deterministic_id()).is_some();
45        Ok(exists)
46    }
47
48    #[tracing::instrument(level = "debug", skip_all, fields(%key))]
49    async fn save(
50        &self,
51        key: ModuleHash,
52        engine: &Engine,
53        module: &Module,
54    ) -> Result<(), CacheError> {
55        self.insert(key, module, &engine.deterministic_id());
56        Ok(())
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    const ADD_WAT: &[u8] = br#"(
        module
            (func
                (export "add")
                (param $x i64)
                (param $y i64)
                (result i64)
                (i64.add (local.get $x) (local.get $y)))
        )"#;

    // Each test uses its own key. The cache is a single thread-local map, so
    // tests that happen to run on the same thread must not see each other's entries.

    /// Saving a module and loading it back with the same key and engine
    /// yields a module with the same exports.
    #[tokio::test(flavor = "current_thread")]
    async fn round_trip_via_cache() {
        let engine = Engine::default();
        let module = Module::new(&engine, ADD_WAT).unwrap();
        let cache = ThreadLocalCache::default();
        let key = ModuleHash::xxhash_from_bytes([0; 8]);

        cache.save(key, &engine, &module).await.unwrap();
        let round_tripped = cache.load(key, &engine).await.unwrap();

        let exports: Vec<_> = round_tripped
            .exports()
            .map(|export| export.name().to_string())
            .collect();
        assert_eq!(exports, ["add"]);
    }

    /// Loading a key that was never saved reports `CacheError::NotFound`.
    #[tokio::test(flavor = "current_thread")]
    async fn load_missing_key_is_not_found() {
        let engine = Engine::default();
        let cache = ThreadLocalCache::default();
        let key = ModuleHash::xxhash_from_bytes([1; 8]);

        let result = cache.load(key, &engine).await;

        assert!(matches!(result.err(), Some(CacheError::NotFound)));
    }

    /// `contains` reports `false` before a module is saved and `true` afterwards.
    #[tokio::test(flavor = "current_thread")]
    async fn contains_reflects_saved_modules() {
        let engine = Engine::default();
        let module = Module::new(&engine, ADD_WAT).unwrap();
        let cache = ThreadLocalCache::default();
        let key = ModuleHash::xxhash_from_bytes([2; 8]);

        assert!(!cache.contains(key, &engine).await.unwrap());
        cache.save(key, &engine, &module).await.unwrap();
        assert!(cache.contains(key, &engine).await.unwrap());
    }
}