From f70a2e74bb557a07e53fdb3923cc0acf277f15c7 Mon Sep 17 00:00:00 2001 From: George Cosma Date: Tue, 17 Sep 2024 16:56:44 +0300 Subject: [PATCH] chore: split FuncInst into local and imported variant. In preparation for linker Signed-off-by: George Cosma --- src/execution/interpreter_loop.rs | 12 ++++- src/execution/mod.rs | 89 +++++++++++++++++++++---------- src/execution/store.rs | 39 +++++++++++++- src/validation/code.rs | 25 +++++++-- src/validation/mod.rs | 56 ++++++++++++++++--- tests/imports.rs | 59 ++++++++++++++++++++ 6 files changed, 237 insertions(+), 43 deletions(-) create mode 100644 tests/imports.rs diff --git a/src/execution/interpreter_loop.rs b/src/execution/interpreter_loop.rs index 6892b52e..b77a1c6f 100644 --- a/src/execution/interpreter_loop.rs +++ b/src/execution/interpreter_loop.rs @@ -42,6 +42,8 @@ pub(super) fn run( let func_inst = store .funcs .get(stack.current_stackframe().func_idx) + .unwrap_validated() + .try_into_local() .unwrap_validated(); // Start reading the function's instructions @@ -78,7 +80,7 @@ pub(super) fn run( let func_to_call_idx = stack.current_stackframe().func_idx; let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated(); - let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated(); + let func_to_call_ty = types.get(func_to_call_inst.ty()).unwrap_validated(); let ret_vals = stack .pop_tail_iter(func_to_call_ty.returns.valtypes.len()) @@ -99,7 +101,13 @@ pub(super) fn run( CALL => { let func_to_call_idx = wasm.read_var_u32().unwrap_validated() as FuncIdx; - let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated(); + // TODO: if it is imported, defer to linking + let func_to_call_inst = store + .funcs + .get(func_to_call_idx) + .unwrap_validated() + .try_into_local() + .expect("TODO: call imported functions"); let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated(); let params = stack.pop_tail_iter(func_to_call_ty.params.valtypes.len()); diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 07a2f500..8463bae2 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -3,10 +3,12 @@ use alloc::vec::Vec; use const_interpreter_loop::run_const; use interpreter_loop::run; use locals::Locals; +use store::{ImportedFuncInst, LocalFuncInst}; use value_stack::Stack; use crate::core::indices::FuncIdx; use crate::core::reader::types::export::{Export, ExportDesc}; +use crate::core::reader::types::import::ImportDesc; use crate::core::reader::types::{FuncType, ValType}; use crate::core::reader::WasmReader; use crate::execution::assert_validated::UnwrapValidatedExt; @@ -90,6 +92,7 @@ where }); if let Some(func_idx) = func_idx { + trace!("Found function index: {} for name {}", func_idx, func_name); self.invoke_func(func_idx, param) } else { Err(RuntimeError::FunctionNotFound) @@ -103,15 +106,30 @@ where params: Param, ) -> Result { // -=-= Verification =-=- - let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx"); + trace!("{:?}", self.store.funcs); + let func_inst = self + .store + .funcs + .get(func_idx) + .ok_or(RuntimeError::FunctionNotFound)? + .try_into_local() + .ok_or(RuntimeError::FunctionNotFound)?; let func_ty = self.types.get(func_inst.ty).unwrap_validated(); // Check correct function parameters and return types if func_ty.params.valtypes != Param::TYS { - panic!("Invalid `Param` generics"); + panic!( + "Invalid `Param` generics. Expected: {:?}, Found: {:?}", + func_ty.params.valtypes, + Param::TYS + ); } if func_ty.returns.valtypes != Returns::TYS { - panic!("Invalid `Returns` generics"); + panic!( + "Invalid `Returns` generics. Expected: {:?}, Found: {:?}", + func_ty.returns.valtypes, + Returns::TYS + ); } // Prepare a new stack with the locals for the entry function @@ -155,7 +173,13 @@ where ret_types: &[ValType], ) -> Result, RuntimeError> { // -=-= Verification =-=- - let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx"); + let func_inst = self + .store + .funcs + .get(func_idx) + .ok_or(RuntimeError::FunctionNotFound)? + .try_into_local() + .ok_or(RuntimeError::FunctionNotFound)?; let func_ty = self.types.get(func_inst.ty).unwrap_validated(); // Verify that the given parameters match the function parameters @@ -184,9 +208,6 @@ where EmptyHookSet, )?; - let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx"); - let func_ty = self.types.get(func_inst.ty).unwrap_validated(); - // Pop return values from stack let return_values = func_ty .returns @@ -209,28 +230,40 @@ where let functions = validation_info.functions.iter(); let func_blocks = validation_info.func_blocks.iter(); - functions - .zip(func_blocks) - .map(|(ty, func)| { - wasm_reader - .move_start_to(*func) - .expect("function index to be in the bounds of the WASM binary"); - - let (locals, bytes_read) = wasm_reader - .measure_num_read_bytes(read_declared_locals) - .unwrap_validated(); - - let code_expr = wasm_reader - .make_span(func.len() - bytes_read) - .expect("TODO remove this expect"); - - FuncInst { - ty: *ty, - locals, - code_expr, - } + let local_function_inst = functions.zip(func_blocks).map(|(ty, func)| { + wasm_reader + .move_start_to(*func) + .expect("function index to be in the bounds of the WASM binary"); + + let (locals, bytes_read) = wasm_reader + .measure_num_read_bytes(read_declared_locals) + .unwrap_validated(); + + let code_expr = wasm_reader + .make_span(func.len() - bytes_read) + .expect("TODO remove this expect"); + + FuncInst::Local(LocalFuncInst { + ty: *ty, + locals, + code_expr, }) - .collect() + }); + + let imported_function_inst = + validation_info + .imports + .iter() + .filter_map(|import| match &import.desc { + ImportDesc::Func(type_idx) => Some(FuncInst::Imported(ImportedFuncInst { + ty: *type_idx, + module_name: import.module_name.clone(), + function_name: import.name.clone(), + })), + _ => None, + }); + + imported_function_inst.chain(local_function_inst).collect() }; let memory_instances: Vec = validation_info diff --git a/src/execution/store.rs b/src/execution/store.rs index 6d897b8f..d3b17bdd 100644 --- a/src/execution/store.rs +++ b/src/execution/store.rs @@ -1,3 +1,4 @@ +use alloc::string::String; use alloc::vec; use alloc::vec::Vec; use core::iter; @@ -13,6 +14,7 @@ use crate::execution::value::{Ref, Value}; /// globals, element segments, and data segments that have been allocated during the life time of /// the abstract machine. /// +#[derive(Debug)] pub struct Store { pub funcs: Vec, // tables: Vec, @@ -20,18 +22,52 @@ pub struct Store { pub globals: Vec, } -pub struct FuncInst { +#[derive(Debug)] +pub enum FuncInst { + Local(LocalFuncInst), + Imported(ImportedFuncInst), +} + +impl FuncInst { + pub fn ty(&self) -> TypeIdx { + match self { + FuncInst::Local(f) => f.ty, + FuncInst::Imported(f) => f.ty, + } + } + + pub fn try_into_local(&self) -> Option<&LocalFuncInst> { + match self { + FuncInst::Local(f) => Some(f), + FuncInst::Imported(_) => None, + } + } +} + +#[derive(Debug)] +pub struct LocalFuncInst { pub ty: TypeIdx, pub locals: Vec, pub code_expr: Span, } +#[derive(Debug)] +pub struct ImportedFuncInst { + pub ty: TypeIdx, + #[allow(dead_code)] + pub module_name: String, + #[allow(dead_code)] + pub function_name: String, +} + #[allow(dead_code)] +#[derive(Debug)] pub struct TableInst { pub ty: TableType, pub elem: Vec, } +#[derive(Debug)] pub struct MemInst { #[allow(warnings)] pub ty: MemType, @@ -61,6 +97,7 @@ impl MemInst { } } +#[derive(Debug)] pub struct GlobalInst { pub global: Global, /// Must be of the same type as specified in `ty` diff --git a/src/validation/code.rs b/src/validation/code.rs index 52400d73..6543f73b 100644 --- a/src/validation/code.rs +++ b/src/validation/code.rs @@ -1,7 +1,7 @@ use alloc::vec::Vec; use core::iter; -use crate::core::indices::{FuncIdx, GlobalIdx, LocalIdx}; +use crate::core::indices::{FuncIdx, GlobalIdx, LocalIdx, TypeIdx}; use crate::core::reader::section_header::{SectionHeader, SectionTy}; use crate::core::reader::span::Span; use crate::core::reader::types::global::Global; @@ -11,17 +11,34 @@ use crate::core::reader::{WasmReadable, WasmReader}; use crate::validation_stack::ValidationStack; use crate::{Error, Result}; +/// +/// +/// # Arguments +/// - `wasm`: The reader over the whole wasm binary. It is expected to be at the beginning of the code section, and +/// after execution it will be at the beginning of the next section if the result is `Ok(...)`. +/// - `section_header`: The header of the code section. +/// - `fn_types`: The types of all functions in the module, including imported functions. +/// - `type_idx_of_fn`: The index of the type of each function in `fn_types`, including imported functions. As per the +/// specification, the indicies of the type of imported functions come first. +/// - `num_imported_funcs`: The number of imported functions. This is used as an offset, to determine the first index of +/// a local function in `type_idx_of_fn`. +/// - `globals`: The global variables of the module. +/// +/// # Returns +/// pub fn validate_code_section( wasm: &mut WasmReader, section_header: SectionHeader, fn_types: &[FuncType], - type_idx_of_fn: &[usize], + type_idx_of_fn: &[TypeIdx], + num_imported_funcs: usize, globals: &[Global], ) -> Result> { assert_eq!(section_header.ty, SectionTy::Code); let code_block_spans = wasm.read_vec_enumerated(|wasm, idx| { - let ty_idx = type_idx_of_fn[idx]; + // We need to offset the index by the number of functions that were imported + let ty_idx = type_idx_of_fn[idx + num_imported_funcs]; let func_ty = fn_types[ty_idx].clone(); debug!("{:x?}", wasm.full_wasm_binary); @@ -39,7 +56,7 @@ pub fn validate_code_section( let mut stack = ValidationStack::new(); read_instructions( - idx, + idx + num_imported_funcs, wasm, &mut stack, &locals, diff --git a/src/validation/mod.rs b/src/validation/mod.rs index d6a52fa9..fc8b40bc 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -5,7 +5,7 @@ use crate::core::reader::section_header::{SectionHeader, SectionTy}; use crate::core::reader::span::Span; use crate::core::reader::types::export::Export; use crate::core::reader::types::global::Global; -use crate::core::reader::types::import::Import; +use crate::core::reader::types::import::{Import, ImportDesc}; use crate::core::reader::types::{FuncType, MemType, TableType}; use crate::core::reader::{WasmReadable, WasmReader}; use crate::{Error, Result}; @@ -73,10 +73,30 @@ pub fn validate(wasm: &[u8]) -> Result { while (skip_section(&mut wasm, &mut header)?).is_some() {} - let functions = handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| { - wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize)) - })? - .unwrap_or_default(); + // The `Function` section only covers module-level (or "local") functions. Imported functions have their types known + // in the `import` section. Both local and imported functions share the same index space. + // + // Imported functions are given priority and have the first indicies, and only after that do the local functions get + // assigned their indices. + let local_functions = + handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| { + wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize)) + })? + .unwrap_or_default(); + + let imported_functions = imports + .iter() + .filter_map(|import| match &import.desc { + ImportDesc::Func(type_idx) => Some(*type_idx), + _ => None, + }) + .collect::>(); + + let all_functions = imported_functions + .iter() + .chain(local_functions.iter()) + .cloned() + .collect::>(); while (skip_section(&mut wasm, &mut header)?).is_some() {} @@ -130,11 +150,22 @@ pub fn validate(wasm: &[u8]) -> Result { while (skip_section(&mut wasm, &mut header)?).is_some() {} let func_blocks = handle_section(&mut wasm, &mut header, SectionTy::Code, |wasm, h| { - code::validate_code_section(wasm, h, &types, &functions, &globals) + code::validate_code_section( + wasm, + h, + &types, + &all_functions, + imported_functions.len(), + &globals, + ) })? .unwrap_or_default(); - assert_eq!(func_blocks.len(), functions.len(), "these should be equal"); // TODO check if this is in the spec + assert_eq!( + func_blocks.len(), + local_functions.len(), + "these should be equal" + ); // TODO check if this is in the spec while (skip_section(&mut wasm, &mut header)?).is_some() {} @@ -154,7 +185,7 @@ pub fn validate(wasm: &[u8]) -> Result { wasm: wasm.into_inner(), types, imports, - functions, + functions: local_functions, tables, memories, globals, @@ -189,3 +220,12 @@ fn handle_section Result>( _ => Ok(None), } } + +impl ValidationInfo<'_> { + pub fn get_imported_funcs(&self) -> impl Iterator { + self.imports.iter().filter_map(|import| match &import.desc { + ImportDesc::Func(type_idx) => Some(type_idx), + _ => None, + }) + } +} diff --git a/tests/imports.rs b/tests/imports.rs new file mode 100644 index 00000000..19084666 --- /dev/null +++ b/tests/imports.rs @@ -0,0 +1,59 @@ +use wasm::{validate, RuntimeInstance}; + +const UNUSED_IMPORTS: &str = r#" +(module + (import "env" "dummy1" (func (param i32))) + (import "env" "dummy2" (func (param i32))) + (func (export "get_three") (param) (result i32) + i32.const 1 + i32.const 2 + i32.add + ) +)"#; + +const SIMPLE_IMPORT: &str = r#" +(module + (import "env" "print" (func $print (param i32))) + (func (export "print_three") + i32.const 1 + i32.const 2 + i32.add + call $print + ) +)"#; + +/// This test checks that the import order is correct, even if the imports are not used. +#[test_log::test] +pub fn import_order() { + let wasm_bytes = wat::parse_str(UNUSED_IMPORTS).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!(3, instance.invoke_named("get_three", ()).unwrap()); + // Function 0 should be the imported function "dummy1" + // Function 1 should be the imported function "dummy2" + // Function 2 should be the local function "get_three" + assert_eq!(3, instance.invoke_func(2, ()).unwrap()); +} + +#[test_log::test] +pub fn compile_simple_import() { + let wasm_bytes = wat::parse_str(SIMPLE_IMPORT).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let _ = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + // assert_eq!((), instance.invoke_named("print_three", ()).unwrap()); + // Function 0 should be the imported function + // assert_eq!((), instance.invoke_func(1, ()).unwrap()); +} + +#[test_log::test] +pub fn run_simple_import() { + let wasm_bytes = wat::parse_str(SIMPLE_IMPORT).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!((), instance.invoke_named("print_three", ()).unwrap()); + // Function 0 should be the imported function + assert_eq!((), instance.invoke_func(1, ()).unwrap()); +}