1use std::{
2 io::Write,
3 path::{Path, PathBuf},
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use anyhow::{Context, Error};
9use http::Method;
10use sha2::{Digest, Sha256};
11use tempfile::NamedTempFile;
12use url::Url;
13use wasmer_config::package::{PackageHash, PackageId, PackageSource};
14use wasmer_package::utils::from_disk;
15
16use crate::{
17 http::{HttpClient, HttpRequest},
18 runtime::resolver::{
19 DistributionInfo, PackageInfo, PackageSummary, QueryError, Source, WebcHash,
20 },
21};
22
/// A [`Source`] which resolves packages referenced by URL by downloading the
/// file over HTTP and caching it on disk.
///
/// Downloads are stored in `cache_dir` under a key derived from the URL's
/// SHA-256 hash, next to a `.etag` file holding the server's `ETag` (when one
/// was returned) so the cached copy can later be revalidated cheaply.
#[derive(Debug, Clone)]
pub struct WebSource {
    /// Directory holding cached downloads and their `.etag` files.
    cache_dir: PathBuf,
    /// Client used for the `HEAD` (revalidation) and `GET` (download) requests.
    client: Arc<dyn HttpClient + Send + Sync>,
    /// How long a cached file is used as-is before the server is consulted.
    retry_period: Duration,
}
43
44impl WebSource {
45 pub const DEFAULT_RETRY_PERIOD: Duration = Duration::from_secs(5 * 60);
46
47 pub fn new(cache_dir: impl Into<PathBuf>, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
48 WebSource {
49 cache_dir: cache_dir.into(),
50 client,
51 retry_period: WebSource::DEFAULT_RETRY_PERIOD,
52 }
53 }
54
55 pub fn with_retry_period(self, retry_period: Duration) -> Self {
58 WebSource {
59 retry_period,
60 ..self
61 }
62 }
63
64 #[tracing::instrument(level = "debug", skip_all, fields(%url))]
66 async fn get_locally_cached_file(&self, url: &Url) -> Result<PathBuf, Error> {
67 let cache_key = sha256(url.as_str().as_bytes());
71
72 let cache_info = CacheInfo::for_url(&cache_key, &self.cache_dir);
74
75 let state = match classify_cache_using_mtime(cache_info, self.retry_period) {
77 Ok(path) => {
78 tracing::debug!(path=%path.display(), "Cache hit!");
79 return Ok(path);
80 }
81 Err(s) => s,
82 };
83
84 if let CacheState::PossiblyDirty { etag, path } = &state {
86 match self.get_etag(url).await {
87 Ok(new_etag) if new_etag == *etag => {
88 return Ok(path.clone());
89 }
90 Ok(different_etag) => {
91 tracing::debug!(
92 original_etag=%etag,
93 new_etag=%different_etag,
94 path=%path.display(),
95 "File has been updated. Redownloading.",
96 );
97 }
98 Err(e) => {
99 tracing::debug!(
100 error=&*e,
101 path=%path.display(),
102 original_etag=%etag,
103 "Unable to check if the etag is out of date",
104 )
105 }
106 }
107 }
108
109 let (bytes, etag) = match self.fetch(url).await {
111 Ok((bytes, etag)) => (bytes, etag),
112 Err(e) => {
113 tracing::warn!(error = &*e, "Download failed");
114 match state.take_path() {
115 Some(path) => {
116 tracing::debug!(
117 path=%path.display(),
118 "Using a possibly stale cached file",
119 );
120 return Ok(path);
121 }
122 None => {
123 return Err(e);
124 }
125 }
126 }
127 };
128
129 let path = self.cache_dir.join(&cache_key);
130 self.atomically_save_file(&path, &bytes)
131 .await
132 .with_context(|| {
133 format!(
134 "Unable to save the downloaded file to \"{}\"",
135 path.display()
136 )
137 })?;
138
139 if let Some(etag) = etag
140 && let Err(e) = self
141 .atomically_save_file(path.with_extension("etag"), etag.as_bytes())
142 .await
143 {
144 tracing::warn!(
145 error=&*e,
146 %etag,
147 %url,
148 path=%path.display(),
149 "Unable to save the etag file",
150 )
151 }
152
153 Ok(path)
154 }
155
156 async fn atomically_save_file(&self, path: impl AsRef<Path>, data: &[u8]) -> Result<(), Error> {
157 let path = path.as_ref();
160
161 if let Some(parent) = path.parent() {
162 std::fs::create_dir_all(parent)
163 .with_context(|| format!("Unable to create \"{}\"", parent.display()))?;
164 }
165
166 let mut temp = NamedTempFile::new_in(&self.cache_dir)?;
167 temp.write_all(data)?;
168 temp.as_file().sync_all()?;
169 temp.persist(path)?;
170
171 Ok(())
172 }
173
174 async fn get_etag(&self, url: &Url) -> Result<String, Error> {
175 let request = HttpRequest {
176 url: url.clone(),
177 method: Method::HEAD,
178 headers: super::utils::webc_headers(),
179 body: None,
180 options: Default::default(),
181 };
182
183 let response = self.client.request(request).await?;
184
185 if !response.is_ok() {
186 return Err(super::utils::http_error(&response)
187 .context(format!("The HEAD request to \"{url}\" failed")));
188 }
189
190 let etag = response
191 .headers
192 .get("ETag")
193 .context("The HEAD request didn't contain an ETag header`")?
194 .to_str()
195 .context("The ETag wasn't valid UTF-8")?;
196
197 Ok(etag.to_string())
198 }
199
200 async fn fetch(&self, url: &Url) -> Result<(Vec<u8>, Option<String>), Error> {
201 let request = HttpRequest {
202 url: url.clone(),
203 method: Method::GET,
204 headers: super::utils::webc_headers(),
205 body: None,
206 options: Default::default(),
207 };
208 let response = self.client.request(request).await?;
209
210 if !response.is_ok() {
211 return Err(super::utils::http_error(&response)
212 .context(format!("The GET request to \"{url}\" failed")));
213 }
214
215 let body = response.body.context("Response didn't contain a body")?;
216
217 let etag = response
218 .headers
219 .get("ETag")
220 .and_then(|etag| etag.to_str().ok())
221 .map(|etag| etag.to_string());
222
223 Ok((body, etag))
224 }
225
226 async fn load_url(&self, url: &Url) -> Result<Vec<PackageSummary>, anyhow::Error> {
227 let local_path = self
228 .get_locally_cached_file(url)
229 .await
230 .context("Unable to get the locally cached file")?;
231
232 let webc_sha256 = crate::block_in_place(|| WebcHash::for_file(&local_path))
233 .with_context(|| format!("Unable to hash \"{}\"", local_path.display()))?;
234
235 let container = crate::block_in_place(|| from_disk(&local_path))
238 .with_context(|| format!("Unable to load \"{}\"", local_path.display()))?;
239
240 let id = PackageInfo::package_id_from_manifest(container.manifest())?
241 .unwrap_or_else(|| PackageId::Hash(PackageHash::from_sha256_bytes(webc_sha256.0)));
242
243 let pkg = PackageInfo::from_manifest(id, container.manifest(), container.version())
244 .context("Unable to determine the package's metadata")?;
245
246 let dist = DistributionInfo {
247 webc: url.clone(),
248 webc_sha256,
249 };
250
251 Ok(vec![PackageSummary { pkg, dist }])
252 }
253}
254
255#[async_trait::async_trait]
256impl Source for WebSource {
257 #[tracing::instrument(level = "debug", skip_all, fields(%package))]
258 async fn query(&self, package: &PackageSource) -> Result<Vec<PackageSummary>, QueryError> {
259 let url = match package {
260 PackageSource::Url(url) => url,
261 _ => {
262 return Err(QueryError::Unsupported {
263 query: package.clone(),
264 });
265 }
266 };
267
268 self.load_url(url)
269 .await
270 .map_err(|error| QueryError::new_other(error, package))
271 }
272}
273
274fn sha256(bytes: &[u8]) -> String {
275 let mut hasher = Sha256::default();
276 hasher.update(bytes);
277 let hash = hasher.finalize();
278 hex::encode_upper(hash)
279}
280
/// What was found on disk for a particular cache key.
#[derive(Debug, Clone, PartialEq)]
enum CacheInfo {
    /// A cached file exists.
    Hit {
        /// Location of the cached file.
        path: PathBuf,
        /// Contents of the sibling `.etag` file, if it could be read.
        etag: Option<String>,
        /// The cached file's modification time, if it could be read.
        last_modified: Option<SystemTime>,
    },
    /// Nothing is cached under this key.
    Miss,
}
292
293impl CacheInfo {
294 fn for_url(key: &str, checkout_dir: &Path) -> CacheInfo {
295 let path = checkout_dir.join(key);
296
297 if !path.exists() {
298 return CacheInfo::Miss;
299 }
300
301 let etag = std::fs::read_to_string(path.with_extension("etag")).ok();
302 let last_modified = path.metadata().and_then(|m| m.modified()).ok();
303
304 CacheInfo::Hit {
305 etag,
306 last_modified,
307 path,
308 }
309 }
310}
311
312fn classify_cache_using_mtime(
313 info: CacheInfo,
314 invalidation_threshold: Duration,
315) -> Result<PathBuf, CacheState> {
316 let (path, last_modified, etag) = match info {
317 CacheInfo::Hit {
318 path,
319 last_modified: Some(last_modified),
320 etag,
321 ..
322 } => (path, last_modified, etag),
323 CacheInfo::Hit {
324 path,
325 last_modified: None,
326 etag: Some(etag),
327 ..
328 } => return Err(CacheState::PossiblyDirty { etag, path }),
329 CacheInfo::Hit {
330 etag: None,
331 last_modified: None,
332 path,
333 ..
334 } => {
335 return Err(CacheState::UnableToVerify { path });
336 }
337 CacheInfo::Miss => return Err(CacheState::Miss),
338 };
339
340 if let Ok(time_since_last_modified) = last_modified.elapsed()
341 && time_since_last_modified <= invalidation_threshold
342 {
343 return Ok(path);
344 }
345
346 match etag {
347 Some(etag) => Err(CacheState::PossiblyDirty { etag, path }),
348 None => Err(CacheState::UnableToVerify { path }),
349 }
350}
351
/// Why a cached file can't be used straight away, and what fallback exists.
#[derive(Debug)]
enum CacheState {
    /// A file is cached with a saved ETag, but its modification time is
    /// unknown or outside the retry period, so the ETag should be revalidated.
    PossiblyDirty { etag: String, path: PathBuf },
    /// A file is cached, but it is stale (or its mtime is unknown) and there
    /// is no ETag to check it against.
    UnableToVerify { path: PathBuf },
    /// Nothing is cached.
    Miss,
}
365
366impl CacheState {
367 fn take_path(self) -> Option<PathBuf> {
368 match self {
369 CacheState::PossiblyDirty { path, .. } | CacheState::UnableToVerify { path } => {
370 Some(path)
371 }
372 _ => None,
373 }
374 }
375}
376
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, sync::Mutex};

    use futures::future::BoxFuture;
    use http::{HeaderMap, StatusCode, header::IntoHeaderName};
    use tempfile::TempDir;

    use crate::http::HttpResponse;

    use super::*;

    const PYTHON: &[u8] = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../wasmer-test-files/examples/python-0.1.0.wasmer"
    ));
    const COREUTILS: &[u8] = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../wasmer-test-files/integration/webc/coreutils-1.0.16-e27dbb4f-2ef2-4b44-b46a-ddd86497c6d7.webc"
    ));
    const DUMMY_URL: &str = "http://my-registry.io/some/package";
    /// The cache key `WebSource` derives for `DUMMY_URL`.
    const DUMMY_URL_HASH: &str = "4D7481F44E1D971A8C60D3C7BD505E2727602CF9369ED623920E029C2BA2351D";

    /// An [`HttpClient`] which records every request and replies with
    /// pre-canned responses, in order.
    ///
    /// Panics if more requests are made than responses were provided.
    #[derive(Debug)]
    pub(crate) struct DummyClient {
        requests: Mutex<Vec<HttpRequest>>,
        responses: Mutex<VecDeque<HttpResponse>>,
    }

    impl DummyClient {
        pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
            DummyClient {
                requests: Mutex::new(Vec::new()),
                responses: Mutex::new(responses.into_iter().collect()),
            }
        }
    }

    impl HttpClient for DummyClient {
        fn request(
            &self,
            request: HttpRequest,
        ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
            let response = self.responses.lock().unwrap().pop_front().unwrap();
            self.requests.lock().unwrap().push(request);
            Box::pin(async { Ok(response) })
        }
    }

    /// Builder for [`HttpResponse`]s. It starts as an empty `200 OK`.
    struct ResponseBuilder(HttpResponse);

    impl ResponseBuilder {
        pub fn new() -> Self {
            ResponseBuilder(HttpResponse {
                body: None,
                redirected: false,
                status: StatusCode::OK,
                headers: HeaderMap::new(),
            })
        }

        pub fn with_status(mut self, code: StatusCode) -> Self {
            self.0.status = code;
            self
        }

        pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Self {
            self.0.body = Some(body.into());
            self
        }

        pub fn with_etag(self, value: &str) -> Self {
            self.with_header("ETag", value)
        }

        pub fn with_header(mut self, name: impl IntoHeaderName, value: &str) -> Self {
            self.0.headers.insert(name, value.parse().unwrap());
            self
        }

        pub fn build(self) -> HttpResponse {
            self.0
        }
    }

    async fn empty_cache_does_a_full_download_internal() {
        let dummy_etag = "This is an etag";
        let temp = TempDir::new().unwrap();
        let client = DummyClient::with_responses([ResponseBuilder::new()
            .with_body(PYTHON)
            .with_etag(dummy_etag)
            .build()]);
        let source = WebSource::new(temp.path(), Arc::new(client));
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        let path = temp.path().join(DUMMY_URL_HASH);
        assert!(path.exists());
        let etag_path = path.with_extension("etag");
        assert!(etag_path.exists());
        assert_eq!(std::fs::read_to_string(etag_path).unwrap(), dummy_etag);
        assert_eq!(std::fs::read(path).unwrap(), PYTHON);
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn empty_cache_does_a_full_download() {
        empty_cache_does_a_full_download_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn empty_cache_does_a_full_download() {
        empty_cache_does_a_full_download_internal().await
    }

    async fn cache_hit_internal() {
        let temp = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([]));
        let source = WebSource::new(temp.path(), client.clone());
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());
        std::fs::write(temp.path().join(DUMMY_URL_HASH), PYTHON).unwrap();

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        assert_eq!(client.requests.lock().unwrap().len(), 0);
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn cache_hit() {
        cache_hit_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn cache_hit() {
        cache_hit_internal().await
    }

    async fn fall_back_to_stale_cache_if_request_fails_internal() {
        let temp = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([ResponseBuilder::new()
            .with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .build()]));
        let python_path = temp.path().join(DUMMY_URL_HASH);
        std::fs::write(&python_path, PYTHON).unwrap();
        let source = WebSource::new(temp.path(), client.clone()).with_retry_period(Duration::ZERO);
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        assert_eq!(client.requests.lock().unwrap().len(), 1);
        assert!(!python_path.with_extension("etag").exists());
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn fall_back_to_stale_cache_if_request_fails() {
        fall_back_to_stale_cache_if_request_fails_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn fall_back_to_stale_cache_if_request_fails() {
        fall_back_to_stale_cache_if_request_fails_internal().await
    }

    async fn download_again_if_etag_is_different_internal() {
        let temp = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([
            ResponseBuilder::new().with_etag("coreutils").build(),
            ResponseBuilder::new()
                .with_body(COREUTILS)
                .with_etag("coreutils")
                .build(),
        ]));
        let path = temp.path().join(DUMMY_URL_HASH);
        std::fs::write(&path, PYTHON).unwrap();
        std::fs::write(path.with_extension("etag"), "python").unwrap();
        let source =
            WebSource::new(temp.path(), client.clone()).with_retry_period(Duration::new(0, 0));
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(
            summaries[0].pkg.id.as_named().unwrap().full_name,
            "sharrattj/coreutils"
        );
        let requests = client.requests.lock().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].method, "HEAD");
        assert_eq!(requests[1].method, "GET");
        assert_eq!(
            std::fs::read_to_string(path.with_extension("etag")).unwrap(),
            "coreutils"
        );
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn download_again_if_etag_is_different() {
        download_again_if_etag_is_different_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn download_again_if_etag_is_different() {
        download_again_if_etag_is_different_internal().await
    }

    /// A stale cached file whose ETag still matches the server's is reused
    /// after a single HEAD request, without being downloaded again.
    async fn reuse_cached_file_if_etag_matches_internal() {
        let temp = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([ResponseBuilder::new()
            .with_etag("python")
            .build()]));
        let path = temp.path().join(DUMMY_URL_HASH);
        std::fs::write(&path, PYTHON).unwrap();
        std::fs::write(path.with_extension("etag"), "python").unwrap();
        // A zero retry period makes the cached file stale immediately, which
        // forces the ETag check.
        let source = WebSource::new(temp.path(), client.clone()).with_retry_period(Duration::ZERO);
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        let requests = client.requests.lock().unwrap();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "HEAD");
        assert_eq!(std::fs::read(&path).unwrap(), PYTHON);
        assert_eq!(
            std::fs::read_to_string(path.with_extension("etag")).unwrap(),
            "python"
        );
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn reuse_cached_file_if_etag_matches() {
        reuse_cached_file_if_etag_matches_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn reuse_cached_file_if_etag_matches() {
        reuse_cached_file_if_etag_matches_internal().await
    }

    /// Without a cached copy to fall back on, a failed download is an error,
    /// and nothing is written to the cache.
    async fn failed_download_without_cache_is_an_error_internal() {
        let temp = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([ResponseBuilder::new()
            .with_status(StatusCode::NOT_FOUND)
            .build()]));
        let source = WebSource::new(temp.path(), client.clone());
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let result = source.query(&spec).await;

        assert!(result.is_err());
        assert_eq!(client.requests.lock().unwrap().len(), 1);
        assert!(!temp.path().join(DUMMY_URL_HASH).exists());
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn failed_download_without_cache_is_an_error() {
        failed_download_without_cache_is_an_error_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn failed_download_without_cache_is_an_error() {
        failed_download_without_cache_is_an_error_internal().await
    }
}