wasmer_wasix/runtime/package_loader/
builtin_loader.rs

1use std::{
2    collections::HashMap,
3    io::{ErrorKind, Read, Write as _},
4    path::PathBuf,
5    sync::{Arc, RwLock},
6};
7
8use anyhow::{Context, Error, bail};
9use bytes::Bytes;
10use http::{HeaderMap, Method};
11use tempfile::NamedTempFile;
12use url::Url;
13use wasmer_package::{
14    package::WasmerPackageError,
15    utils::{from_bytes, from_disk},
16};
17use webc::DetectError;
18use webc::{Container, ContainerError};
19
20use crate::{
21    bin_factory::BinaryPackage,
22    http::{HttpClient, HttpRequest, USER_AGENT},
23    runtime::{
24        package_loader::PackageLoader,
25        resolver::{DistributionInfo, PackageSummary, Resolution, WebcHash},
26    },
27};
28
/// The builtin [`PackageLoader`] that is used by the `wasmer` CLI and
/// respects `$WASMER_DIR`.
#[derive(Debug)]
pub struct BuiltinPackageLoader {
    /// HTTP client used to download packages that are not cached.
    client: Arc<dyn HttpClient + Send + Sync>,
    /// Containers kept in memory, keyed by their [`WebcHash`].
    in_memory: InMemoryCache,
    /// Optional on-disk cache; `None` until [`Self::with_cache_dir`] is used.
    cache: Option<FileSystemCache>,
    /// A mapping from hostnames to tokens
    tokens: HashMap<String, String>,

    /// How downloaded images are checked against their expected hash.
    hash_validation: HashIntegrityValidationMode,
}
41
/// Defines how to validate package hash integrity.
///
/// The mode is applied to the image bytes after they have been downloaded
/// (or read from a `file://` URL).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HashIntegrityValidationMode {
    /// Do not validate anything.
    /// Best for performance.
    NoValidate,
    /// Compute the image hash and emit a `tracing` warning on hash
    /// mismatches; the image is still accepted.
    WarnOnHashMismatch,
    /// Compute the image hash and fail on a mismatch.
    FailOnHashMismatch,
}
53
54impl BuiltinPackageLoader {
55    pub fn new() -> Self {
56        BuiltinPackageLoader {
57            in_memory: InMemoryCache::default(),
58            client: Arc::new(crate::http::default_http_client().unwrap()),
59            cache: None,
60            hash_validation: HashIntegrityValidationMode::NoValidate,
61            tokens: HashMap::new(),
62        }
63    }
64
65    /// Set the validation mode to apply after downloading an image.
66    ///
67    /// See [`HashIntegrityValidationMode`] for details.
68    pub fn with_hash_validation_mode(mut self, mode: HashIntegrityValidationMode) -> Self {
69        self.hash_validation = mode;
70        self
71    }
72
73    pub fn with_cache_dir(self, cache_dir: impl Into<PathBuf>) -> Self {
74        BuiltinPackageLoader {
75            cache: Some(FileSystemCache {
76                cache_dir: cache_dir.into(),
77            }),
78            ..self
79        }
80    }
81
    /// Borrow the filesystem cache, or `None` if no cache directory has been
    /// configured.
    pub fn cache(&self) -> Option<&FileSystemCache> {
        self.cache.as_ref()
    }
85
86    pub fn validate_cache(
87        &self,
88        mode: CacheValidationMode,
89    ) -> Result<Vec<ImageHashMismatchError>, anyhow::Error> {
90        let cache = self
91            .cache
92            .as_ref()
93            .context("can not validate cache - no cache configured")?;
94
95        let items = cache.validate_hashes()?;
96        let mut errors = Vec::new();
97        for (path, error) in items {
98            match mode {
99                CacheValidationMode::WarnOnMismatch => {
100                    tracing::warn!(?error, "hash mismatch in cached image file");
101                }
102                CacheValidationMode::PruneOnMismatch => {
103                    tracing::warn!(?error, "deleting cached image file due to hash mismatch");
104                    match std::fs::remove_file(&path) {
105                        Ok(()) => {}
106                        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
107                        Err(fs_err) => {
108                            tracing::error!(
109                                path=%error.source,
110                                ?fs_err,
111                                "could not delete cached image file with hash mismatch"
112                            );
113                        }
114                    }
115                }
116            }
117
118            errors.push(error);
119        }
120
121        Ok(errors)
122    }
123
    /// Use `client` for all package downloads.
    ///
    /// The client is wrapped in an [`Arc`] and handed to
    /// [`Self::with_shared_http_client`].
    pub fn with_http_client(self, client: impl HttpClient + Send + Sync + 'static) -> Self {
        self.with_shared_http_client(Arc::new(client))
    }
127
128    pub fn with_shared_http_client(self, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
129        BuiltinPackageLoader { client, ..self }
130    }
131
132    pub fn with_tokens<I, K, V>(mut self, tokens: I) -> Self
133    where
134        I: IntoIterator<Item = (K, V)>,
135        K: Into<String>,
136        V: Into<String>,
137    {
138        for (hostname, token) in tokens {
139            self = self.with_token(hostname, token);
140        }
141
142        self
143    }
144
145    /// Add an API token that will be used whenever sending requests to a
146    /// particular hostname.
147    ///
148    /// Note that this uses [`Url::authority()`] when looking up tokens, so it
149    /// will match both plain hostnames (e.g. `registry.wasmer.io`) and hosts
150    /// with a port number (e.g. `localhost:8000`).
151    pub fn with_token(mut self, hostname: impl Into<String>, token: impl Into<String>) -> Self {
152        self.tokens.insert(hostname.into(), token.into());
153        self
154    }
155
    /// Insert a container into the in-memory cache under `hash`.
    ///
    /// An entry that already exists for `hash` is left untouched.
    pub fn insert_cached(&self, hash: WebcHash, container: &Container) {
        self.in_memory.save(container, hash);
    }
160
161    #[tracing::instrument(level = "debug", skip_all, fields(pkg.hash=%hash))]
162    async fn get_cached(&self, hash: &WebcHash) -> Result<Option<Container>, Error> {
163        if let Some(cached) = self.in_memory.lookup(hash) {
164            return Ok(Some(cached));
165        }
166
167        if let Some(cache) = self.cache.as_ref()
168            && let Some(cached) = cache.lookup(hash).await?
169        {
170            // Note: We want to propagate it to the in-memory cache, too
171            tracing::debug!("Copying from the filesystem cache to the in-memory cache");
172            self.in_memory.save(&cached, *hash);
173            return Ok(Some(cached));
174        }
175
176        Ok(None)
177    }
178
179    /// Validate image contents with the specified validation mode.
180    async fn validate_hash(
181        image: &bytes::Bytes,
182        mode: HashIntegrityValidationMode,
183        info: &DistributionInfo,
184    ) -> Result<(), anyhow::Error> {
185        let info = info.clone();
186        let image = image.clone();
187        crate::spawn_blocking(move || Self::validate_hash_sync(&image, mode, &info))
188            .await
189            .context("tokio runtime failed")?
190    }
191
192    /// Validate image contents with the specified validation mode.
193    fn validate_hash_sync(
194        image: &[u8],
195        mode: HashIntegrityValidationMode,
196        info: &DistributionInfo,
197    ) -> Result<(), anyhow::Error> {
198        match mode {
199            HashIntegrityValidationMode::NoValidate => {
200                // Nothing to do.
201                Ok(())
202            }
203            HashIntegrityValidationMode::WarnOnHashMismatch => {
204                let actual_hash = WebcHash::sha256(image);
205                if actual_hash != info.webc_sha256 {
206                    tracing::warn!(%info.webc_sha256, %actual_hash, "image hash mismatch - actual image hash does not match the expected hash!");
207                }
208                Ok(())
209            }
210            HashIntegrityValidationMode::FailOnHashMismatch => {
211                let actual_hash = WebcHash::sha256(image);
212                if actual_hash != info.webc_sha256 {
213                    Err(ImageHashMismatchError {
214                        source: info.webc.to_string(),
215                        actual_hash,
216                        expected_hash: info.webc_sha256,
217                    }
218                    .into())
219                } else {
220                    Ok(())
221                }
222            }
223        }
224    }
225
226    #[tracing::instrument(level = "debug", skip_all, fields(%dist.webc, %dist.webc_sha256))]
227    async fn download(&self, dist: &DistributionInfo) -> Result<Bytes, Error> {
228        if dist.webc.scheme() == "file" {
229            match crate::runtime::resolver::utils::file_path_from_url(&dist.webc) {
230                Ok(path) => {
231                    let bytes = crate::spawn_blocking({
232                        let path = path.clone();
233                        move || std::fs::read(path)
234                    })
235                    .await?
236                    .with_context(|| format!("Unable to read \"{}\"", path.display()))?;
237
238                    let bytes = bytes::Bytes::from(bytes);
239
240                    Self::validate_hash(&bytes, self.hash_validation, dist).await?;
241
242                    return Ok(bytes);
243                }
244                Err(e) => {
245                    tracing::debug!(
246                        url=%dist.webc,
247                        error=&*e,
248                        "Unable to convert the file:// URL to a path",
249                    );
250                }
251            }
252        }
253
254        let request = HttpRequest {
255            headers: self.headers(&dist.webc),
256            url: dist.webc.clone(),
257            method: Method::GET,
258            body: None,
259            options: Default::default(),
260        };
261
262        tracing::debug!(%request.url, %request.method, "webc_package_download_start");
263        tracing::trace!(?request.headers);
264
265        let response = self.client.request(request).await?;
266
267        tracing::trace!(
268            %response.status,
269            %response.redirected,
270            ?response.headers,
271            response.len=response.body.as_ref().map(|body| body.len()),
272            "Received a response",
273        );
274
275        let url = &dist.webc;
276        if !response.is_ok() {
277            return Err(
278                crate::runtime::resolver::utils::http_error(&response).context(format!(
279                    "package download failed: GET request to \"{}\" failed with status {}",
280                    url, response.status
281                )),
282            );
283        }
284
285        let body = response.body.context("package download failed")?;
286        let body = Self::decode_response_body(&response.headers, body)
287            .context("package download failed: could not decode response body")?;
288        tracing::debug!(%url, "package_download_succeeded");
289
290        let body = bytes::Bytes::from(body);
291
292        Self::validate_hash(&body, self.hash_validation, dist).await?;
293
294        Ok(body)
295    }
296
297    fn headers(&self, url: &Url) -> HeaderMap {
298        let mut headers = HeaderMap::new();
299        headers.insert("Accept", "application/webc".parse().unwrap());
300        headers.insert("User-Agent", USER_AGENT.parse().unwrap());
301
302        // Accept compressed responses.
303        // NOTE: gzip and zstd decoding is available on native platforms.
304        // In browser platforms, the fetch implementation should automatically
305        // handle decoding of gzip/zstd responses transparently.
306        headers.insert(
307            http::header::ACCEPT_ENCODING,
308            "zstd;q=1.0, gzip;q=0.8".parse().unwrap(),
309        );
310
311        if url.has_authority()
312            && let Some(token) = self.tokens.get(url.authority())
313        {
314            let header = format!("Bearer {token}");
315            match header.parse() {
316                Ok(header) => {
317                    headers.insert(http::header::AUTHORIZATION, header);
318                }
319                Err(e) => {
320                    tracing::warn!(
321                        error = &e as &dyn std::error::Error,
322                        "An error occurred while parsing the authorization header",
323                    );
324                }
325            }
326        }
327
328        headers
329    }
330
331    /// Decode the response body according to the `Content-Encoding` header.
332    ///
333    /// * Supports `gzip` and `zstd` encodings
334    /// * Supports nested encodings (e.g. `gzip, zstd`)
335    /// * Passes through unencoded bodies or "identity" encoding unchanged
336    fn decode_response_body(headers: &HeaderMap, body: Vec<u8>) -> Result<Vec<u8>, anyhow::Error> {
337        let encodings = match headers.get(http::header::CONTENT_ENCODING) {
338            Some(header) => header
339                .to_str()
340                .context("non-utf8 content-encoding header")?
341                .split(',')
342                .map(|encoding| encoding.trim().to_ascii_lowercase())
343                .filter(|encoding| !encoding.is_empty())
344                .collect::<Vec<_>>(),
345            None => Vec::new(),
346        };
347
348        // Check if there is nothing to decode, return early.
349        // "identity" is the default encoding meaning "no encoding" (See RFC 2616 / RFC 7231)
350        if encodings.is_empty() || (encodings.len() == 1 && encodings[0] == "identity") {
351            return Ok(body);
352        }
353
354        let mut reader: Box<dyn Read> = Box::new(std::io::Cursor::new(body));
355        for encoding in encodings.iter().rev() {
356            match encoding.as_str() {
357                "gzip" => {
358                    reader = Box::new(flate2::read::GzDecoder::new(reader));
359                }
360                "zstd" => {
361                    #[cfg(not(target_arch = "wasm32"))]
362                    {
363                        reader = Box::new(
364                            zstd::stream::read::Decoder::new(reader)
365                                .context("failed to initialize zstd decoder")?,
366                        );
367                    }
368                    #[cfg(target_arch = "wasm32")]
369                    {
370                        // NOTE: in browsers this code will not be hit because
371                        // the fetch API automatically handles content decoding.
372                        bail!("zstd content-encoding is not supported on wasm32");
373                    }
374                }
375                "identity" => {}
376                other => bail!("unsupported content-encoding: {other}"),
377            }
378        }
379
380        let mut decoded = Vec::new();
381        reader
382            .read_to_end(&mut decoded)
383            .context("failed to decode response body")?;
384        Ok(decoded)
385    }
386}
387
388impl Default for BuiltinPackageLoader {
389    fn default() -> Self {
390        BuiltinPackageLoader::new()
391    }
392}
393
394#[async_trait::async_trait]
395impl PackageLoader for BuiltinPackageLoader {
396    #[tracing::instrument(
397        level="debug",
398        skip_all,
399        fields(
400            pkg=%summary.pkg.id,
401        ),
402    )]
403    async fn load(&self, summary: &PackageSummary) -> Result<Container, Error> {
404        if let Some(container) = self.get_cached(&summary.dist.webc_sha256).await? {
405            tracing::debug!("Cache hit!");
406            return Ok(container);
407        }
408
409        // looks like we had a cache miss and need to download it manually
410        let bytes = self
411            .download(&summary.dist)
412            .await
413            .with_context(|| format!("Unable to download \"{}\"", summary.dist.webc))?;
414
415        // We want to cache the container we downloaded, but we want to do it
416        // in a smart way to keep memory usage down.
417
418        if let Some(cache) = &self.cache {
419            match cache
420                .save_and_load_as_mmapped(bytes.clone(), &summary.dist)
421                .await
422            {
423                Ok(container) => {
424                    tracing::debug!("Cached to disk");
425                    self.in_memory.save(&container, summary.dist.webc_sha256);
426                    // The happy path - we've saved to both caches and loaded the
427                    // container from disk (hopefully using mmap) so we're done.
428                    return Ok(container);
429                }
430                Err(e) => {
431                    tracing::warn!(
432                        error=&*e,
433                        pkg=%summary.pkg.id,
434                        pkg.hash=%summary.dist.webc_sha256,
435                        pkg.url=%summary.dist.webc,
436                        "Unable to save the downloaded package to disk",
437                    );
438                }
439            }
440        }
441
442        // The sad path - looks like we don't have a filesystem cache so we'll
443        // need to keep the whole thing in memory.
444        let container = crate::spawn_blocking(move || from_bytes(bytes)).await??;
445        // We still want to cache it in memory, of course
446        self.in_memory.save(&container, summary.dist.webc_sha256);
447        Ok(container)
448    }
449
    /// Load the package tree rooted at `root`.
    ///
    /// Delegates to the parent module's `load_package_tree`, passing this
    /// loader so that it is used for any further loading.
    async fn load_package_tree(
        &self,
        root: &Container,
        resolution: &Resolution,
        root_is_local_dir: bool,
    ) -> Result<BinaryPackage, Error> {
        super::load_package_tree(root, self, resolution, root_is_local_dir).await
    }
458}
459
/// Error describing an image whose computed hash differs from the hash it
/// was expected to have.
#[derive(Clone, Debug)]
pub struct ImageHashMismatchError {
    /// Where the image came from: the download URL or the cached file path.
    source: String,
    /// The hash the image was expected to have.
    expected_hash: WebcHash,
    /// The hash computed from the image contents.
    actual_hash: WebcHash,
}
466
467impl std::fmt::Display for ImageHashMismatchError {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        write!(
470            f,
471            "image hash mismatch! expected hash '{}', but the computed hash is '{}' (source '{}')",
472            self.expected_hash, self.actual_hash, self.source,
473        )
474    }
475}
476
// Marker impl with the trait's default methods, so the type can be used as a
// standard error (e.g. converted with `.into()` into `anyhow::Error`).
impl std::error::Error for ImageHashMismatchError {}
478
/// What [`BuiltinPackageLoader::validate_cache`] does with cached images
/// whose contents do not match the hash in their file name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CacheValidationMode {
    /// Just emit a warning for all images where the filename doesn't match
    /// the expected hash.
    WarnOnMismatch,
    /// Remove images from the cache if the filename doesn't match the actual
    /// hash.
    PruneOnMismatch,
}
488
/// On-disk cache of package images, stored as `<hex hash>.bin` files
/// directly inside `cache_dir`.
// FIXME: `validate_hashes()` still performs blocking filesystem I/O on the
// calling thread; the async methods already go through a spawn_blocking()
// call to run their I/O in the background.
#[derive(Debug)]
pub struct FileSystemCache {
    /// Directory holding the cached files and the `__temp__` staging
    /// directory.
    cache_dir: PathBuf,
}
495
496impl FileSystemCache {
    /// Suffix appended to the hex-encoded hash to form a cache file name.
    const FILE_SUFFIX: &'static str = ".bin";
498
    /// Directory inside the cache where files are staged while they are
    /// being written.
    fn temp_dir(&self) -> PathBuf {
        self.cache_dir.join("__temp__")
    }
502
503    /// Validate that the cached image file names correspond to their actual
504    /// file content hashes.
505    fn validate_hashes(&self) -> Result<Vec<(PathBuf, ImageHashMismatchError)>, anyhow::Error> {
506        let mut items = Vec::<(PathBuf, ImageHashMismatchError)>::new();
507
508        let iter = match std::fs::read_dir(&self.cache_dir) {
509            Ok(v) => v,
510            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
511                // Cache dir does not exist, so nothing to validate.
512                return Ok(Vec::new());
513            }
514            Err(err) => {
515                return Err(err).with_context(|| {
516                    format!(
517                        "Could not read image cache dir: '{}'",
518                        self.cache_dir.display()
519                    )
520                });
521            }
522        };
523
524        for res in iter {
525            let entry = res?;
526            if !entry.file_type()?.is_file() {
527                continue;
528            }
529
530            // Extract the hash from the filename.
531
532            let hash_opt = entry
533                .file_name()
534                .to_str()
535                .and_then(|x| {
536                    let (raw_hash, _) = x.split_once(Self::FILE_SUFFIX)?;
537                    Some(raw_hash)
538                })
539                .and_then(|x| WebcHash::parse_hex(x).ok());
540            let Some(expected_hash) = hash_opt else {
541                continue;
542            };
543
544            // Compute the actual hash.
545            let path = entry.path();
546            let actual_hash = WebcHash::for_file(&path)?;
547
548            if actual_hash != expected_hash {
549                let err = ImageHashMismatchError {
550                    source: path.to_string_lossy().to_string(),
551                    actual_hash,
552                    expected_hash,
553                };
554                items.push((path, err));
555            }
556        }
557
558        Ok(items)
559    }
560
561    async fn lookup(&self, hash: &WebcHash) -> Result<Option<Container>, Error> {
562        let path = self.path(hash);
563
564        let container = crate::spawn_blocking({
565            let path = path.clone();
566            move || from_disk(path)
567        })
568        .await?;
569        match container {
570            Ok(c) => Ok(Some(c)),
571            Err(WasmerPackageError::ContainerError(ContainerError::Open { error, .. }))
572            | Err(WasmerPackageError::ContainerError(ContainerError::Read { error, .. }))
573            | Err(WasmerPackageError::ContainerError(ContainerError::Detect(DetectError::Io(
574                error,
575            )))) if error.kind() == ErrorKind::NotFound => Ok(None),
576            Err(e) => {
577                let msg = format!("Unable to read \"{}\"", path.display());
578                Err(Error::new(e).context(msg))
579            }
580        }
581    }
582
583    async fn save(&self, webc: Bytes, dist: &DistributionInfo) -> Result<PathBuf, Error> {
584        let path = self.path(&dist.webc_sha256);
585        let dist = dist.clone();
586        let temp_dir = self.temp_dir();
587
588        let path2 = path.clone();
589        crate::spawn_blocking(move || {
590            // Keep files in a temporary directory until they are fully written
591            // to prevent temp files being included in [`Self::scan`] or `[Self::retain]`.
592
593            std::fs::create_dir_all(&temp_dir)
594                .with_context(|| format!("Unable to create directory '{}'", temp_dir.display()))?;
595
596            let mut temp = NamedTempFile::new_in(&temp_dir)?;
597            temp.write_all(&webc)?;
598            temp.flush()?;
599            temp.as_file_mut().sync_all()?;
600
601            // Move the temporary file to the final location.
602            temp.persist(&path)?;
603
604            tracing::debug!(
605                pkg.hash=%dist.webc_sha256,
606                pkg.url=%dist.webc,
607                path=%path.display(),
608                num_bytes=webc.len(),
609                "Saved to disk",
610            );
611            Result::<_, Error>::Ok(())
612        })
613        .await??;
614
615        Ok(path2)
616    }
617
618    #[tracing::instrument(level = "debug", skip_all)]
619    async fn save_and_load_as_mmapped(
620        &self,
621        webc: Bytes,
622        dist: &DistributionInfo,
623    ) -> Result<Container, Error> {
624        // First, save it to disk
625        self.save(webc, dist).await?;
626
627        // Now try to load it again. The resulting container should use
628        // a memory-mapped file rather than an in-memory buffer.
629        match self.lookup(&dist.webc_sha256).await? {
630            Some(container) => Ok(container),
631            None => {
632                // Something really weird has occurred and we can't see the
633                // saved file. Just error out and let the fallback code do its
634                // thing.
635                Err(Error::msg("Unable to load the downloaded memory from disk"))
636            }
637        }
638    }
639
640    fn path(&self, hash: &WebcHash) -> PathBuf {
641        self.cache_dir.join(format!(
642            "{}{}",
643            hex::encode(hash.as_bytes()),
644            Self::FILE_SUFFIX
645        ))
646    }
647
648    /// Scan all the cached webc files and invoke the callback for each.
649    pub async fn scan<S, F>(&self, state: S, callback: F) -> Result<S, Error>
650    where
651        S: Send + 'static,
652        F: Fn(&mut S, &std::fs::DirEntry) -> Result<(), Error> + Send + 'static,
653    {
654        let cache_dir = self.cache_dir.clone();
655        tokio::task::spawn_blocking(move || -> Result<S, anyhow::Error> {
656            let mut state = state;
657
658            let iter = match std::fs::read_dir(&cache_dir) {
659                Ok(v) => v,
660                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
661                    // path does not exist, so nothing to scan.
662                    return Ok(state);
663                }
664                Err(err) => {
665                    return Err(err).with_context(|| {
666                        format!("Could not read image cache dir: '{}'", cache_dir.display())
667                    });
668                }
669            };
670
671            for res in iter {
672                let entry = res?;
673                if !entry.file_type()?.is_file() {
674                    continue;
675                }
676
677                callback(&mut state, &entry)?;
678            }
679
680            Ok(state)
681        })
682        .await?
683        .context("tokio runtime failed")
684    }
685
686    /// Remove entries from the cache that do not pass the callback.
687    pub async fn retain<S, F>(&self, state: S, filter: F) -> Result<S, Error>
688    where
689        S: Send + 'static,
690        F: Fn(&mut S, &std::fs::DirEntry) -> Result<bool, anyhow::Error> + Send + 'static,
691    {
692        let cache_dir = self.cache_dir.clone();
693        tokio::task::spawn_blocking(move || {
694            let iter = match std::fs::read_dir(&cache_dir) {
695                Ok(v) => v,
696                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
697                    // path does not exist, so nothing to scan.
698                    return Ok(state);
699                }
700                Err(err) => {
701                    return Err(err).with_context(|| {
702                        format!("Could not read image cache dir: '{}'", cache_dir.display())
703                    });
704                }
705            };
706
707            let mut state = state;
708            for res in iter {
709                let entry = res?;
710                if !entry.file_type()?.is_file() {
711                    continue;
712                }
713
714                if !filter(&mut state, &entry)? {
715                    tracing::debug!(
716                        path=%entry.path().display(),
717                        "Removing cached image file - does not pass the filter",
718                    );
719                    match std::fs::remove_file(entry.path()) {
720                        Ok(()) => {}
721                        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
722                        Err(fs_err) => {
723                            tracing::warn!(
724                                path=%entry.path().display(),
725                                ?fs_err,
726                                "Could not delete cached image file",
727                            );
728                        }
729                    }
730                }
731            }
732
733            Ok(state)
734        })
735        .await?
736        .context("tokio runtime failed")
737    }
738}
739
/// In-memory cache mapping a package hash to its loaded [`Container`],
/// guarded by an [`RwLock`].
#[derive(Debug, Default)]
struct InMemoryCache(RwLock<HashMap<WebcHash, Container>>);
742
743impl InMemoryCache {
744    fn lookup(&self, hash: &WebcHash) -> Option<Container> {
745        self.0.read().unwrap().get(hash).cloned()
746    }
747
748    fn save(&self, container: &Container, hash: WebcHash) {
749        let mut cache = self.0.write().unwrap();
750        cache.entry(hash).or_insert_with(|| container.clone());
751    }
752}
753
754#[cfg(test)]
755mod tests {
756    use std::{collections::VecDeque, io::Write, sync::Mutex};
757
758    use futures::future::BoxFuture;
759    use http::{HeaderMap, HeaderValue, StatusCode};
760    use tempfile::TempDir;
761    use wasmer_config::package::PackageId;
762
763    use crate::{
764        http::{HttpRequest, HttpResponse},
765        runtime::resolver::PackageInfo,
766    };
767
768    use super::*;
769
    /// Raw bytes of the `python-0.1.0.wasmer` package fixture, embedded at
    /// compile time.
    const PYTHON: &[u8] = include_bytes!("../../../../c-api/examples/assets/python-0.1.0.wasmer");
771
    /// Test double for [`HttpClient`] that records every request and
    /// answers with canned responses.
    #[derive(Debug)]
    pub(crate) struct DummyClient {
        /// Requests received so far, in arrival order.
        requests: Mutex<Vec<HttpRequest>>,
        /// Responses still to be handed out, front first.
        responses: Mutex<VecDeque<HttpResponse>>,
    }
777
778    impl DummyClient {
779        pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
780            DummyClient {
781                requests: Mutex::new(Vec::new()),
782                responses: Mutex::new(responses.into_iter().collect()),
783            }
784        }
785    }
786
787    impl HttpClient for DummyClient {
788        fn request(
789            &self,
790            request: HttpRequest,
791        ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
792            let response = self.responses.lock().unwrap().pop_front().unwrap();
793            self.requests.lock().unwrap().push(request);
794            Box::pin(async { Ok(response) })
795        }
796    }
797
    /// Shared body of the `cache_misses_will_trigger_a_download` tests.
    ///
    /// Loads a package that is in neither cache and checks that exactly the
    /// expected HTTP request was sent, that the right container came back,
    /// and that the package ended up in both the filesystem cache and the
    /// in-memory cache.
    async fn cache_misses_will_trigger_a_download_internal() {
        let temp = TempDir::new().unwrap();
        // The dummy client answers the single expected request with the raw
        // (unencoded) fixture bytes.
        let client = Arc::new(DummyClient::with_responses([HttpResponse {
            body: Some(PYTHON.to_vec()),
            redirected: false,
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        }]));
        let loader = BuiltinPackageLoader::new()
            .with_cache_dir(temp.path())
            .with_shared_http_client(client.clone());
        // The hash below is a placeholder; the loader is left in its default
        // `NoValidate` mode, so it is never compared with the fixture.
        let summary = PackageSummary {
            pkg: PackageInfo {
                id: PackageId::new_named("python/python", "0.1.0".parse().unwrap()),
                dependencies: Vec::new(),
                commands: Vec::new(),
                entrypoint: Some("asdf".to_string()),
                filesystem: Vec::new(),
            },
            dist: DistributionInfo {
                webc: "https://wasmer.io/python/python".parse().unwrap(),
                webc_sha256: [0xaa; 32].into(),
            },
        };

        let container = loader.load(&summary).await.unwrap();

        // A HTTP request was sent
        let requests = client.requests.lock().unwrap();
        let request = &requests[0];
        assert_eq!(request.url, summary.dist.webc);
        assert_eq!(request.method, "GET");
        // Native targets additionally advertise compressed encodings.
        #[cfg(not(target_arch = "wasm32"))]
        {
            assert_eq!(request.headers.len(), 3);
            assert_eq!(request.headers["Accept-Encoding"], "zstd;q=1.0, gzip;q=0.8");
        }
        // NOTE(review): this branch expects no Accept-Encoding header on
        // wasm32 - confirm the loader's header construction omits it there.
        #[cfg(target_arch = "wasm32")]
        {
            assert_eq!(request.headers.len(), 2);
            assert!(!request.headers.contains_key(http::header::ACCEPT_ENCODING));
        }
        assert_eq!(request.headers["Accept"], "application/webc");
        assert_eq!(request.headers["User-Agent"], USER_AGENT);
        // Make sure we got the right package
        let manifest = container.manifest();
        assert_eq!(manifest.entrypoint.as_deref(), Some("python"));
        // it should have been automatically saved to disk
        let path = loader
            .cache
            .as_ref()
            .unwrap()
            .path(&summary.dist.webc_sha256);
        assert!(path.exists());
        assert_eq!(std::fs::read(&path).unwrap(), PYTHON);
        // and cached in memory for next time
        let in_memory = loader.in_memory.0.read().unwrap();
        assert!(in_memory.contains_key(&summary.dist.webc_sha256));
    }
857
858    #[cfg(not(target_arch = "wasm32"))]
859    #[tokio::test(flavor = "multi_thread")]
860    async fn cache_misses_will_trigger_a_download() {
861        cache_misses_will_trigger_a_download_internal().await
862    }
863
/// wasm32 entry point for the cache-miss download scenario.
///
/// Same scenario body as the native variant, but without requesting the
/// multi-threaded runtime flavor.
#[cfg(target_arch = "wasm32")]
#[tokio::test]
async fn cache_misses_will_trigger_a_download() {
    cache_misses_will_trigger_a_download_internal().await;
}
869
870    /// Small helper to construct headers with a given content-encoding.
871    fn headers_with_encoding(content_encoding: Option<&str>) -> HeaderMap {
872        let mut headers = HeaderMap::new();
873        if let Some(value) = content_encoding {
874            headers.insert(http::header::CONTENT_ENCODING, value.parse().unwrap());
875        }
876        headers
877    }
878
879    /// Small helper to construct headers with a raw content-encoding value.
880    fn headers_with_raw_encoding(value: &[u8]) -> HeaderMap {
881        let mut headers = HeaderMap::new();
882        headers.insert(
883            http::header::CONTENT_ENCODING,
884            HeaderValue::from_bytes(value).unwrap(),
885        );
886        headers
887    }
888
/// A body must come back byte-for-byte unchanged both when no
/// `Content-Encoding` header is present and when it is `identity`.
#[test]
fn decode_response_body_passthrough() {
    let payload = b"plain-bytes".to_vec();

    // First without any header, then with an explicit `identity` encoding.
    for encoding in [None, Some("identity")] {
        let headers = headers_with_encoding(encoding);
        let decoded =
            BuiltinPackageLoader::decode_response_body(&headers, payload.clone()).unwrap();
        assert_eq!(decoded, payload);
    }
}
906
/// A `Content-Encoding` list made only of separators and whitespace must be
/// handled like "no encoding": the body is returned unchanged.
#[test]
fn decode_response_body_empty_encoding_list() {
    let headers = headers_with_encoding(Some(" , , "));
    let payload = b"plain-bytes".to_vec();

    let decoded =
        BuiltinPackageLoader::decode_response_body(&headers, payload.clone()).unwrap();

    assert_eq!(decoded, payload);
}
918
/// A `Content-Encoding` header that is not valid UTF-8 must be rejected with
/// an error mentioning the non-utf8 header.
#[test]
fn decode_response_body_non_utf8_encoding_header() {
    // 0xff is a byte that never appears in well-formed UTF-8.
    let headers = headers_with_raw_encoding(&[0xff]);

    let outcome = BuiltinPackageLoader::decode_response_body(&headers, b"bytes".to_vec());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("non-utf8 content-encoding"));
}
929
/// A gzip-compressed body labelled `Content-Encoding: gzip` must decode back
/// to the original bytes.
#[test]
fn decode_response_body_gzip() {
    let plain = b"gzip-bytes".to_vec();

    // Compress the payload up front so only the decode path is under test.
    let compressed = {
        let mut gz = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(&plain).unwrap();
        gz.finish().unwrap()
    };

    let headers = headers_with_encoding(Some("gzip"));
    let decoded = BuiltinPackageLoader::decode_response_body(&headers, compressed).unwrap();
    assert_eq!(decoded, plain);
}
945
/// When `identity` is listed alongside `gzip`, the `identity` entry must be
/// ignored and the gzip payload decoded as usual.
#[test]
fn decode_response_body_identity_and_gzip() {
    let plain = b"gzip-bytes".to_vec();

    let compressed = {
        let mut gz = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(&plain).unwrap();
        gz.finish().unwrap()
    };

    let headers = headers_with_encoding(Some("identity, gzip"));
    let decoded = BuiltinPackageLoader::decode_response_body(&headers, compressed).unwrap();
    assert_eq!(decoded, plain);
}
961
/// A body that claims to be gzip but is not must produce a decode error.
#[test]
fn decode_response_body_gzip_invalid_payload() {
    let headers = headers_with_encoding(Some("gzip"));

    let outcome = BuiltinPackageLoader::decode_response_body(&headers, b"not-gzip".to_vec());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("failed to decode response body"));
}
972
/// A zstd-compressed body labelled `Content-Encoding: zstd` must decode back
/// to the original bytes. Not compiled for wasm32.
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn decode_response_body_zstd() {
    let plain = b"zstd-bytes".to_vec();
    let compressed = zstd::stream::encode_all(plain.as_slice(), 0).unwrap();

    let headers = headers_with_encoding(Some("zstd"));
    let decoded = BuiltinPackageLoader::decode_response_body(&headers, compressed).unwrap();

    assert_eq!(decoded, plain);
}
987
/// A body that claims to be zstd but is not must produce a decode error.
/// Not compiled for wasm32.
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn decode_response_body_zstd_invalid_payload() {
    let headers = headers_with_encoding(Some("zstd"));

    let outcome = BuiltinPackageLoader::decode_response_body(&headers, b"not-zstd".to_vec());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("failed to decode response body"));
}
999
/// A body that was gzip-compressed and then zstd-compressed, labelled
/// `Content-Encoding: gzip, zstd`, must decode back to the original bytes.
/// Not compiled for wasm32.
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn decode_response_body_zstd_and_gzip() {
    let plain = b"layered-bytes".to_vec();

    // Inner layer: gzip.
    let gzipped = {
        let mut gz = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(&plain).unwrap();
        gz.finish().unwrap()
    };
    // Outer layer: zstd wrapped around the gzip stream.
    let layered = zstd::stream::encode_all(gzipped.as_slice(), 0).unwrap();

    let headers = headers_with_encoding(Some("gzip, zstd"));
    let decoded = BuiltinPackageLoader::decode_response_body(&headers, layered).unwrap();
    assert_eq!(decoded, plain);
}
1017
/// An encoding the decoder does not handle (`br` here) must be reported as
/// an unsupported content-encoding error.
#[test]
fn decode_response_body_unknown_encoding() {
    let headers = headers_with_encoding(Some("br"));

    let outcome = BuiltinPackageLoader::decode_response_body(&headers, b"weird".to_vec());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("unsupported content-encoding"));
}
1028
1029    // NOTE: must be a tokio test because the BuiltinPackageLoader::new()
1030    // constructor requires a runtime...
1031    #[tokio::test]
1032    async fn test_builtin_package_downloader_cache_validation() {
1033        let dir = tempfile::tempdir().unwrap();
1034        let path = dir.path();
1035
1036        let contents = "fail";
1037        let correct_hash = WebcHash::sha256(contents);
1038        let used_hash =
1039            WebcHash::parse_hex("0000a28ea38a000f3a3328cb7fabe330638d3258affe1a869e3f92986222d997")
1040                .unwrap();
1041        let filename = format!("{}{}", used_hash, FileSystemCache::FILE_SUFFIX);
1042        let file_path = path.join(filename);
1043        std::fs::write(&file_path, contents).unwrap();
1044
1045        let dl = BuiltinPackageLoader::new().with_cache_dir(path);
1046
1047        let errors = dl
1048            .validate_cache(CacheValidationMode::PruneOnMismatch)
1049            .unwrap();
1050        assert_eq!(errors.len(), 1);
1051        assert_eq!(errors[0].actual_hash, correct_hash);
1052        assert_eq!(errors[0].expected_hash, used_hash);
1053
1054        assert_eq!(file_path.exists(), false);
1055    }
1056
1057    #[tokio::test]
1058    async fn test_file_cache_scan_retain() {
1059        let dir = tempfile::tempdir().unwrap();
1060        let path = dir.path();
1061
1062        let cache = FileSystemCache {
1063            cache_dir: path.to_path_buf(),
1064        };
1065
1066        {
1067            let state = cache
1068                .scan(0u64, |state: &mut u64, _entry| {
1069                    *state += 1;
1070                    Ok(())
1071                })
1072                .await
1073                .unwrap();
1074
1075            assert_eq!(state, 0);
1076        }
1077
1078        let path1 = cache
1079            .save(
1080                Bytes::from_static(b"test1"),
1081                &DistributionInfo {
1082                    webc: Url::parse("file:///test1.webc").unwrap(),
1083                    webc_sha256: WebcHash::sha256(b"test1"),
1084                },
1085            )
1086            .await
1087            .unwrap();
1088        let path2 = cache
1089            .save(
1090                Bytes::from_static(b"test2"),
1091                &DistributionInfo {
1092                    webc: Url::parse("file:///test2.webc").unwrap(),
1093                    webc_sha256: WebcHash::sha256(b"test2"),
1094                },
1095            )
1096            .await
1097            .unwrap();
1098
1099        {
1100            let path1 = path1.clone();
1101            let path2 = path2.clone();
1102            let state = cache
1103                .scan(0u64, move |state: &mut u64, entry| {
1104                    *state += 1;
1105                    assert!(entry.path() == path1 || entry.path() == path2);
1106                    Ok(())
1107                })
1108                .await
1109                .unwrap();
1110
1111            assert_eq!(state, 2);
1112        }
1113
1114        {
1115            let path1 = path1.clone();
1116            let state = cache
1117                .retain(0u64, move |state: &mut u64, entry| {
1118                    *state += 1;
1119                    Ok(entry.path() == path1)
1120                })
1121                .await
1122                .unwrap();
1123            assert_eq!(state, 2);
1124        }
1125
1126        assert!(path1.exists());
1127        assert!(!path2.exists(), "Path 2 should have been deleted");
1128
1129        {
1130            let path1 = path1.clone();
1131            let state = cache
1132                .scan(0u64, move |state: &mut u64, entry| {
1133                    *state += 1;
1134                    assert!(entry.path() == path1);
1135                    Ok(())
1136                })
1137                .await
1138                .unwrap();
1139            assert_eq!(state, 1);
1140        }
1141    }
1142}