Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: import example implementation #81

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/execution/interpreter_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ pub(super) fn run<H: HookSet>(
mut hooks: H,
) -> Result<(), RuntimeError> {
let func_inst = store
.funcs
.get(stack.current_stackframe().func_idx)
.local_funcs
.get(stack.current_stackframe().func_idx - store.imports.len())
.unwrap_validated();

// Start reading the function's instructions
Expand Down Expand Up @@ -75,9 +75,9 @@ pub(super) fn run<H: HookSet>(
RETURN => {
trace!("returning from function");

let func_to_call_idx = stack.current_stackframe().func_idx;
let func_to_call_idx = stack.current_stackframe().func_idx - store.imports.len();

let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated();
let func_to_call_inst = store.local_funcs.get(func_to_call_idx).unwrap_validated();
let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated();

let ret_vals = stack
Expand All @@ -99,7 +99,7 @@ pub(super) fn run<H: HookSet>(
CALL => {
let func_to_call_idx = wasm.read_var_u32().unwrap_validated() as FuncIdx;

let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated();
let func_to_call_inst = store.local_funcs.get(func_to_call_idx).unwrap_validated();
let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated();

let params = stack.pop_tail_iter(func_to_call_ty.params.valtypes.len());
Expand Down
223 changes: 135 additions & 88 deletions src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use alloc::vec::Vec;

use interpreter_loop::run;
use locals::Locals;
use store::ImportInst;
use value_stack::Stack;

use crate::core::indices::FuncIdx;
use crate::core::reader::types::export::{Export, ExportDesc};
use crate::core::reader::types::import::ImportDesc;
use crate::core::reader::types::{FuncType, ValType};
use crate::core::reader::WasmReader;
use crate::execution::assert_validated::UnwrapValidatedExt;
Expand Down Expand Up @@ -79,6 +81,7 @@ where
let func_idx = self.exports.iter().find_map(|export| {
if export.name == func_name {
match export.desc {
// the first ones are the imported functions
ExportDesc::FuncIdx(idx) => Some(idx),
_ => None,
}
Expand All @@ -100,49 +103,62 @@ where
func_idx: FuncIdx,
params: Param,
) -> Result<Returns, RuntimeError> {
// -=-= Verification =-=-
let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx");
let func_ty = self.types.get(func_inst.ty).unwrap_validated();
if func_idx >= self.store.imports.len() {
// -=-= Verification =-=-
let func_inst = self
.store
.local_funcs
.get(func_idx - self.store.imports.len())
.expect("valid FuncIdx");
let func_ty = self
.types
.get(func_inst.ty + self.store.imports.len())
.unwrap_validated();

// Check correct function parameters and return types
if func_ty.params.valtypes != Param::TYS {
panic!("Invalid `Param` generics");
}
if func_ty.returns.valtypes != Returns::TYS {
panic!("Invalid `Returns` generics");
}

// Check correct function parameters and return types
if func_ty.params.valtypes != Param::TYS {
panic!("Invalid `Param` generics");
}
if func_ty.returns.valtypes != Returns::TYS {
panic!("Invalid `Returns` generics");
// Prepare a new stack with the locals for the entry function
let mut stack = Stack::new();
let locals = Locals::new(
params.into_values().into_iter(),
func_inst.locals.iter().cloned(),
);

// setting `usize::MAX` as return address for the outermost function ensures that we
// observably fail upon errornoeusly continuing execution after that function returns.

// WARN: here, func_idx is not the index of the function in the wasm binary, but in the `local_funcs` vector
stack.push_stackframe(func_idx, func_ty, locals, usize::MAX);

// Run the interpreter
run(
self.wasm_bytecode,
&self.types,
&mut self.store,
&mut stack,
EmptyHookSet,
)?;

// Pop return values from stack
let return_values = Returns::TYS
.iter()
.map(|ty| stack.pop_value(*ty))
.collect::<Vec<Value>>();

// Values are reversed because they were popped from stack one-by-one. Now reverse them back
let reversed_values = return_values.into_iter().rev();
let ret: Returns = Returns::from_values(reversed_values);
debug!("Successfully invoked function");
Ok(ret)
} else {
panic!("Calling imported function is not implemented yet");
}

// Prepare a new stack with the locals for the entry function
let mut stack = Stack::new();
let locals = Locals::new(
params.into_values().into_iter(),
func_inst.locals.iter().cloned(),
);

// setting `usize::MAX` as return address for the outermost function ensures that we
// observably fail upon errornoeusly continuing execution after that function returns.
stack.push_stackframe(func_idx, func_ty, locals, usize::MAX);

// Run the interpreter
run(
self.wasm_bytecode,
&self.types,
&mut self.store,
&mut stack,
EmptyHookSet,
)?;

// Pop return values from stack
let return_values = Returns::TYS
.iter()
.map(|ty| stack.pop_value(*ty))
.collect::<Vec<Value>>();

// Values are reversed because they were popped from stack one-by-one. Now reverse them back
let reversed_values = return_values.into_iter().rev();
let ret: Returns = Returns::from_values(reversed_values);
debug!("Successfully invoked function");
Ok(ret)
}

/// Invokes a function with the given parameters, and return types which are not known at compile time.
Expand All @@ -152,62 +168,83 @@ where
params: Vec<Value>,
ret_types: &[ValType],
) -> Result<Vec<Value>, RuntimeError> {
// -=-= Verification =-=-
let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx");
let func_ty = self.types.get(func_inst.ty).unwrap_validated();

// Verify that the given parameters match the function parameters
let param_types = params.iter().map(|v| v.to_ty()).collect::<Vec<_>>();
if func_idx >= self.store.imports.len() {
// -=-= Verification =-=-
let func_inst = self
.store
.local_funcs
.get(func_idx - self.store.imports.len())
.expect("valid FuncIdx");
let func_ty = self
.types
.get(func_inst.ty + self.store.imports.len())
.unwrap_validated();

// Verify that the given parameters match the function parameters
let param_types = params.iter().map(|v| v.to_ty()).collect::<Vec<_>>();

if func_ty.params.valtypes != param_types {
panic!("Invalid parameters for function");
}

if func_ty.params.valtypes != param_types {
panic!("Invalid parameters for function");
}
// Verify that the given return types match the function return types
if func_ty.returns.valtypes != ret_types {
panic!("Invalid return types for function");
}

// Verify that the given return types match the function return types
if func_ty.returns.valtypes != ret_types {
panic!("Invalid return types for function");
// Prepare a new stack with the locals for the entry function
let mut stack = Stack::new();
let locals = Locals::new(params.into_iter(), func_inst.locals.iter().cloned());
stack.push_stackframe(func_idx, func_ty, locals, 0);

// Run the interpreter
run(
self.wasm_bytecode,
&self.types,
&mut self.store,
&mut stack,
EmptyHookSet,
)?;

let func_inst = self
.store
.local_funcs
.get(func_idx - self.store.imports.len())
.expect("valid FuncIdx");
let func_ty = self
.types
.get(func_inst.ty + self.store.imports.len())
.unwrap_validated();

// Pop return values from stack
let return_values = func_ty
.returns
.valtypes
.iter()
.map(|ty| stack.pop_value(*ty))
.collect::<Vec<Value>>();

// Values are reversed because they were popped from stack one-by-one. Now reverse them back
let reversed_values = return_values.into_iter().rev();
let ret = reversed_values.collect();
debug!("Successfully invoked function");
Ok(ret)
} else {
// we have to call an imported function
// make the call to the linker haha
panic!("Calling imported function is not implemented yet");
}

// Prepare a new stack with the locals for the entry function
let mut stack = Stack::new();
let locals = Locals::new(params.into_iter(), func_inst.locals.iter().cloned());
stack.push_stackframe(func_idx, func_ty, locals, 0);

// Run the interpreter
run(
self.wasm_bytecode,
&self.types,
&mut self.store,
&mut stack,
EmptyHookSet,
)?;

let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx");
let func_ty = self.types.get(func_inst.ty).unwrap_validated();

// Pop return values from stack
let return_values = func_ty
.returns
.valtypes
.iter()
.map(|ty| stack.pop_value(*ty))
.collect::<Vec<Value>>();

// Values are reversed because they were popped from stack one-by-one. Now reverse them back
let reversed_values = return_values.into_iter().rev();
let ret = reversed_values.collect();
debug!("Successfully invoked function");
Ok(ret)
}

fn init_store(validation_info: &ValidationInfo) -> Store {
let function_instances: Vec<FuncInst> = {
let local_function_instances: Vec<FuncInst> = {
let mut wasm_reader = WasmReader::new(validation_info.wasm);

let functions = validation_info.functions.iter();
let local_funcs = validation_info.functions.iter();
// let functions = validation_info.functions.iter();
let func_blocks = validation_info.func_blocks.iter();

functions
local_funcs
.zip(func_blocks)
.map(|(ty, func)| {
wasm_reader
Expand All @@ -231,6 +268,15 @@ where
.collect()
};

let import_instances: Vec<ImportInst> = validation_info
.imports
.iter()
.filter_map(|import| match import.desc {
ImportDesc::Func(func) => Some(func),
_ => None,
})
.map(ImportInst::Func)
.collect();
let memory_instances: Vec<MemInst> = validation_info
.memories
.iter()
Expand All @@ -252,7 +298,8 @@ where
.collect();

Store {
funcs: function_instances,
imports: import_instances,
local_funcs: local_function_instances,
mems: memory_instances,
globals: global_instances,
}
Expand Down
10 changes: 8 additions & 2 deletions src/execution/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use alloc::vec;
use alloc::vec::Vec;
use core::iter;

use crate::core::indices::TypeIdx;
use crate::core::indices::{FuncIdx, TypeIdx};
use crate::core::reader::span::Span;
use crate::core::reader::types::global::Global;
use crate::core::reader::types::{MemType, TableType, ValType};
Expand All @@ -14,10 +14,16 @@ use crate::execution::value::{Ref, Value};
/// the abstract machine.
/// <https://webassembly.github.io/spec/core/exec/runtime.html#store>
pub struct Store {
pub funcs: Vec<FuncInst>,
pub local_funcs: Vec<FuncInst>,
// tables: Vec<TableInst>,
pub mems: Vec<MemInst>,
pub globals: Vec<GlobalInst>,
pub imports: Vec<ImportInst>,
}

pub enum ImportInst {
#[allow(dead_code)]
Func(FuncIdx),
}

pub struct FuncInst {
Expand Down
35 changes: 28 additions & 7 deletions src/validation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::core::reader::section_header::{SectionHeader, SectionTy};
use crate::core::reader::span::Span;
use crate::core::reader::types::export::Export;
use crate::core::reader::types::global::Global;
use crate::core::reader::types::import::Import;
use crate::core::reader::types::import::{Import, ImportDesc};
use crate::core::reader::types::{FuncType, MemType, TableType};
use crate::core::reader::{WasmReadable, WasmReader};
use crate::{Error, Result};
Expand Down Expand Up @@ -71,10 +71,19 @@ pub fn validate(wasm: &[u8]) -> Result<ValidationInfo> {

while (skip_section(&mut wasm, &mut header)?).is_some() {}

let functions = handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| {
wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize))
})?
.unwrap_or_default();
let local_functions =
handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| {
wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize))
})?
.unwrap_or_default();

let imported_functions = imports
.iter()
.filter_map(|import| match &import.desc {
ImportDesc::Func(type_idx) => Some(*type_idx),
_ => None,
})
.collect::<Vec<TypeIdx>>();

while (skip_section(&mut wasm, &mut header)?).is_some() {}

Expand Down Expand Up @@ -132,12 +141,24 @@ pub fn validate(wasm: &[u8]) -> Result<ValidationInfo> {

while (skip_section(&mut wasm, &mut header)?).is_some() {}

// WARN: Make sure the linker will validate imported functions.
let func_blocks = handle_section(&mut wasm, &mut header, SectionTy::Code, |wasm, h| {
code::validate_code_section(wasm, h, &types, &functions, &globals)
code::validate_code_section(wasm, h, &types, &local_functions, &globals)
})?
.unwrap_or_default();

assert_eq!(func_blocks.len(), functions.len(), "these should be equal"); // TODO check if this is in the spec
// This is NOT EXPLICITLY stated in the spec, but it is implicitly required for a valid WASM Module
assert_eq!(
func_blocks.len(),
local_functions.len(),
"these should be equal"
);

let functions = imported_functions
.iter()
.chain(local_functions.iter())
.cloned()
.collect::<Vec<TypeIdx>>();

while (skip_section(&mut wasm, &mut header)?).is_some() {}

Expand Down
Loading