wasmer_middlewares/
metering.rs

1//! `metering` is a middleware for tracking how many operators are
2//! executed in total and putting a limit on the total number of
3//! operators executed. The WebAssembly instance execution is stopped
4//! when the limit is reached.
5//!
6//! # Example
7//!
8//! [See the `metering` detailed and complete
9//! example](https://github.com/wasmerio/wasmer/blob/main/examples/metering.rs).
10
11use std::convert::TryInto;
12use std::fmt;
13use std::sync::{Arc, Mutex};
14use wasmer::wasmparser::{BlockType as WpTypeOrFuncType, Operator};
15use wasmer::{
16    AsStoreMut, ExportIndex, GlobalInit, GlobalType, Instance, LocalFunctionIndex, Mutability,
17    Type,
18    sys::{FunctionMiddleware, MiddlewareError, MiddlewareReaderState, ModuleMiddleware},
19};
20use wasmer_types::{GlobalIndex, ModuleInfo};
21
22#[derive(Clone)]
23struct MeteringGlobalIndexes(GlobalIndex, GlobalIndex);
24
25impl MeteringGlobalIndexes {
26    /// The global index in the current module for remaining points.
27    fn remaining_points(&self) -> GlobalIndex {
28        self.0
29    }
30
31    /// The global index in the current module for a boolean indicating whether points are exhausted
32    /// or not.
33    /// This boolean is represented as a i32 global:
34    ///   * 0: there are remaining points
35    ///   * 1: points have been exhausted
36    fn points_exhausted(&self) -> GlobalIndex {
37        self.1
38    }
39}
40
41impl fmt::Debug for MeteringGlobalIndexes {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.debug_struct("MeteringGlobalIndexes")
44            .field("remaining_points", &self.remaining_points())
45            .field("points_exhausted", &self.points_exhausted())
46            .finish()
47    }
48}
49
50/// The module-level metering middleware.
51///
52/// # Panic
53///
54/// An instance of `Metering` should _not_ be shared among different
55/// modules, since it tracks module-specific information like the
56/// global index to store metering state. Attempts to use a `Metering`
57/// instance from multiple modules will result in a panic.
58///
59/// # Example
60///
61/// ```rust
62/// use std::sync::Arc;
63/// use wasmer::{wasmparser::Operator, sys::CompilerConfig};
64/// use wasmer_middlewares::Metering;
65///
66/// fn create_metering_middleware(compiler_config: &mut dyn CompilerConfig) {
67///     // Let's define a dummy cost function,
68///     // which counts 1 for all operators.
69///     let cost_function = |_operator: &Operator| -> u64 { 1 };
70///
71///     // Let's define the initial limit.
72///     let initial_limit = 10;
73///
74///     // Let's creating the metering middleware.
75///     let metering = Arc::new(Metering::new(
76///         initial_limit,
77///         cost_function
78///     ));
79///
80///     // Finally, let's push the middleware.
81///     compiler_config.push_middleware(metering);
82/// }
83/// ```
84pub struct Metering<F: Fn(&Operator) -> u64 + Send + Sync> {
85    /// Initial limit of points.
86    initial_limit: u64,
87
88    /// Function that maps each operator to a cost in "points".
89    cost_function: Arc<F>,
90
91    /// The global indexes for metering points.
92    global_indexes: Mutex<Option<MeteringGlobalIndexes>>,
93}
94
95/// The function-level metering middleware.
96pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Send + Sync> {
97    /// Function that maps each operator to a cost in "points".
98    cost_function: Arc<F>,
99
100    /// The global indexes for metering points.
101    global_indexes: MeteringGlobalIndexes,
102
103    /// Accumulated cost of the current basic block.
104    accumulated_cost: u64,
105}
106
/// Represents the type of the metering points, either `Remaining` or
/// `Exhausted`.
///
/// This is a small plain value: it is `Copy` and can be compared and hashed.
///
/// # Example
///
/// See the [`get_remaining_points`] function to get an example.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum MeteringPoints {
    /// The given number of metering points is left for the execution.
    /// If the value is 0, all points are consumed but the execution
    /// was not terminated.
    Remaining(u64),

    /// The execution was terminated because the metering points were
    /// exhausted. You can recover from this state by setting the
    /// points via [`set_remaining_points`] and restarting the execution.
    Exhausted,
}
125
126impl<F: Fn(&Operator) -> u64 + Send + Sync> Metering<F> {
127    /// Creates a `Metering` middleware.
128    ///
129    /// When providing a cost function, you should consider that branching operations do
130    /// additional work to track the metering points and probably need to have a higher cost.
131    /// To find out which operations are affected by this, you can call [`is_accounting`].
132    pub fn new(initial_limit: u64, cost_function: F) -> Self {
133        Self {
134            initial_limit,
135            cost_function: Arc::new(cost_function),
136            global_indexes: Mutex::new(None),
137        }
138    }
139}
140
141impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for Metering<F> {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        f.debug_struct("Metering")
144            .field("initial_limit", &self.initial_limit)
145            .field("cost_function", &"<function>")
146            .field("global_indexes", &self.global_indexes)
147            .finish()
148    }
149}
150
151impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Metering<F> {
152    /// Generates a `FunctionMiddleware` for a given function.
153    fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
154        Box::new(FunctionMetering {
155            cost_function: self.cost_function.clone(),
156            global_indexes: self.global_indexes.lock().unwrap().clone().unwrap(),
157            accumulated_cost: 0,
158        })
159    }
160
161    /// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
162    fn transform_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
163        let mut global_indexes = self.global_indexes.lock().unwrap();
164
165        if global_indexes.is_some() {
166            panic!(
167                "Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules."
168            );
169        }
170
171        // Append a global for remaining points and initialize it.
172        let remaining_points_global_index = module_info
173            .globals
174            .push(GlobalType::new(Type::I64, Mutability::Var));
175
176        module_info
177            .global_initializers
178            .push(GlobalInit::I64Const(self.initial_limit as i64));
179
180        module_info.exports.insert(
181            "wasmer_metering_remaining_points".to_string(),
182            ExportIndex::Global(remaining_points_global_index),
183        );
184
185        // Append a global for the exhausted points boolean and initialize it.
186        let points_exhausted_global_index = module_info
187            .globals
188            .push(GlobalType::new(Type::I32, Mutability::Var));
189
190        module_info
191            .global_initializers
192            .push(GlobalInit::I32Const(0));
193
194        module_info.exports.insert(
195            "wasmer_metering_points_exhausted".to_string(),
196            ExportIndex::Global(points_exhausted_global_index),
197        );
198
199        *global_indexes = Some(MeteringGlobalIndexes(
200            remaining_points_global_index,
201            points_exhausted_global_index,
202        ));
203
204        Ok(())
205    }
206}
207
208/// Returns `true` if and only if the given operator is an accounting operator.
209/// Accounting operators do additional work to track the metering points.
210pub fn is_accounting(operator: &Operator) -> bool {
211    // Possible sources and targets of a branch.
212    matches!(
213        operator,
214        Operator::Loop { .. } // loop headers are branch targets
215            | Operator::End // block ends are branch targets
216            | Operator::If { .. } // branch source, "if" can branch to else branch
217            | Operator::Else // "else" is the "end" of an if branch
218            | Operator::Br { .. } // branch source
219            | Operator::BrTable { .. } // branch source
220            | Operator::BrIf { .. } // branch source
221            | Operator::Call { .. } // function call - branch source
222            | Operator::CallIndirect { .. } // function call - branch source
223            | Operator::Return // end of function - branch source
224            // exceptions proposal
225            | Operator::Throw { .. } // branch source
226            | Operator::ThrowRef // branch source
227            | Operator::Rethrow { .. } // branch source
228            | Operator::Delegate { .. } // branch source
229            | Operator::Catch { .. } // branch target
230            // tail_call proposal
231            | Operator::ReturnCall { .. } // branch source
232            | Operator::ReturnCallIndirect { .. } // branch source
233            // gc proposal
234            | Operator::BrOnCast { .. } // branch source
235            | Operator::BrOnCastFail { .. } // branch source
236            // function_references proposal
237            | Operator::CallRef { .. } // branch source
238            | Operator::ReturnCallRef { .. } // branch source
239            | Operator::BrOnNull { .. } // branch source
240            | Operator::BrOnNonNull { .. } // branch source
241    )
242}
243
244impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        f.debug_struct("FunctionMetering")
247            .field("cost_function", &"<function>")
248            .field("global_indexes", &self.global_indexes)
249            .finish()
250    }
251}
252
253impl<F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware for FunctionMetering<F> {
254    fn feed<'a>(
255        &mut self,
256        operator: Operator<'a>,
257        state: &mut MiddlewareReaderState<'a>,
258    ) -> Result<(), MiddlewareError> {
259        // Get the cost of the current operator, and add it to the accumulator.
260        // This needs to be done before the metering logic, to prevent operators like `Call` from escaping metering in some
261        // corner cases.
262        self.accumulated_cost += (self.cost_function)(&operator);
263
264        // Finalize the cost of the previous basic block and perform necessary checks.
265        if is_accounting(&operator) && self.accumulated_cost > 0 {
266            state.extend(&[
267                // if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
268                Operator::GlobalGet {
269                    global_index: self.global_indexes.remaining_points().as_u32(),
270                },
271                Operator::I64Const {
272                    value: self.accumulated_cost as i64,
273                },
274                Operator::I64LtU,
275                Operator::If {
276                    blockty: WpTypeOrFuncType::Empty,
277                },
278                Operator::I32Const { value: 1 },
279                Operator::GlobalSet {
280                    global_index: self.global_indexes.points_exhausted().as_u32(),
281                },
282                Operator::Unreachable,
283                Operator::End,
284                // globals[remaining_points_index] -= self.accumulated_cost;
285                Operator::GlobalGet {
286                    global_index: self.global_indexes.remaining_points().as_u32(),
287                },
288                Operator::I64Const {
289                    value: self.accumulated_cost as i64,
290                },
291                Operator::I64Sub,
292                Operator::GlobalSet {
293                    global_index: self.global_indexes.remaining_points().as_u32(),
294                },
295            ]);
296
297            self.accumulated_cost = 0;
298        }
299        state.push_operator(operator);
300
301        Ok(())
302    }
303}
304
305/// Get the remaining points in an [`Instance`].
306///
307/// Note: This can be used in a headless engine after an ahead-of-time
308/// compilation as all required state lives in the instance.
309///
310/// # Panic
311///
312/// The [`Instance`] must have been processed with
313/// the [`Metering`] middleware at compile time, otherwise this will
314/// panic.
315///
316/// # Example
317///
318/// ```rust
319/// use wasmer::Instance;
320/// use wasmer::AsStoreMut;
321/// use wasmer_middlewares::metering::{get_remaining_points, MeteringPoints};
322///
323/// /// Check whether the instance can continue to run based on the
324/// /// number of remaining points.
325/// fn can_continue_to_run(store: &mut impl AsStoreMut, instance: &Instance) -> bool {
326///     matches!(get_remaining_points(store, instance), MeteringPoints::Remaining(points) if points > 0)
327/// }
328/// ```
329pub fn get_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance) -> MeteringPoints {
330    let exhausted: i32 = instance
331        .exports
332        .get_global("wasmer_metering_points_exhausted")
333        .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
334        .get(ctx)
335        .try_into()
336        .expect("`wasmer_metering_points_exhausted` from Instance has wrong type");
337
338    if exhausted > 0 {
339        return MeteringPoints::Exhausted;
340    }
341
342    let points = instance
343        .exports
344        .get_global("wasmer_metering_remaining_points")
345        .expect("Can't get `wasmer_metering_remaining_points` from Instance")
346        .get(ctx)
347        .try_into()
348        .expect("`wasmer_metering_remaining_points` from Instance has wrong type");
349
350    MeteringPoints::Remaining(points)
351}
352
353/// Set the new provided remaining points in an [`Instance`].
354///
355/// Note: This can be used in a headless engine after an ahead-of-time
356/// compilation as all required state lives in the instance.
357///
358/// # Panic
359///
360/// The given [`Instance`] must have been processed
361/// with the [`Metering`] middleware at compile time, otherwise this
362/// will panic.
363///
364/// # Example
365///
366/// ```rust
367/// use wasmer::{AsStoreMut, Instance};
368/// use wasmer_middlewares::metering::set_remaining_points;
369///
370/// fn update_remaining_points(store: &mut impl AsStoreMut, instance: &Instance) {
371///     // The new limit.
372///     let new_limit = 10;
373///
374///     // Update the remaining points to the `new_limit`.
375///     set_remaining_points(store, instance, new_limit);
376/// }
377/// ```
378pub fn set_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance, points: u64) {
379    instance
380        .exports
381        .get_global("wasmer_metering_remaining_points")
382        .expect("Can't get `wasmer_metering_remaining_points` from Instance")
383        .set(ctx, points.into())
384        .expect("Can't set `wasmer_metering_remaining_points` in Instance");
385
386    instance
387        .exports
388        .get_global("wasmer_metering_points_exhausted")
389        .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
390        .set(ctx, 0i32.into())
391        .expect("Can't set `wasmer_metering_points_exhausted` in Instance");
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use wasmer::sys::EngineBuilder;
    use wasmer::{
        Module, Store, TypedFunction, imports,
        sys::{CompilerConfig, Cranelift},
        wat2wasm,
    };

    /// Cost function for the `add_one` tests: `i32.add` costs 2 points,
    /// `local.get` and `i32.const` cost 1 point, everything else is free.
    fn cost_function(operator: &Operator) -> u64 {
        match operator {
            Operator::I32Add { .. } => 2,
            Operator::LocalGet { .. } | Operator::I32Const { .. } => 1,
            _ => 0,
        }
    }

    /// The module under test, compiled from WAT.
    fn bytecode() -> Vec<u8> {
        wat2wasm(
            br#"(module
            (type $add_t (func (param i32) (result i32)))
            (func $add_one_f (type $add_t) (param $value i32) (result i32)
                local.get $value
                i32.const 1
                i32.add)
            (func $short_loop_f
                (local $x f64) (local $j i32)
                (local.set $x (f64.const 5.5))

                (loop $named_loop
                    ;; $j++
                    local.get $j
                    i32.const 1
                    i32.add
                    local.set $j

                    ;; if $j < 5, one more time
                    local.get $j
                    i32.const 5
                    i32.lt_s
                    br_if $named_loop
                )
            )
            (func $infi_loop_f
                (loop $infi_loop_start
                    br $infi_loop_start
                )
            )
            (export "add_one" (func $add_one_f))
            (export "short_loop" (func $short_loop_f))
            (export "infi_loop" (func $infi_loop_f))
        )"#,
        )
        .unwrap()
        .into()
    }

    /// Compiles the test module with a fresh `Metering` middleware
    /// (`initial_limit` points priced by `cost`). Instantiates it in a new store.
    fn instantiate<C>(initial_limit: u64, cost: C) -> (Store, Instance)
    where
        C: Fn(&Operator) -> u64 + Send + Sync + 'static,
    {
        let metering = Arc::new(Metering::new(initial_limit, cost));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();
        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        (store, instance)
    }

    #[test]
    fn get_remaining_points_works() {
        let (mut store, instance) = instantiate(10, cost_function);
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );

        // Each `add_one` call costs 4 points:
        // `local.get` (1) + `i32.const` (1) + `i32.add` (2).
        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();

        for expected_remaining in [6, 2] {
            add_one.call(&mut store, 1).unwrap();
            assert_eq!(
                get_remaining_points(&mut store, &instance),
                MeteringPoints::Remaining(expected_remaining)
            );
        }

        // Only 2 points are left, not enough for a third call.
        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );
    }

    #[test]
    fn set_remaining_points_works() {
        let (mut store, instance) = instantiate(10, cost_function);
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );
        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();

        // Raise the budget so that exactly three 4-point calls fit.
        set_remaining_points(&mut store, &instance, 12);

        for expected_remaining in [8, 4, 0] {
            add_one.call(&mut store, 1).unwrap();
            assert_eq!(
                get_remaining_points(&mut store, &instance),
                MeteringPoints::Remaining(expected_remaining)
            );
        }

        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );

        // Topping up clears the exhausted state.
        set_remaining_points(&mut store, &instance, 4);
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(4)
        );
    }

    #[test]
    fn metering_works_for_loops() {
        const INITIAL_POINTS: u64 = 10_000;

        fn cost(operator: &Operator) -> u64 {
            match operator {
                Operator::Loop { .. } => 1000,
                Operator::Br { .. } | Operator::BrIf { .. } => 10,
                Operator::F64Const { .. } => 7,
                _ => 0,
            }
        }

        // Short loop: five iterations, then falls through.
        {
            let (mut store, instance) = instantiate(INITIAL_POINTS, cost);
            let short_loop: TypedFunction<(), ()> = instance
                .exports
                .get_function("short_loop")
                .unwrap()
                .typed(&store)
                .unwrap();
            short_loop.call(&mut store).unwrap();

            let remaining = match get_remaining_points(&mut store, &instance) {
                MeteringPoints::Remaining(remaining) => remaining,
                MeteringPoints::Exhausted => panic!("Unexpected exhausted"),
            };

            // Charges: pre-loop `f64.const` (7), the `loop` instruction once
            // (1000), and five conditional breaks (5 * 10).
            assert_eq!(INITIAL_POINTS - remaining, 7 + 1000 + 50);
        }

        // Infinite loop: runs until the points are exhausted.
        {
            let (mut store, instance) = instantiate(INITIAL_POINTS, cost);
            let infi_loop: TypedFunction<(), ()> = instance
                .exports
                .get_function("infi_loop")
                .unwrap()
                .typed(&store)
                .unwrap();
            infi_loop.call(&mut store).unwrap_err(); // exhausted leads to runtime error

            assert_eq!(
                get_remaining_points(&mut store, &instance),
                MeteringPoints::Exhausted
            );
        }
    }
}