1use std::{
2 fmt::Write as _,
3 io::Write,
4 path::{Path, PathBuf},
5 sync::Arc,
6 time::{Duration, SystemTime},
7};
8
9use anyhow::{Context, Error};
10use http::Method;
11use sha2::{Digest, Sha256};
12use tempfile::NamedTempFile;
13use url::Url;
14use wasmer_config::package::{PackageHash, PackageId, PackageSource};
15use wasmer_package::utils::from_disk;
16
17use crate::{
18 http::{HttpClient, HttpRequest},
19 runtime::resolver::{
20 DistributionInfo, PackageInfo, PackageSummary, QueryError, Source, WebcHash,
21 },
22};
23
24#[derive(Debug, Clone)]
39pub struct WebSource {
40 cache_dir: PathBuf,
41 client: Arc<dyn HttpClient + Send + Sync>,
42 retry_period: Duration,
43}
44
45impl WebSource {
46 pub const DEFAULT_RETRY_PERIOD: Duration = Duration::from_secs(5 * 60);
47
48 pub fn new(cache_dir: impl Into<PathBuf>, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
49 WebSource {
50 cache_dir: cache_dir.into(),
51 client,
52 retry_period: WebSource::DEFAULT_RETRY_PERIOD,
53 }
54 }
55
56 pub fn with_retry_period(self, retry_period: Duration) -> Self {
59 WebSource {
60 retry_period,
61 ..self
62 }
63 }
64
65 #[tracing::instrument(level = "debug", skip_all, fields(%url))]
67 async fn get_locally_cached_file(&self, url: &Url) -> Result<PathBuf, Error> {
68 let cache_key = sha256(url.as_str().as_bytes());
72
73 let cache_info = CacheInfo::for_url(&cache_key, &self.cache_dir);
75
76 let state = match classify_cache_using_mtime(cache_info, self.retry_period) {
78 Ok(path) => {
79 tracing::debug!(path=%path.display(), "Cache hit!");
80 return Ok(path);
81 }
82 Err(s) => s,
83 };
84
85 if let CacheState::PossiblyDirty { etag, path } = &state {
87 match self.get_etag(url).await {
88 Ok(new_etag) if new_etag == *etag => {
89 return Ok(path.clone());
90 }
91 Ok(different_etag) => {
92 tracing::debug!(
93 original_etag=%etag,
94 new_etag=%different_etag,
95 path=%path.display(),
96 "File has been updated. Redownloading.",
97 );
98 }
99 Err(e) => {
100 tracing::debug!(
101 error=&*e,
102 path=%path.display(),
103 original_etag=%etag,
104 "Unable to check if the etag is out of date",
105 )
106 }
107 }
108 }
109
110 let (bytes, etag) = match self.fetch(url).await {
112 Ok((bytes, etag)) => (bytes, etag),
113 Err(e) => {
114 tracing::warn!(error = &*e, "Download failed");
115 match state.take_path() {
116 Some(path) => {
117 tracing::debug!(
118 path=%path.display(),
119 "Using a possibly stale cached file",
120 );
121 return Ok(path);
122 }
123 None => {
124 return Err(e);
125 }
126 }
127 }
128 };
129
130 let path = self.cache_dir.join(&cache_key);
131 self.atomically_save_file(&path, &bytes)
132 .await
133 .with_context(|| {
134 format!(
135 "Unable to save the downloaded file to \"{}\"",
136 path.display()
137 )
138 })?;
139
140 if let Some(etag) = etag {
141 if let Err(e) = self
142 .atomically_save_file(path.with_extension("etag"), etag.as_bytes())
143 .await
144 {
145 tracing::warn!(
146 error=&*e,
147 %etag,
148 %url,
149 path=%path.display(),
150 "Unable to save the etag file",
151 )
152 }
153 }
154
155 Ok(path)
156 }
157
158 async fn atomically_save_file(&self, path: impl AsRef<Path>, data: &[u8]) -> Result<(), Error> {
159 let path = path.as_ref();
162
163 if let Some(parent) = path.parent() {
164 std::fs::create_dir_all(parent)
165 .with_context(|| format!("Unable to create \"{}\"", parent.display()))?;
166 }
167
168 let mut temp = NamedTempFile::new_in(&self.cache_dir)?;
169 temp.write_all(data)?;
170 temp.as_file().sync_all()?;
171 temp.persist(path)?;
172
173 Ok(())
174 }
175
176 async fn get_etag(&self, url: &Url) -> Result<String, Error> {
177 let request = HttpRequest {
178 url: url.clone(),
179 method: Method::HEAD,
180 headers: super::utils::webc_headers(),
181 body: None,
182 options: Default::default(),
183 };
184
185 let response = self.client.request(request).await?;
186
187 if !response.is_ok() {
188 return Err(super::utils::http_error(&response)
189 .context(format!("The HEAD request to \"{url}\" failed")));
190 }
191
192 let etag = response
193 .headers
194 .get("ETag")
195 .context("The HEAD request didn't contain an ETag header`")?
196 .to_str()
197 .context("The ETag wasn't valid UTF-8")?;
198
199 Ok(etag.to_string())
200 }
201
202 async fn fetch(&self, url: &Url) -> Result<(Vec<u8>, Option<String>), Error> {
203 let request = HttpRequest {
204 url: url.clone(),
205 method: Method::GET,
206 headers: super::utils::webc_headers(),
207 body: None,
208 options: Default::default(),
209 };
210 let response = self.client.request(request).await?;
211
212 if !response.is_ok() {
213 return Err(super::utils::http_error(&response)
214 .context(format!("The GET request to \"{url}\" failed")));
215 }
216
217 let body = response.body.context("Response didn't contain a body")?;
218
219 let etag = response
220 .headers
221 .get("ETag")
222 .and_then(|etag| etag.to_str().ok())
223 .map(|etag| etag.to_string());
224
225 Ok((body, etag))
226 }
227
228 async fn load_url(&self, url: &Url) -> Result<Vec<PackageSummary>, anyhow::Error> {
229 let local_path = self
230 .get_locally_cached_file(url)
231 .await
232 .context("Unable to get the locally cached file")?;
233
234 let webc_sha256 = crate::block_in_place(|| WebcHash::for_file(&local_path))
235 .with_context(|| format!("Unable to hash \"{}\"", local_path.display()))?;
236
237 let container = crate::block_in_place(|| from_disk(&local_path))
240 .with_context(|| format!("Unable to load \"{}\"", local_path.display()))?;
241
242 let id = PackageInfo::package_id_from_manifest(container.manifest())?
243 .unwrap_or_else(|| PackageId::Hash(PackageHash::from_sha256_bytes(webc_sha256.0)));
244
245 let pkg = PackageInfo::from_manifest(id, container.manifest(), container.version())
246 .context("Unable to determine the package's metadata")?;
247
248 let dist = DistributionInfo {
249 webc: url.clone(),
250 webc_sha256,
251 };
252
253 Ok(vec![PackageSummary { pkg, dist }])
254 }
255}
256
257#[async_trait::async_trait]
258impl Source for WebSource {
259 #[tracing::instrument(level = "debug", skip_all, fields(%package))]
260 async fn query(&self, package: &PackageSource) -> Result<Vec<PackageSummary>, QueryError> {
261 let url = match package {
262 PackageSource::Url(url) => url,
263 _ => {
264 return Err(QueryError::Unsupported {
265 query: package.clone(),
266 });
267 }
268 };
269
270 self.load_url(url)
271 .await
272 .map_err(|error| QueryError::new_other(error, package))
273 }
274}
275
276fn sha256(bytes: &[u8]) -> String {
277 let mut hasher = Sha256::default();
278 hasher.update(bytes);
279 let hash = hasher.finalize();
280 let mut buffer = String::with_capacity(hash.len() * 2);
281 for byte in hash {
282 write!(buffer, "{byte:02X}").expect("Unreachable");
283 }
284
285 buffer
286}
287
/// What the file system knows about a cache entry.
#[derive(Debug, Clone, PartialEq)]
enum CacheInfo {
    /// There is no file for this cache key.
    Miss,
    /// A file for this cache key exists.
    Hit {
        /// Location of the cached file.
        path: PathBuf,
        /// Content of the `.etag` file stored next to the cached file, if it
        /// exists and could be read as a string.
        etag: Option<String>,
        /// Modification time of the cached file, if the platform reports one.
        last_modified: Option<SystemTime>,
    },
}

impl CacheInfo {
    /// Inspect the entry named `key` inside `checkout_dir`.
    ///
    /// The file is stat-ed exactly once: the same metadata both decides
    /// between [`CacheInfo::Miss`] and [`CacheInfo::Hit`] and provides the
    /// modification time. This way a file that disappears concurrently can't
    /// be reported as a hit with a missing mtime.
    fn for_url(key: &str, checkout_dir: &Path) -> CacheInfo {
        let path = checkout_dir.join(key);

        // Any error (most commonly "not found") counts as a miss, which is
        // also what `Path::exists()` would report.
        let metadata = match path.metadata() {
            Ok(metadata) => metadata,
            Err(_) => return CacheInfo::Miss,
        };

        let etag = std::fs::read_to_string(path.with_extension("etag")).ok();
        let last_modified = metadata.modified().ok();

        CacheInfo::Hit {
            etag,
            last_modified,
            path,
        }
    }
}
318
319fn classify_cache_using_mtime(
320 info: CacheInfo,
321 invalidation_threshold: Duration,
322) -> Result<PathBuf, CacheState> {
323 let (path, last_modified, etag) = match info {
324 CacheInfo::Hit {
325 path,
326 last_modified: Some(last_modified),
327 etag,
328 ..
329 } => (path, last_modified, etag),
330 CacheInfo::Hit {
331 path,
332 last_modified: None,
333 etag: Some(etag),
334 ..
335 } => return Err(CacheState::PossiblyDirty { etag, path }),
336 CacheInfo::Hit {
337 etag: None,
338 last_modified: None,
339 path,
340 ..
341 } => {
342 return Err(CacheState::UnableToVerify { path });
343 }
344 CacheInfo::Miss => return Err(CacheState::Miss),
345 };
346
347 if let Ok(time_since_last_modified) = last_modified.elapsed() {
348 if time_since_last_modified <= invalidation_threshold {
349 return Ok(path);
350 }
351 }
352
353 match etag {
354 Some(etag) => Err(CacheState::PossiblyDirty { etag, path }),
355 None => Err(CacheState::UnableToVerify { path }),
356 }
357}
358
/// Why a cache entry could not be used directly.
#[derive(Debug)]
enum CacheState {
    /// Nothing is cached.
    Miss,
    /// A cached file exists together with the ETag it was stored with, so it
    /// can be compared against the server's current ETag.
    PossiblyDirty { etag: String, path: PathBuf },
    /// A cached file exists, but there is no stored ETag to compare against.
    UnableToVerify { path: PathBuf },
}

impl CacheState {
    /// Consume the state and return the cached file's path, if there is one.
    fn take_path(self) -> Option<PathBuf> {
        match self {
            CacheState::Miss => None,
            CacheState::PossiblyDirty { path, .. } => Some(path),
            CacheState::UnableToVerify { path } => Some(path),
        }
    }
}
383
384#[cfg(test)]
385mod tests {
386 use std::{collections::VecDeque, sync::Mutex};
387
388 use futures::future::BoxFuture;
389 use http::{HeaderMap, StatusCode, header::IntoHeaderName};
390 use tempfile::TempDir;
391
392 use crate::http::HttpResponse;
393
394 use super::*;
395
396 const PYTHON: &[u8] = include_bytes!("../../../../c-api/examples/assets/python-0.1.0.wasmer");
397 const COREUTILS: &[u8] = include_bytes!(
398 "../../../../../tests/integration/cli/tests/webc/coreutils-1.0.16-e27dbb4f-2ef2-4b44-b46a-ddd86497c6d7.webc"
399 );
400 const DUMMY_URL: &str = "http://my-registry.io/some/package";
401 const DUMMY_URL_HASH: &str = "4D7481F44E1D971A8C60D3C7BD505E2727602CF9369ED623920E029C2BA2351D";
402
403 #[derive(Debug)]
404 pub(crate) struct DummyClient {
405 requests: Mutex<Vec<HttpRequest>>,
406 responses: Mutex<VecDeque<HttpResponse>>,
407 }
408
409 impl DummyClient {
410 pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
411 DummyClient {
412 requests: Mutex::new(Vec::new()),
413 responses: Mutex::new(responses.into_iter().collect()),
414 }
415 }
416 }
417
418 impl HttpClient for DummyClient {
419 fn request(
420 &self,
421 request: HttpRequest,
422 ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
423 let response = self.responses.lock().unwrap().pop_front().unwrap();
424 self.requests.lock().unwrap().push(request);
425 Box::pin(async { Ok(response) })
426 }
427 }
428
429 struct ResponseBuilder(HttpResponse);
430
431 impl ResponseBuilder {
432 pub fn new() -> Self {
433 ResponseBuilder(HttpResponse {
434 body: None,
435 redirected: false,
436 status: StatusCode::OK,
437 headers: HeaderMap::new(),
438 })
439 }
440
441 pub fn with_status(mut self, code: StatusCode) -> Self {
442 self.0.status = code;
443 self
444 }
445
446 pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Self {
447 self.0.body = Some(body.into());
448 self
449 }
450
451 pub fn with_etag(self, value: &str) -> Self {
452 self.with_header("ETag", value)
453 }
454
455 pub fn with_header(mut self, name: impl IntoHeaderName, value: &str) -> Self {
456 self.0.headers.insert(name, value.parse().unwrap());
457 self
458 }
459
460 pub fn build(self) -> HttpResponse {
461 self.0
462 }
463 }
464
465 async fn empty_cache_does_a_full_download_internal() {
466 let dummy_etag = "This is an etag";
467 let temp = TempDir::new().unwrap();
468 let client = DummyClient::with_responses([ResponseBuilder::new()
469 .with_body(PYTHON)
470 .with_etag(dummy_etag)
471 .build()]);
472 let source = WebSource::new(temp.path(), Arc::new(client));
473 let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());
474
475 let summaries = source.query(&spec).await.unwrap();
476
477 assert_eq!(summaries.len(), 1);
479 assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
480 let path = temp.path().join(DUMMY_URL_HASH);
482 assert!(path.exists());
483 let etag_path = path.with_extension("etag");
484 assert!(etag_path.exists());
485 assert_eq!(std::fs::read_to_string(etag_path).unwrap(), dummy_etag);
487 assert_eq!(std::fs::read(path).unwrap(), PYTHON);
488 }
489 #[cfg(not(target_arch = "wasm32"))]
490 #[tokio::test(flavor = "multi_thread")]
491 async fn empty_cache_does_a_full_download() {
492 empty_cache_does_a_full_download_internal().await
493 }
494 #[cfg(target_arch = "wasm32")]
495 #[tokio::test()]
496 async fn empty_cache_does_a_full_download() {
497 empty_cache_does_a_full_download_internal().await
498 }
499
500 async fn cache_hit_internal() {
501 let temp = TempDir::new().unwrap();
502 let client = Arc::new(DummyClient::with_responses([]));
503 let source = WebSource::new(temp.path(), client.clone());
504 let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());
505 std::fs::write(temp.path().join(DUMMY_URL_HASH), PYTHON).unwrap();
507
508 let summaries = source.query(&spec).await.unwrap();
509
510 assert_eq!(summaries.len(), 1);
512 assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
513 assert_eq!(client.requests.lock().unwrap().len(), 0);
515 }
516 #[cfg(not(target_arch = "wasm32"))]
517 #[tokio::test(flavor = "multi_thread")]
518 async fn cache_hit() {
519 cache_hit_internal().await
520 }
521 #[cfg(target_arch = "wasm32")]
522 #[tokio::test()]
523 async fn cache_hit() {
524 cache_hit_internal().await
525 }
526
527 async fn fall_back_to_stale_cache_if_request_fails_internal() {
528 let temp = TempDir::new().unwrap();
529 let client = Arc::new(DummyClient::with_responses([ResponseBuilder::new()
530 .with_status(StatusCode::INTERNAL_SERVER_ERROR)
531 .build()]));
532 let python_path = temp.path().join(DUMMY_URL_HASH);
534 std::fs::write(&python_path, PYTHON).unwrap();
535 let source = WebSource::new(temp.path(), client.clone()).with_retry_period(Duration::ZERO);
536 let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());
537
538 let summaries = source.query(&spec).await.unwrap();
539
540 assert_eq!(summaries.len(), 1);
542 assert_eq!(summaries[0].pkg.id.as_named().unwrap().full_name, "python");
543 assert_eq!(client.requests.lock().unwrap().len(), 1);
545 assert!(!python_path.with_extension("etag").exists());
547 }
548 #[cfg(not(target_arch = "wasm32"))]
549 #[tokio::test(flavor = "multi_thread")]
550 async fn fall_back_to_stale_cache_if_request_fails() {
551 fall_back_to_stale_cache_if_request_fails_internal().await
552 }
553 #[cfg(target_arch = "wasm32")]
554 #[tokio::test()]
555 async fn fall_back_to_stale_cache_if_request_fails() {
556 fall_back_to_stale_cache_if_request_fails_internal().await
557 }
558
559 async fn download_again_if_etag_is_different_internal() {
560 let temp = TempDir::new().unwrap();
561 let client = Arc::new(DummyClient::with_responses([
562 ResponseBuilder::new().with_etag("coreutils").build(),
563 ResponseBuilder::new()
564 .with_body(COREUTILS)
565 .with_etag("coreutils")
566 .build(),
567 ]));
568 let path = temp.path().join(DUMMY_URL_HASH);
570 std::fs::write(&path, PYTHON).unwrap();
571 std::fs::write(path.with_extension("etag"), "python").unwrap();
572 let source =
574 WebSource::new(temp.path(), client.clone()).with_retry_period(Duration::new(0, 0));
575 let spec = PackageSource::Url(DUMMY_URL.parse().unwrap());
576
577 let summaries = source.query(&spec).await.unwrap();
578
579 assert_eq!(summaries.len(), 1);
581 assert_eq!(
582 summaries[0].pkg.id.as_named().unwrap().full_name,
583 "sharrattj/coreutils"
584 );
585 let requests = client.requests.lock().unwrap();
587 assert_eq!(requests.len(), 2);
588 assert_eq!(requests[0].method, "HEAD");
589 assert_eq!(requests[1].method, "GET");
590 assert_eq!(
592 std::fs::read_to_string(path.with_extension("etag")).unwrap(),
593 "coreutils"
594 );
595 }
596 #[cfg(not(target_arch = "wasm32"))]
597 #[tokio::test(flavor = "multi_thread")]
598 async fn download_again_if_etag_is_different() {
599 download_again_if_etag_is_different_internal().await
600 }
601 #[cfg(target_arch = "wasm32")]
602 #[tokio::test()]
603 async fn download_again_if_etag_is_different() {
604 download_again_if_etag_is_different_internal().await
605 }
606}