wasmer_wasix/runtime/resolver/
web_source.rs

1use std::{
2    fmt::Write as _,
3    io::Write,
4    path::{Path, PathBuf},
5    sync::Arc,
6    time::{Duration, SystemTime},
7};
8
9use anyhow::{Context, Error};
10use http::Method;
11use sha2::{Digest, Sha256};
12use tempfile::NamedTempFile;
13use url::Url;
14use wasmer_config::package::{PackageHash, PackageId, PackageSource};
15use wasmer_package::utils::from_disk;
16
17use crate::{
18    http::{HttpClient, HttpRequest},
19    runtime::resolver::{
20        DistributionInfo, PackageInfo, PackageSummary, QueryError, Source, WebcHash,
21    },
22};
23
24/// A [`Source`] which can query arbitrary packages on the internet.
25///
26/// # Implementation Notes
27///
28/// Unlike other [`Source`] implementations, this will need to download
29/// a package if it is a [`PackageSource::Url`]. Optionally, these downloaded
30/// packages can be cached in a local directory.
31///
32/// After a certain period ([`WebSource::with_retry_period()`]), the
33/// [`WebSource`] will re-check the uploaded source to make sure the cached
34/// package is still valid. This checking is done using the [ETag][ETag] header,
35/// if available.
36///
37/// [ETag]: https://en.wikipedia.org/wiki/HTTP_ETag
38#[derive(Debug, Clone)]
39pub struct WebSource {
40    cache_dir: PathBuf,
41    client: Arc<dyn HttpClient + Send + Sync>,
42    retry_period: Duration,
43}
44
45impl WebSource {
46    pub const DEFAULT_RETRY_PERIOD: Duration = Duration::from_secs(5 * 60);
47
48    pub fn new(cache_dir: impl Into<PathBuf>, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
49        WebSource {
50            cache_dir: cache_dir.into(),
51            client,
52            retry_period: WebSource::DEFAULT_RETRY_PERIOD,
53        }
54    }
55
56    /// Set the period after which an item should be marked as "possibly dirty"
57    /// in the cache.
58    pub fn with_retry_period(self, retry_period: Duration) -> Self {
59        WebSource {
60            retry_period,
61            ..self
62        }
63    }
64
65    /// Download a package and cache it locally.
66    #[tracing::instrument(level = "debug", skip_all, fields(%url))]
67    async fn get_locally_cached_file(&self, url: &Url) -> Result<PathBuf, Error> {
68        // This function is a bit tricky because we go to great lengths to avoid
69        // unnecessary downloads.
70
71        let cache_key = sha256(url.as_str().as_bytes());
72
73        // First, we figure out some basic information about the item
74        let cache_info = CacheInfo::for_url(&cache_key, &self.cache_dir);
75
76        // Next we check if we definitely got a cache hit
77        let state = match classify_cache_using_mtime(cache_info, self.retry_period) {
78            Ok(path) => {
79                tracing::debug!(path=%path.display(), "Cache hit!");
80                return Ok(path);
81            }
82            Err(s) => s,
83        };
84
85        // Let's check if the ETag is still valid
86        if let CacheState::PossiblyDirty { etag, path } = &state {
87            match self.get_etag(url).await {
88                Ok(new_etag) if new_etag == *etag => {
89                    return Ok(path.clone());
90                }
91                Ok(different_etag) => {
92                    tracing::debug!(
93                        original_etag=%etag,
94                        new_etag=%different_etag,
95                        path=%path.display(),
96                        "File has been updated. Redownloading.",
97                    );
98                }
99                Err(e) => {
100                    tracing::debug!(
101                        error=&*e,
102                        path=%path.display(),
103                        original_etag=%etag,
104                        "Unable to check if the etag is out of date",
105                    )
106                }
107            }
108        }
109
110        // Oh well, looks like we'll need to download it again
111        let (bytes, etag) = match self.fetch(url).await {
112            Ok((bytes, etag)) => (bytes, etag),
113            Err(e) => {
114                tracing::warn!(error = &*e, "Download failed");
115                match state.take_path() {
116                    Some(path) => {
117                        tracing::debug!(
118                            path=%path.display(),
119                            "Using a possibly stale cached file",
120                        );
121                        return Ok(path);
122                    }
123                    None => {
124                        return Err(e);
125                    }
126                }
127            }
128        };
129
130        let path = self.cache_dir.join(&cache_key);
131        self.atomically_save_file(&path, &bytes)
132            .await
133            .with_context(|| {
134                format!(
135                    "Unable to save the downloaded file to \"{}\"",
136                    path.display()
137                )
138            })?;
139
140        if let Some(etag) = etag
141            && let Err(e) = self
142                .atomically_save_file(path.with_extension("etag"), etag.as_bytes())
143                .await
144        {
145            tracing::warn!(
146                error=&*e,
147                %etag,
148                %url,
149                path=%path.display(),
150                "Unable to save the etag file",
151            )
152        }
153
154        Ok(path)
155    }
156
157    async fn atomically_save_file(&self, path: impl AsRef<Path>, data: &[u8]) -> Result<(), Error> {
158        // FIXME: This will all block the main thread
159
160        let path = path.as_ref();
161
162        if let Some(parent) = path.parent() {
163            std::fs::create_dir_all(parent)
164                .with_context(|| format!("Unable to create \"{}\"", parent.display()))?;
165        }
166
167        let mut temp = NamedTempFile::new_in(&self.cache_dir)?;
168        temp.write_all(data)?;
169        temp.as_file().sync_all()?;
170        temp.persist(path)?;
171
172        Ok(())
173    }
174
175    async fn get_etag(&self, url: &Url) -> Result<String, Error> {
176        let request = HttpRequest {
177            url: url.clone(),
178            method: Method::HEAD,
179            headers: super::utils::webc_headers(),
180            body: None,
181            options: Default::default(),
182        };
183
184        let response = self.client.request(request).await?;
185
186        if !response.is_ok() {
187            return Err(super::utils::http_error(&response)
188                .context(format!("The HEAD request to \"{url}\" failed")));
189        }
190
191        let etag = response
192            .headers
193            .get("ETag")
194            .context("The HEAD request didn't contain an ETag header`")?
195            .to_str()
196            .context("The ETag wasn't valid UTF-8")?;
197
198        Ok(etag.to_string())
199    }
200
201    async fn fetch(&self, url: &Url) -> Result<(Vec<u8>, Option<String>), Error> {
202        let request = HttpRequest {
203            url: url.clone(),
204            method: Method::GET,
205            headers: super::utils::webc_headers(),
206            body: None,
207            options: Default::default(),
208        };
209        let response = self.client.request(request).await?;
210
211        if !response.is_ok() {
212            return Err(super::utils::http_error(&response)
213                .context(format!("The GET request to \"{url}\" failed")));
214        }
215
216        let body = response.body.context("Response didn't contain a body")?;
217
218        let etag = response
219            .headers
220            .get("ETag")
221            .and_then(|etag| etag.to_str().ok())
222            .map(|etag| etag.to_string());
223
224        Ok((body, etag))
225    }
226
227    async fn load_url(&self, url: &Url) -> Result<Vec<PackageSummary>, anyhow::Error> {
228        let local_path = self
229            .get_locally_cached_file(url)
230            .await
231            .context("Unable to get the locally cached file")?;
232
233        let webc_sha256 = crate::block_in_place(|| WebcHash::for_file(&local_path))
234            .with_context(|| format!("Unable to hash \"{}\"", local_path.display()))?;
235
236        // Note: We want to use Container::from_disk() rather than the bytes
237        // our HTTP client gave us because then we can use memory-mapped files
238        let container = crate::block_in_place(|| from_disk(&local_path))
239            .with_context(|| format!("Unable to load \"{}\"", local_path.display()))?;
240
241        let id = PackageInfo::package_id_from_manifest(container.manifest())?
242            .unwrap_or_else(|| PackageId::Hash(PackageHash::from_sha256_bytes(webc_sha256.0)));
243
244        let pkg = PackageInfo::from_manifest(id, container.manifest(), container.version())
245            .context("Unable to determine the package's metadata")?;
246
247        let dist = DistributionInfo {
248            webc: url.clone(),
249            webc_sha256,
250        };
251
252        Ok(vec![PackageSummary { pkg, dist }])
253    }
254}
255
256#[async_trait::async_trait]
257impl Source for WebSource {
258    #[tracing::instrument(level = "debug", skip_all, fields(%package))]
259    async fn query(&self, package: &PackageSource) -> Result<Vec<PackageSummary>, QueryError> {
260        let url = match package {
261            PackageSource::Url(url) => url,
262            _ => {
263                return Err(QueryError::Unsupported {
264                    query: package.clone(),
265                });
266            }
267        };
268
269        self.load_url(url)
270            .await
271            .map_err(|error| QueryError::new_other(error, package))
272    }
273}
274
275fn sha256(bytes: &[u8]) -> String {
276    let mut hasher = Sha256::default();
277    hasher.update(bytes);
278    let hash = hasher.finalize();
279    let mut buffer = String::with_capacity(hash.len() * 2);
280    for byte in hash {
281        write!(buffer, "{byte:02X}").expect("Unreachable");
282    }
283
284    buffer
285}
286
/// What the cache directory currently holds for a given key.
#[derive(Debug, Clone, PartialEq)]
enum CacheInfo {
    /// Nothing is stored under this key yet (it could be cached later on).
    Miss,
    /// A file is stored under this key.
    Hit {
        /// Location of the cached file.
        path: PathBuf,
        /// Contents of the `<key>.etag` sidecar, if it could be read.
        etag: Option<String>,
        /// The cached file's modification time, if it could be determined.
        last_modified: Option<SystemTime>,
    },
}

impl CacheInfo {
    /// Look up the entry named `key` inside `checkout_dir`.
    ///
    /// Failing to read the ETag sidecar or the modification time is not an
    /// error; the corresponding field is simply `None`.
    fn for_url(key: &str, checkout_dir: &Path) -> CacheInfo {
        let path = checkout_dir.join(key);

        if path.exists() {
            CacheInfo::Hit {
                etag: std::fs::read_to_string(path.with_extension("etag")).ok(),
                last_modified: path.metadata().and_then(|meta| meta.modified()).ok(),
                path,
            }
        } else {
            CacheInfo::Miss
        }
    }
}
317
318fn classify_cache_using_mtime(
319    info: CacheInfo,
320    invalidation_threshold: Duration,
321) -> Result<PathBuf, CacheState> {
322    let (path, last_modified, etag) = match info {
323        CacheInfo::Hit {
324            path,
325            last_modified: Some(last_modified),
326            etag,
327            ..
328        } => (path, last_modified, etag),
329        CacheInfo::Hit {
330            path,
331            last_modified: None,
332            etag: Some(etag),
333            ..
334        } => return Err(CacheState::PossiblyDirty { etag, path }),
335        CacheInfo::Hit {
336            etag: None,
337            last_modified: None,
338            path,
339            ..
340        } => {
341            return Err(CacheState::UnableToVerify { path });
342        }
343        CacheInfo::Miss => return Err(CacheState::Miss),
344    };
345
346    if let Ok(time_since_last_modified) = last_modified.elapsed()
347        && time_since_last_modified <= invalidation_threshold
348    {
349        return Ok(path);
350    }
351
352    match etag {
353        Some(etag) => Err(CacheState::PossiblyDirty { etag, path }),
354        None => Err(CacheState::UnableToVerify { path }),
355    }
356}
357
/// Classification of how valid an item is based on filesystem metadata.
#[derive(Debug)]
enum CacheState {
    /// The item isn't in the cache.
    Miss,
    /// The cached item might be invalid, but it has an ETag we can use for
    /// further validation.
    PossiblyDirty { etag: String, path: PathBuf },
    /// The cached item exists on disk, but we weren't able to tell whether it
    /// is still valid, and there aren't any other ways to validate it further.
    /// You can probably reuse this if you are having internet issues.
    UnableToVerify { path: PathBuf },
}

impl CacheState {
    /// Consume the state, returning the cached file's path if one exists on
    /// disk.
    fn take_path(self) -> Option<PathBuf> {
        match self {
            CacheState::PossiblyDirty { path, .. } | CacheState::UnableToVerify { path } => {
                Some(path)
            }
            // Spelled out instead of `_` so that adding a variant forces a
            // decision here.
            CacheState::Miss => None,
        }
    }
}
382
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, sync::Mutex};

    use futures::future::BoxFuture;
    use http::{HeaderMap, StatusCode, header::IntoHeaderName};
    use tempfile::TempDir;

    use crate::http::HttpResponse;

    use super::*;

    const PYTHON: &[u8] = include_bytes!("../../../../c-api/examples/assets/python-0.1.0.wasmer");
    const COREUTILS: &[u8] = include_bytes!(
        "../../../../../tests/integration/cli/tests/webc/coreutils-1.0.16-e27dbb4f-2ef2-4b44-b46a-ddd86497c6d7.webc"
    );
    const DUMMY_URL: &str = "http://my-registry.io/some/package";
    const DUMMY_URL_HASH: &str = "4D7481F44E1D971A8C60D3C7BD505E2727602CF9369ED623920E029C2BA2351D";

    /// Declare async test `$name` that awaits `$body()`: on a multi-threaded
    /// tokio runtime natively, and on the default runtime for `wasm32`.
    macro_rules! runtime_test {
        ($name:ident, $body:ident) => {
            #[cfg(not(target_arch = "wasm32"))]
            #[tokio::test(flavor = "multi_thread")]
            async fn $name() {
                $body().await
            }

            #[cfg(target_arch = "wasm32")]
            #[tokio::test()]
            async fn $name() {
                $body().await
            }
        };
    }

    /// An [`HttpClient`] that replays canned responses in order and records
    /// every request it receives.
    #[derive(Debug)]
    pub(crate) struct DummyClient {
        requests: Mutex<Vec<HttpRequest>>,
        responses: Mutex<VecDeque<HttpResponse>>,
    }

    impl DummyClient {
        pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
            let queued: VecDeque<HttpResponse> = responses.into_iter().collect();
            DummyClient {
                requests: Mutex::new(Vec::new()),
                responses: Mutex::new(queued),
            }
        }
    }

    impl HttpClient for DummyClient {
        fn request(
            &self,
            request: HttpRequest,
        ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
            // Panics when a test didn't queue enough responses.
            let next = self.responses.lock().unwrap().pop_front().unwrap();
            self.requests.lock().unwrap().push(request);
            Box::pin(async move { Ok(next) })
        }
    }

    /// Fluent helper for assembling [`HttpResponse`]s.
    struct ResponseBuilder(HttpResponse);

    impl ResponseBuilder {
        pub fn new() -> Self {
            let response = HttpResponse {
                body: None,
                redirected: false,
                status: StatusCode::OK,
                headers: HeaderMap::new(),
            };
            ResponseBuilder(response)
        }

        pub fn with_status(self, code: StatusCode) -> Self {
            let ResponseBuilder(mut response) = self;
            response.status = code;
            ResponseBuilder(response)
        }

        pub fn with_body(self, body: impl Into<Vec<u8>>) -> Self {
            let ResponseBuilder(mut response) = self;
            response.body = Some(body.into());
            ResponseBuilder(response)
        }

        pub fn with_etag(self, value: &str) -> Self {
            self.with_header("ETag", value)
        }

        pub fn with_header(self, name: impl IntoHeaderName, value: &str) -> Self {
            let ResponseBuilder(mut response) = self;
            response.headers.insert(name, value.parse().unwrap());
            ResponseBuilder(response)
        }

        pub fn build(self) -> HttpResponse {
            let ResponseBuilder(response) = self;
            response
        }
    }

    async fn empty_cache_does_a_full_download_internal() {
        let served_etag = "This is an etag";
        let cache_dir = TempDir::new().unwrap();
        let response = ResponseBuilder::new()
            .with_body(PYTHON)
            .with_etag(served_etag)
            .build();
        let source = WebSource::new(
            cache_dir.path(),
            Arc::new(DummyClient::with_responses([response])),
        );
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        // The package was resolved from the downloaded bytes...
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // ...and both the package and its ETag landed in the cache.
        let cached = cache_dir.path().join(DUMMY_URL_HASH);
        let cached_etag = cached.with_extension("etag");
        assert!(cached.exists());
        assert!(cached_etag.exists());
        assert_eq!(std::fs::read_to_string(&cached_etag).unwrap(), served_etag);
        assert_eq!(std::fs::read(&cached).unwrap(), PYTHON);
    }
    runtime_test!(
        empty_cache_does_a_full_download,
        empty_cache_does_a_full_download_internal
    );

    async fn cache_hit_internal() {
        let cache_dir = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([]));
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());
        // Prime the cache before anything asks for the package.
        std::fs::write(cache_dir.path().join(DUMMY_URL_HASH), PYTHON).unwrap();
        let source = WebSource::new(cache_dir.path(), client.clone());

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // A fresh cache entry means the network is never touched.
        assert_eq!(client.requests.lock().unwrap().len(), 0);
    }
    runtime_test!(cache_hit, cache_hit_internal);

    async fn fall_back_to_stale_cache_if_request_fails_internal() {
        let cache_dir = TempDir::new().unwrap();
        let failure = ResponseBuilder::new()
            .with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .build();
        let client = Arc::new(DummyClient::with_responses([failure]));
        // Seed the cache without an ETag, so the entry can't be revalidated.
        let cached = cache_dir.path().join(DUMMY_URL_HASH);
        std::fs::write(&cached, PYTHON).unwrap();
        let source =
            WebSource::new(cache_dir.path(), client.clone()).with_retry_period(Duration::ZERO);
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // Exactly one (failed) download attempt was made...
        assert_eq!(client.requests.lock().unwrap().len(), 1);
        // ...and no ETag was invented for the stale copy.
        assert!(!cached.with_extension("etag").exists());
    }
    runtime_test!(
        fall_back_to_stale_cache_if_request_fails,
        fall_back_to_stale_cache_if_request_fails_internal
    );

    async fn download_again_if_etag_is_different_internal() {
        let cache_dir = TempDir::new().unwrap();
        let head_response = ResponseBuilder::new().with_etag("coreutils").build();
        let get_response = ResponseBuilder::new()
            .with_body(COREUTILS)
            .with_etag("coreutils")
            .build();
        let client = Arc::new(DummyClient::with_responses([head_response, get_response]));
        // Seed the cache with Python plus an ETag the server no longer reports.
        let cached = cache_dir.path().join(DUMMY_URL_HASH);
        std::fs::write(&cached, PYTHON).unwrap();
        std::fs::write(cached.with_extension("etag"), "python").unwrap();
        // A zero retry period forces the ETag to be re-checked on every query.
        let source =
            WebSource::new(cache_dir.path(), client.clone()).with_retry_period(Duration::ZERO);
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        // The freshly downloaded coreutils replaces the cached Python package.
        assert_eq!(summaries.len(), 1);
        assert_eq!(
            summaries[0].pkg.id.as_named().unwrap().full_name,
            "sharrattj/coreutils"
        );
        // One HEAD to compare ETags, then one GET to download.
        let requests = client.requests.lock().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].method, "HEAD");
        assert_eq!(requests[1].method, "GET");
        // The stored ETag now describes the new download.
        assert_eq!(
            std::fs::read_to_string(cached.with_extension("etag")).unwrap(),
            "coreutils"
        );
    }
    runtime_test!(
        download_again_if_etag_is_different,
        download_again_if_etag_is_different_internal
    );
}