wasmer_wasix/runtime/module_cache/
thread_local.rs1use std::{cell::RefCell, collections::HashMap};
2
3use wasmer::{Engine, Module};
4
5use crate::runtime::module_cache::{CacheError, ModuleCache};
6use wasmer_types::ModuleHash;
7
8std::thread_local! {
9 static CACHED_MODULES: RefCell<HashMap<(ModuleHash, String), Module>>
10 = RefCell::new(HashMap::new());
11}
12
13#[derive(Debug, Clone, Default)]
15#[non_exhaustive]
16pub struct ThreadLocalCache {}
17
18impl ThreadLocalCache {
19 fn lookup(&self, key: ModuleHash, deterministic_id: &str) -> Option<Module> {
20 let key = (key, deterministic_id.to_string());
21 CACHED_MODULES.with(|m| m.borrow().get(&key).cloned())
22 }
23
24 fn insert(&self, key: ModuleHash, module: &Module, deterministic_id: &str) {
25 let key = (key, deterministic_id.to_string());
26 CACHED_MODULES.with(|m| m.borrow_mut().insert(key, module.clone()));
27 }
28}
29
30#[async_trait::async_trait]
31impl ModuleCache for ThreadLocalCache {
32 #[tracing::instrument(level = "debug", skip_all, fields(%key))]
33 async fn load(&self, key: ModuleHash, engine: &Engine) -> Result<Module, CacheError> {
34 match self.lookup(key, &engine.deterministic_id()) {
35 Some(m) => {
36 tracing::debug!("Cache hit!");
37 Ok(m)
38 }
39 None => Err(CacheError::NotFound),
40 }
41 }
42
43 async fn contains(&self, key: ModuleHash, engine: &Engine) -> Result<bool, CacheError> {
44 let exists = self.lookup(key, &engine.deterministic_id()).is_some();
45 Ok(exists)
46 }
47
48 #[tracing::instrument(level = "debug", skip_all, fields(%key))]
49 async fn save(
50 &self,
51 key: ModuleHash,
52 engine: &Engine,
53 module: &Module,
54 ) -> Result<(), CacheError> {
55 self.insert(key, module, &engine.deterministic_id());
56 Ok(())
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    const ADD_WAT: &[u8] = br#"(
        module
            (func
                (export "add")
                (param $x i64)
                (param $y i64)
                (result i64)
                (i64.add (local.get $x) (local.get $y)))
        )"#;

    // Each test uses a distinct key so that thread-local state left behind by
    // another test (should the harness reuse a thread) cannot affect results.

    #[tokio::test(flavor = "current_thread")]
    async fn round_trip_via_cache() {
        let engine = Engine::default();
        let module = Module::new(&engine, ADD_WAT).unwrap();
        let cache = ThreadLocalCache::default();
        let key = ModuleHash::xxhash_from_bytes([0; 8]);

        cache.save(key, &engine, &module).await.unwrap();
        let round_tripped = cache.load(key, &engine).await.unwrap();

        let exports: Vec<_> = round_tripped
            .exports()
            .map(|export| export.name().to_string())
            .collect();
        assert_eq!(exports, ["add"]);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn missing_module_is_not_found() {
        let engine = Engine::default();
        let cache = ThreadLocalCache::default();
        let key = ModuleHash::xxhash_from_bytes([1; 8]);

        assert!(!cache.contains(key, &engine).await.unwrap());
        assert!(matches!(
            cache.load(key, &engine).await,
            Err(CacheError::NotFound)
        ));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn cached_modules_are_not_shared_between_threads() {
        let engine = Engine::default();
        let module = Module::new(&engine, ADD_WAT).unwrap();
        let cache = ThreadLocalCache::default();
        let key = ModuleHash::xxhash_from_bytes([2; 8]);

        cache.save(key, &engine, &module).await.unwrap();
        assert!(cache.contains(key, &engine).await.unwrap());

        // The same lookup on a fresh OS thread must see an empty cache.
        let engine_id = engine.deterministic_id().to_string();
        let other_cache = cache.clone();
        let seen_elsewhere =
            std::thread::spawn(move || other_cache.lookup(key, &engine_id).is_some())
                .join()
                .unwrap();
        assert!(!seen_elsewhere);
    }
}