wasmer_wasix/runtime/package_loader/
builtin_loader.rs

1use std::{
2    collections::HashMap,
3    io::{ErrorKind, Read, Write as _},
4    path::PathBuf,
5    sync::{Arc, RwLock},
6};
7
8use anyhow::{Context, Error, bail};
9use bytes::Bytes;
10use http::{HeaderMap, Method};
11use tempfile::NamedTempFile;
12use url::Url;
13use wasmer_package::{
14    package::WasmerPackageError,
15    utils::{from_bytes, from_disk},
16};
17use webc::DetectError;
18use webc::{Container, ContainerError};
19
20use crate::{
21    bin_factory::BinaryPackage,
22    http::{HttpClient, HttpRequest, USER_AGENT},
23    runtime::{
24        package_loader::PackageLoader,
25        resolver::{DistributionInfo, PackageSummary, Resolution, WebcHash},
26    },
27};
28
/// The builtin [`PackageLoader`] that is used by the `wasmer` CLI and
/// respects `$WASMER_DIR`.
///
/// A package is looked up in the optional in-memory cache first, then in the
/// optional filesystem cache, and is only downloaded when both miss.
#[derive(Debug)]
pub struct BuiltinPackageLoader {
    /// HTTP client used to download packages that are not cached.
    client: Arc<dyn HttpClient + Send + Sync>,
    /// In-memory cache of loaded containers; `None` disables it.
    in_memory: Option<InMemoryCache>,
    /// On-disk cache of downloaded images; `None` until a cache directory is
    /// configured.
    cache: Option<FileSystemCache>,
    /// A mapping from hostnames to tokens
    tokens: HashMap<String, String>,

    /// How freshly downloaded images are checked against their expected hash.
    hash_validation: HashIntegrityValidationMode,
}
41
/// Defines how to validate package hash integrity.
///
/// The mode is applied to the raw image bytes right after they were
/// downloaded (or read from a `file://` URL).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HashIntegrityValidationMode {
    /// Do not validate anything.
    /// Best for performance.
    ///
    /// This is the mode a loader created with [`BuiltinPackageLoader::new()`]
    /// starts with.
    NoValidate,
    /// Compute the image hash and produce a trace warning on hash mismatches.
    ///
    /// The load itself still succeeds.
    WarnOnHashMismatch,
    /// Compute the image hash and fail on a mismatch.
    ///
    /// The failure is reported as an [`ImageHashMismatchError`].
    FailOnHashMismatch,
}
53
54impl BuiltinPackageLoader {
55    pub fn new() -> Self {
56        BuiltinPackageLoader {
57            in_memory: Some(InMemoryCache::default()),
58            client: Arc::new(crate::http::default_http_client().unwrap()),
59            cache: None,
60            hash_validation: HashIntegrityValidationMode::NoValidate,
61            tokens: HashMap::new(),
62        }
63    }
64
65    /// Set the validation mode to apply after downloading an image.
66    ///
67    /// See [`HashIntegrityValidationMode`] for details.
68    pub fn with_hash_validation_mode(mut self, mode: HashIntegrityValidationMode) -> Self {
69        self.hash_validation = mode;
70        self
71    }
72
73    pub fn with_cache_dir(self, cache_dir: impl Into<PathBuf>) -> Self {
74        BuiltinPackageLoader {
75            cache: Some(FileSystemCache {
76                cache_dir: cache_dir.into(),
77            }),
78            ..self
79        }
80    }
81
82    /// Disable promotion of loaded containers into the in-memory cache.
83    pub fn without_in_memory_cache(self) -> Self {
84        BuiltinPackageLoader {
85            in_memory: None,
86            ..self
87        }
88    }
89
90    pub fn cache(&self) -> Option<&FileSystemCache> {
91        self.cache.as_ref()
92    }
93
94    pub fn validate_cache(
95        &self,
96        mode: CacheValidationMode,
97    ) -> Result<Vec<ImageHashMismatchError>, anyhow::Error> {
98        let cache = self
99            .cache
100            .as_ref()
101            .context("can not validate cache - no cache configured")?;
102
103        let items = cache.validate_hashes()?;
104        let mut errors = Vec::new();
105        for (path, error) in items {
106            match mode {
107                CacheValidationMode::WarnOnMismatch => {
108                    tracing::warn!(?error, "hash mismatch in cached image file");
109                }
110                CacheValidationMode::PruneOnMismatch => {
111                    tracing::warn!(?error, "deleting cached image file due to hash mismatch");
112                    match std::fs::remove_file(&path) {
113                        Ok(()) => {}
114                        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
115                        Err(fs_err) => {
116                            tracing::error!(
117                                path=%error.source,
118                                ?fs_err,
119                                "could not delete cached image file with hash mismatch"
120                            );
121                        }
122                    }
123                }
124            }
125
126            errors.push(error);
127        }
128
129        Ok(errors)
130    }
131
132    pub fn with_http_client(self, client: impl HttpClient + Send + Sync + 'static) -> Self {
133        self.with_shared_http_client(Arc::new(client))
134    }
135
136    pub fn with_shared_http_client(self, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
137        BuiltinPackageLoader { client, ..self }
138    }
139
140    pub fn with_tokens<I, K, V>(mut self, tokens: I) -> Self
141    where
142        I: IntoIterator<Item = (K, V)>,
143        K: Into<String>,
144        V: Into<String>,
145    {
146        for (hostname, token) in tokens {
147            self = self.with_token(hostname, token);
148        }
149
150        self
151    }
152
153    /// Add an API token that will be used whenever sending requests to a
154    /// particular hostname.
155    ///
156    /// Note that this uses [`Url::authority()`] when looking up tokens, so it
157    /// will match both plain hostnames (e.g. `registry.wasmer.io`) and hosts
158    /// with a port number (e.g. `localhost:8000`).
159    pub fn with_token(mut self, hostname: impl Into<String>, token: impl Into<String>) -> Self {
160        self.tokens.insert(hostname.into(), token.into());
161        self
162    }
163
164    /// Insert a container into the in-memory hash.
165    pub fn insert_cached(&self, hash: WebcHash, container: &Container) {
166        if let Some(in_memory) = &self.in_memory {
167            in_memory.save(container, hash);
168        }
169    }
170
171    /// Remove a container from the in-memory cache.
172    pub fn evict_cached(&self, hash: &WebcHash) -> Option<Container> {
173        self.in_memory
174            .as_ref()
175            .and_then(|in_memory| in_memory.remove(hash))
176    }
177
178    #[tracing::instrument(level = "debug", skip_all, fields(pkg.hash=%hash))]
179    async fn get_cached(&self, hash: &WebcHash) -> Result<Option<Container>, Error> {
180        if let Some(cached) = self
181            .in_memory
182            .as_ref()
183            .and_then(|in_memory| in_memory.lookup(hash))
184        {
185            return Ok(Some(cached));
186        }
187
188        if let Some(cache) = self.cache.as_ref()
189            && let Some(cached) = cache.lookup(hash).await?
190        {
191            if let Some(in_memory) = &self.in_memory {
192                tracing::debug!("Copying from the filesystem cache to the in-memory cache");
193                in_memory.save(&cached, *hash);
194            }
195            return Ok(Some(cached));
196        }
197
198        Ok(None)
199    }
200
201    /// Validate image contents with the specified validation mode.
202    async fn validate_hash(
203        image: &bytes::Bytes,
204        mode: HashIntegrityValidationMode,
205        info: &DistributionInfo,
206    ) -> Result<(), anyhow::Error> {
207        let info = info.clone();
208        let image = image.clone();
209        crate::spawn_blocking(move || Self::validate_hash_sync(&image, mode, &info))
210            .await
211            .context("tokio runtime failed")?
212    }
213
214    /// Validate image contents with the specified validation mode.
215    fn validate_hash_sync(
216        image: &[u8],
217        mode: HashIntegrityValidationMode,
218        info: &DistributionInfo,
219    ) -> Result<(), anyhow::Error> {
220        match mode {
221            HashIntegrityValidationMode::NoValidate => {
222                // Nothing to do.
223                Ok(())
224            }
225            HashIntegrityValidationMode::WarnOnHashMismatch => {
226                let actual_hash = WebcHash::sha256(image);
227                if actual_hash != info.webc_sha256 {
228                    tracing::warn!(%info.webc_sha256, %actual_hash, "image hash mismatch - actual image hash does not match the expected hash!");
229                }
230                Ok(())
231            }
232            HashIntegrityValidationMode::FailOnHashMismatch => {
233                let actual_hash = WebcHash::sha256(image);
234                if actual_hash != info.webc_sha256 {
235                    Err(ImageHashMismatchError {
236                        source: info.webc.to_string(),
237                        actual_hash,
238                        expected_hash: info.webc_sha256,
239                    }
240                    .into())
241                } else {
242                    Ok(())
243                }
244            }
245        }
246    }
247
248    #[tracing::instrument(level = "debug", skip_all, fields(%dist.webc, %dist.webc_sha256))]
249    async fn download(&self, dist: &DistributionInfo) -> Result<Bytes, Error> {
250        if dist.webc.scheme() == "file" {
251            match crate::runtime::resolver::utils::file_path_from_url(&dist.webc) {
252                Ok(path) => {
253                    let bytes = crate::spawn_blocking({
254                        let path = path.clone();
255                        move || std::fs::read(path)
256                    })
257                    .await?
258                    .with_context(|| format!("Unable to read \"{}\"", path.display()))?;
259
260                    let bytes = bytes::Bytes::from(bytes);
261
262                    Self::validate_hash(&bytes, self.hash_validation, dist).await?;
263
264                    return Ok(bytes);
265                }
266                Err(e) => {
267                    tracing::debug!(
268                        url=%dist.webc,
269                        error=&*e,
270                        "Unable to convert the file:// URL to a path",
271                    );
272                }
273            }
274        }
275
276        let request = HttpRequest {
277            headers: self.headers(&dist.webc),
278            url: dist.webc.clone(),
279            method: Method::GET,
280            body: None,
281            options: Default::default(),
282        };
283
284        tracing::debug!(%request.url, %request.method, "webc_package_download_start");
285        tracing::trace!(?request.headers);
286
287        let response = self.client.request(request).await?;
288
289        tracing::trace!(
290            %response.status,
291            %response.redirected,
292            ?response.headers,
293            response.len=response.body.as_ref().map(|body| body.len()),
294            "Received a response",
295        );
296
297        let url = &dist.webc;
298        if !response.is_ok() {
299            return Err(
300                crate::runtime::resolver::utils::http_error(&response).context(format!(
301                    "package download failed: GET request to \"{}\" failed with status {}",
302                    url, response.status
303                )),
304            );
305        }
306
307        let body = response.body.context("package download failed")?;
308        let body = Self::decode_response_body(&response.headers, body)
309            .context("package download failed: could not decode response body")?;
310        tracing::debug!(%url, "package_download_succeeded");
311
312        let body = bytes::Bytes::from(body);
313
314        Self::validate_hash(&body, self.hash_validation, dist).await?;
315
316        Ok(body)
317    }
318
319    fn headers(&self, url: &Url) -> HeaderMap {
320        let mut headers = HeaderMap::new();
321        headers.insert("Accept", "application/webc".parse().unwrap());
322        headers.insert("User-Agent", USER_AGENT.parse().unwrap());
323
324        // Accept compressed responses.
325        // NOTE: gzip and zstd decoding is available on native platforms.
326        // In browser platforms, the fetch implementation should automatically
327        // handle decoding of gzip/zstd responses transparently.
328        headers.insert(
329            http::header::ACCEPT_ENCODING,
330            "zstd;q=1.0, gzip;q=0.8".parse().unwrap(),
331        );
332
333        if url.has_authority()
334            && let Some(token) = self.tokens.get(url.authority())
335        {
336            let header = format!("Bearer {token}");
337            match header.parse() {
338                Ok(header) => {
339                    headers.insert(http::header::AUTHORIZATION, header);
340                }
341                Err(e) => {
342                    tracing::warn!(
343                        error = &e as &dyn std::error::Error,
344                        "An error occurred while parsing the authorization header",
345                    );
346                }
347            }
348        }
349
350        headers
351    }
352
353    /// Decode the response body according to the `Content-Encoding` header.
354    ///
355    /// * Supports `gzip` and `zstd` encodings
356    /// * Supports nested encodings (e.g. `gzip, zstd`)
357    /// * Passes through unencoded bodies or "identity" encoding unchanged
358    fn decode_response_body(headers: &HeaderMap, body: Vec<u8>) -> Result<Vec<u8>, anyhow::Error> {
359        let encodings = match headers.get(http::header::CONTENT_ENCODING) {
360            Some(header) => header
361                .to_str()
362                .context("non-utf8 content-encoding header")?
363                .split(',')
364                .map(|encoding| encoding.trim().to_ascii_lowercase())
365                .filter(|encoding| !encoding.is_empty())
366                .collect::<Vec<_>>(),
367            None => Vec::new(),
368        };
369
370        // Check if there is nothing to decode, return early.
371        // "identity" is the default encoding meaning "no encoding" (See RFC 2616 / RFC 7231)
372        if encodings.is_empty() || (encodings.len() == 1 && encodings[0] == "identity") {
373            return Ok(body);
374        }
375
376        let mut reader: Box<dyn Read> = Box::new(std::io::Cursor::new(body));
377        for encoding in encodings.iter().rev() {
378            match encoding.as_str() {
379                "gzip" => {
380                    reader = Box::new(flate2::read::GzDecoder::new(reader));
381                }
382                "zstd" => {
383                    #[cfg(not(target_arch = "wasm32"))]
384                    {
385                        reader = Box::new(
386                            zstd::stream::read::Decoder::new(reader)
387                                .context("failed to initialize zstd decoder")?,
388                        );
389                    }
390                    #[cfg(target_arch = "wasm32")]
391                    {
392                        // NOTE: in browsers this code will not be hit because
393                        // the fetch API automatically handles content decoding.
394                        bail!("zstd content-encoding is not supported on wasm32");
395                    }
396                }
397                "identity" => {}
398                other => bail!("unsupported content-encoding: {other}"),
399            }
400        }
401
402        let mut decoded = Vec::new();
403        reader
404            .read_to_end(&mut decoded)
405            .context("failed to decode response body")?;
406        Ok(decoded)
407    }
408}
409
410impl Default for BuiltinPackageLoader {
411    fn default() -> Self {
412        BuiltinPackageLoader::new()
413    }
414}
415
416#[async_trait::async_trait]
417impl PackageLoader for BuiltinPackageLoader {
418    #[tracing::instrument(
419        level="debug",
420        skip_all,
421        fields(
422            pkg=%summary.pkg.id,
423        ),
424    )]
425    async fn load(&self, summary: &PackageSummary) -> Result<Container, Error> {
426        if let Some(container) = self.get_cached(&summary.dist.webc_sha256).await? {
427            tracing::debug!("Cache hit!");
428            return Ok(container);
429        }
430
431        // looks like we had a cache miss and need to download it manually
432        let bytes = self
433            .download(&summary.dist)
434            .await
435            .with_context(|| format!("Unable to download \"{}\"", summary.dist.webc))?;
436
437        // We want to cache the container we downloaded, but we want to do it
438        // in a smart way to keep memory usage down.
439
440        if let Some(cache) = &self.cache {
441            match cache
442                .save_and_load_as_mmapped(bytes.clone(), &summary.dist)
443                .await
444            {
445                Ok(container) => {
446                    tracing::debug!("Cached to disk");
447                    if let Some(in_memory) = &self.in_memory {
448                        in_memory.save(&container, summary.dist.webc_sha256);
449                    }
450                    // The happy path - we've saved to both caches and loaded the
451                    // container from disk (hopefully using mmap) so we're done.
452                    return Ok(container);
453                }
454                Err(e) => {
455                    tracing::warn!(
456                        error=&*e,
457                        pkg=%summary.pkg.id,
458                        pkg.hash=%summary.dist.webc_sha256,
459                        pkg.url=%summary.dist.webc,
460                        "Unable to save the downloaded package to disk",
461                    );
462                }
463            }
464        }
465
466        // The sad path - looks like we don't have a filesystem cache so we'll
467        // need to keep the whole thing in memory.
468        let container = crate::spawn_blocking(move || from_bytes(bytes)).await??;
469        if let Some(in_memory) = &self.in_memory {
470            // We still want to cache it in memory, of course
471            in_memory.save(&container, summary.dist.webc_sha256);
472        }
473        Ok(container)
474    }
475
    /// Load the package tree rooted at `root` by delegating to
    /// `super::load_package_tree`, passing this loader along with the
    /// remaining arguments unchanged.
    async fn load_package_tree(
        &self,
        root: &Container,
        resolution: &Resolution,
        root_is_local_dir: bool,
    ) -> Result<BinaryPackage, Error> {
        super::load_package_tree(root, self, resolution, root_is_local_dir).await
    }
484}
485
/// Error describing an image whose computed hash differs from the expected
/// one.
#[derive(Clone, Debug)]
pub struct ImageHashMismatchError {
    /// Where the image came from: the download URL or the path of the cached
    /// file.
    source: String,
    /// The hash the image was expected to have.
    expected_hash: WebcHash,
    /// The hash computed from the image contents.
    actual_hash: WebcHash,
}
492
493impl std::fmt::Display for ImageHashMismatchError {
494    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495        write!(
496            f,
497            "image hash mismatch! expected hash '{}', but the computed hash is '{}' (source '{}')",
498            self.expected_hash, self.actual_hash, self.source,
499        )
500    }
501}
502
// Marker impl so the mismatch can be used as a standard error (for example
// converted with `.into()` into an `anyhow::Error`).
impl std::error::Error for ImageHashMismatchError {}
504
/// What [`BuiltinPackageLoader::validate_cache`] does with cached images
/// whose contents do not match the hash in their file name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CacheValidationMode {
    /// Just emit a warning for all images where the filename doesn't match
    /// the expected hash.
    WarnOnMismatch,
    /// Remove images from the cache if the filename doesn't match the actual
    /// hash.
    PruneOnMismatch,
}
514
// FIXME: This implementation will block the async runtime and should use
// some sort of spawn_blocking() call to run it in the background.
// NOTE(review): `lookup`, `save`, `scan` and `retain` already offload their
// I/O to a blocking task; `validate_hashes` still does blocking I/O inline.
// Confirm whether the FIXME is still needed.
/// On-disk cache that stores every image as `<hex hash>.bin` inside
/// `cache_dir`.
#[derive(Debug)]
pub struct FileSystemCache {
    /// Directory that holds the cached image files.
    cache_dir: PathBuf,
}
521
522impl FileSystemCache {
    /// Suffix appended to the hex-encoded hash to form a cached image's file
    /// name.
    const FILE_SUFFIX: &'static str = ".bin";
524
525    fn temp_dir(&self) -> PathBuf {
526        self.cache_dir.join("__temp__")
527    }
528
529    /// Validate that the cached image file names correspond to their actual
530    /// file content hashes.
531    fn validate_hashes(&self) -> Result<Vec<(PathBuf, ImageHashMismatchError)>, anyhow::Error> {
532        let mut items = Vec::<(PathBuf, ImageHashMismatchError)>::new();
533
534        let iter = match std::fs::read_dir(&self.cache_dir) {
535            Ok(v) => v,
536            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
537                // Cache dir does not exist, so nothing to validate.
538                return Ok(Vec::new());
539            }
540            Err(err) => {
541                return Err(err).with_context(|| {
542                    format!(
543                        "Could not read image cache dir: '{}'",
544                        self.cache_dir.display()
545                    )
546                });
547            }
548        };
549
550        for res in iter {
551            let entry = res?;
552            if !entry.file_type()?.is_file() {
553                continue;
554            }
555
556            // Extract the hash from the filename.
557
558            let hash_opt = entry
559                .file_name()
560                .to_str()
561                .and_then(|x| {
562                    let (raw_hash, _) = x.split_once(Self::FILE_SUFFIX)?;
563                    Some(raw_hash)
564                })
565                .and_then(|x| WebcHash::parse_hex(x).ok());
566            let Some(expected_hash) = hash_opt else {
567                continue;
568            };
569
570            // Compute the actual hash.
571            let path = entry.path();
572            let actual_hash = WebcHash::for_file(&path)?;
573
574            if actual_hash != expected_hash {
575                let err = ImageHashMismatchError {
576                    source: path.to_string_lossy().to_string(),
577                    actual_hash,
578                    expected_hash,
579                };
580                items.push((path, err));
581            }
582        }
583
584        Ok(items)
585    }
586
587    async fn lookup(&self, hash: &WebcHash) -> Result<Option<Container>, Error> {
588        let path = self.path(hash);
589
590        let container = crate::spawn_blocking({
591            let path = path.clone();
592            move || from_disk(path)
593        })
594        .await?;
595        match container {
596            Ok(c) => Ok(Some(c)),
597            Err(WasmerPackageError::ContainerError(ContainerError::Open { error, .. }))
598            | Err(WasmerPackageError::ContainerError(ContainerError::Read { error, .. }))
599            | Err(WasmerPackageError::ContainerError(ContainerError::Detect(DetectError::Io(
600                error,
601            )))) if error.kind() == ErrorKind::NotFound => Ok(None),
602            Err(e) => {
603                let msg = format!("Unable to read \"{}\"", path.display());
604                Err(Error::new(e).context(msg))
605            }
606        }
607    }
608
609    async fn save(&self, webc: Bytes, dist: &DistributionInfo) -> Result<PathBuf, Error> {
610        let path = self.path(&dist.webc_sha256);
611        let dist = dist.clone();
612        let temp_dir = self.temp_dir();
613
614        let path2 = path.clone();
615        crate::spawn_blocking(move || {
616            // Keep files in a temporary directory until they are fully written
617            // to prevent temp files being included in [`Self::scan`] or `[Self::retain]`.
618
619            std::fs::create_dir_all(&temp_dir)
620                .with_context(|| format!("Unable to create directory '{}'", temp_dir.display()))?;
621
622            let mut temp = NamedTempFile::new_in(&temp_dir)?;
623            temp.write_all(&webc)?;
624            temp.flush()?;
625            temp.as_file_mut().sync_all()?;
626
627            // Move the temporary file to the final location.
628            temp.persist(&path)?;
629
630            tracing::debug!(
631                pkg.hash=%dist.webc_sha256,
632                pkg.url=%dist.webc,
633                path=%path.display(),
634                num_bytes=webc.len(),
635                "Saved to disk",
636            );
637            Result::<_, Error>::Ok(())
638        })
639        .await??;
640
641        Ok(path2)
642    }
643
644    #[tracing::instrument(level = "debug", skip_all)]
645    async fn save_and_load_as_mmapped(
646        &self,
647        webc: Bytes,
648        dist: &DistributionInfo,
649    ) -> Result<Container, Error> {
650        // First, save it to disk
651        self.save(webc, dist).await?;
652
653        // Now try to load it again. The resulting container should use
654        // a memory-mapped file rather than an in-memory buffer.
655        match self.lookup(&dist.webc_sha256).await? {
656            Some(container) => Ok(container),
657            None => {
658                // Something really weird has occurred and we can't see the
659                // saved file. Just error out and let the fallback code do its
660                // thing.
661                Err(Error::msg("Unable to load the downloaded memory from disk"))
662            }
663        }
664    }
665
666    fn path(&self, hash: &WebcHash) -> PathBuf {
667        self.cache_dir.join(format!(
668            "{}{}",
669            hex::encode(hash.as_bytes()),
670            Self::FILE_SUFFIX
671        ))
672    }
673
674    /// Scan all the cached webc files and invoke the callback for each.
675    pub async fn scan<S, F>(&self, state: S, callback: F) -> Result<S, Error>
676    where
677        S: Send + 'static,
678        F: Fn(&mut S, &std::fs::DirEntry) -> Result<(), Error> + Send + 'static,
679    {
680        let cache_dir = self.cache_dir.clone();
681        tokio::task::spawn_blocking(move || -> Result<S, anyhow::Error> {
682            let mut state = state;
683
684            let iter = match std::fs::read_dir(&cache_dir) {
685                Ok(v) => v,
686                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
687                    // path does not exist, so nothing to scan.
688                    return Ok(state);
689                }
690                Err(err) => {
691                    return Err(err).with_context(|| {
692                        format!("Could not read image cache dir: '{}'", cache_dir.display())
693                    });
694                }
695            };
696
697            for res in iter {
698                let entry = res?;
699                if !entry.file_type()?.is_file() {
700                    continue;
701                }
702
703                callback(&mut state, &entry)?;
704            }
705
706            Ok(state)
707        })
708        .await?
709        .context("tokio runtime failed")
710    }
711
712    /// Remove entries from the cache that do not pass the callback.
713    pub async fn retain<S, F>(&self, state: S, filter: F) -> Result<S, Error>
714    where
715        S: Send + 'static,
716        F: Fn(&mut S, &std::fs::DirEntry) -> Result<bool, anyhow::Error> + Send + 'static,
717    {
718        let cache_dir = self.cache_dir.clone();
719        tokio::task::spawn_blocking(move || {
720            let iter = match std::fs::read_dir(&cache_dir) {
721                Ok(v) => v,
722                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
723                    // path does not exist, so nothing to scan.
724                    return Ok(state);
725                }
726                Err(err) => {
727                    return Err(err).with_context(|| {
728                        format!("Could not read image cache dir: '{}'", cache_dir.display())
729                    });
730                }
731            };
732
733            let mut state = state;
734            for res in iter {
735                let entry = res?;
736                if !entry.file_type()?.is_file() {
737                    continue;
738                }
739
740                if !filter(&mut state, &entry)? {
741                    tracing::debug!(
742                        path=%entry.path().display(),
743                        "Removing cached image file - does not pass the filter",
744                    );
745                    match std::fs::remove_file(entry.path()) {
746                        Ok(()) => {}
747                        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
748                        Err(fs_err) => {
749                            tracing::warn!(
750                                path=%entry.path().display(),
751                                ?fs_err,
752                                "Could not delete cached image file",
753                            );
754                        }
755                    }
756                }
757            }
758
759            Ok(state)
760        })
761        .await?
762        .context("tokio runtime failed")
763    }
764}
765
/// Lock-protected map from image hash to an already loaded container.
#[derive(Debug, Default)]
struct InMemoryCache(RwLock<HashMap<WebcHash, Container>>);
768
769impl InMemoryCache {
770    fn lookup(&self, hash: &WebcHash) -> Option<Container> {
771        self.0.read().unwrap().get(hash).cloned()
772    }
773
774    fn save(&self, container: &Container, hash: WebcHash) {
775        let mut cache = self.0.write().unwrap();
776        cache.entry(hash).or_insert_with(|| container.clone());
777    }
778
779    fn remove(&self, hash: &WebcHash) -> Option<Container> {
780        self.0.write().unwrap().remove(hash)
781    }
782}
783
784#[cfg(test)]
785mod tests {
786    use std::{collections::VecDeque, io::Write, sync::Mutex};
787
788    use futures::future::BoxFuture;
789    use http::{HeaderMap, HeaderValue, StatusCode};
790    use tempfile::TempDir;
791    use wasmer_config::package::PackageId;
792
793    use crate::{
794        http::{HttpRequest, HttpResponse},
795        runtime::resolver::PackageInfo,
796    };
797
798    use super::*;
799
800    const PYTHON: &[u8] =
801        include_bytes!("../../../../../wasmer-test-files/examples/python-0.1.0.wasmer");
802
    /// Test double for [`HttpClient`] that records every request and answers
    /// with pre-canned responses.
    #[derive(Debug)]
    pub(crate) struct DummyClient {
        /// Requests received so far, in arrival order.
        requests: Mutex<Vec<HttpRequest>>,
        /// Responses still to be handed out, front first.
        responses: Mutex<VecDeque<HttpResponse>>,
    }
808
809    impl DummyClient {
810        pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
811            DummyClient {
812                requests: Mutex::new(Vec::new()),
813                responses: Mutex::new(responses.into_iter().collect()),
814            }
815        }
816    }
817
    impl HttpClient for DummyClient {
        /// Record `request` and answer with the next queued response.
        ///
        /// Panics when no queued response is left.
        fn request(
            &self,
            request: HttpRequest,
        ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
            // Taken before the request is recorded, so a request that finds
            // the queue empty panics without being logged.
            let response = self.responses.lock().unwrap().pop_front().unwrap();
            self.requests.lock().unwrap().push(request);
            Box::pin(async { Ok(response) })
        }
    }
828
    /// Shared body of the cache-miss test: loading a package that is in
    /// neither cache must send one HTTP request and populate both the
    /// filesystem cache and the in-memory cache.
    async fn cache_misses_will_trigger_a_download_internal() {
        let temp = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([HttpResponse {
            body: Some(PYTHON.to_vec()),
            redirected: false,
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        }]));
        let loader = BuiltinPackageLoader::new()
            .with_cache_dir(temp.path())
            .with_shared_http_client(client.clone());
        // The hash is a dummy value; the loader's default mode does not
        // validate it against the downloaded bytes.
        let summary = PackageSummary {
            pkg: PackageInfo {
                id: PackageId::new_named("python/python", "0.1.0".parse().unwrap()),
                dependencies: Vec::new(),
                commands: Vec::new(),
                entrypoint: Some("asdf".to_string()),
                filesystem: Vec::new(),
            },
            dist: DistributionInfo {
                webc: "https://wasmer.io/python/python".parse().unwrap(),
                webc_sha256: [0xaa; 32].into(),
            },
        };

        let container = loader.load(&summary).await.unwrap();

        // A HTTP request was sent
        let requests = client.requests.lock().unwrap();
        let request = &requests[0];
        assert_eq!(request.url, summary.dist.webc);
        assert_eq!(request.method, "GET");
        // Native targets send Accept, User-Agent and Accept-Encoding.
        #[cfg(not(target_arch = "wasm32"))]
        {
            assert_eq!(request.headers.len(), 3);
            assert_eq!(request.headers["Accept-Encoding"], "zstd;q=1.0, gzip;q=0.8");
        }
        // wasm32 is expected to send no Accept-Encoding header.
        #[cfg(target_arch = "wasm32")]
        {
            assert_eq!(request.headers.len(), 2);
            assert!(!request.headers.contains_key(http::header::ACCEPT_ENCODING));
        }
        assert_eq!(request.headers["Accept"], "application/webc");
        assert_eq!(request.headers["User-Agent"], USER_AGENT);
        // Make sure we got the right package
        let manifest = container.manifest();
        assert_eq!(manifest.entrypoint.as_deref(), Some("python"));
        // it should have been automatically saved to disk
        let path = loader
            .cache
            .as_ref()
            .unwrap()
            .path(&summary.dist.webc_sha256);
        assert!(path.exists());
        assert_eq!(std::fs::read(&path).unwrap(), PYTHON);
        // and cached in memory for next time
        let in_memory = loader.in_memory.as_ref().unwrap().0.read().unwrap();
        assert!(in_memory.contains_key(&summary.dist.webc_sha256));
    }
888
889    #[cfg(not(target_arch = "wasm32"))]
890    #[tokio::test(flavor = "multi_thread")]
891    async fn cache_misses_will_trigger_a_download() {
892        cache_misses_will_trigger_a_download_internal().await
893    }
894
895    #[cfg(not(target_arch = "wasm32"))]
896    #[tokio::test]
897    async fn can_disable_in_memory_cache() {
898        let temp = TempDir::new().unwrap();
899        let client = Arc::new(DummyClient::with_responses([HttpResponse {
900            body: Some(PYTHON.to_vec()),
901            redirected: false,
902            status: StatusCode::OK,
903            headers: HeaderMap::new(),
904        }]));
905        let loader = BuiltinPackageLoader::new()
906            .with_cache_dir(temp.path())
907            .without_in_memory_cache()
908            .with_shared_http_client(client);
909        let summary = PackageSummary {
910            pkg: PackageInfo {
911                id: PackageId::new_named("python/python", "0.1.0".parse().unwrap()),
912                dependencies: Vec::new(),
913                commands: Vec::new(),
914                entrypoint: Some("asdf".to_string()),
915                filesystem: Vec::new(),
916            },
917            dist: DistributionInfo {
918                webc: "https://wasmer.io/python/python".parse().unwrap(),
919                webc_sha256: [0xbb; 32].into(),
920            },
921        };
922
923        loader.load(&summary).await.unwrap();
924
925        assert!(loader.in_memory.is_none());
926    }
927
/// wasm32 entry point: runs the shared download test on tokio's default
/// test runtime.
#[cfg(target_arch = "wasm32")]
#[tokio::test]
async fn cache_misses_will_trigger_a_download() {
    cache_misses_will_trigger_a_download_internal().await;
}
933
934    #[tokio::test]
935    async fn evict_cached_removes_in_memory_container() {
936        let loader = BuiltinPackageLoader::new();
937        let container = from_bytes(PYTHON).unwrap();
938        let hash: WebcHash = [0xaa; 32].into();
939        loader.insert_cached(hash, &container);
940        let evicted = loader.evict_cached(&hash);
941        assert!(evicted.is_some());
942        {
943            let in_memory = loader.in_memory.as_ref().unwrap().0.read().unwrap();
944            assert!(!in_memory.contains_key(&hash));
945        }
946        assert!(loader.evict_cached(&hash).is_none());
947    }
948
949    /// Small helper to construct headers with a given content-encoding.
950    fn headers_with_encoding(content_encoding: Option<&str>) -> HeaderMap {
951        let mut headers = HeaderMap::new();
952        if let Some(value) = content_encoding {
953            headers.insert(http::header::CONTENT_ENCODING, value.parse().unwrap());
954        }
955        headers
956    }
957
958    /// Small helper to construct headers with a raw content-encoding value.
959    fn headers_with_raw_encoding(value: &[u8]) -> HeaderMap {
960        let mut headers = HeaderMap::new();
961        headers.insert(
962            http::header::CONTENT_ENCODING,
963            HeaderValue::from_bytes(value).unwrap(),
964        );
965        headers
966    }
967
/// A body with no `Content-Encoding` header, or with `identity`, comes back
/// from `decode_response_body` unchanged.
#[test]
fn decode_response_body_passthrough() {
    let body = b"plain-bytes".to_vec();

    for encoding in [None, Some("identity")] {
        let headers = headers_with_encoding(encoding);
        let decoded =
            BuiltinPackageLoader::decode_response_body(&headers, body.clone()).unwrap();
        assert_eq!(decoded, body);
    }
}
985
/// A `Content-Encoding` list made only of commas and whitespace is treated
/// like no encoding: the body is returned unchanged.
#[test]
fn decode_response_body_empty_encoding_list() {
    let headers = headers_with_encoding(Some(" , , "));
    let original = b"plain-bytes".to_vec();

    let decoded =
        BuiltinPackageLoader::decode_response_body(&headers, original.clone()).unwrap();

    assert_eq!(decoded, original);
}
997
/// A `Content-Encoding` header that is not valid UTF-8 makes
/// `decode_response_body` fail with a "non-utf8 content-encoding" error.
#[test]
fn decode_response_body_non_utf8_encoding_header() {
    let headers = headers_with_raw_encoding(&[0xff]);

    let result = BuiltinPackageLoader::decode_response_body(&headers, b"bytes".to_vec());

    let message = result.unwrap_err().to_string();
    assert!(message.contains("non-utf8 content-encoding"));
}
1008
/// A gzip-compressed body labelled `Content-Encoding: gzip` is decoded back
/// to the original bytes.
#[test]
fn decode_response_body_gzip() {
    let plain = b"gzip-bytes".to_vec();
    let compressed = {
        let mut gz = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(&plain).unwrap();
        gz.finish().unwrap()
    };
    let headers = headers_with_encoding(Some("gzip"));

    let decoded = BuiltinPackageLoader::decode_response_body(&headers, compressed).unwrap();

    assert_eq!(decoded, plain);
}
1024
/// `identity` listed alongside `gzip` is a no-op: a gzip-compressed body
/// labelled `identity, gzip` still decodes to the original bytes.
#[test]
fn decode_response_body_identity_and_gzip() {
    let plain = b"gzip-bytes".to_vec();
    let compressed = {
        let mut gz = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(&plain).unwrap();
        gz.finish().unwrap()
    };
    let headers = headers_with_encoding(Some("identity, gzip"));

    let decoded = BuiltinPackageLoader::decode_response_body(&headers, compressed).unwrap();

    assert_eq!(decoded, plain);
}
1040
/// Bytes that are not gzip data but are labelled `Content-Encoding: gzip`
/// produce a "failed to decode response body" error.
#[test]
fn decode_response_body_gzip_invalid_payload() {
    let headers = headers_with_encoding(Some("gzip"));

    let result = BuiltinPackageLoader::decode_response_body(&headers, b"not-gzip".to_vec());

    let message = result.unwrap_err().to_string();
    assert!(message.contains("failed to decode response body"));
}
1051
/// A zstd-compressed body labelled `Content-Encoding: zstd` is decoded back
/// to the original bytes (not compiled for wasm32).
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn decode_response_body_zstd() {
    let plain = b"zstd-bytes".to_vec();
    let compressed = zstd::stream::encode_all(plain.as_slice(), 0).unwrap();
    let headers = headers_with_encoding(Some("zstd"));

    let decoded = BuiltinPackageLoader::decode_response_body(&headers, compressed).unwrap();

    assert_eq!(decoded, plain);
}
1066
/// Bytes that are not zstd data but are labelled `Content-Encoding: zstd`
/// produce a "failed to decode response body" error (not compiled for wasm32).
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn decode_response_body_zstd_invalid_payload() {
    let headers = headers_with_encoding(Some("zstd"));

    let result = BuiltinPackageLoader::decode_response_body(&headers, b"not-zstd".to_vec());

    let message = result.unwrap_err().to_string();
    assert!(message.contains("failed to decode response body"));
}
1078
/// A body that was gzip-compressed and then zstd-compressed, labelled
/// `Content-Encoding: gzip, zstd`, is unwrapped back to the original bytes
/// (not compiled for wasm32).
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn decode_response_body_zstd_and_gzip() {
    let plain = b"layered-bytes".to_vec();

    // Inner layer: gzip.
    let gzipped = {
        let mut gz = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(&plain).unwrap();
        gz.finish().unwrap()
    };
    // Outer layer: zstd over the gzip stream.
    let layered = zstd::stream::encode_all(std::io::Cursor::new(gzipped), 0).unwrap();
    let headers = headers_with_encoding(Some("gzip, zstd"));

    let decoded = BuiltinPackageLoader::decode_response_body(&headers, layered).unwrap();

    assert_eq!(decoded, plain);
}
1096
/// An encoding the loader does not handle (`br` here) is rejected with an
/// "unsupported content-encoding" error.
#[test]
fn decode_response_body_unknown_encoding() {
    let headers = headers_with_encoding(Some("br"));

    let result = BuiltinPackageLoader::decode_response_body(&headers, b"weird".to_vec());

    let message = result.unwrap_err().to_string();
    assert!(message.contains("unsupported content-encoding"));
}
1107
1108    // NOTE: must be a tokio test because the BuiltinPackageLoader::new()
1109    // constructor requires a runtime...
1110    #[tokio::test]
1111    async fn test_builtin_package_downloader_cache_validation() {
1112        let dir = tempfile::tempdir().unwrap();
1113        let path = dir.path();
1114
1115        let contents = "fail";
1116        let correct_hash = WebcHash::sha256(contents);
1117        let used_hash =
1118            WebcHash::parse_hex("0000a28ea38a000f3a3328cb7fabe330638d3258affe1a869e3f92986222d997")
1119                .unwrap();
1120        let filename = format!("{}{}", used_hash, FileSystemCache::FILE_SUFFIX);
1121        let file_path = path.join(filename);
1122        std::fs::write(&file_path, contents).unwrap();
1123
1124        let dl = BuiltinPackageLoader::new().with_cache_dir(path);
1125
1126        let errors = dl
1127            .validate_cache(CacheValidationMode::PruneOnMismatch)
1128            .unwrap();
1129        assert_eq!(errors.len(), 1);
1130        assert_eq!(errors[0].actual_hash, correct_hash);
1131        assert_eq!(errors[0].expected_hash, used_hash);
1132
1133        assert_eq!(file_path.exists(), false);
1134    }
1135
1136    #[tokio::test]
1137    async fn test_file_cache_scan_retain() {
1138        let dir = tempfile::tempdir().unwrap();
1139        let path = dir.path();
1140
1141        let cache = FileSystemCache {
1142            cache_dir: path.to_path_buf(),
1143        };
1144
1145        {
1146            let state = cache
1147                .scan(0u64, |state: &mut u64, _entry| {
1148                    *state += 1;
1149                    Ok(())
1150                })
1151                .await
1152                .unwrap();
1153
1154            assert_eq!(state, 0);
1155        }
1156
1157        let path1 = cache
1158            .save(
1159                Bytes::from_static(b"test1"),
1160                &DistributionInfo {
1161                    webc: Url::parse("file:///test1.webc").unwrap(),
1162                    webc_sha256: WebcHash::sha256(b"test1"),
1163                },
1164            )
1165            .await
1166            .unwrap();
1167        let path2 = cache
1168            .save(
1169                Bytes::from_static(b"test2"),
1170                &DistributionInfo {
1171                    webc: Url::parse("file:///test2.webc").unwrap(),
1172                    webc_sha256: WebcHash::sha256(b"test2"),
1173                },
1174            )
1175            .await
1176            .unwrap();
1177
1178        {
1179            let path1 = path1.clone();
1180            let path2 = path2.clone();
1181            let state = cache
1182                .scan(0u64, move |state: &mut u64, entry| {
1183                    *state += 1;
1184                    assert!(entry.path() == path1 || entry.path() == path2);
1185                    Ok(())
1186                })
1187                .await
1188                .unwrap();
1189
1190            assert_eq!(state, 2);
1191        }
1192
1193        {
1194            let path1 = path1.clone();
1195            let state = cache
1196                .retain(0u64, move |state: &mut u64, entry| {
1197                    *state += 1;
1198                    Ok(entry.path() == path1)
1199                })
1200                .await
1201                .unwrap();
1202            assert_eq!(state, 2);
1203        }
1204
1205        assert!(path1.exists());
1206        assert!(!path2.exists(), "Path 2 should have been deleted");
1207
1208        {
1209            let path1 = path1.clone();
1210            let state = cache
1211                .scan(0u64, move |state: &mut u64, entry| {
1212                    *state += 1;
1213                    assert!(entry.path() == path1);
1214                    Ok(())
1215                })
1216                .await
1217                .unwrap();
1218            assert_eq!(state, 1);
1219        }
1220    }
1221}