wasmer_middlewares/
metering.rs

1//! `metering` is a middleware for tracking how many operators are
2//! executed in total and putting a limit on the total number of
3//! operators executed. The WebAssembly instance execution is stopped
4//! when the limit is reached.
5//!
6//! # Example
7//!
8//! [See the `metering` detailed and complete
9//! example](https://github.com/wasmerio/wasmer/blob/main/examples/metering.rs).
10
11use std::convert::TryInto;
12use std::fmt;
13use std::sync::{Arc, Mutex};
14use wasmer::wasmparser::{BlockType as WpTypeOrFuncType, Operator};
15use wasmer::{
16    AsStoreMut, ExportIndex, GlobalInit, GlobalType, Instance, LocalFunctionIndex, Mutability,
17    Type,
18    sys::{FunctionMiddleware, MiddlewareError, MiddlewareReaderState, ModuleMiddleware},
19};
20use wasmer_types::{GlobalIndex, ModuleInfo};
21
22#[derive(Clone)]
23struct MeteringGlobalIndexes(GlobalIndex, GlobalIndex);
24
25impl MeteringGlobalIndexes {
26    /// The global index in the current module for remaining points.
27    fn remaining_points(&self) -> GlobalIndex {
28        self.0
29    }
30
31    /// The global index in the current module for a boolean indicating whether points are exhausted
32    /// or not.
33    /// This boolean is represented as a i32 global:
34    ///   * 0: there are remaining points
35    ///   * 1: points have been exhausted
36    fn points_exhausted(&self) -> GlobalIndex {
37        self.1
38    }
39}
40
41impl fmt::Debug for MeteringGlobalIndexes {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.debug_struct("MeteringGlobalIndexes")
44            .field("remaining_points", &self.remaining_points())
45            .field("points_exhausted", &self.points_exhausted())
46            .finish()
47    }
48}
49
50/// The module-level metering middleware.
51///
52/// # Panic
53///
54/// An instance of `Metering` should _not_ be shared among different
55/// modules, since it tracks module-specific information like the
56/// global index to store metering state. Attempts to use a `Metering`
57/// instance from multiple modules will result in a panic.
58///
59/// # Example
60///
61/// ```rust
62/// use std::sync::Arc;
63/// use wasmer::{wasmparser::Operator, sys::CompilerConfig};
64/// use wasmer_middlewares::Metering;
65///
66/// fn create_metering_middleware(compiler_config: &mut dyn CompilerConfig) {
67///     // Let's define a dummy cost function,
68///     // which counts 1 for all operators.
69///     let cost_function = |_operator: &Operator| -> u64 { 1 };
70///
71///     // Let's define the initial limit.
72///     let initial_limit = 10;
73///
74///     // Let's creating the metering middleware.
75///     let metering = Arc::new(Metering::new(
76///         initial_limit,
77///         cost_function
78///     ));
79///
80///     // Finally, let's push the middleware.
81///     compiler_config.push_middleware(metering);
82/// }
83/// ```
84pub struct Metering<F: Fn(&Operator) -> u64 + Send + Sync> {
85    /// Initial limit of points.
86    initial_limit: u64,
87
88    /// Function that maps each operator to a cost in "points".
89    cost_function: Arc<F>,
90
91    /// The global indexes for metering points.
92    global_indexes: Mutex<Option<MeteringGlobalIndexes>>,
93}
94
95/// The function-level metering middleware.
96pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Send + Sync> {
97    /// Function that maps each operator to a cost in "points".
98    cost_function: Arc<F>,
99
100    /// The global indexes for metering points.
101    global_indexes: MeteringGlobalIndexes,
102
103    /// Accumulated cost of the current basic block.
104    accumulated_cost: u64,
105}
106
/// Represents the type of the metering points, either `Remaining` or
/// `Exhausted`.
///
/// This is a small plain value, so it is `Copy` and `Hash` and can be stored,
/// compared and used as a map key freely.
///
/// # Example
///
/// See the [`get_remaining_points`] function to get an example.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum MeteringPoints {
    /// The given number of metering points is left for the execution.
    /// If the value is 0, all points are consumed but the execution
    /// was not terminated.
    Remaining(u64),

    /// The execution was terminated because the metering points were
    /// exhausted. You can recover from this state by setting the
    /// points via [`set_remaining_points`] and restarting the execution.
    Exhausted,
}
125
126impl<F: Fn(&Operator) -> u64 + Send + Sync> Metering<F> {
127    /// Creates a `Metering` middleware.
128    ///
129    /// When providing a cost function, you should consider that branching operations do
130    /// additional work to track the metering points and probably need to have a higher cost.
131    /// To find out which operations are affected by this, you can call [`is_accounting`].
132    pub fn new(initial_limit: u64, cost_function: F) -> Self {
133        Self {
134            initial_limit,
135            cost_function: Arc::new(cost_function),
136            global_indexes: Mutex::new(None),
137        }
138    }
139}
140
141impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for Metering<F> {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        f.debug_struct("Metering")
144            .field("initial_limit", &self.initial_limit)
145            .field("cost_function", &"<function>")
146            .field("global_indexes", &self.global_indexes)
147            .finish()
148    }
149}
150
151impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Metering<F> {
152    /// Generates a `FunctionMiddleware` for a given function.
153    fn generate_function_middleware<'a>(
154        &self,
155        _: LocalFunctionIndex,
156    ) -> Box<dyn FunctionMiddleware<'a> + 'a> {
157        Box::new(FunctionMetering {
158            cost_function: self.cost_function.clone(),
159            global_indexes: self.global_indexes.lock().unwrap().clone().unwrap(),
160            accumulated_cost: 0,
161        })
162    }
163
164    /// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
165    fn transform_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
166        let mut global_indexes = self.global_indexes.lock().unwrap();
167
168        if global_indexes.is_some() {
169            panic!(
170                "Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules."
171            );
172        }
173
174        // Append a global for remaining points and initialize it.
175        let remaining_points_global_index = module_info
176            .globals
177            .push(GlobalType::new(Type::I64, Mutability::Var));
178
179        module_info
180            .global_initializers
181            .push(GlobalInit::I64Const(self.initial_limit as i64));
182
183        module_info.exports.insert(
184            "wasmer_metering_remaining_points".to_string(),
185            ExportIndex::Global(remaining_points_global_index),
186        );
187
188        // Append a global for the exhausted points boolean and initialize it.
189        let points_exhausted_global_index = module_info
190            .globals
191            .push(GlobalType::new(Type::I32, Mutability::Var));
192
193        module_info
194            .global_initializers
195            .push(GlobalInit::I32Const(0));
196
197        module_info.exports.insert(
198            "wasmer_metering_points_exhausted".to_string(),
199            ExportIndex::Global(points_exhausted_global_index),
200        );
201
202        *global_indexes = Some(MeteringGlobalIndexes(
203            remaining_points_global_index,
204            points_exhausted_global_index,
205        ));
206
207        Ok(())
208    }
209}
210
211/// Returns `true` if and only if the given operator is an accounting operator.
212/// Accounting operators do additional work to track the metering points.
213pub fn is_accounting(operator: &Operator) -> bool {
214    // Possible sources and targets of a branch.
215    matches!(
216        operator,
217        Operator::Loop { .. } // loop headers are branch targets
218            | Operator::End // block ends are branch targets
219            | Operator::If { .. } // branch source, "if" can branch to else branch
220            | Operator::Else // "else" is the "end" of an if branch
221            | Operator::Br { .. } // branch source
222            | Operator::BrTable { .. } // branch source
223            | Operator::BrIf { .. } // branch source
224            | Operator::Call { .. } // function call - branch source
225            | Operator::CallIndirect { .. } // function call - branch source
226            | Operator::Return // end of function - branch source
227            // exceptions proposal
228            | Operator::Throw { .. } // branch source
229            | Operator::ThrowRef // branch source
230            | Operator::Rethrow { .. } // branch source
231            | Operator::Delegate { .. } // branch source
232            | Operator::Catch { .. } // branch target
233            // tail_call proposal
234            | Operator::ReturnCall { .. } // branch source
235            | Operator::ReturnCallIndirect { .. } // branch source
236            // gc proposal
237            | Operator::BrOnCast { .. } // branch source
238            | Operator::BrOnCastFail { .. } // branch source
239            // function_references proposal
240            | Operator::CallRef { .. } // branch source
241            | Operator::ReturnCallRef { .. } // branch source
242            | Operator::BrOnNull { .. } // branch source
243            | Operator::BrOnNonNull { .. } // branch source
244    )
245}
246
247impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
248    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249        f.debug_struct("FunctionMetering")
250            .field("cost_function", &"<function>")
251            .field("global_indexes", &self.global_indexes)
252            .finish()
253    }
254}
255
256impl<'a, F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware<'a> for FunctionMetering<F> {
257    fn feed(
258        &mut self,
259        operator: Operator<'a>,
260        state: &mut MiddlewareReaderState<'a>,
261    ) -> Result<(), MiddlewareError> {
262        // Get the cost of the current operator, and add it to the accumulator.
263        // This needs to be done before the metering logic, to prevent operators like `Call` from escaping metering in some
264        // corner cases.
265        self.accumulated_cost += (self.cost_function)(&operator);
266
267        // Finalize the cost of the previous basic block and perform necessary checks.
268        if is_accounting(&operator) && self.accumulated_cost > 0 {
269            state.extend(&[
270                // if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
271                Operator::GlobalGet {
272                    global_index: self.global_indexes.remaining_points().as_u32(),
273                },
274                Operator::I64Const {
275                    value: self.accumulated_cost as i64,
276                },
277                Operator::I64LtU,
278                Operator::If {
279                    blockty: WpTypeOrFuncType::Empty,
280                },
281                Operator::I32Const { value: 1 },
282                Operator::GlobalSet {
283                    global_index: self.global_indexes.points_exhausted().as_u32(),
284                },
285                Operator::Unreachable,
286                Operator::End,
287                // globals[remaining_points_index] -= self.accumulated_cost;
288                Operator::GlobalGet {
289                    global_index: self.global_indexes.remaining_points().as_u32(),
290                },
291                Operator::I64Const {
292                    value: self.accumulated_cost as i64,
293                },
294                Operator::I64Sub,
295                Operator::GlobalSet {
296                    global_index: self.global_indexes.remaining_points().as_u32(),
297                },
298            ]);
299
300            self.accumulated_cost = 0;
301        }
302        state.push_operator(operator);
303
304        Ok(())
305    }
306}
307
308/// Get the remaining points in an [`Instance`].
309///
310/// Note: This can be used in a headless engine after an ahead-of-time
311/// compilation as all required state lives in the instance.
312///
313/// # Panic
314///
315/// The [`Instance`] must have been processed with
316/// the [`Metering`] middleware at compile time, otherwise this will
317/// panic.
318///
319/// # Example
320///
321/// ```rust
322/// use wasmer::Instance;
323/// use wasmer::AsStoreMut;
324/// use wasmer_middlewares::metering::{get_remaining_points, MeteringPoints};
325///
326/// /// Check whether the instance can continue to run based on the
327/// /// number of remaining points.
328/// fn can_continue_to_run(store: &mut impl AsStoreMut, instance: &Instance) -> bool {
329///     matches!(get_remaining_points(store, instance), MeteringPoints::Remaining(points) if points > 0)
330/// }
331/// ```
332pub fn get_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance) -> MeteringPoints {
333    let exhausted: i32 = instance
334        .exports
335        .get_global("wasmer_metering_points_exhausted")
336        .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
337        .get(ctx)
338        .try_into()
339        .expect("`wasmer_metering_points_exhausted` from Instance has wrong type");
340
341    if exhausted > 0 {
342        return MeteringPoints::Exhausted;
343    }
344
345    let points = instance
346        .exports
347        .get_global("wasmer_metering_remaining_points")
348        .expect("Can't get `wasmer_metering_remaining_points` from Instance")
349        .get(ctx)
350        .try_into()
351        .expect("`wasmer_metering_remaining_points` from Instance has wrong type");
352
353    MeteringPoints::Remaining(points)
354}
355
356/// Set the new provided remaining points in an [`Instance`].
357///
358/// Note: This can be used in a headless engine after an ahead-of-time
359/// compilation as all required state lives in the instance.
360///
361/// # Panic
362///
363/// The given [`Instance`] must have been processed
364/// with the [`Metering`] middleware at compile time, otherwise this
365/// will panic.
366///
367/// # Example
368///
369/// ```rust
370/// use wasmer::{AsStoreMut, Instance};
371/// use wasmer_middlewares::metering::set_remaining_points;
372///
373/// fn update_remaining_points(store: &mut impl AsStoreMut, instance: &Instance) {
374///     // The new limit.
375///     let new_limit = 10;
376///
377///     // Update the remaining points to the `new_limit`.
378///     set_remaining_points(store, instance, new_limit);
379/// }
380/// ```
381pub fn set_remaining_points(ctx: &mut impl AsStoreMut, instance: &Instance, points: u64) {
382    instance
383        .exports
384        .get_global("wasmer_metering_remaining_points")
385        .expect("Can't get `wasmer_metering_remaining_points` from Instance")
386        .set(ctx, points.into())
387        .expect("Can't set `wasmer_metering_remaining_points` in Instance");
388
389    instance
390        .exports
391        .get_global("wasmer_metering_points_exhausted")
392        .expect("Can't get `wasmer_metering_points_exhausted` from Instance")
393        .set(ctx, 0i32.into())
394        .expect("Can't set `wasmer_metering_points_exhausted` in Instance");
395}
396
// Tests for the metering middleware. Each test compiles the WAT module returned
// by `bytecode()` with Cranelift plus a `Metering` middleware, instantiates it
// and checks the remaining points after calling exported functions.
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use wasmer::sys::EngineBuilder;
    use wasmer::{
        Module, Store, TypedFunction, imports,
        sys::{CompilerConfig, Cranelift},
        wat2wasm,
    };

    // Cost function shared by the `add_one` tests: `local.get` and `i32.const`
    // cost 1 point, `i32.add` costs 2, and everything else is free.
    fn cost_function(operator: &Operator) -> u64 {
        match operator {
            Operator::LocalGet { .. } | Operator::I32Const { .. } => 1,
            Operator::I32Add { .. } => 2,
            _ => 0,
        }
    }

    // Test module exporting three functions:
    // * `add_one`: returns its argument plus one (straight-line code);
    // * `short_loop`: a loop that runs five times, ending with a `br_if`;
    // * `infi_loop`: a loop that branches back to itself forever.
    fn bytecode() -> Vec<u8> {
        wat2wasm(
            br#"(module
            (type $add_t (func (param i32) (result i32)))
            (func $add_one_f (type $add_t) (param $value i32) (result i32)
                local.get $value
                i32.const 1
                i32.add)
            (func $short_loop_f
                (local $x f64) (local $j i32)
                (local.set $x (f64.const 5.5))

                (loop $named_loop
                    ;; $j++
                    local.get $j
                    i32.const 1
                    i32.add
                    local.set $j

                    ;; if $j < 5, one more time
                    local.get $j
                    i32.const 5
                    i32.lt_s
                    br_if $named_loop
                )
            )
            (func $infi_loop_f
                (loop $infi_loop_start
                    br $infi_loop_start
                )
            )
            (export "add_one" (func $add_one_f))
            (export "short_loop" (func $short_loop_f))
            (export "infi_loop" (func $infi_loop_f))
        )"#,
        )
        .unwrap()
        .into()
    }

    // With a budget of 10 and `add_one` costing 4 points per call, two calls
    // succeed (10 -> 6 -> 2) and the third traps, leaving the state Exhausted.
    #[test]
    fn get_remaining_points_works() {
        let metering = Arc::new(Metering::new(10, cost_function));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        // Instantiate
        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );

        // First call
        //
        // Calling add_one costs 4 points. Here are the details of how it has been computed:
        // * `local.get $value` is a `Operator::LocalGet` which costs 1 point;
        // * `i32.const` is a `Operator::I32Const` which costs 1 point;
        // * `i32.add` is a `Operator::I32Add` which costs 2 points.
        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(6)
        );

        // Second call
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(2)
        );

        // Third call fails due to limit
        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );
    }

    // Resetting the budget to 12 allows exactly three 4-point calls (down to
    // 0); the fourth traps, and a further reset clears the exhausted state.
    #[test]
    fn set_remaining_points_works() {
        let metering = Arc::new(Metering::new(10, cost_function));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        // Instantiate
        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(10)
        );
        let add_one: TypedFunction<i32, i32> = instance
            .exports
            .get_function("add_one")
            .unwrap()
            .typed(&store)
            .unwrap();

        // Increase a bit to have enough for 3 calls
        set_remaining_points(&mut store, &instance, 12);

        // Ensure we can use the new points now
        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(8)
        );

        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(4)
        );

        add_one.call(&mut store, 1).unwrap();
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(0)
        );

        assert!(add_one.call(&mut store, 1).is_err());
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );

        // Add some points for another call
        set_remaining_points(&mut store, &instance, 4);
        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Remaining(4)
        );
    }

    // Checks that loop bodies are charged on every iteration: the `loop`
    // header is charged once on entry, while the `br_if`/`br` inside the loop
    // is charged each time it executes.
    #[test]
    fn metering_works_for_loops() {
        const INITIAL_POINTS: u64 = 10_000;

        fn cost(operator: &Operator) -> u64 {
            match operator {
                Operator::Loop { .. } => 1000,
                Operator::Br { .. } | Operator::BrIf { .. } => 10,
                Operator::F64Const { .. } => 7,
                _ => 0,
            }
        }

        // Short loop

        let metering = Arc::new(Metering::new(INITIAL_POINTS, cost));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        let short_loop: TypedFunction<(), ()> = instance
            .exports
            .get_function("short_loop")
            .unwrap()
            .typed(&store)
            .unwrap();
        short_loop.call(&mut store).unwrap();

        let points_used: u64 = match get_remaining_points(&mut store, &instance) {
            MeteringPoints::Exhausted => panic!("Unexpected exhausted"),
            MeteringPoints::Remaining(remaining) => INITIAL_POINTS - remaining,
        };

        assert_eq!(
            points_used,
            7 /* pre-loop instructions */ +
            1000 /* loop instruction */ + 50 /* five conditional breaks */
        );

        // Infinite loop

        // Each iteration costs 10 points for the `br`, so the 10_000-point
        // budget eventually runs out and the call traps.
        let metering = Arc::new(Metering::new(INITIAL_POINTS, cost));
        let mut compiler_config = Cranelift::default();
        compiler_config.push_middleware(metering);
        let mut store = Store::new(EngineBuilder::new(compiler_config));
        let module = Module::new(&store, bytecode()).unwrap();

        let instance = Instance::new(&mut store, &module, &imports! {}).unwrap();
        let infi_loop: TypedFunction<(), ()> = instance
            .exports
            .get_function("infi_loop")
            .unwrap()
            .typed(&store)
            .unwrap();
        infi_loop.call(&mut store).unwrap_err(); // exhausted leads to runtime error

        assert_eq!(
            get_remaining_points(&mut store, &instance),
            MeteringPoints::Exhausted
        );
    }
}