use smallvec::SmallVec;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::ops::{Deref, Range};
use wasmer_types::{LocalFunctionIndex, MiddlewareError, ModuleInfo, WasmResult};
use wasmparser::{BinaryReader, Operator, ValType, WasmFeatures};
use super::error::from_binaryreadererror_wasmerror;
use crate::translator::environ::FunctionBinaryReader;
/// A module-wide middleware that acts as a factory for per-function stages.
///
/// One instance is shared across the module; it produces a fresh
/// [`FunctionMiddleware`] for every local function and may optionally adjust
/// the module's [`ModuleInfo`].
pub trait ModuleMiddleware: Debug + Send + Sync {
    /// Builds the function-level middleware for the function identified by
    /// `local_function_index`.
    fn generate_function_middleware(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Box<dyn FunctionMiddleware>;

    /// Hook for rewriting module-level metadata.
    ///
    /// The default implementation leaves `ModuleInfo` untouched and succeeds.
    fn transform_module_info(
        &self,
        _module_info: &mut ModuleInfo,
    ) -> Result<(), MiddlewareError> {
        Ok(())
    }
}
/// A single stage in a per-function operator-transformation pipeline.
///
/// Each operator emitted by the previous stage is handed to [`feed`]; the
/// stage writes zero or more operators into `state`, which become the input
/// of the next stage.
///
/// [`feed`]: FunctionMiddleware::feed
pub trait FunctionMiddleware: Debug {
    /// Processes one operator.
    ///
    /// The default implementation is a pass-through: it forwards `operator`
    /// unchanged to the pending queue.
    fn feed<'a>(
        &mut self,
        operator: Operator<'a>,
        state: &mut MiddlewareReaderState<'a>,
    ) -> Result<(), MiddlewareError> {
        // Forward untouched; the next stage (or the caller) will see it as-is.
        state.push_operator(operator);
        Ok(())
    }
}
/// Operator reader that pushes every decoded operator through a chain of
/// [`FunctionMiddleware`] stages before returning it to the caller.
#[derive(Debug)]
pub struct MiddlewareBinaryReader<'a> {
    /// Underlying byte reader plus the queue of operators produced by the
    /// middleware chain that have not been returned yet.
    state: MiddlewareReaderState<'a>,
    /// Middleware stages, applied in order. When empty, operators are read
    /// directly from the underlying reader with no buffering.
    chain: Vec<Box<dyn FunctionMiddleware>>,
}
/// State shared between [`MiddlewareBinaryReader`] and its middleware stages.
///
/// Stages emit operators into it via `push_operator` or `Extend`.
#[derive(Debug)]
pub struct MiddlewareReaderState<'a> {
    /// Raw byte reader over the function body.
    inner: BinaryReader<'a>,
    /// FIFO of operators emitted by the current stage; consumed from the front.
    pending_operations: VecDeque<Operator<'a>>,
}
/// Operations over an ordered collection of [`ModuleMiddleware`]s.
pub trait ModuleMiddlewareChain {
    /// Produces the per-function middleware chain for `local_function_index`,
    /// one stage per module middleware (the slice impl preserves order).
    fn generate_function_middleware_chain(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Vec<Box<dyn FunctionMiddleware>>;
    /// Lets every module middleware transform `module_info` (the slice impl
    /// applies them in order and stops at the first error).
    fn apply_on_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError>;
}
impl<T: Deref<Target = dyn ModuleMiddleware>> ModuleMiddlewareChain for [T] {
    /// Asks each module middleware, in slice order, for its function-level
    /// stage and collects the results into a chain.
    fn generate_function_middleware_chain(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Vec<Box<dyn FunctionMiddleware>> {
        let mut stages = Vec::with_capacity(self.len());
        for middleware in self {
            stages.push(middleware.generate_function_middleware(local_function_index));
        }
        stages
    }

    /// Runs every middleware's `transform_module_info` in slice order,
    /// short-circuiting on the first error.
    fn apply_on_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
        self.iter()
            .try_for_each(|middleware| middleware.transform_module_info(module_info))
    }
}
impl<'a> MiddlewareReaderState<'a> {
    /// Queues `operator` behind everything already pending, so emission
    /// order is preserved.
    pub fn push_operator(&mut self, operator: Operator<'a>) {
        let queue = &mut self.pending_operations;
        queue.push_back(operator);
    }
}
/// Queues a batch of owned operators, in iteration order.
impl<'a> Extend<Operator<'a>> for MiddlewareReaderState<'a> {
    fn extend<I: IntoIterator<Item = Operator<'a>>>(&mut self, iter: I) {
        let queue = &mut self.pending_operations;
        queue.extend(iter);
    }
}
impl<'a: 'b, 'b> Extend<&'b Operator<'a>> for MiddlewareReaderState<'a> {
fn extend<I: IntoIterator<Item = &'b Operator<'a>>>(&mut self, iter: I) {
self.pending_operations.extend(iter.into_iter().cloned());
}
}
impl<'a> MiddlewareBinaryReader<'a> {
    /// Creates a reader over `data`.
    ///
    /// `original_offset` is forwarded to the underlying `BinaryReader`. The
    /// reader is configured with `WasmFeatures::default()` and starts with an
    /// empty middleware chain and an empty pending queue.
    pub fn new_with_offset(data: &'a [u8], original_offset: usize) -> Self {
        let state = MiddlewareReaderState {
            inner: BinaryReader::new(data, original_offset, WasmFeatures::default()),
            pending_operations: VecDeque::new(),
        };
        Self {
            state,
            chain: Vec::new(),
        }
    }

    /// Replaces the current middleware chain with `stages`, which are applied
    /// in the given order to every subsequently read operator.
    pub fn set_middleware_chain(&mut self, stages: Vec<Box<dyn FunctionMiddleware>>) {
        self.chain = stages;
    }
}
impl<'a> FunctionBinaryReader<'a> for MiddlewareBinaryReader<'a> {
    /// Reads the number of local declaration groups directly from the byte
    /// stream (locals are not routed through the middleware chain).
    fn read_local_count(&mut self) -> WasmResult<u32> {
        self.state
            .inner
            .read_var_u32()
            .map_err(from_binaryreadererror_wasmerror)
    }

    /// Reads one `(count, type)` local declaration directly from the byte stream.
    fn read_local_decl(&mut self) -> WasmResult<(u32, ValType)> {
        let reader = &mut self.state.inner;
        let count = reader
            .read_var_u32()
            .map_err(from_binaryreadererror_wasmerror)?;
        let ty = reader
            .read::<ValType>()
            .map_err(from_binaryreadererror_wasmerror)?;
        Ok((count, ty))
    }

    /// Returns the next operator produced by the middleware chain.
    ///
    /// With no middleware configured, operators come straight from the
    /// underlying reader. Otherwise, raw operators are read one at a time and
    /// pushed through every stage in order until at least one output operator
    /// is queued. A stage may drop an operator (emit nothing) or expand it
    /// into several.
    ///
    /// # Errors
    /// Propagates decoding errors from the underlying reader and any
    /// `MiddlewareError` returned by a stage's `feed`.
    fn read_operator(&mut self) -> WasmResult<Operator<'a>> {
        if self.chain.is_empty() {
            // Fast path: no stages, so no buffering is needed.
            return self
                .state
                .inner
                .read_operator()
                .map_err(from_binaryreadererror_wasmerror);
        }

        // Keep pulling raw operators until the chain yields something; stages
        // that swallow operators can leave the queue empty after a round.
        while self.state.pending_operations.is_empty() {
            let raw_op = self
                .state
                .inner
                .read_operator()
                .map_err(from_binaryreadererror_wasmerror)?;
            self.state.pending_operations.push_back(raw_op);

            for stage in &mut self.chain {
                // Move the previous stage's output out of the shared queue so
                // this stage can write its own output into it.
                let pending: SmallVec<[Operator<'a>; 2]> =
                    self.state.pending_operations.drain(..).collect();
                for pending_op in pending {
                    stage.feed(pending_op, &mut self.state)?;
                }
            }
        }

        Ok(self
            .state
            .pending_operations
            .pop_front()
            .expect("pending_operations is non-empty after the fill loop"))
    }

    /// Position in the underlying byte stream. When middleware has queued
    /// operators, this is already past the raw operator that produced them.
    fn current_position(&self) -> usize {
        self.state.inner.current_position()
    }

    /// Original (module-relative) position of the underlying byte reader;
    /// like `current_position`, it tracks bytes consumed, not operators yielded.
    fn original_position(&self) -> usize {
        self.state.inner.original_position()
    }

    /// Bytes left in the underlying reader (excludes queued operators).
    fn bytes_remaining(&self) -> usize {
        self.state.inner.bytes_remaining()
    }

    /// Reports whether no further operators can be read.
    ///
    /// Operators already queued by the middleware chain count as remaining
    /// input. A stage may expand the final raw operator into several; the
    /// underlying reader is then at EOF while those are still waiting to be
    /// handed out, and reporting EOF at that point would drop them.
    fn eof(&self) -> bool {
        self.state.pending_operations.is_empty() && self.state.inner.eof()
    }

    /// Byte range covered by the underlying reader.
    fn range(&self) -> Range<usize> {
        self.state.inner.range()
    }
}