1use std::{
2 fmt::Write as _,
3 io::Write,
4 path::{Path, PathBuf},
5 sync::Arc,
6 time::{Duration, SystemTime},
7};
8
9use anyhow::{Context, Error};
10use http::Method;
11use sha2::{Digest, Sha256};
12use tempfile::NamedTempFile;
13use url::Url;
14use wasmer_config::package::{PackageHash, PackageId, PackageSource};
15use wasmer_package::utils::from_disk;
16
17use crate::{
18 http::{HttpClient, HttpRequest},
19 runtime::resolver::{
20 DistributionInfo, PackageInfo, PackageSummary, QueryError, Source, WebcHash,
21 },
22};
23
/// A [`Source`] which resolves [`PackageSource::Url`] queries by downloading
/// the package at that URL and keeping a copy in a local cache directory.
#[derive(Debug, Clone)]
pub struct WebSource {
    /// Directory downloaded files (and their `.etag` companions) live in.
    cache_dir: PathBuf,
    /// Client used for the `HEAD` and `GET` requests.
    client: Arc<dyn HttpClient + Send + Sync>,
    /// How long a cached file is used before it is revalidated with the server.
    retry_period: Duration,
}
44
45impl WebSource {
46 pub const DEFAULT_RETRY_PERIOD: Duration = Duration::from_secs(5 * 60);
47
48 pub fn new(cache_dir: impl Into<PathBuf>, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
49 WebSource {
50 cache_dir: cache_dir.into(),
51 client,
52 retry_period: WebSource::DEFAULT_RETRY_PERIOD,
53 }
54 }
55
56 pub fn with_retry_period(self, retry_period: Duration) -> Self {
59 WebSource {
60 retry_period,
61 ..self
62 }
63 }
64
65 #[tracing::instrument(level = "debug", skip_all, fields(%url))]
67 async fn get_locally_cached_file(&self, url: &Url) -> Result<PathBuf, Error> {
68 let cache_key = sha256(url.as_str().as_bytes());
72
73 let cache_info = CacheInfo::for_url(&cache_key, &self.cache_dir);
75
76 let state = match classify_cache_using_mtime(cache_info, self.retry_period) {
78 Ok(path) => {
79 tracing::debug!(path=%path.display(), "Cache hit!");
80 return Ok(path);
81 }
82 Err(s) => s,
83 };
84
85 if let CacheState::PossiblyDirty { etag, path } = &state {
87 match self.get_etag(url).await {
88 Ok(new_etag) if new_etag == *etag => {
89 return Ok(path.clone());
90 }
91 Ok(different_etag) => {
92 tracing::debug!(
93 original_etag=%etag,
94 new_etag=%different_etag,
95 path=%path.display(),
96 "File has been updated. Redownloading.",
97 );
98 }
99 Err(e) => {
100 tracing::debug!(
101 error=&*e,
102 path=%path.display(),
103 original_etag=%etag,
104 "Unable to check if the etag is out of date",
105 )
106 }
107 }
108 }
109
110 let (bytes, etag) = match self.fetch(url).await {
112 Ok((bytes, etag)) => (bytes, etag),
113 Err(e) => {
114 tracing::warn!(error = &*e, "Download failed");
115 match state.take_path() {
116 Some(path) => {
117 tracing::debug!(
118 path=%path.display(),
119 "Using a possibly stale cached file",
120 );
121 return Ok(path);
122 }
123 None => {
124 return Err(e);
125 }
126 }
127 }
128 };
129
130 let path = self.cache_dir.join(&cache_key);
131 self.atomically_save_file(&path, &bytes)
132 .await
133 .with_context(|| {
134 format!(
135 "Unable to save the downloaded file to \"{}\"",
136 path.display()
137 )
138 })?;
139
140 if let Some(etag) = etag
141 && let Err(e) = self
142 .atomically_save_file(path.with_extension("etag"), etag.as_bytes())
143 .await
144 {
145 tracing::warn!(
146 error=&*e,
147 %etag,
148 %url,
149 path=%path.display(),
150 "Unable to save the etag file",
151 )
152 }
153
154 Ok(path)
155 }
156
157 async fn atomically_save_file(&self, path: impl AsRef<Path>, data: &[u8]) -> Result<(), Error> {
158 let path = path.as_ref();
161
162 if let Some(parent) = path.parent() {
163 std::fs::create_dir_all(parent)
164 .with_context(|| format!("Unable to create \"{}\"", parent.display()))?;
165 }
166
167 let mut temp = NamedTempFile::new_in(&self.cache_dir)?;
168 temp.write_all(data)?;
169 temp.as_file().sync_all()?;
170 temp.persist(path)?;
171
172 Ok(())
173 }
174
175 async fn get_etag(&self, url: &Url) -> Result<String, Error> {
176 let request = HttpRequest {
177 url: url.clone(),
178 method: Method::HEAD,
179 headers: super::utils::webc_headers(),
180 body: None,
181 options: Default::default(),
182 };
183
184 let response = self.client.request(request).await?;
185
186 if !response.is_ok() {
187 return Err(super::utils::http_error(&response)
188 .context(format!("The HEAD request to \"{url}\" failed")));
189 }
190
191 let etag = response
192 .headers
193 .get("ETag")
194 .context("The HEAD request didn't contain an ETag header`")?
195 .to_str()
196 .context("The ETag wasn't valid UTF-8")?;
197
198 Ok(etag.to_string())
199 }
200
201 async fn fetch(&self, url: &Url) -> Result<(Vec<u8>, Option<String>), Error> {
202 let request = HttpRequest {
203 url: url.clone(),
204 method: Method::GET,
205 headers: super::utils::webc_headers(),
206 body: None,
207 options: Default::default(),
208 };
209 let response = self.client.request(request).await?;
210
211 if !response.is_ok() {
212 return Err(super::utils::http_error(&response)
213 .context(format!("The GET request to \"{url}\" failed")));
214 }
215
216 let body = response.body.context("Response didn't contain a body")?;
217
218 let etag = response
219 .headers
220 .get("ETag")
221 .and_then(|etag| etag.to_str().ok())
222 .map(|etag| etag.to_string());
223
224 Ok((body, etag))
225 }
226
227 async fn load_url(&self, url: &Url) -> Result<Vec<PackageSummary>, anyhow::Error> {
228 let local_path = self
229 .get_locally_cached_file(url)
230 .await
231 .context("Unable to get the locally cached file")?;
232
233 let webc_sha256 = crate::block_in_place(|| WebcHash::for_file(&local_path))
234 .with_context(|| format!("Unable to hash \"{}\"", local_path.display()))?;
235
236 let container = crate::block_in_place(|| from_disk(&local_path))
239 .with_context(|| format!("Unable to load \"{}\"", local_path.display()))?;
240
241 let id = PackageInfo::package_id_from_manifest(container.manifest())?
242 .unwrap_or_else(|| PackageId::Hash(PackageHash::from_sha256_bytes(webc_sha256.0)));
243
244 let pkg = PackageInfo::from_manifest(id, container.manifest(), container.version())
245 .context("Unable to determine the package's metadata")?;
246
247 let dist = DistributionInfo {
248 webc: url.clone(),
249 webc_sha256,
250 };
251
252 Ok(vec![PackageSummary { pkg, dist }])
253 }
254}
255
256#[async_trait::async_trait]
257impl Source for WebSource {
258 #[tracing::instrument(level = "debug", skip_all, fields(%package))]
259 async fn query(&self, package: &PackageSource) -> Result<Vec<PackageSummary>, QueryError> {
260 let url = match package {
261 PackageSource::Url(url) => url,
262 _ => {
263 return Err(QueryError::Unsupported {
264 query: package.clone(),
265 });
266 }
267 };
268
269 self.load_url(url)
270 .await
271 .map_err(|error| QueryError::new_other(error, package))
272 }
273}
274
275fn sha256(bytes: &[u8]) -> String {
276 let mut hasher = Sha256::default();
277 hasher.update(bytes);
278 let hash = hasher.finalize();
279 let mut buffer = String::with_capacity(hash.len() * 2);
280 for byte in hash {
281 write!(buffer, "{byte:02X}").expect("Unreachable");
282 }
283
284 buffer
285}
286
/// What the cache directory holds for a particular URL.
#[derive(Debug, Clone, PartialEq)]
enum CacheInfo {
    /// No file has been cached under this key.
    Miss,
    /// A cached file exists on disk.
    Hit {
        /// Location of the cached file.
        path: PathBuf,
        /// Contents of the sibling `.etag` file, if it could be read.
        etag: Option<String>,
        /// The cached file's modification time, if it could be determined.
        last_modified: Option<SystemTime>,
    },
}

impl CacheInfo {
    /// Look for a file named `key` inside `checkout_dir` and gather what is
    /// known about it.
    fn for_url(key: &str, checkout_dir: &Path) -> CacheInfo {
        let cached = checkout_dir.join(key);

        if cached.exists() {
            CacheInfo::Hit {
                etag: std::fs::read_to_string(cached.with_extension("etag")).ok(),
                last_modified: cached.metadata().and_then(|meta| meta.modified()).ok(),
                path: cached,
            }
        } else {
            CacheInfo::Miss
        }
    }
}
317
318fn classify_cache_using_mtime(
319 info: CacheInfo,
320 invalidation_threshold: Duration,
321) -> Result<PathBuf, CacheState> {
322 let (path, last_modified, etag) = match info {
323 CacheInfo::Hit {
324 path,
325 last_modified: Some(last_modified),
326 etag,
327 ..
328 } => (path, last_modified, etag),
329 CacheInfo::Hit {
330 path,
331 last_modified: None,
332 etag: Some(etag),
333 ..
334 } => return Err(CacheState::PossiblyDirty { etag, path }),
335 CacheInfo::Hit {
336 etag: None,
337 last_modified: None,
338 path,
339 ..
340 } => {
341 return Err(CacheState::UnableToVerify { path });
342 }
343 CacheInfo::Miss => return Err(CacheState::Miss),
344 };
345
346 if let Ok(time_since_last_modified) = last_modified.elapsed()
347 && time_since_last_modified <= invalidation_threshold
348 {
349 return Ok(path);
350 }
351
352 match etag {
353 Some(etag) => Err(CacheState::PossiblyDirty { etag, path }),
354 None => Err(CacheState::UnableToVerify { path }),
355 }
356}
357
/// The outcome of a cache lookup that could not be served from disk
/// straight away.
#[derive(Debug)]
enum CacheState {
    /// Nothing is cached for this URL.
    Miss,
    /// A cached file exists together with the ETag it was downloaded with, so
    /// it can be revalidated with a `HEAD` request.
    PossiblyDirty { etag: String, path: PathBuf },
    /// A cached file exists but there is no ETag to revalidate it with.
    UnableToVerify { path: PathBuf },
}

impl CacheState {
    /// Consume the state, returning the cached file's path if there is one.
    fn take_path(self) -> Option<PathBuf> {
        match self {
            CacheState::PossiblyDirty { path, .. } | CacheState::UnableToVerify { path } => {
                Some(path)
            }
            // Listed explicitly rather than with `_`, so that adding a new
            // variant forces a decision here.
            CacheState::Miss => None,
        }
    }
}
382
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, sync::Mutex};

    use futures::future::BoxFuture;
    use http::{HeaderMap, StatusCode, header::IntoHeaderName};
    use tempfile::TempDir;

    use crate::http::HttpResponse;

    use super::*;

    // Real packages used as response bodies. The assertions below expect
    // `PYTHON` to be named "python" and `COREUTILS` "sharrattj/coreutils".
    const PYTHON: &[u8] = include_bytes!("../../../../c-api/examples/assets/python-0.1.0.wasmer");
    const COREUTILS: &[u8] = include_bytes!(
        "../../../../../tests/integration/cli/tests/webc/coreutils-1.0.16-e27dbb4f-2ef2-4b44-b46a-ddd86497c6d7.webc"
    );
    const DUMMY_URL: &str = "http://my-registry.io/some/package";
    // The cache key derived from `DUMMY_URL`, i.e. the file name a download
    // of `DUMMY_URL` is stored under inside the cache directory.
    const DUMMY_URL_HASH: &str = "4D7481F44E1D971A8C60D3C7BD505E2727602CF9369ED623920E029C2BA2351D";

    /// A fake HTTP client which replays canned responses in order and records
    /// every request it receives.
    #[derive(Debug)]
    pub(crate) struct DummyClient {
        requests: Mutex<Vec<HttpRequest>>,
        responses: Mutex<VecDeque<HttpResponse>>,
    }

    impl DummyClient {
        pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
            DummyClient {
                requests: Mutex::new(Vec::new()),
                responses: Mutex::new(responses.into_iter().collect()),
            }
        }
    }

    impl HttpClient for DummyClient {
        // Panics if more requests are made than responses were provided.
        fn request(
            &self,
            request: HttpRequest,
        ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
            let response = self.responses.lock().unwrap().pop_front().unwrap();
            self.requests.lock().unwrap().push(request);
            Box::pin(async { Ok(response) })
        }
    }

    /// Builder for fake responses. It starts out as a `200 OK` with no body
    /// and no headers.
    struct ResponseBuilder(HttpResponse);

    impl ResponseBuilder {
        pub fn new() -> Self {
            ResponseBuilder(HttpResponse {
                body: None,
                redirected: false,
                status: StatusCode::OK,
                headers: HeaderMap::new(),
            })
        }

        pub fn with_status(mut self, code: StatusCode) -> Self {
            self.0.status = code;
            self
        }

        pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Self {
            self.0.body = Some(body.into());
            self
        }

        pub fn with_etag(self, value: &str) -> Self {
            self.with_header("ETag", value)
        }

        pub fn with_header(mut self, name: impl IntoHeaderName, value: &str) -> Self {
            self.0.headers.insert(name, value.parse().unwrap());
            self
        }

        pub fn build(self) -> HttpResponse {
            self.0
        }
    }

    async fn empty_cache_does_a_full_download_internal() {
        let dummy_etag = "This is an etag";
        let temp = TempDir::new().unwrap();
        let client = DummyClient::with_responses([ResponseBuilder::new()
            .with_body(PYTHON)
            .with_etag(dummy_etag)
            .build()]);
        let source = WebSource::new(temp.path(), Arc::new(client));
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        // The downloaded package was loaded and its metadata extracted.
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // The download was cached under the URL's hash...
        let path = temp.path().join(DUMMY_URL_HASH);
        assert!(path.exists());
        // ...with the response's ETag saved alongside it.
        let etag_path = path.with_extension("etag");
        assert!(etag_path.exists());
        assert_eq!(std::fs::read_to_string(etag_path).unwrap(), dummy_etag);
        // The cached bytes are exactly what the server sent.
        assert_eq!(std::fs::read(path).unwrap(), PYTHON);
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn empty_cache_does_a_full_download() {
        empty_cache_does_a_full_download_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn empty_cache_does_a_full_download() {
        empty_cache_does_a_full_download_internal().await
    }

    async fn cache_hit_internal() {
        let temp = TempDir::new().unwrap();
        // No responses: any request would make `DummyClient` panic.
        let client = Arc::new(DummyClient::with_responses([]));
        let source = WebSource::new(temp.path(), client.clone());
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());
        // Pre-populate the cache with a file that was just written, so it is
        // within the default retry period.
        std::fs::write(temp.path().join(DUMMY_URL_HASH), PYTHON).unwrap();

        let summaries = source.query(&spec).await.unwrap();

        // The package was served from the cache without any network traffic.
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        assert_eq!(client.requests.lock().unwrap().len(), 0);
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn cache_hit() {
        cache_hit_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn cache_hit() {
        cache_hit_internal().await
    }

    async fn fall_back_to_stale_cache_if_request_fails_internal() {
        let temp = TempDir::new().unwrap();
        let client = Arc::new(DummyClient::with_responses([ResponseBuilder::new()
            .with_status(StatusCode::INTERNAL_SERVER_ERROR)
            .build()]));
        // A cached file without an `.etag` file, made stale by the zero retry
        // period, so the source has to try downloading it again.
        let python_path = temp.path().join(DUMMY_URL_HASH);
        std::fs::write(&python_path, PYTHON).unwrap();
        let source = WebSource::new(temp.path(), client.clone()).with_retry_period(Duration::ZERO);
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        // The failed download fell back to the stale cached copy.
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
        // Exactly one (failed) request was attempted.
        assert_eq!(client.requests.lock().unwrap().len(), 1);
        // Nothing was written for the failed download.
        assert!(!python_path.with_extension("etag").exists());
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn fall_back_to_stale_cache_if_request_fails() {
        fall_back_to_stale_cache_if_request_fails_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn fall_back_to_stale_cache_if_request_fails() {
        fall_back_to_stale_cache_if_request_fails_internal().await
    }

    async fn download_again_if_etag_is_different_internal() {
        let temp = TempDir::new().unwrap();
        // First the HEAD response, then the GET response.
        let client = Arc::new(DummyClient::with_responses([
            ResponseBuilder::new().with_etag("coreutils").build(),
            ResponseBuilder::new()
                .with_body(COREUTILS)
                .with_etag("coreutils")
                .build(),
        ]));
        // The cache holds an older file whose stored ETag ("python") differs
        // from the one the server will report.
        let path = temp.path().join(DUMMY_URL_HASH);
        std::fs::write(&path, PYTHON).unwrap();
        std::fs::write(path.with_extension("etag"), "python").unwrap();
        // A zero retry period makes the cached file stale immediately.
        let source =
            WebSource::new(temp.path(), client.clone()).with_retry_period(Duration::new(0, 0));
        let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());

        let summaries = source.query(&spec).await.unwrap();

        // The new package replaced the cached one.
        assert_eq!(summaries.len(), 1);
        assert_eq!(
            summaries[0].pkg.id.as_named().unwrap().full_name,
            "sharrattj/coreutils"
        );
        // The ETag check came first, then the re-download.
        let requests = client.requests.lock().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].method, "HEAD");
        assert_eq!(requests[1].method, "GET");
        // The stored ETag was updated to match the new download.
        assert_eq!(
            std::fs::read_to_string(path.with_extension("etag")).unwrap(),
            "coreutils"
        );
    }
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn download_again_if_etag_is_different() {
        download_again_if_etag_is_different_internal().await
    }
    #[cfg(target_arch = "wasm32")]
    #[tokio::test()]
    async fn download_again_if_etag_is_different() {
        download_again_if_etag_is_different_internal().await
    }
}