wasmer_compiler/translator/
middleware.rs

//! The middleware parses a function's binary bytecode and transforms its
//! operators with the chosen middleware functions.
3
4use smallvec::SmallVec;
5use std::collections::VecDeque;
6use std::fmt::Debug;
7use std::ops::{Deref, Range};
8use wasmer_types::{LocalFunctionIndex, MiddlewareError, ModuleInfo, WasmError, WasmResult};
9use wasmparser::{BinaryReader, FunctionBody, Operator, OperatorsReader, ValType};
10
11use super::error::from_binaryreadererror_wasmerror;
12use crate::translator::environ::FunctionBinaryReader;
13
14/// A shared builder for function middlewares.
15pub trait ModuleMiddleware: Debug + Send + Sync {
16    /// Generates a `FunctionMiddleware` for a given function.
17    ///
18    /// Here we generate a separate object for each function instead of executing directly on per-function operators,
19    /// in order to enable concurrent middleware application. Takes immutable `&self` because this function can be called
20    /// concurrently from multiple compilation threads.
21    fn generate_function_middleware<'a>(
22        &self,
23        local_function_index: LocalFunctionIndex,
24    ) -> Box<dyn FunctionMiddleware<'a> + 'a>;
25
26    /// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
27    fn transform_module_info(&self, _: &mut ModuleInfo) -> Result<(), MiddlewareError> {
28        Ok(())
29    }
30}
31
32/// A function middleware specialized for a single function.
33pub trait FunctionMiddleware<'a>: Debug {
34    /// Provide info on the function's locals. This is called before feed.
35    fn locals_info(&mut self, _locals: &[ValType]) {}
36
37    /// Processes the given operator.
38    fn feed(
39        &mut self,
40        operator: Operator<'a>,
41        state: &mut MiddlewareReaderState<'a>,
42    ) -> Result<(), MiddlewareError> {
43        state.push_operator(operator);
44        Ok(())
45    }
46}
47
48/// A Middleware binary reader of the WebAssembly structures and types.
49pub struct MiddlewareBinaryReader<'a> {
50    /// Parsing state.
51    state: MiddlewareReaderState<'a>,
52
53    /// The backing middleware chain for this reader.
54    chain: Vec<Box<dyn FunctionMiddleware<'a> + 'a>>,
55}
56
57enum MiddlewareInnerReader<'a> {
58    Binary {
59        reader: BinaryReader<'a>,
60        original_reader: BinaryReader<'a>,
61    },
62    Operator(OperatorsReader<'a>),
63}
64
65/// The state of the binary reader. Exposed to middlewares to push their outputs.
66pub struct MiddlewareReaderState<'a> {
67    /// Raw binary reader.
68    inner: Option<MiddlewareInnerReader<'a>>,
69
70    /// The pending operations added by the middleware.
71    pending_operations: VecDeque<Operator<'a>>,
72
73    /// Number of local declaration groups (each group is a count + type pair).
74    local_decls_group: u32,
75
76    /// Number of local declaration groups read so far.
77    local_decls_group_read: u32,
78
79    /// Locals read so far.
80    locals: Vec<ValType>,
81}
82
83/// Trait for generating middleware chains from "prototype" (generator) chains.
84pub trait ModuleMiddlewareChain {
85    /// Generates a function middleware chain.
86    fn generate_function_middleware_chain<'a>(
87        &self,
88        local_function_index: LocalFunctionIndex,
89    ) -> Vec<Box<dyn FunctionMiddleware<'a> + 'a>>;
90
91    /// Applies the chain on a `ModuleInfo` struct.
92    fn apply_on_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError>;
93}
94
95impl<T: Deref<Target = dyn ModuleMiddleware>> ModuleMiddlewareChain for [T] {
96    /// Generates a function middleware chain.
97    fn generate_function_middleware_chain<'a>(
98        &self,
99        local_function_index: LocalFunctionIndex,
100    ) -> Vec<Box<dyn FunctionMiddleware<'a> + 'a>> {
101        self.iter()
102            .map(|x| x.generate_function_middleware(local_function_index))
103            .collect()
104    }
105
106    /// Applies the chain on a `ModuleInfo` struct.
107    fn apply_on_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
108        for item in self {
109            item.transform_module_info(module_info)?;
110        }
111        Ok(())
112    }
113}
114
115impl<'a> MiddlewareReaderState<'a> {
116    /// Push an operator.
117    pub fn push_operator(&mut self, operator: Operator<'a>) {
118        self.pending_operations.push_back(operator);
119    }
120}
121
122impl<'a> Extend<Operator<'a>> for MiddlewareReaderState<'a> {
123    fn extend<I: IntoIterator<Item = Operator<'a>>>(&mut self, iter: I) {
124        self.pending_operations.extend(iter);
125    }
126}
127
128impl<'a: 'b, 'b> Extend<&'b Operator<'a>> for MiddlewareReaderState<'a> {
129    fn extend<I: IntoIterator<Item = &'b Operator<'a>>>(&mut self, iter: I) {
130        self.pending_operations.extend(iter.into_iter().cloned());
131    }
132}
133
134impl<'a> MiddlewareBinaryReader<'a> {
135    /// Constructs a `MiddlewareBinaryReader` with an explicit starting offset.
136    pub fn new_with_offset(data: &'a [u8], original_offset: usize) -> Self {
137        let inner = BinaryReader::new(data, original_offset);
138        Self {
139            state: MiddlewareReaderState {
140                inner: Some(MiddlewareInnerReader::Binary {
141                    original_reader: inner.clone(),
142                    reader: inner,
143                }),
144                pending_operations: VecDeque::new(),
145                local_decls_group: 0,
146                local_decls_group_read: 0,
147                locals: vec![],
148            },
149            chain: vec![],
150        }
151    }
152
153    /// Replaces the middleware chain with a new one.
154    pub fn set_middleware_chain(&mut self, stages: Vec<Box<dyn FunctionMiddleware<'a> + 'a>>) {
155        self.chain = stages;
156    }
157
158    /// Pass info about the locals of a function to all middlewares
159    fn emit_locals_info(&mut self) {
160        for middleware in &mut self.chain {
161            middleware.locals_info(&self.state.locals)
162        }
163    }
164}
165
166impl<'a> FunctionBinaryReader<'a> for MiddlewareBinaryReader<'a> {
167    fn read_local_count(&mut self) -> WasmResult<u32> {
168        let total = match self.state.inner.as_mut().expect("inner state must exist") {
169            MiddlewareInnerReader::Binary { reader, .. } => reader
170                .read_var_u32()
171                .map_err(from_binaryreadererror_wasmerror),
172            MiddlewareInnerReader::Operator(..) => Err(WasmError::InvalidWebAssembly {
173                message: "locals must be read before the function body".to_string(),
174                offset: self.current_position(),
175            }),
176        }?;
177        self.state.local_decls_group = total;
178        self.state.locals.reserve(total as usize);
179        if total == 0 {
180            self.emit_locals_info();
181        }
182        Ok(total)
183    }
184
185    fn read_local_decl(&mut self) -> WasmResult<(u32, ValType)> {
186        let (count, ty) = match self.state.inner.as_mut().expect("inner state must exist") {
187            MiddlewareInnerReader::Binary { reader, .. } => {
188                let count = reader
189                    .read_var_u32()
190                    .map_err(from_binaryreadererror_wasmerror)?;
191                let ty: ValType = reader
192                    .read::<ValType>()
193                    .map_err(from_binaryreadererror_wasmerror)?;
194                Ok((count, ty))
195            }
196            MiddlewareInnerReader::Operator(..) => Err(WasmError::InvalidWebAssembly {
197                message: "locals must be read before the function body".to_string(),
198                offset: self.current_position(),
199            }),
200        }?;
201        for _ in 0..count {
202            self.state.locals.push(ty);
203        }
204
205        self.state.local_decls_group_read += 1;
206        if self.state.local_decls_group_read == self.state.local_decls_group {
207            self.emit_locals_info();
208        }
209        Ok((count, ty))
210    }
211
212    fn read_operator(&mut self) -> WasmResult<Operator<'a>> {
213        if let Some(inner) = self.state.inner.take() {
214            self.state.inner = Some(match inner {
215                MiddlewareInnerReader::Binary {
216                    original_reader, ..
217                } => {
218                    let operator_reader = FunctionBody::new(original_reader)
219                        .get_operators_reader()
220                        .map_err(from_binaryreadererror_wasmerror)?;
221                    MiddlewareInnerReader::Operator(operator_reader)
222                }
223                other => other,
224            });
225        }
226
227        let read_operator = |state: &mut MiddlewareReaderState<'a>| {
228            let Some(MiddlewareInnerReader::Operator(operator_reader)) = state.inner.as_mut()
229            else {
230                unreachable!();
231            };
232            operator_reader
233                .read()
234                .map_err(from_binaryreadererror_wasmerror)
235        };
236
237        if self.chain.is_empty() {
238            // We short-circuit in case no chain is used
239            return read_operator(&mut self.state);
240        }
241
242        // Try to fill the `self.pending_operations` buffer, until it is non-empty.
243        while self.state.pending_operations.is_empty() {
244            let raw_op = read_operator(&mut self.state)?;
245
246            // Fill the initial raw operator into pending buffer.
247            self.state.pending_operations.push_back(raw_op);
248
249            // Run the operator through each stage.
250            for stage in &mut self.chain {
251                // Take the outputs from the previous stage.
252                let pending: SmallVec<[Operator<'a>; 2]> =
253                    self.state.pending_operations.drain(0..).collect();
254
255                // ...and feed them into the current stage.
256                for pending_op in pending {
257                    stage.feed(pending_op, &mut self.state)?;
258                }
259            }
260        }
261
262        Ok(self.state.pending_operations.pop_front().unwrap())
263    }
264
265    fn current_position(&self) -> usize {
266        match self.state.inner.as_ref().expect("inner state must exist") {
267            MiddlewareInnerReader::Binary { reader, .. } => reader.current_position(),
268            MiddlewareInnerReader::Operator(operator_reader) => {
269                operator_reader.get_binary_reader().current_position()
270            }
271        }
272    }
273
274    fn original_position(&self) -> usize {
275        match self.state.inner.as_ref().expect("inner state must exist") {
276            MiddlewareInnerReader::Binary { reader, .. } => reader.original_position(),
277            MiddlewareInnerReader::Operator(operator_reader) => operator_reader.original_position(),
278        }
279    }
280
281    fn bytes_remaining(&self) -> usize {
282        match self.state.inner.as_ref().expect("inner state must exist") {
283            MiddlewareInnerReader::Binary { reader, .. } => reader.bytes_remaining(),
284            MiddlewareInnerReader::Operator(operator_reader) => {
285                operator_reader.get_binary_reader().bytes_remaining()
286            }
287        }
288    }
289
290    fn eof(&self) -> bool {
291        match self.state.inner.as_ref().expect("inner state must exist") {
292            MiddlewareInnerReader::Binary { reader, .. } => reader.eof(),
293            MiddlewareInnerReader::Operator(operator_reader) => operator_reader.eof(),
294        }
295    }
296
297    fn range(&self) -> Range<usize> {
298        match self.state.inner.as_ref().expect("inner state must exist") {
299            MiddlewareInnerReader::Binary { reader, .. } => reader.range(),
300            MiddlewareInnerReader::Operator(operator_reader) => {
301                operator_reader.get_binary_reader().range()
302            }
303        }
304    }
305}