1use std::convert::TryInto;
12use std::fmt;
13use std::sync::{Arc, Mutex};
14use wasmer::wasmparser::{BlockType as WpTypeOrFuncType, Operator};
15use wasmer::{
16 AsStoreMut, ExportIndex, GlobalInit, GlobalType, Instance, LocalFunctionIndex, Mutability,
17 Type,
18 sys::{FunctionMiddleware, MiddlewareError, MiddlewareReaderState, ModuleMiddleware},
19};
20use wasmer_types::{GlobalIndex, ModuleInfo};
21
22#[derive(Clone)]
23struct MeteringGlobalIndexes(GlobalIndex, GlobalIndex);
24
25impl MeteringGlobalIndexes {
26 fn remaining_points(&self) -> GlobalIndex {
28 self.0
29 }
30
31 fn points_exhausted(&self) -> GlobalIndex {
37 self.1
38 }
39}
40
41impl fmt::Debug for MeteringGlobalIndexes {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 f.debug_struct("MeteringGlobalIndexes")
44 .field("remaining_points", &self.remaining_points())
45 .field("points_exhausted", &self.points_exhausted())
46 .finish()
47 }
48}
49
/// Module middleware that charges points, computed by a user-supplied cost function,
/// for the operators executed by every function of a module.
///
/// When the module is transformed, two exported globals are added:
/// `wasmer_metering_remaining_points` (`i64`) and `wasmer_metering_points_exhausted` (`i32`).
/// Each function is instrumented by a [`FunctionMetering`]. That middleware subtracts the
/// accumulated cost from the remaining points, and traps (after setting the exhausted flag)
/// when the remaining points are lower than the cost to be charged.
///
/// A single `Metering` value can only transform one module.
pub struct Metering<F: Fn(&Operator) -> u64 + Send + Sync> {
    /// Initial value written into the remaining-points global.
    initial_limit: u64,

    /// Per-operator cost; shared (via `Arc`) with every generated [`FunctionMetering`].
    cost_function: Arc<F>,

    /// Indexes of the injected globals; `None` until `transform_module_info` has run.
    global_indexes: Mutex<Option<MeteringGlobalIndexes>>,
}
94
/// Per-function middleware created by [`Metering`]; injects the metering checks
/// into the operator stream of one function.
pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Send + Sync> {
    /// Cost function shared with the parent [`Metering`].
    cost_function: Arc<F>,

    /// Indexes of the metering globals in the module being compiled.
    global_indexes: MeteringGlobalIndexes,

    /// Cost of the operators fed since the last emitted check; reset to 0 after each check.
    accumulated_cost: u64,
}
106
/// Metering state of an instance, as reported by `get_remaining_points`.
///
/// This is a small value type, so it is `Copy` and `Hash` in addition to comparable.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum MeteringPoints {
    /// The exhausted flag is not set; holds the current value of the
    /// remaining-points global.
    Remaining(u64),

    /// The exhausted flag is set: an injected check found fewer remaining points
    /// than the cost it had to charge, and execution trapped.
    Exhausted,
}
125
126impl<F: Fn(&Operator) -> u64 + Send + Sync> Metering<F> {
127 pub fn new(initial_limit: u64, cost_function: F) -> Self {
133 Self {
134 initial_limit,
135 cost_function: Arc::new(cost_function),
136 global_indexes: Mutex::new(None),
137 }
138 }
139}
140
141impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for Metering<F> {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 f.debug_struct("Metering")
144 .field("initial_limit", &self.initial_limit)
145 .field("cost_function", &"<function>")
146 .field("global_indexes", &self.global_indexes)
147 .finish()
148 }
149}
150
151impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Metering<F> {
152 fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
154 Box::new(FunctionMetering {
155 cost_function: self.cost_function.clone(),
156 global_indexes: self.global_indexes.lock().unwrap().clone().unwrap(),
157 accumulated_cost: 0,
158 })
159 }
160
161 fn transform_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
163 let mut global_indexes = self.global_indexes.lock().unwrap();
164
165 if global_indexes.is_some() {
166 panic!(
167 "Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules."
168 );
169 }
170
171 let remaining_points_global_index = module_info
173 .globals
174 .push(GlobalType::new(Type::I64, Mutability::Var));
175
176 module_info
177 .global_initializers
178 .push(GlobalInit::I64Const(self.initial_limit as i64));
179
180 module_info.exports.insert(
181 "wasmer_metering_remaining_points".to_string(),
182 ExportIndex::Global(remaining_points_global_index),
183 );
184
185 let points_exhausted_global_index = module_info
187 .globals
188 .push(GlobalType::new(Type::I32, Mutability::Var));
189
190 module_info
191 .global_initializers
192 .push(GlobalInit::I32Const(0));
193
194 module_info.exports.insert(
195 "wasmer_metering_points_exhausted".to_string(),
196 ExportIndex::Global(points_exhausted_global_index),
197 );
198
199 *global_indexes = Some(MeteringGlobalIndexes(
200 remaining_points_global_index,
201 points_exhausted_global_index,
202 ));
203
204 Ok(())
205 }
206}
207
208pub fn is_accounting(operator: &Operator) -> bool {
211 matches!(
213 operator,
214 Operator::Loop { .. } | Operator::End | Operator::If { .. } | Operator::Else | Operator::Br { .. } | Operator::BrTable { .. } | Operator::BrIf { .. } | Operator::Call { .. } | Operator::CallIndirect { .. } | Operator::Return | Operator::Throw { .. } | Operator::ThrowRef | Operator::Rethrow { .. } | Operator::Delegate { .. } | Operator::Catch { .. } | Operator::ReturnCall { .. } | Operator::ReturnCallIndirect { .. } | Operator::BrOnCast { .. } | Operator::BrOnCastFail { .. } | Operator::CallRef { .. } | Operator::ReturnCallRef { .. } | Operator::BrOnNull { .. } | Operator::BrOnNonNull { .. } )
242}
243
244impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 f.debug_struct("FunctionMetering")
247 .field("cost_function", &"<function>")
248 .field("global_indexes", &self.global_indexes)
249 .finish()
250 }
251}
252
253impl<F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware for FunctionMetering<F> {
254 fn feed<'a>(
255 &mut self,
256 operator: Operator<'a>,
257 state: &mut MiddlewareReaderState<'a>,
258 ) -> Result<(), MiddlewareError> {
259 self.accumulated_cost += (self.cost_function)(&operator);
263
264 if is_accounting(&operator) && self.accumulated_cost > 0 {
266 state.extend(&[
267 Operator::GlobalGet {
269 global_index: self.global_indexes.remaining_points().as_u32(),
270 },
271 Operator::I64Const {
272 value: self.accumulated_cost as i64,
273 },
274 Operator::I64LtU,
275 Operator::If {
276 blockty: WpTypeOrFuncType::Empty,
277 },
278 Operator::I32Const { value: 1 },
279 Operator::GlobalSet {
280 global_index: self.global_indexes.points_exhausted().as_u32(),
281 },
282 Operator::Unreachable,
283 Operator::End,
284 Operator::GlobalGet {
286 global_index: self.global_indexes.remaining_points().as_u32(),
287 },
288 Operator::I64Const {
289 value: self.accumulated_cost as i64,
290 },
291 Operator::I64Sub,
292 Operator::GlobalSet {
293 global_index: self.global_indexes.remaining_points().as_u32(),
294 },
295 ]);
296
297 self.accumulated_cost = 0;
298 }
299 state.push_operator(operator);
300
301 Ok(())
302 }
303}
304
305pub fn get_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance) -> MeteringPoints {
330 let exhausted: i32 = instance
331 .exports
332 .get_global("wasmer_metering_points_exhausted")
333 .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
334 .get(ctx)
335 .try_into()
336 .expect("`wasmer_metering_points_exhausted` from Instance has wrong type");
337
338 if exhausted > 0 {
339 return MeteringPoints::Exhausted;
340 }
341
342 let points = instance
343 .exports
344 .get_global("wasmer_metering_remaining_points")
345 .expect("Can't get `wasmer_metering_remaining_points` from Instance")
346 .get(ctx)
347 .try_into()
348 .expect("`wasmer_metering_remaining_points` from Instance has wrong type");
349
350 MeteringPoints::Remaining(points)
351}
352
353pub fn set_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance, points: u64) {
379 instance
380 .exports
381 .get_global("wasmer_metering_remaining_points")
382 .expect("Can't get `wasmer_metering_remaining_points` from Instance")
383 .set(ctx, points.into())
384 .expect("Can't set `wasmer_metering_remaining_points` in Instance");
385
386 instance
387 .exports
388 .get_global("wasmer_metering_points_exhausted")
389 .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
390 .set(ctx, 0i32.into())
391 .expect("Can't set `wasmer_metering_points_exhausted` in Instance");
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use wasmer::sys::EngineBuilder;
    use wasmer::{
        Module, Store, TypedFunction, imports,
        sys::{CompilerConfig, Cranelift},
        wat2wasm,
    };

    /// Charges 1 point for `local.get` / `i32.const` and 2 for `i32.add`, so one call
    /// of `add_one` costs 4 points.
    fn cost_function(operator: &Operator) -> u64 {
        match operator {
            Operator::LocalGet { .. } | Operator::I32Const { .. } => 1,
            Operator::I32Add { .. } => 2,
            _ => 0,
        }
    }

    /// Test module exporting `add_one`, `short_loop` (5 iterations) and `infi_loop`.
    fn bytecode() -> Vec<u8> {
        wat2wasm(
            br#"(module
            (type $add_t (func (param i32) (result i32)))
            (func $add_one_f (type $add_t) (param $value i32) (result i32)
                local.get $value
                i32.const 1
                i32.add)
            (func $short_loop_f
                (local $x f64) (local $j i32)
                (local.set $x (f64.const 5.5))

                (loop $named_loop
                    ;; $j++
                    local.get $j
                    i32.const 1
                    i32.add
                    local.set $j

                    ;; if $j < 5, one more time
                    local.get $j
                    i32.const 5
                    i32.lt_s
                    br_if $named_loop
                )
            )
            (func $infi_loop_f
                (loop $infi_loop_start
                    br $infi_loop_start
                )
            )
            (export "add_one" (func $add_one_f))
            (export "short_loop" (func $short_loop_f))
            (export "infi_loop" (func $infi_loop_f))
        )"#,
        )
        .unwrap()
        .into()
    }

    /// Compiles [`bytecode`] with a [`Metering`] middleware that uses `cost` and a budget
    /// of `initial_limit` points, then instantiates it in a fresh store.
    fn metered_instance<F>(initial_limit: u64, cost: F) -> (Store, Instance)
    where
        F: Fn(&Operator) -> u64 + Send + Sync + 'static,
    {
        let metering = Arc::new(Metering::new(initial_limit, cost));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();
        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        (store, instance)
    }

    #[test]
    fn get_remaining_points_works() {
        let (mut store, instance) = metered_instance(10, cost_function);

        // Nothing has executed yet: the whole budget is available.
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );

        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();

        // Each call costs 4 points: 10 -> 6 -> 2.
        for expected in [6, 2] {
            add_one.call(&mut store, 1).unwrap();
            assert_eq!(
                get_remaining_points(&mut store, &instance),
                MeteringPoints::Remaining(expected)
            );
        }

        // 2 points cannot pay for another 4-point call: it traps and sets the flag.
        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );
    }

    #[test]
    fn set_remaining_points_works() {
        let (mut store, instance) = metered_instance(10, cost_function);

        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );

        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();

        // Raise the budget to 12 so exactly three 4-point calls fit.
        set_remaining_points(&mut store, &instance, 12);

        for expected in [8, 4, 0] {
            add_one.call(&mut store, 1).unwrap();
            assert_eq!(
                get_remaining_points(&mut store, &instance),
                MeteringPoints::Remaining(expected)
            );
        }

        // The budget is empty: the next call traps.
        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );

        // Setting points again also clears the exhausted flag.
        set_remaining_points(&mut store, &instance, 4);
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(4)
        );
    }

    #[test]
    fn metering_works_for_loops() {
        const INITIAL_POINTS: u64 = 10_000;

        fn cost(operator: &Operator) -> u64 {
            match operator {
                Operator::Loop { .. } => 1000,
                Operator::Br { .. } | Operator::BrIf { .. } => 10,
                Operator::F64Const { .. } => 7,
                _ => 0,
            }
        }

        // A bounded loop is charged per executed iteration.
        let (mut store, instance) = metered_instance(INITIAL_POINTS, cost);
        let short_loop: TypedFunction<(), ()> = instance
            .exports
            .get_function("short_loop")
            .unwrap()
            .typed(&store)
            .unwrap();
        short_loop.call(&mut store).unwrap();

        let points_used: u64 = match get_remaining_points(&mut store, &instance) {
            MeteringPoints::Exhausted => panic!("Unexpected exhausted"),
            MeteringPoints::Remaining(remaining) => INITIAL_POINTS - remaining,
        };

        // f64.const (7) + loop header charged once (1000) + br_if on each of 5 iterations.
        assert_eq!(points_used, 7 + 1000 + 5 * 10);

        // An infinite loop must eventually run out of points and trap.
        let (mut store, instance) = metered_instance(INITIAL_POINTS, cost);
        let infi_loop: TypedFunction<(), ()> = instance
            .exports
            .get_function("infi_loop")
            .unwrap()
            .typed(&store)
            .unwrap();
        infi_loop.call(&mut store).unwrap_err();

        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );
    }
}