1use std::{
2 io::Write,
3 path::{Path, PathBuf},
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use anyhow::{Context, Error};
9use http::Method;
10use sha2::{Digest, Sha256};
11use tempfile::NamedTempFile;
12use url::Url;
13use wasmer_config::package::{PackageHash, PackageId, PackageSource};
14use wasmer_package::utils::from_disk;
15
16use crate::{
17 http::{HttpClient, HttpRequest},
18 runtime::resolver::{
19 DistributionInfo, PackageInfo, PackageSummary, QueryError, Source, WebcHash,
20 },
21};
22
23#[derive(Debug, Clone)]
38pub struct WebSource {
39 cache_dir: PathBuf,
40 client: Arc<dyn HttpClient + Send + Sync>,
41 retry_period: Duration,
42}
43
44impl WebSource {
45 pub const DEFAULT_RETRY_PERIOD: Duration = Duration::from_secs(5 * 60);
46
47 pub fn new(cache_dir: impl Into<PathBuf>, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
48 WebSource {
49 cache_dir: cache_dir.into(),
50 client,
51 retry_period: WebSource::DEFAULT_RETRY_PERIOD,
52 }
53 }
54
55 pub fn with_retry_period(self, retry_period: Duration) -> Self {
58 WebSource {
59 retry_period,
60 ..self
61 }
62 }
63
64 #[tracing::instrument(level = "debug", skip_all, fields(%url))]
66 async fn get_locally_cached_file(&self, url: &Url) -> Result<PathBuf, Error> {
67 let cache_key = sha256(url.as_str().as_bytes());
71
72 let cache_info = CacheInfo::for_url(&cache_key, &self.cache_dir);
74
75 let state = match classify_cache_using_mtime(cache_info, self.retry_period) {
77 Ok(path) => {
78 tracing::debug!(path=%path.display(), "Cache hit!");
79 return Ok(path);
80 }
81 Err(s) => s,
82 };
83
84 if let CacheState::PossiblyDirty { etag, path } = &state {
86 match self.get_etag(url).await {
87 Ok(new_etag) if new_etag == *etag => {
88 return Ok(path.clone());
89 }
90 Ok(different_etag) => {
91 tracing::debug!(
92 original_etag=%etag,
93 new_etag=%different_etag,
94 path=%path.display(),
95 "File has been updated. Redownloading.",
96 );
97 }
98 Err(e) => {
99 tracing::debug!(
100 error=&*e,
101 path=%path.display(),
102 original_etag=%etag,
103 "Unable to check if the etag is out of date",
104 )
105 }
106 }
107 }
108
109 let (bytes, etag) = match self.fetch(url).await {
111 Ok((bytes, etag)) => (bytes, etag),
112 Err(e) => {
113 tracing::warn!(error = &*e, "Download failed");
114 match state.take_path() {
115 Some(path) => {
116 tracing::debug!(
117 path=%path.display(),
118 "Using a possibly stale cached file",
119 );
120 return Ok(path);
121 }
122 None => {
123 return Err(e);
124 }
125 }
126 }
127 };
128
129 let path = self.cache_dir.join(&cache_key);
130 self.atomically_save_file(&path, &bytes)
131 .await
132 .with_context(|| {
133 format!(
134 "Unable to save the downloaded file to \"{}\"",
135 path.display()
136 )
137 })?;
138
139 if let Some(etag) = etag
140 && let Err(e) = self
141 .atomically_save_file(path.with_extension("etag"), etag.as_bytes())
142 .await
143 {
144 tracing::warn!(
145 error=&*e,
146 %etag,
147 %url,
148 path=%path.display(),
149 "Unable to save the etag file",
150 )
151 }
152
153 Ok(path)
154 }
155
156 async fn atomically_save_file(&self, path: impl AsRef<Path>, data: &[u8]) -> Result<(), Error> {
157 let path = path.as_ref();
160
161 if let Some(parent) = path.parent() {
162 std::fs::create_dir_all(parent)
163 .with_context(|| format!("Unable to create \"{}\"", parent.display()))?;
164 }
165
166 let mut temp = NamedTempFile::new_in(&self.cache_dir)?;
167 temp.write_all(data)?;
168 temp.as_file().sync_all()?;
169 temp.persist(path)?;
170
171 Ok(())
172 }
173
174 async fn get_etag(&self, url: &Url) -> Result<String, Error> {
175 let request = HttpRequest {
176 url: url.clone(),
177 method: Method::HEAD,
178 headers: super::utils::webc_headers(),
179 body: None,
180 options: Default::default(),
181 };
182
183 let response = self.client.request(request).await?;
184
185 if !response.is_ok() {
186 return Err(super::utils::http_error(&response)
187 .context(format!("The HEAD request to \"{url}\" failed")));
188 }
189
190 let etag = response
191 .headers
192 .get("ETag")
193 .context("The HEAD request didn't contain an ETag header`")?
194 .to_str()
195 .context("The ETag wasn't valid UTF-8")?;
196
197 Ok(etag.to_string())
198 }
199
200 async fn fetch(&self, url: &Url) -> Result<(Vec<u8>, Option<String>), Error> {
201 let request = HttpRequest {
202 url: url.clone(),
203 method: Method::GET,
204 headers: super::utils::webc_headers(),
205 body: None,
206 options: Default::default(),
207 };
208 let response = self.client.request(request).await?;
209
210 if !response.is_ok() {
211 return Err(super::utils::http_error(&response)
212 .context(format!("The GET request to \"{url}\" failed")));
213 }
214
215 let body = response.body.context("Response didn't contain a body")?;
216
217 let etag = response
218 .headers
219 .get("ETag")
220 .and_then(|etag| etag.to_str().ok())
221 .map(|etag| etag.to_string());
222
223 Ok((body, etag))
224 }
225
226 async fn load_url(&self, url: &Url) -> Result<Vec<PackageSummary>, anyhow::Error> {
227 let local_path = self
228 .get_locally_cached_file(url)
229 .await
230 .context("Unable to get the locally cached file")?;
231
232 let webc_sha256 = crate::block_in_place(|| WebcHash::for_file(&local_path))
233 .with_context(|| format!("Unable to hash \"{}\"", local_path.display()))?;
234
235 let container = crate::block_in_place(|| from_disk(&local_path))
238 .with_context(|| format!("Unable to load \"{}\"", local_path.display()))?;
239
240 let id = PackageInfo::package_id_from_manifest(container.manifest())?
241 .unwrap_or_else(|| PackageId::Hash(PackageHash::from_sha256_bytes(webc_sha256.0)));
242
243 let pkg = PackageInfo::from_manifest(id, container.manifest(), container.version())
244 .context("Unable to determine the package's metadata")?;
245
246 let dist = DistributionInfo {
247 webc: url.clone(),
248 webc_sha256,
249 };
250
251 Ok(vec![PackageSummary { pkg, dist }])
252 }
253}
254
255#[async_trait::async_trait]
256impl Source for WebSource {
257 #[tracing::instrument(level = "debug", skip_all, fields(%package))]
258 async fn query(&self, package: &PackageSource) -> Result<Vec<PackageSummary>, QueryError> {
259 let url = match package {
260 PackageSource::Url(url) => url,
261 _ => {
262 return Err(QueryError::Unsupported {
263 query: package.clone(),
264 });
265 }
266 };
267
268 self.load_url(url)
269 .await
270 .map_err(|error| QueryError::new_other(error, package))
271 }
272}
273
274fn sha256(bytes: &[u8]) -> String {
275 let mut hasher = Sha256::default();
276 hasher.update(bytes);
277 let hash = hasher.finalize();
278 hex::encode_upper(hash)
279}
280
/// What the filesystem says about a cache entry, before any freshness policy
/// is applied.
#[derive(Debug, Clone, PartialEq)]
enum CacheInfo {
    /// No file exists for the key.
    Miss,
    /// A file exists for the key.
    Hit {
        /// Location of the cached file.
        path: PathBuf,
        /// Contents of the sibling `.etag` file, if it could be read.
        etag: Option<String>,
        /// Modification time of the cached file, if it could be determined.
        last_modified: Option<SystemTime>,
    },
}

impl CacheInfo {
    /// Inspects the entry stored under `key` inside `checkout_dir`.
    ///
    /// Returns [`CacheInfo::Miss`] when `checkout_dir/key` does not exist.
    /// Otherwise returns a [`CacheInfo::Hit`]; problems reading the `.etag`
    /// file or the modification time simply leave the corresponding field
    /// as `None`.
    fn for_url(key: &str, checkout_dir: &Path) -> CacheInfo {
        let path = checkout_dir.join(key);

        if path.exists() {
            let etag_file = path.with_extension("etag");

            CacheInfo::Hit {
                etag: std::fs::read_to_string(etag_file).ok(),
                last_modified: std::fs::metadata(&path)
                    .and_then(|meta| meta.modified())
                    .ok(),
                path,
            }
        } else {
            CacheInfo::Miss
        }
    }
}
311
312fn classify_cache_using_mtime(
313 info: CacheInfo,
314 invalidation_threshold: Duration,
315) -> Result<PathBuf, CacheState> {
316 let (path, last_modified, etag) = match info {
317 CacheInfo::Hit {
318 path,
319 last_modified: Some(last_modified),
320 etag,
321 ..
322 } => (path, last_modified, etag),
323 CacheInfo::Hit {
324 path,
325 last_modified: None,
326 etag: Some(etag),
327 ..
328 } => return Err(CacheState::PossiblyDirty { etag, path }),
329 CacheInfo::Hit {
330 etag: None,
331 last_modified: None,
332 path,
333 ..
334 } => {
335 return Err(CacheState::UnableToVerify { path });
336 }
337 CacheInfo::Miss => return Err(CacheState::Miss),
338 };
339
340 if let Ok(time_since_last_modified) = last_modified.elapsed()
341 && time_since_last_modified <= invalidation_threshold
342 {
343 return Ok(path);
344 }
345
346 match etag {
347 Some(etag) => Err(CacheState::PossiblyDirty { etag, path }),
348 None => Err(CacheState::UnableToVerify { path }),
349 }
350}
351
/// Why a cache entry could not be used without further checks.
#[derive(Debug)]
enum CacheState {
    /// There is no cached file.
    Miss,
    /// A cached file exists together with the ETag recorded for it.
    PossiblyDirty { etag: String, path: PathBuf },
    /// A cached file exists, but no ETag was recorded for it.
    UnableToVerify { path: PathBuf },
}

impl CacheState {
    /// Consumes the state and returns the cached file's path, or `None` for
    /// [`CacheState::Miss`].
    fn take_path(self) -> Option<PathBuf> {
        match self {
            CacheState::Miss => None,
            CacheState::PossiblyDirty { path, .. } => Some(path),
            CacheState::UnableToVerify { path } => Some(path),
        }
    }
}
376
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, sync::Mutex};

    use futures::future::BoxFuture;
    use http::{HeaderMap, StatusCode, header::IntoHeaderName};
    use tempfile::TempDir;

    use crate::http::HttpResponse;

    use super::*;

    /// Package fixture whose name is expected to resolve to `python`.
    const PYTHON: &[u8] = include_bytes!("../../../../c-api/examples/assets/python-0.1.0.wasmer");
    /// Package fixture whose name is expected to resolve to
    /// `sharrattj/coreutils`.
    const COREUTILS: &[u8] = include_bytes!(
        "../../../../../tests/integration/cli/tests/webc/coreutils-1.0.16-e27dbb4f-2ef2-4b44-b46a-ddd86497c6d7.webc"
    );
    /// URL every test queries.
    const DUMMY_URL: &str = "http://my-registry.io/some/package";
    /// File name `DUMMY_URL` is expected to be cached under.
    const DUMMY_URL_HASH: &str = "4D7481F44E1D971A8C60D3C7BD505E2727602CF9369ED623920E029C2BA2351D";

    /// An [`HttpClient`] that replays canned responses in order and records
    /// every request it receives.
    #[derive(Debug)]
    pub(crate) struct DummyClient {
        requests: Mutex<Vec<HttpRequest>>,
        responses: Mutex<VecDeque<HttpResponse>>,
    }

    impl DummyClient {
        /// Creates a client that answers its n-th request with the n-th item
        /// of `responses`.
        pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
            let queued: VecDeque<HttpResponse> = responses.into_iter().collect();

            DummyClient {
                requests: Mutex::default(),
                responses: Mutex::new(queued),
            }
        }
    }

    impl HttpClient for DummyClient {
        /// Hands out the next canned response and records `request`.
        ///
        /// Panics when the canned responses have run out.
        fn request(
            &self,
            request: HttpRequest,
        ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
            let next = self.responses.lock().unwrap().pop_front().unwrap();
            self.requests.lock().unwrap().push(request);
            Box::pin(async move { Ok(next) })
        }
    }

    /// Fluent helper for assembling an [`HttpResponse`].
    struct ResponseBuilder(HttpResponse);

    impl ResponseBuilder {
        /// Starts from a `200 OK` response with no body and no headers.
        pub fn new() -> Self {
            let blank = HttpResponse {
                body: None,
                redirected: false,
                status: StatusCode::OK,
                headers: HeaderMap::new(),
            };
            ResponseBuilder(blank)
        }

        /// Replaces the status code.
        pub fn with_status(mut self, code: StatusCode) -> Self {
            self.0.status = code;
            self
        }

        /// Sets the response body.
        pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Self {
            self.0.body = Some(body.into());
            self
        }

        /// Sets the `ETag` header.
        pub fn with_etag(self, value: &str) -> Self {
            self.with_header("ETag", value)
        }

        /// Sets an arbitrary header; panics if `value` is not a valid header
        /// value.
        pub fn with_header(mut self, name: impl IntoHeaderName, value: &str) -> Self {
            let parsed = value.parse().unwrap();
            self.0.headers.insert(name, parsed);
            self
        }

        /// Finishes the builder.
        pub fn build(self) -> HttpResponse {
            self.0
        }
    }

    /// The query every test sends: a URL lookup for [`DUMMY_URL`].
    fn dummy_spec() -> PackageSource {
        PackageSource::Url(DUMMY_URL.parse().unwrap())
    }

    /// With nothing cached, the package is downloaded and both the file and
    /// its ETag end up in the cache directory.
    async fn empty_cache_does_a_full_download_internal() {
        let dummy_etag = "This is an etag";
        let cache_dir = TempDir::new().unwrap();
        let response = ResponseBuilder::new()
            .with_body(PYTHON)
            .with_etag(dummy_etag)
            .build();
        let client = DummyClient::with_responses([response]);
        let source = WebSource::new(cache_dir.path(), Arc::new(client));
        let spec = dummy_spec();

        let summaries = source.query(&spec).await.unwrap();

        // The package metadata was read correctly...
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // ...and the download was cached along with its ETag.
        let cached_file = cache_dir.path().join(DUMMY_URL_HASH);
        assert!(cached_file.exists());
        let etag_file = cached_file.with_extension("etag");
        assert!(etag_file.exists());
        assert_eq!(std::fs::read_to_string(etag_file).unwrap(), dummy_etag);
        assert_eq!(std::fs::read(cached_file).unwrap(), PYTHON);
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn empty_cache_does_a_full_download() {
        empty_cache_does_a_full_download_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn empty_cache_does_a_full_download() {
        empty_cache_does_a_full_download_internal().await
    }

    /// A freshly written cache entry is served without any HTTP request.
    async fn cache_hit_internal() {
        let cache_dir = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([]));
        let source = WebSource::new(cache_dir.path(), client.clone());
        let spec = dummy_spec();
        // Seed the cache with a file that was modified just now.
        std::fs::write(cache_dir.path().join(DUMMY_URL_HASH), PYTHON).unwrap();

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // Nothing went over the wire.
        assert_eq!(client.requests.lock().unwrap().len(), 0);
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn cache_hit() {
        cache_hit_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn cache_hit() {
        cache_hit_internal().await
    }

    /// When the cached file has expired and the download fails, the cached
    /// copy is used anyway.
    async fn fall_back_to_stale_cache_if_request_fails_internal() {
        let cache_dir = TempDir::new().unwrap();
        let failure = ResponseBuilder::new()
            .with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .build();
        let client = Arc::new(DummyClient::with_responses([failure]));
        // Seed the cache; a zero retry period makes the entry count as expired.
        let python_path = cache_dir.path().join(DUMMY_URL_HASH);
        std::fs::write(&python_path, PYTHON).unwrap();
        let source =
            WebSource::new(cache_dir.path(), client.clone()).with_retry_period(Duration::ZERO);
        let spec = dummy_spec();

        let summaries = source.query(&spec).await.unwrap();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // Exactly one (failed) request was made...
        assert_eq!(client.requests.lock().unwrap().len(), 1);
        // ...and it left no etag file behind.
        assert!(!python_path.with_extension("etag").exists());
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn fall_back_to_stale_cache_if_request_fails() {
        fall_back_to_stale_cache_if_request_fails_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn fall_back_to_stale_cache_if_request_fails() {
        fall_back_to_stale_cache_if_request_fails_internal().await
    }

    /// When the server reports a different ETag than the cached one, the
    /// package is downloaded again and the stored ETag is updated.
    async fn download_again_if_etag_is_different_internal() {
        let cache_dir = TempDir::new().unwrap();
        let head_response = ResponseBuilder::new().with_etag("coreutils").build();
        let get_response = ResponseBuilder::new()
            .with_body(COREUTILS)
            .with_etag("coreutils")
            .build();
        let client = Arc::new(DummyClient::with_responses([head_response, get_response]));
        // Seed the cache with a different package and ETag.
        let cached_file = cache_dir.path().join(DUMMY_URL_HASH);
        std::fs::write(&cached_file, PYTHON).unwrap();
        std::fs::write(cached_file.with_extension("etag"), "python").unwrap();
        // A zero retry period makes the seeded entry count as expired.
        let source =
            WebSource::new(cache_dir.path(), client.clone()).with_retry_period(Duration::ZERO);
        let spec = dummy_spec();

        let summaries = source.query(&spec).await.unwrap();

        // The newly downloaded package is the one that was resolved.
        assert_eq!(summaries.len(), 1);
        assert_eq!(
            summaries[0].pkg.id.as_named().unwrap().full_name,
            "sharrattj/coreutils"
        );
        // The ETag was checked first, then the file was downloaded.
        let requests = client.requests.lock().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].method, "HEAD");
        assert_eq!(requests[1].method, "GET");
        // The stored ETag now matches the new download.
        assert_eq!(
            std::fs::read_to_string(cached_file.with_extension("etag")).unwrap(),
            "coreutils"
        );
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn download_again_if_etag_is_different() {
        download_again_if_etag_is_different_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn download_again_if_etag_is_different() {
        download_again_if_etag_is_different_internal().await
    }
}