1use std::convert::TryInto;
12use std::fmt;
13use std::sync::{Arc, Mutex};
14use wasmer::wasmparser::{BlockType as WpTypeOrFuncType, Operator};
15use wasmer::{
16 AsStoreMut, ExportIndex, GlobalInit, GlobalType, Instance, LocalFunctionIndex, Mutability,
17 Type,
18 sys::{FunctionMiddleware, MiddlewareError, MiddlewareReaderState, ModuleMiddleware},
19};
20use wasmer_types::{GlobalIndex, ModuleInfo};
21
22#[derive(Clone)]
23struct MeteringGlobalIndexes(GlobalIndex, GlobalIndex);
24
25impl MeteringGlobalIndexes {
26 fn remaining_points(&self) -> GlobalIndex {
28 self.0
29 }
30
31 fn points_exhausted(&self) -> GlobalIndex {
37 self.1
38 }
39}
40
41impl fmt::Debug for MeteringGlobalIndexes {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 f.debug_struct("MeteringGlobalIndexes")
44 .field("remaining_points", &self.remaining_points())
45 .field("points_exhausted", &self.points_exhausted())
46 .finish()
47 }
48}
49
/// Module middleware that instruments every local function of a module with point
/// (gas) accounting.
///
/// The cost of each operator is given by `cost_function`. Accumulated costs are deducted
/// from an exported `i64` global at every accounting point (see [`is_accounting`]), and
/// execution traps once the remaining points cannot cover a charge.
///
/// A `Metering` value can be applied to a single module only: `transform_module_info`
/// panics if it is called a second time.
pub struct Metering<F: Fn(&Operator) -> u64 + Send + Sync> {
    /// Value the remaining-points global is initialized with.
    initial_limit: u64,

    /// Cost of a single operator; shared with every `FunctionMetering` this creates.
    cost_function: Arc<F>,

    /// Indexes of the injected globals; `None` until `transform_module_info` has run.
    global_indexes: Mutex<Option<MeteringGlobalIndexes>>,
}
94
/// Per-function half of [`Metering`]: rewrites one function's operator stream to charge
/// accumulated costs at accounting points.
pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Send + Sync> {
    /// Cost of a single operator (shared with the owning `Metering`).
    cost_function: Arc<F>,

    /// Indexes of the metering globals of the module being compiled.
    global_indexes: MeteringGlobalIndexes,

    /// Cost of the operators seen since the last emitted charge.
    accumulated_cost: u64,
}
106
/// Metering state of an instance, as reported by [`get_remaining_points`].
#[derive(Debug, Eq, PartialEq)]
pub enum MeteringPoints {
    /// The points-exhausted flag is not set; holds the current remaining-points value.
    Remaining(u64),

    /// A charge exceeded the remaining points and execution trapped. The failing charge
    /// was not deducted, so the remaining-points global should not be relied upon here.
    Exhausted,
}
125
126impl<F: Fn(&Operator) -> u64 + Send + Sync> Metering<F> {
127 pub fn new(initial_limit: u64, cost_function: F) -> Self {
133 Self {
134 initial_limit,
135 cost_function: Arc::new(cost_function),
136 global_indexes: Mutex::new(None),
137 }
138 }
139}
140
141impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for Metering<F> {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 f.debug_struct("Metering")
144 .field("initial_limit", &self.initial_limit)
145 .field("cost_function", &"<function>")
146 .field("global_indexes", &self.global_indexes)
147 .finish()
148 }
149}
150
151impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Metering<F> {
152 fn generate_function_middleware<'a>(
154 &self,
155 _: LocalFunctionIndex,
156 ) -> Box<dyn FunctionMiddleware<'a> + 'a> {
157 Box::new(FunctionMetering {
158 cost_function: self.cost_function.clone(),
159 global_indexes: self.global_indexes.lock().unwrap().clone().unwrap(),
160 accumulated_cost: 0,
161 })
162 }
163
164 fn transform_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
166 let mut global_indexes = self.global_indexes.lock().unwrap();
167
168 if global_indexes.is_some() {
169 panic!(
170 "Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules."
171 );
172 }
173
174 let remaining_points_global_index = module_info
176 .globals
177 .push(GlobalType::new(Type::I64, Mutability::Var));
178
179 module_info
180 .global_initializers
181 .push(GlobalInit::I64Const(self.initial_limit as i64));
182
183 module_info.exports.insert(
184 "wasmer_metering_remaining_points".to_string(),
185 ExportIndex::Global(remaining_points_global_index),
186 );
187
188 let points_exhausted_global_index = module_info
190 .globals
191 .push(GlobalType::new(Type::I32, Mutability::Var));
192
193 module_info
194 .global_initializers
195 .push(GlobalInit::I32Const(0));
196
197 module_info.exports.insert(
198 "wasmer_metering_points_exhausted".to_string(),
199 ExportIndex::Global(points_exhausted_global_index),
200 );
201
202 *global_indexes = Some(MeteringGlobalIndexes(
203 remaining_points_global_index,
204 points_exhausted_global_index,
205 ));
206
207 Ok(())
208 }
209}
210
211pub fn is_accounting(operator: &Operator) -> bool {
214 matches!(
216 operator,
217 Operator::Loop { .. } | Operator::End | Operator::If { .. } | Operator::Else | Operator::Br { .. } | Operator::BrTable { .. } | Operator::BrIf { .. } | Operator::Call { .. } | Operator::CallIndirect { .. } | Operator::Return | Operator::Throw { .. } | Operator::ThrowRef | Operator::Rethrow { .. } | Operator::Delegate { .. } | Operator::Catch { .. } | Operator::ReturnCall { .. } | Operator::ReturnCallIndirect { .. } | Operator::BrOnCast { .. } | Operator::BrOnCastFail { .. } | Operator::CallRef { .. } | Operator::ReturnCallRef { .. } | Operator::BrOnNull { .. } | Operator::BrOnNonNull { .. } )
245}
246
247impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 f.debug_struct("FunctionMetering")
250 .field("cost_function", &"<function>")
251 .field("global_indexes", &self.global_indexes)
252 .finish()
253 }
254}
255
256impl<'a, F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware<'a> for FunctionMetering<F> {
257 fn feed(
258 &mut self,
259 operator: Operator<'a>,
260 state: &mut MiddlewareReaderState<'a>,
261 ) -> Result<(), MiddlewareError> {
262 self.accumulated_cost += (self.cost_function)(&operator);
266
267 if is_accounting(&operator) && self.accumulated_cost > 0 {
269 state.extend(&[
270 Operator::GlobalGet {
272 global_index: self.global_indexes.remaining_points().as_u32(),
273 },
274 Operator::I64Const {
275 value: self.accumulated_cost as i64,
276 },
277 Operator::I64LtU,
278 Operator::If {
279 blockty: WpTypeOrFuncType::Empty,
280 },
281 Operator::I32Const { value: 1 },
282 Operator::GlobalSet {
283 global_index: self.global_indexes.points_exhausted().as_u32(),
284 },
285 Operator::Unreachable,
286 Operator::End,
287 Operator::GlobalGet {
289 global_index: self.global_indexes.remaining_points().as_u32(),
290 },
291 Operator::I64Const {
292 value: self.accumulated_cost as i64,
293 },
294 Operator::I64Sub,
295 Operator::GlobalSet {
296 global_index: self.global_indexes.remaining_points().as_u32(),
297 },
298 ]);
299
300 self.accumulated_cost = 0;
301 }
302 state.push_operator(operator);
303
304 Ok(())
305 }
306}
307
308pub fn get_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance) -> MeteringPoints {
333 let exhausted: i32 = instance
334 .exports
335 .get_global("wasmer_metering_points_exhausted")
336 .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
337 .get(ctx)
338 .try_into()
339 .expect("`wasmer_metering_points_exhausted` from Instance has wrong type");
340
341 if exhausted > 0 {
342 return MeteringPoints::Exhausted;
343 }
344
345 let points = instance
346 .exports
347 .get_global("wasmer_metering_remaining_points")
348 .expect("Can't get `wasmer_metering_remaining_points` from Instance")
349 .get(ctx)
350 .try_into()
351 .expect("`wasmer_metering_remaining_points` from Instance has wrong type");
352
353 MeteringPoints::Remaining(points)
354}
355
356pub fn set_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance, points: u64) {
382 instance
383 .exports
384 .get_global("wasmer_metering_remaining_points")
385 .expect("Can't get `wasmer_metering_remaining_points` from Instance")
386 .set(ctx, points.into())
387 .expect("Can't set `wasmer_metering_remaining_points` in Instance");
388
389 instance
390 .exports
391 .get_global("wasmer_metering_points_exhausted")
392 .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
393 .set(ctx, 0i32.into())
394 .expect("Can't set `wasmer_metering_points_exhausted` in Instance");
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use wasmer::sys::EngineBuilder;
    use wasmer::{
        Module, Store, TypedFunction, imports,
        sys::{CompilerConfig, Cranelift},
        wat2wasm,
    };

    /// Test cost function: `local.get` and `i32.const` cost 1 point, `i32.add` costs 2,
    /// and every other operator is free. One `add_one` call therefore costs
    /// 1 + 1 + 2 = 4 points.
    fn cost_function(operator: &Operator) -> u64 {
        match operator {
            Operator::LocalGet { .. } | Operator::I32Const { .. } => 1,
            Operator::I32Add { .. } => 2,
            _ => 0,
        }
    }

    /// Module shared by the tests. It exports:
    /// * `add_one` - returns its argument plus one;
    /// * `short_loop` - a loop whose `br_if` executes 5 times before falling through;
    /// * `infi_loop` - a loop that only stops when metering traps it.
    fn bytecode() -> Vec<u8> {
        wat2wasm(
            br#"(module
            (type $add_t (func (param i32) (result i32)))
            (func $add_one_f (type $add_t) (param $value i32) (result i32)
                local.get $value
                i32.const 1
                i32.add)
            (func $short_loop_f
                (local $x f64) (local $j i32)
                (local.set $x (f64.const 5.5))

                (loop $named_loop
                    ;; $j++
                    local.get $j
                    i32.const 1
                    i32.add
                    local.set $j

                    ;; if $j < 5, one more time
                    local.get $j
                    i32.const 5
                    i32.lt_s
                    br_if $named_loop
                )
            )
            (func $infi_loop_f
                (loop $infi_loop_start
                    br $infi_loop_start
                )
            )
            (export "add_one" (func $add_one_f))
            (export "short_loop" (func $short_loop_f))
            (export "infi_loop" (func $infi_loop_f))
        )"#,
        )
        .unwrap()
        .into()
    }

    #[test]
    fn get_remaining_points_works() {
        let metering = Arc::new(Metering::new(10, cost_function));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        // Instantiation alone consumes no points.
        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );

        // Each `add_one` call costs 4 points (see `cost_function`): 10 - 4 = 6.
        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(6)
        );

        // Second call: 6 - 4 = 2.
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(2)
        );

        // Third call needs 4 points but only 2 remain: it traps and sets the exhausted flag.
        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );
    }

    #[test]
    fn set_remaining_points_works() {
        let metering = Arc::new(Metering::new(10, cost_function));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        // Instantiation alone consumes no points.
        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );
        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();

        // Raise the budget to 12 points: exactly three `add_one` calls (4 points each).
        set_remaining_points(&mut store, &instance, 12);

        // 12 - 4 = 8.
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(8)
        );

        // 8 - 4 = 4.
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(4)
        );

        // 4 - 4 = 0: a charge equal to the remaining points still succeeds.
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(0)
        );

        // No points left: the next call traps.
        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );

        // Setting points also clears the exhausted flag.
        set_remaining_points(&mut store, &instance, 4);
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(4)
        );
    }

    #[test]
    fn metering_works_for_loops() {
        const INITIAL_POINTS: u64 = 10_000;

        // Large, distinct costs make it easy to see which operators were charged and
        // how many times.
        fn cost(operator: &Operator) -> u64 {
            match operator {
                Operator::Loop { .. } => 1000,
                Operator::Br { .. } | Operator::BrIf { .. } => 10,
                Operator::F64Const { .. } => 7,
                _ => 0,
            }
        }

        // Case 1: a terminating loop is charged per executed branch, not per loop entry.
        let metering = Arc::new(Metering::new(INITIAL_POINTS, cost));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        let short_loop: TypedFunction<(), ()> = instance
            .exports
            .get_function("short_loop")
            .unwrap()
            .typed(&store)
            .unwrap();
        short_loop.call(&mut store).unwrap();

        let points_used: u64 = match get_remaining_points(&mut store, &instance) {
            MeteringPoints::Exhausted => panic!("Unexpected exhausted"),
            MeteringPoints::Remaining(remaining) => INITIAL_POINTS - remaining,
        };

        assert_eq!(
            points_used,
            7 + // 1 `f64.const`
            1000 + // 1 `loop`: charged before entering, so only once
            50 // 5 `br_if` executions, 10 points each
        );

        // Case 2: an infinite loop is stopped by metering.
        let metering = Arc::new(Metering::new(INITIAL_POINTS, cost));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        let infi_loop: TypedFunction<(), ()> = instance
            .exports
            .get_function("infi_loop")
            .unwrap()
            .typed(&store)
            .unwrap();
        // Each `br` iteration costs 10 points, so the call must eventually trap.
        infi_loop.call(&mut store).unwrap_err();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );
    }
}