diff --git a/golem-rib/src/compiler/byte_code.rs b/golem-rib/src/compiler/byte_code.rs index cebf097e8..71e6ed990 100644 --- a/golem-rib/src/compiler/byte_code.rs +++ b/golem-rib/src/compiler/byte_code.rs @@ -376,9 +376,9 @@ mod internal { convert_to_analysed_type_for(expr, inferred_type)?, )); } - CallType::EnumConstructor(enmum_name) => { + CallType::EnumConstructor(enum_name) => { instructions.push(RibIR::PushEnum( - enmum_name.clone(), + enum_name.clone(), convert_to_analysed_type_for(expr, inferred_type)?, )); } @@ -496,8 +496,8 @@ mod internal { #[cfg(test)] mod compiler_tests { use super::*; - use crate::{compiler, ArmPattern, InferredType, MatchArm, Number, VariableId}; - use golem_wasm_ast::analysis::{AnalysedType, NameTypePair, TypeRecord, TypeStr, TypeU32}; + use crate::{ArmPattern, InferredType, MatchArm, Number, VariableId}; + use golem_wasm_ast::analysis::{AnalysedType, NameTypePair, TypeRecord, TypeStr}; use golem_wasm_rpc::protobuf::type_annotated_value::TypeAnnotatedValue; #[test] @@ -984,6 +984,208 @@ mod compiler_tests { assert_eq!(instructions, expected_instructions); } + #[cfg(test)] + mod invalid_function_invoke_tests { + use crate::compiler::byte_code::compiler_tests::internal; + use crate::{compiler, Expr}; + use golem_wasm_ast::analysis::{AnalysedType, TypeStr}; + + #[test] + fn test_unknown_function() { + let expr = r#" + foo(request); + "success" + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &vec![]).unwrap_err(); + + assert_eq!(compiler_error, "Unknown function call: `foo`"); + } + + #[test] + fn test_unknown_resource_constructor() { + let metadata = internal::metadata_with_resource_methods(); + let expr = r#" + let user_id = "user"; + golem:it/api.{cart(user_id).add-item}("apple"); + golem:it/api.{cart0(user_id).add-item}("apple"); + "success" + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Unknown resource constructor call: `golem:it/api.{cart0(user_id).add-item}`. Resource `cart0` doesn't exist" + ); + } + + #[test] + fn test_unknown_resource_method() { + let metadata = internal::metadata_with_resource_methods(); + let expr = r#" + let user_id = "user"; + golem:it/api.{cart(user_id).add-item}("apple"); + golem:it/api.{cart(user_id).foo}("apple"); + "success" + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Unknown resource method call `golem:it/api.{cart(user_id).foo}`. `foo` doesn't exist in resource `cart`" + ); + } + + #[test] + fn test_invalid_arg_size_function() { + let metadata = internal::get_component_metadata( + "foo", + vec![AnalysedType::Str(TypeStr)], + AnalysedType::Str(TypeStr), + ); + + let expr = r#" + let user_id = "user"; + let result = foo(user_id, user_id); + result + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Incorrect number of arguments for function `foo`. Expected 1, but provided 2" + ); + } + + #[test] + fn test_invalid_arg_size_resource_constructor() { + let metadata = internal::metadata_with_resource_methods(); + let expr = r#" + let user_id = "user"; + golem:it/api.{cart(user_id, user_id).add-item}("apple"); + "success" + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Incorrect number of arguments for resource constructor `cart`. Expected 1, but provided 2" + ); + } + + #[test] + fn test_invalid_arg_size_resource_method() { + let metadata = internal::metadata_with_resource_methods(); + let expr = r#" + let user_id = "user"; + golem:it/api.{cart(user_id).add-item}("apple", "samsung"); + "success" + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Incorrect number of arguments in resource method `golem:it/api.{cart(user_id).add-item}`. Expected 1, but provided 2" + ); + } + + #[test] + fn test_invalid_arg_size_variants() { + let metadata = internal::metadata_with_variants(); + + let expr = r#" + let regiser_user_action = register-user(1, "foo"); + let result = golem:it/api.{foo}(regiser_user_action); + result + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Invalid number of arguments in variant `register-user`. Expected 1, but provided 2" + ); + } + + #[test] + fn test_invalid_arg_types_function() { + let metadata = internal::get_component_metadata( + "foo", + vec![AnalysedType::Str(TypeStr)], + AnalysedType::Str(TypeStr), + ); + + let expr = r#" + let result = foo(1u64); + result + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Invalid type for the argument in function `foo`. Expected type `str`, but provided argument `1u64` is a `number`" + ); + } + + #[test] + fn test_invalid_arg_types_resource_method() { + let metadata = internal::metadata_with_resource_methods(); + let expr = r#" + let user_id = "user"; + golem:it/api.{cart(user_id).add-item}("apple"); + "success" + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Invalid type for the argument in resource method `golem:it/api.{cart(user_id).add-item}`. Expected type `record`, but provided argument `\"apple\"` is a `str`" + ); + } + + #[test] + fn test_invalid_arg_types_resource_constructor() { + let metadata = internal::metadata_with_resource_methods(); + let expr = r#" + golem:it/api.{cart({foo : "bar"}).add-item}("apple"); + "success" + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Invalid type for the argument in resource constructor `cart`. Expected type `str`, but provided argument `{foo: \"bar\"}` is a `record`" + ); + } + + #[test] + fn test_invalid_arg_types_variants() { + let metadata = internal::metadata_with_variants(); + + let expr = r#" + let regiser_user_action = register-user("foo"); + let result = golem:it/api.{foo}(regiser_user_action); + result + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiler_error = compiler::compile(&expr, &metadata).unwrap_err(); + assert_eq!( + compiler_error, + "Invalid type for the argument in variant constructor `register-user`. Expected type `number`, but provided argument `\"foo\"` is a `str`" + ); + } + } + #[cfg(test)] mod global_input_tests { use crate::compiler::byte_code::compiler_tests::internal; @@ -993,6 +1195,64 @@ mod compiler_tests { TypeRecord, TypeResult, TypeStr, TypeTuple, TypeU32, TypeU64, TypeVariant, }; + #[tokio::test] + async fn test_str_global_input() { + let request_value_type = AnalysedType::Str(TypeStr); + + let output_analysed_type = AnalysedType::Str(TypeStr); + + let analysed_exports = internal::get_component_metadata( + "my-worker-function", + vec![request_value_type.clone()], + output_analysed_type, + ); + + let expr = r#" + let x = request; + my-worker-function(x); + match x { + "foo" => "success", + _ => "fallback" + } + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiled = compiler::compile(&expr, &analysed_exports).unwrap(); + let expected_type_info = + internal::rib_input_type_info(vec![("request", request_value_type)]); + + assert_eq!(compiled.global_input_type_info, expected_type_info); + } + + #[tokio::test] + async fn test_number_global_input() { + let request_value_type = AnalysedType::U32(TypeU32); + + let output_analysed_type = AnalysedType::Str(TypeStr); + + let analysed_exports = internal::get_component_metadata( + "my-worker-function", + vec![request_value_type.clone()], + output_analysed_type, + ); + + let expr = r#" + let x = request; + my-worker-function(x); + match x { + 1 => "success", + 0 => "failure" + } + "#; + + let expr = Expr::from_text(expr).unwrap(); + let compiled = compiler::compile(&expr, &analysed_exports).unwrap(); + let expected_type_info = + internal::rib_input_type_info(vec![("request", request_value_type)]); + + assert_eq!(compiled.global_input_type_info, expected_type_info); + } + #[tokio::test] async fn test_variant_type_info() { let request_value_type = AnalysedType::Variant(TypeVariant { @@ -1280,72 +1540,93 @@ mod compiler_tests { } } - #[tokio::test] - async fn test_str_global_input() { - let request_value_type = AnalysedType::Str(TypeStr); - - let output_analysed_type = AnalysedType::Str(TypeStr); - - let analysed_exports = internal::get_component_metadata( - "my-worker-function", - vec![request_value_type.clone()], - output_analysed_type, - ); - - let expr = r#" - let x = request; - my-worker-function(x); - match x { - "foo" => "success", - _ => "fallback" - } - "#; - - let expr = Expr::from_text(expr).unwrap(); - let compiled = compiler::compile(&expr, &analysed_exports).unwrap(); - let expected_type_info = - internal::rib_input_type_info(vec![("request", request_value_type)]); - - assert_eq!(compiled.global_input_type_info, expected_type_info); - } - - #[tokio::test] - async fn test_number_global_input() { - let request_value_type = AnalysedType::U32(TypeU32); - - let output_analysed_type = AnalysedType::Str(TypeStr); - - let analysed_exports = internal::get_component_metadata( - "my-worker-function", - vec![request_value_type.clone()], - output_analysed_type, - ); - - let expr = r#" - let x = request; - my-worker-function(x); - match x { - 1 => "success", - 0 => "failure" - } - "#; - - let expr = Expr::from_text(expr).unwrap(); - let compiled = compiler::compile(&expr, &analysed_exports).unwrap(); - let expected_type_info = - internal::rib_input_type_info(vec![("request", request_value_type)]); - - assert_eq!(compiled.global_input_type_info, expected_type_info); - } - mod internal { use crate::RibInputTypeInfo; - use golem_wasm_ast::analysis::{ - AnalysedExport, AnalysedFunction, AnalysedFunctionParameter, AnalysedFunctionResult, - AnalysedType, - }; + use golem_wasm_ast::analysis::*; use std::collections::HashMap; + pub(crate) fn metadata_with_variants() -> Vec { + let instance = AnalysedExport::Instance(AnalysedInstance { + name: "golem:it/api".to_string(), + functions: vec![AnalysedFunction { + name: "foo".to_string(), + parameters: vec![AnalysedFunctionParameter { + name: "param1".to_string(), + typ: AnalysedType::Variant(TypeVariant { + cases: vec![ + NameOptionTypePair { + name: "register-user".to_string(), + typ: Some(AnalysedType::U64(TypeU64)), + }, + NameOptionTypePair { + name: "process-user".to_string(), + typ: Some(AnalysedType::Str(TypeStr)), + }, + NameOptionTypePair { + name: "validate".to_string(), + typ: None, + }, + ], + }), + }], + results: vec![AnalysedFunctionResult { + name: None, + typ: AnalysedType::Handle(TypeHandle { + resource_id: AnalysedResourceId(0), + mode: AnalysedResourceMode::Owned, + }), + }], + }], + }); + + vec![instance] + } + + pub(crate) fn metadata_with_resource_methods() -> Vec { + let instance = AnalysedExport::Instance(AnalysedInstance { + name: "golem:it/api".to_string(), + functions: vec![ + AnalysedFunction { + name: "[constructor]cart".to_string(), + parameters: vec![AnalysedFunctionParameter { + name: "param1".to_string(), + typ: AnalysedType::Str(TypeStr), + }], + results: vec![AnalysedFunctionResult { + name: None, + typ: AnalysedType::Handle(TypeHandle { + resource_id: AnalysedResourceId(0), + mode: AnalysedResourceMode::Owned, + }), + }], + }, + AnalysedFunction { + name: "[method]cart.add-item".to_string(), + parameters: vec![ + AnalysedFunctionParameter { + name: "self".to_string(), + typ: AnalysedType::Handle(TypeHandle { + resource_id: AnalysedResourceId(0), + mode: AnalysedResourceMode::Borrowed, + }), + }, + AnalysedFunctionParameter { + name: "item".to_string(), + typ: AnalysedType::Record(TypeRecord { + fields: vec![NameTypePair { + name: "name".to_string(), + typ: AnalysedType::Str(TypeStr), + }], + }), + }, + ], + results: vec![], + }, + ], + }); + + vec![instance] + } pub(crate) fn get_component_metadata( function_name: &str, input_types: Vec, diff --git a/golem-rib/src/compiler/desugar.rs b/golem-rib/src/compiler/desugar.rs index e2c2a0ac5..cab54a592 100644 --- a/golem-rib/src/compiler/desugar.rs +++ b/golem-rib/src/compiler/desugar.rs @@ -261,19 +261,19 @@ mod internal { ) -> Option { match pred_expr_inferred_type { InferredType::Record(field_and_types) => { - // Resolution body is a list of expressions which grows (may be with some let bindings) + // Resolution body is a list of expressions which grows (maybe with some let bindings) // as we recursively iterate over the bind patterns // where bind patterns are {name: x, age: _, address : _ } in the case of `match record { {name: x, age: _, address : _ } ) =>` - // These will exist prior to the original resolution of a successful tuple match. + // These will exist prior to the original resolution of a successful record match. let mut resolution_body = vec![]; // The conditions keep growing as we recursively iterate over the bind patterns - // and there are multiple conditions (if condition) for each element in the tuple + // and there are multiple conditions (if condition) for each element in the record. let mut conditions = vec![]; - // We assume pred-expr can be queried by field using Expr::select_field and we pick each element in the bind pattern + // We assume pred-expr can be queried by field using Expr::select_field, and we pick each element in the bind pattern // to get the corresponding expr in pred-expr and keep recursively iterating until the record is completed. - // However there is no resolution body for each of this iteration, so we use an empty expression + // However, there is no resolution body for each of this iteration, so we use an empty expression // and finally push the original resolution body once we fully build the conditions. for (field, arm_pattern) in bind_patterns.iter() { let new_pred = Expr::select_field(pred_expr.clone(), field); diff --git a/golem-rib/src/expr.rs b/golem-rib/src/expr.rs index f33ed518e..787e5e127 100644 --- a/golem-rib/src/expr.rs +++ b/golem-rib/src/expr.rs @@ -132,6 +132,10 @@ impl Expr { matches!(self, Expr::Cond(_, _, _, _)) } + pub fn is_function_call(&self) -> bool { + matches!(self, Expr::Call(_, _, _)) + } + pub fn is_match_expr(&self) -> bool { matches!(self, Expr::PatternMatch(_, _, _)) } @@ -419,10 +423,10 @@ impl Expr { self.bind_types(); self.name_binding_pattern_match_variables(); self.name_binding_local_variables(); - self.infer_function_types(function_type_registry) - .map_err(|x| vec![x])?; self.infer_variants(function_type_registry); self.infer_enums(function_type_registry); + self.infer_call_arguments_type(function_type_registry) + .map_err(|x| vec![x])?; type_inference::type_inference_fix_point(Self::inference_scan, self) .map_err(|x| vec![x])?; self.unify_types()?; @@ -452,11 +456,11 @@ impl Expr { } // At this point we simply update the types to the parameter type expressions and the call expression itself. - pub fn infer_function_types( + pub fn infer_call_arguments_type( &mut self, function_type_registry: &FunctionTypeRegistry, ) -> Result<(), String> { - type_inference::infer_function_types(self, function_type_registry) + type_inference::infer_call_arguments_type(self, function_type_registry) } pub fn push_types_down(&mut self) -> Result<(), String> { diff --git a/golem-rib/src/function_name.rs b/golem-rib/src/function_name.rs index 52b277ec6..2586d7a6a 100644 --- a/golem-rib/src/function_name.rs +++ b/golem-rib/src/function_name.rs @@ -476,25 +476,35 @@ impl ParsedFunctionReference { Self::RawResourceStaticMethod { resource, method, .. } => format!("[static]{resource}.{method}"), - ParsedFunctionReference::IndexedResourceConstructor { resource, .. } => { + Self::IndexedResourceConstructor { resource, .. } => { format!("[constructor]{resource}") } - ParsedFunctionReference::IndexedResourceMethod { + Self::IndexedResourceMethod { resource, method, .. } => { format!("[method]{resource}.{method}") } - ParsedFunctionReference::IndexedResourceStaticMethod { + Self::IndexedResourceStaticMethod { resource, method, .. } => { format!("[static]{resource}.{method}") } - ParsedFunctionReference::IndexedResourceDrop { resource, .. } => { + Self::IndexedResourceDrop { resource, .. } => { format!("[drop]{resource}") } } } + pub fn resource_method_name(&self) -> Option { + match self { + Self::IndexedResourceStaticMethod { method, .. } + | Self::RawResourceMethod { method, .. } + | Self::RawResourceStaticMethod { method, .. } + | Self::IndexedResourceMethod { method, .. } => Some(method.clone()), + _ => None, + } + } + pub fn method_as_static(&self) -> Option { match self { Self::RawResourceMethod { resource, method } => Some(Self::RawResourceStaticMethod { diff --git a/golem-rib/src/inferred_type.rs b/golem-rib/src/inferred_type.rs index 6577def69..bee7f8269 100644 --- a/golem-rib/src/inferred_type.rs +++ b/golem-rib/src/inferred_type.rs @@ -1117,6 +1117,25 @@ impl InferredType { } } } + + pub fn from_variant_cases(type_variant: &TypeVariant) -> InferredType { + let cases = type_variant + .cases + .iter() + .map(|name_type_pair| { + ( + name_type_pair.name.clone(), + name_type_pair.typ.clone().map(|t| t.into()), + ) + }) + .collect(); + + InferredType::Variant(cases) + } + + pub fn from_enum_cases(type_enum: &TypeEnum) -> InferredType { + InferredType::Enum(type_enum.cases.clone()) + } } impl From for InferredType { @@ -1146,7 +1165,7 @@ impl From for InferredType { .collect(), ), AnalysedType::Flags(vs) => InferredType::Flags(vs.names), - AnalysedType::Enum(vs) => InferredType::Enum(vs.cases), + AnalysedType::Enum(vs) => InferredType::from_enum_cases(&vs), AnalysedType::Option(t) => InferredType::Option(Box::new((*t.inner).into())), AnalysedType::Result(golem_wasm_ast::analysis::TypeResult { ok, err, .. }) => { InferredType::Result { @@ -1154,14 +1173,7 @@ impl From for InferredType { error: err.map(|t| Box::new((*t).into())), } } - AnalysedType::Variant(vs) => InferredType::Variant( - vs.cases - .into_iter() - .map(|name_type_pair| { - (name_type_pair.name, name_type_pair.typ.map(|t| t.into())) - }) - .collect(), - ), + AnalysedType::Variant(vs) => InferredType::from_variant_cases(&vs), AnalysedType::Handle(golem_wasm_ast::analysis::TypeHandle { resource_id, mode }) => { InferredType::Resource { resource_id: resource_id.0, diff --git a/golem-rib/src/interpreter/mod.rs b/golem-rib/src/interpreter/mod.rs index d86dd8f7f..ab44f8e5c 100644 --- a/golem-rib/src/interpreter/mod.rs +++ b/golem-rib/src/interpreter/mod.rs @@ -35,3 +35,14 @@ pub async fn interpret( let mut interpreter = Interpreter::new(rib_input, function_invoke); interpreter.run(rib.clone()).await } + +// This function can be used for those the Rib Scripts +// where there are no side effecting function calls. +// It is recommended to use `interpret` over `interpret_pure` if you are unsure. +pub async fn interpret_pure( + rib: &RibByteCode, + rib_input: &HashMap, +) -> Result { + let mut interpreter = Interpreter::pure(rib_input.clone()); + interpreter.run(rib.clone()).await +} diff --git a/golem-rib/src/interpreter/result.rs b/golem-rib/src/interpreter/result.rs index c100ae567..934dbbc5b 100644 --- a/golem-rib/src/interpreter/result.rs +++ b/golem-rib/src/interpreter/result.rs @@ -47,20 +47,21 @@ impl RibInterpreterResult { pub fn get_bool(&self) -> Option { match self { RibInterpreterResult::Val(TypeAnnotatedValue::Bool(bool)) => Some(*bool), - _ => None, + RibInterpreterResult::Val(_) => None, + RibInterpreterResult::Unit => None, } } pub fn get_val(&self) -> Option { match self { RibInterpreterResult::Val(val) => Some(val.clone()), - _ => None, + RibInterpreterResult::Unit => None, } } pub fn get_literal(&self) -> Option { match self { RibInterpreterResult::Val(val) => val.get_literal(), - _ => None, + RibInterpreterResult::Unit => None, } } diff --git a/golem-rib/src/interpreter/rib_interpreter.rs b/golem-rib/src/interpreter/rib_interpreter.rs index 8ee11279b..087984dcc 100644 --- a/golem-rib/src/interpreter/rib_interpreter.rs +++ b/golem-rib/src/interpreter/rib_interpreter.rs @@ -50,7 +50,9 @@ impl Interpreter { } } - pub fn from_input(env: HashMap) -> Self { + // Interpreter that's not expected to call a side-effecting function call. + // All it needs is environment with the required variables to evaluate the Rib script + pub fn pure(env: HashMap) -> Self { Interpreter { stack: InterpreterStack::new(), env: InterpreterEnv::from_input(env), @@ -140,8 +142,8 @@ impl Interpreter { internal::run_create_function_name_instruction(site, function_type, self)?; } - RibIR::InvokeFunction(arity, _) => { - internal::run_call_instruction(arity, self).await?; + RibIR::InvokeFunction(arg_size, _) => { + internal::run_call_instruction(arg_size, self).await?; } RibIR::PushVariant(variant_name, analysed_type) => { @@ -335,7 +337,7 @@ mod internal { .map(|interpreter_result| { interpreter_result .get_val() - .ok_or("Failed to get value from the stack".to_string()) + .ok_or("Internal Error: Failed to construct list".to_string()) }) .collect::, String>>()?; @@ -363,12 +365,13 @@ mod internal { .pop_n(list_size) .ok_or(format!("Expected {} value on the stack", list_size))?; + dbg!(last_list.clone()); let type_annotated_values = last_list .iter() .map(|interpreter_result| { interpreter_result .get_val() - .ok_or("Failed to get value from the stack".to_string()) + .ok_or("Internal Error: Failed to construct tuple".to_string()) }) .collect::, String>>()?; @@ -634,7 +637,7 @@ mod internal { .map(|interpreter_result| { interpreter_result .get_val() - .ok_or("Failed to get value from the stack".to_string()) + .ok_or("Internal Error: Failed to construct resource".to_string()) }) .collect::, String>>()?; @@ -666,9 +669,9 @@ mod internal { let type_anntoated_values = last_n_elements .iter() .map(|interpreter_result| { - interpreter_result - .get_val() - .ok_or("Failed to get value from the stack".to_string()) + interpreter_result.get_val().ok_or( + "Internal Error: Failed to call indexed resource method".to_string(), + ) }) .collect::, String>>()?; @@ -693,17 +696,17 @@ mod internal { arg_size, method, } => { - let last_n_elements = interpreter - .stack - .pop_n(arg_size) - .ok_or("Failed to get values from the stack".to_string())?; + let last_n_elements = interpreter.stack.pop_n(arg_size).ok_or( + "Internal error: Failed to get arguments for static resource method" + .to_string(), + )?; let type_anntoated_values = last_n_elements .iter() .map(|interpreter_result| { - interpreter_result - .get_val() - .ok_or("Failed to get value from the stack".to_string()) + interpreter_result.get_val().ok_or( + "Internal error: Failed to call static resource method".to_string(), + ) }) .collect::, String>>()?; @@ -724,17 +727,17 @@ mod internal { .push_val(TypeAnnotatedValue::Str(parsed_function_name.to_string())); } FunctionReferenceType::IndexedResourceDrop { resource, arg_size } => { - let last_n_elements = interpreter - .stack - .pop_n(arg_size) - .ok_or("Failed to get values from the stack".to_string())?; + let last_n_elements = interpreter.stack.pop_n(arg_size).ok_or( + "Internal Error: Failed to get resource parameters for indexed resource drop" + .to_string(), + )?; - let type_anntoated_values = last_n_elements + let type_annotated_values = last_n_elements .iter() .map(|interpreter_result| { - interpreter_result - .get_val() - .ok_or("Failed to get value from the stack".to_string()) + interpreter_result.get_val().ok_or( + "Internal Error: Failed to call indexed resource drop".to_string(), + ) }) .collect::, String>>()?; @@ -742,7 +745,7 @@ mod internal { site, function: ParsedFunctionReference::IndexedResourceDrop { resource, - resource_params: type_anntoated_values + resource_params: type_annotated_values .iter() .map(type_annotated_value_to_string) .collect::, String>>()?, @@ -758,27 +761,27 @@ mod internal { Ok(()) } - // Separate variant pub(crate) async fn run_call_instruction( - argument_size: usize, + arg_size: usize, interpreter: &mut Interpreter, ) -> Result<(), String> { let function_name = interpreter .stack .pop_str() - .ok_or("Failed to get a function name from the stack".to_string())?; + .ok_or("Internal Error: Failed to get a function name".to_string())?; let last_n_elements = interpreter .stack - .pop_n(argument_size) - .ok_or("Failed to get values from the stack".to_string())?; + .pop_n(arg_size) + .ok_or("Internal Error: Failed to get arguments for the function call".to_string())?; let type_anntoated_values = last_n_elements .iter() .map(|interpreter_result| { - interpreter_result - .get_val() - .ok_or("Failed to get value from the stack".to_string()) + interpreter_result.get_val().ok_or(format!( + "Internal Error: Failed to call function {}", + function_name + )) }) .collect::, String>>()?; @@ -921,19 +924,19 @@ mod internal { ) -> Result<(), String> { let last_n_elements = interpreter_stack .pop_n(arg_size) - .ok_or("Failed to get values from the stack".to_string())?; + .ok_or("Internal Error: Failed to get arguments for concatenation".to_string())?; - let type_anntoated_values = last_n_elements + let type_annotated_values = last_n_elements .iter() .map(|interpreter_result| { interpreter_result .get_val() - .ok_or("Failed to get value from the stack".to_string()) + .ok_or("Internal Error: Failed to execute concatenation".to_string()) }) .collect::, String>>()?; let mut str = String::new(); - for value in type_anntoated_values { + for value in type_annotated_values { let result = value .get_literal() .ok_or("Expected a literal value".to_string())? diff --git a/golem-rib/src/type_inference/call_arguments_inference.rs b/golem-rib/src/type_inference/call_arguments_inference.rs new file mode 100644 index 000000000..855d4dca4 --- /dev/null +++ b/golem-rib/src/type_inference/call_arguments_inference.rs @@ -0,0 +1,492 @@ +// Copyright 2024 Golem Cloud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::type_registry::FunctionTypeRegistry; +use crate::Expr; +use std::collections::VecDeque; + +pub fn infer_call_arguments_type( + expr: &mut Expr, + function_type_registry: &FunctionTypeRegistry, +) -> Result<(), String> { + let mut queue = VecDeque::new(); + queue.push_back(expr); + while let Some(expr) = queue.pop_back() { + match expr { + Expr::Call(parsed_fn_name, args, inferred_type) => { + internal::resolve_call_argument_types( + parsed_fn_name, + function_type_registry, + args, + inferred_type, + )?; + } + _ => expr.visit_children_mut_bottom_up(&mut queue), + } + } + + Ok(()) +} + +mod internal { + use crate::call_type::CallType; + use crate::type_inference::kind::GetTypeKind; + use crate::{Expr, FunctionTypeRegistry, InferredType, RegistryKey, RegistryValue}; + use golem_wasm_ast::analysis::AnalysedType; + use std::fmt::Display; + + pub(crate) fn resolve_call_argument_types( + call_type: &mut CallType, + function_type_registry: &FunctionTypeRegistry, + args: &mut [Expr], + inferred_type: &mut InferredType, + ) -> Result<(), String> { + match call_type { + CallType::Function(dynamic_parsed_function_name) => { + let parsed_function_static = dynamic_parsed_function_name.clone().to_static(); + let function = parsed_function_static.clone().function; + if function.resource_name().is_some() { + let resource_name = + function.resource_name().ok_or("Resource name not found")?; + + let constructor_name = { format!["[constructor]{}", resource_name] }; + + let mut constructor_params: &mut Vec = &mut vec![]; + + if let Some(resource_params) = dynamic_parsed_function_name + .function + .raw_resource_params_mut() + { + constructor_params = resource_params + } + + let registry_key = RegistryKey::from_function_name( + &parsed_function_static.site, + constructor_name.as_str(), + ); + + // Infer the types of constructor parameter expressions + infer_types( + &FunctionTypeInternal::ResourceConstructorName { + fqn: parsed_function_static.to_string(), + resource_constructor_name_pretty: parsed_function_static + .function + .resource_name() + .cloned() + .unwrap_or_default(), + resource_constructor_name: constructor_name, + }, + function_type_registry, + registry_key, + constructor_params, + inferred_type, + ) + .map_err(|e| e.to_string())?; + + // Infer the types of resource method parameters + let resource_method_name = function.function_name(); + let registry_key = RegistryKey::from_function_name( + &parsed_function_static.site, + resource_method_name.as_str(), + ); + + infer_types( + &FunctionTypeInternal::ResourceMethodName { + fqn: parsed_function_static.to_string(), + resource_constructor_name_pretty: parsed_function_static + .function + .resource_name() + .cloned() + .unwrap_or_default(), + resource_method_name_pretty: parsed_function_static + .function + .resource_method_name() + .unwrap_or_default(), + resource_method_name, + }, + function_type_registry, + registry_key, + args, + inferred_type, + ) + .map_err(|e| e.to_string()) + } else { + let registry_key = RegistryKey::from_invocation_name(call_type); + infer_types( + &FunctionTypeInternal::Fqn(parsed_function_static.to_string()), + function_type_registry, + registry_key, + args, + inferred_type, + ) + .map_err(|e| e.to_string()) + } + } + + CallType::EnumConstructor(_) => { + if args.is_empty() { + Ok(()) + } else { + Err("Enum constructor does not take any arguments".to_string()) + } + } + + CallType::VariantConstructor(variant_name) => { + let registry_key = RegistryKey::FunctionName(variant_name.clone()); + infer_types( + &FunctionTypeInternal::VariantName(variant_name.clone()), + function_type_registry, + registry_key, + args, + inferred_type, + ) + .map_err(|e| e.to_string()) + } + } + } + + // An internal error type for all possibilities of errors + // when inferring the type of arguments + enum FunctionArgsTypeInferenceError { + UnknownFunction(FunctionTypeInternal), + ArgumentSizeMisMatch { + function_type_internal: FunctionTypeInternal, + expected: usize, + provided: usize, + }, + TypeMisMatchError { + function_type_internal: FunctionTypeInternal, + expected: AnalysedType, + provided: Expr, + }, + } + + impl FunctionArgsTypeInferenceError { + fn type_mismatch( + function_type_internal: FunctionTypeInternal, + expected: AnalysedType, + provided: Expr, + ) -> FunctionArgsTypeInferenceError { + FunctionArgsTypeInferenceError::TypeMisMatchError { + function_type_internal, + expected, + provided, + } + } + } + + impl Display for FunctionArgsTypeInferenceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FunctionArgsTypeInferenceError::UnknownFunction(FunctionTypeInternal::Fqn( + parsed_function_name, + )) => { + write!(f, "Unknown function call: `{}`", parsed_function_name) + } + FunctionArgsTypeInferenceError::UnknownFunction( + FunctionTypeInternal::ResourceMethodName { + fqn, + resource_constructor_name_pretty: resource_name_human, + resource_method_name_pretty: resource_method_name_human, + .. + }, + ) => { + write!( + f, + "Unknown resource method call `{}`. `{}` doesn't exist in resource `{}`", + fqn, resource_method_name_human, resource_name_human + ) + } + FunctionArgsTypeInferenceError::UnknownFunction( + FunctionTypeInternal::ResourceConstructorName { + fqn, + resource_constructor_name_pretty: resource_constructor_name_human, + .. + }, + ) => { + write!( + f, + "Unknown resource constructor call: `{}`. Resource `{}` doesn't exist", + fqn, resource_constructor_name_human + ) + } + + FunctionArgsTypeInferenceError::UnknownFunction( + FunctionTypeInternal::VariantName(variant_name), + ) => { + write!(f, "Invalid variant constructor call: {}", variant_name) + } + + FunctionArgsTypeInferenceError::TypeMisMatchError { + function_type_internal, + expected, + provided, + } => match function_type_internal { + FunctionTypeInternal::ResourceConstructorName { + resource_constructor_name_pretty: resource_constructor_name_human, + .. + } => { + write!(f,"Invalid type for the argument in resource constructor `{}`. Expected type `{}`, but provided argument `{}` is a `{}`", resource_constructor_name_human, expected.get_type_kind(), provided, provided.inferred_type().get_type_kind()) + } + FunctionTypeInternal::ResourceMethodName { fqn, .. } => { + write!(f,"Invalid type for the argument in resource method `{}`. Expected type `{}`, but provided argument `{}` is a `{}`", fqn, expected.get_type_kind(), provided, provided.inferred_type().get_type_kind()) + } + FunctionTypeInternal::Fqn(fqn) => { + write!(f,"Invalid type for the argument in function `{}`. Expected type `{}`, but provided argument `{}` is a `{}`", fqn, expected.get_type_kind(), provided, provided.inferred_type().get_type_kind()) + } + FunctionTypeInternal::VariantName(str) => { + write!(f,"Invalid type for the argument in variant constructor `{}`. Expected type `{}`, but provided argument `{}` is a `{}`", str, expected.get_type_kind(), provided, provided.inferred_type().get_type_kind()) + } + }, + FunctionArgsTypeInferenceError::ArgumentSizeMisMatch { + function_type_internal, + expected, + provided, + } => match function_type_internal { + FunctionTypeInternal::ResourceConstructorName { + resource_constructor_name_pretty, + .. + } => { + write!(f, "Incorrect number of arguments for resource constructor `{}`. Expected {}, but provided {}", resource_constructor_name_pretty, expected, provided) + } + FunctionTypeInternal::ResourceMethodName { fqn, .. } => { + write!(f, "Incorrect number of arguments in resource method `{}`. Expected {}, but provided {}", fqn, expected, provided) + } + FunctionTypeInternal::Fqn(fqn) => { + write!(f, "Incorrect number of arguments for function `{}`. Expected {}, but provided {}", fqn, expected, provided) + } + FunctionTypeInternal::VariantName(str) => { + write!(f, "Invalid number of arguments in variant `{}`. Expected {}, but provided {}", str, expected, provided) + } + }, + } + } + } + + fn infer_types( + function_name: &FunctionTypeInternal, + function_type_registry: &FunctionTypeRegistry, + key: RegistryKey, + args: &mut [Expr], + inferred_type: &mut InferredType, + ) -> Result<(), FunctionArgsTypeInferenceError> { + if let Some(value) = function_type_registry.types.get(&key) { + match value { + RegistryValue::Value(_) => Ok(()), + RegistryValue::Variant { + parameter_types, + variant_type, + } => { + let parameter_types = parameter_types.clone(); + + if parameter_types.len() == args.len() { + tag_argument_types(function_name, args, ¶meter_types)?; + *inferred_type = InferredType::from_variant_cases(variant_type); + + Ok(()) + } else { + Err(FunctionArgsTypeInferenceError::ArgumentSizeMisMatch { + function_type_internal: function_name.clone(), + expected: parameter_types.len(), + provided: args.len(), + }) + } + } + RegistryValue::Function { + parameter_types, + return_types, + } => { + let mut parameter_types = parameter_types.clone(); + + if let FunctionTypeInternal::ResourceMethodName { .. } = function_name { + if let Some(AnalysedType::Handle(_)) = parameter_types.first() { + parameter_types.remove(0); + } + } + + if parameter_types.len() == args.len() { + tag_argument_types(function_name, args, ¶meter_types)?; + + *inferred_type = { + if return_types.len() == 1 { + return_types[0].clone().into() + } else { + InferredType::Sequence( + return_types.iter().map(|t| t.clone().into()).collect(), + ) + } + }; + + Ok(()) + } else { + Err(FunctionArgsTypeInferenceError::ArgumentSizeMisMatch { + function_type_internal: function_name.clone(), + expected: parameter_types.len(), + provided: args.len(), + }) + } + } + } + } else { + Err(FunctionArgsTypeInferenceError::UnknownFunction( + function_name.clone(), + )) + } + } + + #[derive(Clone)] + enum FunctionTypeInternal { + ResourceConstructorName { + fqn: String, + resource_constructor_name_pretty: String, + resource_constructor_name: String, + }, + ResourceMethodName { + fqn: String, + resource_constructor_name_pretty: String, + resource_method_name_pretty: String, + resource_method_name: String, + }, + Fqn(String), + VariantName(String), + } + + impl Display for FunctionTypeInternal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FunctionTypeInternal::ResourceConstructorName { + resource_constructor_name, + .. + } => { + write!(f, "{}", resource_constructor_name) + } + FunctionTypeInternal::ResourceMethodName { + resource_method_name, + .. + } => { + write!(f, "{}", resource_method_name) + } + FunctionTypeInternal::Fqn(fqn) => { + write!(f, "{}", fqn) + } + FunctionTypeInternal::VariantName(name) => { + write!(f, "{}", name) + } + } + } + } + + // A preliminary check of the arguments passed before typ inference + fn check_function_arguments( + function_name: &FunctionTypeInternal, + expected: &AnalysedType, + provided: &Expr, + ) -> Result<(), FunctionArgsTypeInferenceError> { + let is_valid = if provided.inferred_type().is_unknown() { + true + } else { + provided.inferred_type().get_type_kind() == expected.get_type_kind() + }; + + if is_valid { + Ok(()) + } else { + Err(FunctionArgsTypeInferenceError::type_mismatch( + function_name.clone(), + expected.clone(), + provided.clone(), + )) + } + } + + fn tag_argument_types( + function_name: &FunctionTypeInternal, + args: &mut [Expr], + parameter_types: &[AnalysedType], + ) -> Result<(), FunctionArgsTypeInferenceError> { + for (arg, param_type) in args.iter_mut().zip(parameter_types) { + check_function_arguments(function_name, param_type, arg)?; + arg.add_infer_type_mut(param_type.clone().into()); + } + + Ok(()) + } +} + +#[cfg(test)] +mod function_parameters_inference_tests { + use crate::call_type::CallType; + use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference}; + use crate::type_registry::FunctionTypeRegistry; + use crate::{Expr, InferredType, ParsedFunctionSite, VariableId}; + use golem_wasm_ast::analysis::{ + AnalysedExport, AnalysedFunction, AnalysedFunctionParameter, AnalysedType, TypeU32, TypeU64, + }; + + fn get_function_type_registry() -> FunctionTypeRegistry { + let metadata = vec![ + AnalysedExport::Function(AnalysedFunction { + name: "foo".to_string(), + parameters: vec![AnalysedFunctionParameter { + name: "my_parameter".to_string(), + typ: AnalysedType::U64(TypeU64), + }], + results: vec![], + }), + AnalysedExport::Function(AnalysedFunction { + name: "baz".to_string(), + parameters: vec![AnalysedFunctionParameter { + name: "my_parameter".to_string(), + typ: AnalysedType::U32(TypeU32), + }], + results: vec![], + }), + ]; + FunctionTypeRegistry::from_export_metadata(&metadata) + } + + #[test] + fn test_infer_function_types() { + let rib_expr = r#" + let x = 1; + foo(x) + "#; + + let function_type_registry = get_function_type_registry(); + + let mut expr = Expr::from_text(rib_expr).unwrap(); + expr.infer_call_arguments_type(&function_type_registry) + .unwrap(); + + let let_binding = Expr::let_binding("x", Expr::number(1f64)); + + let call_expr = Expr::Call( + CallType::Function(DynamicParsedFunctionName { + site: ParsedFunctionSite::Global, + function: DynamicParsedFunctionReference::Function { + function: "foo".to_string(), + }, + }), + vec![Expr::Identifier( + VariableId::global("x".to_string()), + InferredType::U64, // Call argument's types are updated + )], + InferredType::Sequence(vec![]), // Call Expressions return type is updated + ); + + let expected = Expr::Multiple(vec![let_binding, call_expr], InferredType::Unknown); + + assert_eq!(expr, expected); + } +} diff --git a/golem-rib/src/type_inference/enum_resolution.rs b/golem-rib/src/type_inference/enum_resolution.rs index 11226529f..58937995e 100644 --- a/golem-rib/src/type_inference/enum_resolution.rs +++ b/golem-rib/src/type_inference/enum_resolution.rs @@ -23,6 +23,7 @@ pub fn infer_enums(expr: &mut Expr, function_type_registry: &FunctionTypeRegistr mod internal { use crate::call_type::CallType; use crate::{Expr, FunctionTypeRegistry, RegistryKey, RegistryValue}; + use golem_wasm_ast::analysis::AnalysedType; use std::collections::VecDeque; pub(crate) fn convert_identifiers_to_enum_function_calls( @@ -62,12 +63,13 @@ mod internal { match expr { Expr::Identifier(variable_id, inferred_type) => { // Retrieve the possible no-arg variant from the registry - let key = RegistryKey::EnumName(variable_id.name().clone()); - if let Some(RegistryValue::Value(analysed_type)) = + let key = RegistryKey::FunctionName(variable_id.name().clone()); + if let Some(RegistryValue::Value(AnalysedType::Enum(typed_enum))) = function_type_registry.types.get(&key) { enum_cases.push(variable_id.name()); - *inferred_type = inferred_type.merge(analysed_type.clone().into()); + *inferred_type = inferred_type + .merge(AnalysedType::Enum(typed_enum.clone()).clone().into()); } } diff --git a/golem-rib/src/type_inference/function_type_inference.rs b/golem-rib/src/type_inference/function_type_inference.rs deleted file mode 100644 index 484d195ec..000000000 --- a/golem-rib/src/type_inference/function_type_inference.rs +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright 2024 Golem Cloud -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::type_registry::FunctionTypeRegistry; -use crate::Expr; -use std::collections::VecDeque; - -pub fn infer_function_types( - expr: &mut Expr, - function_type_registry: &FunctionTypeRegistry, -) -> Result<(), String> { - let mut queue = VecDeque::new(); - queue.push_back(expr); - while let Some(expr) = queue.pop_back() { - match expr { - Expr::Call(parsed_fn_name, args, inferred_type) => { - internal::resolve_call_expressions( - parsed_fn_name, - function_type_registry, - args, - inferred_type, - )?; - } - _ => expr.visit_children_mut_bottom_up(&mut queue), - } - } - - Ok(()) -} - -mod internal { - use crate::call_type::CallType; - use crate::{ - Expr, FunctionTypeRegistry, InferredType, ParsedFunctionName, RegistryKey, RegistryValue, - }; - use golem_wasm_ast::analysis::AnalysedType; - use std::fmt::Display; - - pub(crate) fn resolve_call_expressions( - call_type: &mut CallType, - function_type_registry: &FunctionTypeRegistry, - args: &mut [Expr], - inferred_type: &mut InferredType, - ) -> Result<(), String> { - match call_type { - CallType::Function(dynamic_parsed_function_name) => { - let parsed_function_static = dynamic_parsed_function_name.clone().to_static(); - let function = parsed_function_static.clone().function; - if function.resource_name().is_some() { - let constructor_name = { - let raw_str = function.resource_name().ok_or("Resource name not found")?; - format!["[constructor]{}", raw_str] - }; - - let mut constructor_params: &mut Vec = &mut vec![]; - - if let Some(resource_params) = dynamic_parsed_function_name - .function - .raw_resource_params_mut() - { - constructor_params = resource_params - } - - let registry_key = RegistryKey::from_function_name( - &parsed_function_static.site, - constructor_name.as_str(), - ); - - // Infer the types of constructor parameter expressions - infer_types( - &FunctionNameInternal::ResourceConstructorName(constructor_name), - function_type_registry, - registry_key, - constructor_params, - inferred_type, - )?; - - // Infer the types of resource method parameters - let resource_method_name = function.function_name(); - let registry_key = RegistryKey::from_function_name( - &parsed_function_static.site, - resource_method_name.as_str(), - ); - - infer_types( - &FunctionNameInternal::ResourceMethodName(resource_method_name), - function_type_registry, - registry_key, - args, - inferred_type, - ) - } else { - let registry_key = RegistryKey::from_invocation_name(call_type); - infer_types( - &FunctionNameInternal::Fqn(parsed_function_static), - function_type_registry, - registry_key, - args, - inferred_type, - ) - } - } - - _ => Ok(()), - } - } - - fn infer_types( - function_name: &FunctionNameInternal, - function_type_registry: &FunctionTypeRegistry, - key: RegistryKey, - args: &mut [Expr], - inferred_type: &mut InferredType, - ) -> Result<(), String> { - if let Some(value) = function_type_registry.types.get(&key) { - match value { - RegistryValue::Value(_) => {} - RegistryValue::Function { - parameter_types, - return_types, - } => { - let mut parameter_types = parameter_types.clone(); - - if let FunctionNameInternal::ResourceMethodName(_) = function_name { - if let Some(AnalysedType::Handle(_)) = parameter_types.first() { - parameter_types.remove(0); - } - } - - if parameter_types.len() == args.len() { - for (arg, param_type) in args.iter_mut().zip(parameter_types) { - check_function_arguments(¶m_type, arg)?; - arg.add_infer_type_mut(param_type.clone().into()); - arg.push_types_down()? - } - - *inferred_type = { - if return_types.len() == 1 { - return_types[0].clone().into() - } else { - InferredType::Sequence( - return_types.iter().map(|t| t.clone().into()).collect(), - ) - } - } - } else { - return Err(format!( - "Function {} expects {} arguments, but {} were provided", - function_name, - parameter_types.len(), - args.len() - )); - } - } - } - } - - Ok(()) - } - - enum FunctionNameInternal { - ResourceConstructorName(String), - ResourceMethodName(String), - Fqn(ParsedFunctionName), - } - - impl Display for FunctionNameInternal { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FunctionNameInternal::ResourceConstructorName(name) => { - write!(f, "{}", name) - } - FunctionNameInternal::ResourceMethodName(name) => { - write!(f, "{}", name) - } - FunctionNameInternal::Fqn(name) => { - write!(f, "{}", name) - } - } - } - } - - // A preliminary check of the arguments passed before typ inference - pub(crate) fn check_function_arguments( - expected: &AnalysedType, - passed: &Expr, - ) -> Result<(), String> { - let valid_possibilities = passed.is_identifier() - || passed.is_select_field() - || passed.is_select_index() - || passed.is_select_field() - || passed.is_match_expr() - || passed.is_if_else(); - - match expected { - AnalysedType::U32(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected U32, but found {:?}", passed)) - } - } - - AnalysedType::U64(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected U64, but found {:?}", passed)) - } - } - - AnalysedType::Variant(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected Variant, but found {:?}", passed)) - } - } - - AnalysedType::Result(_) => { - if valid_possibilities || passed.is_result() { - Ok(()) - } else { - Err(format!("Expected Result, but found {:?}", passed)) - } - } - AnalysedType::Option(_) => { - if valid_possibilities || passed.is_option() { - Ok(()) - } else { - Err(format!("Expected Option, but found {:?}", passed)) - } - } - AnalysedType::Enum(_) => { - if valid_possibilities { - Ok(()) - } else { - Err(format!("Expected Enum, but found {:?}", passed)) - } - } - AnalysedType::Flags(_) => { - if valid_possibilities || passed.is_flags() { - Ok(()) - } else { - Err(format!("Expected Flags, but found {:?}", passed)) - } - } - AnalysedType::Record(_) => { - if valid_possibilities || passed.is_record() { - Ok(()) - } else { - Err(format!("Expected Record, but found {:?}", passed)) - } - } - AnalysedType::Tuple(_) => { - if valid_possibilities || passed.is_tuple() { - Ok(()) - } else { - Err(format!("Expected Tuple, but found {:?}", passed)) - } - } - AnalysedType::List(_) => { - if valid_possibilities || passed.is_list() { - Ok(()) - } else { - Err(format!("Expected List, but found {:?}", passed)) - } - } - AnalysedType::Str(_) => { - if valid_possibilities || passed.is_concat() || passed.is_literal() { - Ok(()) - } else { - Err(format!("Expected Str, but found {:?}", passed)) - } - } - // TODO? - AnalysedType::Chr(_) => { - if valid_possibilities || passed.is_literal() { - Ok(()) - } else { - Err(format!("Expected Chr, but found {:?}", passed)) - } - } - AnalysedType::F64(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected F64, but found {:?}", passed)) - } - } - AnalysedType::F32(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected F32, but found {:?}", passed)) - } - } - AnalysedType::S64(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected S64, but found {:?}", passed)) - } - } - AnalysedType::S32(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected S32, but found {:?}", passed)) - } - } - AnalysedType::U16(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected U16, but found {:?}", passed)) - } - } - AnalysedType::S16(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected S16, but found {:?}", passed)) - } - } - AnalysedType::U8(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected U8, but found {:?}", passed)) - } - } - AnalysedType::S8(_) => { - if valid_possibilities || passed.is_number() { - Ok(()) - } else { - Err(format!("Expected S8, but found {:?}", passed)) - } - } - AnalysedType::Bool(_) => { - if valid_possibilities || passed.is_boolean() || passed.is_comparison() { - Ok(()) - } else { - Err(format!("Expected Bool, but found {:?}", passed)) - } - } - AnalysedType::Handle(_) => { - if valid_possibilities { - Ok(()) - } else { - Err(format!("Expected Handle, but found {:?}", passed)) - } - } - } - } -} - -#[cfg(test)] -mod function_parameters_inference_tests { - use crate::call_type::CallType; - use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference}; - use crate::type_registry::FunctionTypeRegistry; - use crate::{Expr, InferredType, ParsedFunctionSite, VariableId}; - use golem_wasm_ast::analysis::{ - AnalysedExport, AnalysedFunction, AnalysedFunctionParameter, AnalysedType, TypeU32, TypeU64, - }; - - fn get_function_type_registry() -> FunctionTypeRegistry { - let metadata = vec![ - AnalysedExport::Function(AnalysedFunction { - name: "foo".to_string(), - parameters: vec![AnalysedFunctionParameter { - name: "my_parameter".to_string(), - typ: AnalysedType::U64(TypeU64), - }], - results: vec![], - }), - AnalysedExport::Function(AnalysedFunction { - name: "baz".to_string(), - parameters: vec![AnalysedFunctionParameter { - name: "my_parameter".to_string(), - typ: AnalysedType::U32(TypeU32), - }], - results: vec![], - }), - ]; - FunctionTypeRegistry::from_export_metadata(&metadata) - } - - #[test] - fn test_infer_function_types() { - let rib_expr = r#" - let x = 1; - foo(x) - "#; - - let function_type_registry = get_function_type_registry(); - - let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_function_types(&function_type_registry).unwrap(); - - let let_binding = Expr::let_binding("x", Expr::number(1f64)); - - let call_expr = Expr::Call( - CallType::Function(DynamicParsedFunctionName { - site: ParsedFunctionSite::Global, - function: DynamicParsedFunctionReference::Function { - function: "foo".to_string(), - }, - }), - vec![Expr::Identifier( - VariableId::global("x".to_string()), - InferredType::U64, // Call argument's types are updated - )], - InferredType::Sequence(vec![]), // Call Expressions return type is updated - ); - - let expected = Expr::Multiple(vec![let_binding, call_expr], InferredType::Unknown); - - assert_eq!(expr, expected); - } -} diff --git a/golem-rib/src/type_inference/kind.rs b/golem-rib/src/type_inference/kind.rs new file mode 100644 index 000000000..4c4386f02 --- /dev/null +++ b/golem-rib/src/type_inference/kind.rs @@ -0,0 +1,126 @@ +use crate::InferredType; +use golem_wasm_ast::analysis::AnalysedType; +use std::fmt::Display; + +pub trait GetTypeKind { + fn get_type_kind(&self) -> TypeKind; +} + +#[derive(PartialEq)] +pub enum TypeKind { + Record, + Tuple, + Flag, + Str, + Number, + List, + Boolean, + Option, + Enum, + Char, + Result, + Resource, + Variant, + Unknown, +} + +impl Display for TypeKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeKind::Record => write!(f, "record"), + TypeKind::Tuple => write!(f, "tuple"), + TypeKind::Flag => write!(f, "flag"), + TypeKind::Str => write!(f, "str"), + TypeKind::Number => write!(f, "number"), + TypeKind::List => write!(f, "list"), + TypeKind::Boolean => write!(f, "boolean"), + TypeKind::Option => write!(f, "option"), + TypeKind::Enum => write!(f, "enum"), + TypeKind::Char => write!(f, "chr"), + TypeKind::Result => write!(f, "result"), + TypeKind::Resource => write!(f, "resource"), + TypeKind::Variant => write!(f, "variant"), + TypeKind::Unknown => write!(f, "unknown"), + } + } +} + +impl GetTypeKind for AnalysedType { + fn get_type_kind(&self) -> TypeKind { + match self { + AnalysedType::Record(_) => TypeKind::Record, + AnalysedType::Tuple(_) => TypeKind::Tuple, + AnalysedType::Flags(_) => TypeKind::Flag, + AnalysedType::Str(_) => TypeKind::Str, + AnalysedType::S8(_) => TypeKind::Number, + AnalysedType::U8(_) => TypeKind::Number, + AnalysedType::S16(_) => TypeKind::Number, + AnalysedType::U16(_) => TypeKind::Number, + AnalysedType::S32(_) => TypeKind::Number, + AnalysedType::U32(_) => TypeKind::Number, + AnalysedType::S64(_) => TypeKind::Number, + AnalysedType::U64(_) => TypeKind::Number, + AnalysedType::F32(_) => TypeKind::Number, + AnalysedType::F64(_) => TypeKind::Number, + AnalysedType::Chr(_) => TypeKind::Char, + AnalysedType::List(_) => TypeKind::List, + AnalysedType::Bool(_) => TypeKind::Boolean, + AnalysedType::Option(_) => TypeKind::Option, + AnalysedType::Enum(_) => TypeKind::Enum, + AnalysedType::Result(_) => TypeKind::Result, + AnalysedType::Handle(_) => TypeKind::Resource, + AnalysedType::Variant(_) => TypeKind::Variant, + } + } +} + +impl GetTypeKind for InferredType { + fn get_type_kind(&self) -> TypeKind { + match self { + InferredType::Bool => TypeKind::Boolean, + InferredType::S8 => TypeKind::Number, + InferredType::U8 => TypeKind::Number, + InferredType::S16 => TypeKind::Number, + InferredType::U16 => TypeKind::Number, + InferredType::S32 => TypeKind::Number, + InferredType::U32 => TypeKind::Number, + InferredType::S64 => TypeKind::Number, + InferredType::U64 => TypeKind::Number, + InferredType::F32 => TypeKind::Number, + InferredType::F64 => TypeKind::Number, + InferredType::Chr => TypeKind::Char, + InferredType::Str => TypeKind::Str, + InferredType::List(_) => TypeKind::List, + InferredType::Tuple(_) => TypeKind::Tuple, + InferredType::Record(_) => TypeKind::Record, + InferredType::Flags(_) => TypeKind::Flag, + InferredType::Enum(_) => TypeKind::Enum, + InferredType::Option(_) => TypeKind::Option, + InferredType::Result { .. } => TypeKind::Result, + InferredType::Variant(_) => TypeKind::Variant, + InferredType::Resource { .. } => TypeKind::Resource, + InferredType::OneOf(possibilities) => internal::get_type_kind(possibilities), + InferredType::AllOf(possibilities) => internal::get_type_kind(possibilities), + InferredType::Unknown => TypeKind::Unknown, + InferredType::Sequence(_) => TypeKind::Unknown, + } + } +} + +mod internal { + use crate::type_inference::kind::{GetTypeKind, TypeKind}; + use crate::InferredType; + + pub(crate) fn get_type_kind(possibilities: &[InferredType]) -> TypeKind { + if let Some(first) = possibilities.first() { + let first = first.get_type_kind(); + if possibilities.iter().all(|p| p.get_type_kind() == first) { + first + } else { + TypeKind::Unknown + } + } else { + TypeKind::Unknown + } + } +} diff --git a/golem-rib/src/type_inference/mod.rs b/golem-rib/src/type_inference/mod.rs index 9f2b96182..36ad43937 100644 --- a/golem-rib/src/type_inference/mod.rs +++ b/golem-rib/src/type_inference/mod.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub use call_arguments_inference::*; pub use enum_resolution::*; pub use expr_visitor::*; -pub use function_type_inference::*; pub use global_input_inference::*; pub use identifier_inference::*; pub use inference_fix_point::*; @@ -29,8 +29,8 @@ pub use type_reset::*; pub use type_unification::*; pub use variant_resolution::*; +mod call_arguments_inference; mod expr_visitor; -mod function_type_inference; mod identifier_inference; mod name_binding; mod pattern_match_binding; @@ -45,6 +45,7 @@ mod variant_resolution; mod enum_resolution; mod global_input_inference; mod inference_fix_point; +pub(crate) mod kind; mod type_binding; #[cfg(test)] diff --git a/golem-rib/src/type_inference/rib_input_type.rs b/golem-rib/src/type_inference/rib_input_type.rs index 6b144e097..0b4a23ea9 100644 --- a/golem-rib/src/type_inference/rib_input_type.rs +++ b/golem-rib/src/type_inference/rib_input_type.rs @@ -20,6 +20,8 @@ use poem_openapi::Object; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, VecDeque}; +// RibInputTypeInfo refers to the required global inputs to a RibScript +// with its type information. Example: `request` variable which should be of the type `Record`. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Encode, Decode, Object)] pub struct RibInputTypeInfo { pub types: HashMap, diff --git a/golem-rib/src/type_inference/variant_resolution.rs b/golem-rib/src/type_inference/variant_resolution.rs index 298b214d3..28209d102 100644 --- a/golem-rib/src/type_inference/variant_resolution.rs +++ b/golem-rib/src/type_inference/variant_resolution.rs @@ -24,7 +24,8 @@ pub fn infer_variants(expr: &mut Expr, function_type_registry: &FunctionTypeRegi mod internal { use crate::call_type::CallType; - use crate::{Expr, FunctionTypeRegistry, RegistryKey, RegistryValue}; + use crate::{Expr, FunctionTypeRegistry, InferredType, RegistryKey, RegistryValue}; + use golem_wasm_ast::analysis::AnalysedType; use std::collections::VecDeque; pub(crate) fn convert_function_calls_to_variant_calls( @@ -88,26 +89,25 @@ mod internal { while let Some(expr) = queue.pop_back() { match expr { Expr::Identifier(variable_id, inferred_type) => { - let key = RegistryKey::VariantName(variable_id.name().clone()); - if let Some(RegistryValue::Value(analysed_type)) = + let key = RegistryKey::FunctionName(variable_id.name().clone()); + if let Some(RegistryValue::Value(AnalysedType::Variant(type_variant))) = function_type_registry.types.get(&key) { no_arg_variants.push(variable_id.name()); - *inferred_type = inferred_type.merge(analysed_type.clone().into()); + *inferred_type = + inferred_type.merge(InferredType::from_variant_cases(type_variant)); } } Expr::Call(CallType::Function(parsed_function_name), exprs, inferred_type) => { - let key = RegistryKey::VariantName(parsed_function_name.to_string()); - if let Some(RegistryValue::Function { return_types, .. }) = + let key = RegistryKey::FunctionName(parsed_function_name.to_string()); + if let Some(RegistryValue::Variant { variant_type, .. }) = function_type_registry.types.get(&key) { - variant_with_args.push(parsed_function_name.to_string()); + let variant_inferred_type = InferredType::from_variant_cases(variant_type); + *inferred_type = inferred_type.merge(variant_inferred_type); - // TODO; return type is only 1 in reality for variants - we can make this typed - if let Some(variant_type) = return_types.first() { - *inferred_type = inferred_type.merge(variant_type.clone().into()); - } + variant_with_args.push(parsed_function_name.to_string()); } for expr in exprs { diff --git a/golem-rib/src/type_registry.rs b/golem-rib/src/type_registry.rs index dad0f6a04..feee91b9a 100644 --- a/golem-rib/src/type_registry.rs +++ b/golem-rib/src/type_registry.rs @@ -14,8 +14,8 @@ use crate::call_type::CallType; use crate::ParsedFunctionSite; -use golem_wasm_ast::analysis::AnalysedExport; use golem_wasm_ast::analysis::AnalysedType; +use golem_wasm_ast::analysis::{AnalysedExport, TypeVariant}; use std::collections::{HashMap, HashSet}; // A type-registry is a mapping from a function name (global or part of an interface in WIT) @@ -29,8 +29,6 @@ use std::collections::{HashMap, HashSet}; // then the RegistryValue is simply an AnalysedType representing the variant type itself. #[derive(Hash, Eq, PartialEq, Clone, Debug)] pub enum RegistryKey { - VariantName(String), - EnumName(String), FunctionName(String), FunctionNameWithInterface { interface_name: String, @@ -51,9 +49,9 @@ impl RegistryKey { pub fn from_invocation_name(invocation_name: &CallType) -> RegistryKey { match invocation_name { CallType::VariantConstructor(variant_name) => { - RegistryKey::VariantName(variant_name.clone()) + RegistryKey::FunctionName(variant_name.clone()) } - CallType::EnumConstructor(enum_name) => RegistryKey::EnumName(enum_name.clone()), + CallType::EnumConstructor(enum_name) => RegistryKey::FunctionName(enum_name.clone()), CallType::Function(function_name) => match function_name.site.interface_name() { None => RegistryKey::FunctionName(function_name.function_name()), Some(interface_name) => RegistryKey::FunctionNameWithInterface { @@ -68,6 +66,10 @@ impl RegistryKey { #[derive(PartialEq, Clone, Debug)] pub enum RegistryValue { Value(AnalysedType), + Variant { + parameter_types: Vec, + variant_type: TypeVariant, + }, Function { parameter_types: Vec, return_types: Vec, @@ -117,16 +119,16 @@ impl FunctionTypeRegistry { }) .collect::>(); - let registry_value = RegistryValue::Function { - parameter_types, - return_types, - }; - let registry_key = RegistryKey::FunctionNameWithInterface { interface_name: interface_name.clone(), function_name: function_name.clone(), }; + let registry_value = RegistryValue::Function { + parameter_types, + return_types, + }; + map.insert(registry_key, registry_value); } } @@ -179,7 +181,7 @@ impl FunctionTypeRegistry { mod internal { use crate::{RegistryKey, RegistryValue}; - use golem_wasm_ast::analysis::AnalysedType; + use golem_wasm_ast::analysis::{AnalysedType, TypeResult}; use std::collections::HashMap; pub(crate) fn update_registry( @@ -188,13 +190,14 @@ mod internal { ) { match ty.clone() { AnalysedType::Variant(variant) => { - for name_type_pair in variant.cases { - registry.insert(RegistryKey::VariantName(name_type_pair.name.clone()), { - name_type_pair.typ.map_or( + let type_variant = variant.clone(); + for name_type_pair in &type_variant.cases { + registry.insert(RegistryKey::FunctionName(name_type_pair.name.clone()), { + name_type_pair.typ.clone().map_or( RegistryValue::Value(ty.clone()), - |variant_parameter_typ| RegistryValue::Function { + |variant_parameter_typ| RegistryValue::Variant { parameter_types: vec![variant_parameter_typ], - return_types: vec![ty.clone()], + variant_type: type_variant.clone(), }, ) }); @@ -204,7 +207,7 @@ mod internal { AnalysedType::Enum(type_enum) => { for name_type_pair in type_enum.cases { registry.insert( - RegistryKey::EnumName(name_type_pair.clone()), + RegistryKey::FunctionName(name_type_pair.clone()), RegistryValue::Value(ty.clone()), ); } @@ -226,7 +229,47 @@ mod internal { } } - _ => {} + AnalysedType::Result(TypeResult { + ok: Some(ok_type), + err: Some(err_type), + }) => { + update_registry(ok_type.as_ref(), registry); + update_registry(err_type.as_ref(), registry); + } + AnalysedType::Result(TypeResult { + ok: None, + err: Some(err_type), + }) => { + update_registry(err_type.as_ref(), registry); + } + AnalysedType::Result(TypeResult { + ok: Some(ok_type), + err: None, + }) => { + update_registry(ok_type.as_ref(), registry); + } + AnalysedType::Option(type_option) => { + update_registry(type_option.inner.as_ref(), registry); + } + AnalysedType::Result(TypeResult { + ok: None, + err: None, + }) => {} + AnalysedType::Flags(_) => {} + AnalysedType::Str(_) => {} + AnalysedType::Chr(_) => {} + AnalysedType::F64(_) => {} + AnalysedType::F32(_) => {} + AnalysedType::U64(_) => {} + AnalysedType::S64(_) => {} + AnalysedType::U32(_) => {} + AnalysedType::S32(_) => {} + AnalysedType::U16(_) => {} + AnalysedType::S16(_) => {} + AnalysedType::U8(_) => {} + AnalysedType::S8(_) => {} + AnalysedType::Bool(_) => {} + AnalysedType::Handle(_) => {} } } } diff --git a/golem-worker-service-base/src/api/custom_http_request_api.rs b/golem-worker-service-base/src/api/custom_http_request_api.rs index dc273e645..23ebec6ea 100644 --- a/golem-worker-service-base/src/api/custom_http_request_api.rs +++ b/golem-worker-service-base/src/api/custom_http_request_api.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::sync::Arc; use crate::api_definition::http::CompiledHttpApiDefinition; -use crate::worker_service_rib_interpreter::{DefaultEvaluator, WorkerServiceRibInterpreter}; +use crate::worker_service_rib_interpreter::{DefaultRibInterpreter, WorkerServiceRibInterpreter}; use futures_util::FutureExt; use hyper::header::HOST; use poem::http::StatusCode; @@ -19,7 +19,7 @@ use crate::worker_bridge_execution::WorkerRequestExecutor; // This is a common API projects can make use of, similar to healthcheck service #[derive(Clone)] pub struct CustomHttpRequestApi { - pub evaluator: Arc, + pub worker_service_rib_interpreter: Arc, pub api_definition_lookup_service: Arc + Sync + Send>, } @@ -31,12 +31,12 @@ impl CustomHttpRequestApi { dyn ApiDefinitionsLookup + Sync + Send, >, ) -> Self { - let evaluator = Arc::new(DefaultEvaluator::from_worker_request_executor( + let evaluator = Arc::new(DefaultRibInterpreter::from_worker_request_executor( worker_request_executor_service.clone(), )); Self { - evaluator, + worker_service_rib_interpreter: evaluator, api_definition_lookup_service, } } @@ -71,7 +71,7 @@ impl CustomHttpRequestApi { } }; - let api_request = InputHttpRequest { + let input_http_request = InputHttpRequest { input_path: ApiInputPath { base_path: uri.path().to_string(), query_path: uri.query().map(|x| x.to_string()), @@ -83,25 +83,28 @@ impl CustomHttpRequestApi { let possible_api_definitions = match self .api_definition_lookup_service - .get(api_request.clone()) + .get(input_http_request.clone()) .await { - Ok(api_definition) => api_definition, - Err(err) => { - error!("API request host: {} - error: {}", host, err); + Ok(api_defs) => api_defs, + Err(api_defs_lookup_error) => { + error!( + "API request host: {} - error: {}", + host, api_defs_lookup_error + ); return Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from_string("Internal error".to_string())); } }; - match api_request + match input_http_request .resolve_worker_binding(possible_api_definitions) .await { - Ok(resolved_worker_request) => { - resolved_worker_request - .interpret_response_mapping::(&self.evaluator) + Ok(resolved_worker_binding) => { + resolved_worker_binding + .interpret_response_mapping(&self.worker_service_rib_interpreter) .await } diff --git a/golem-worker-service-base/src/api/register_api_definition_api.rs b/golem-worker-service-base/src/api/register_api_definition_api.rs index d47f9aa55..aa9d6344f 100644 --- a/golem-worker-service-base/src/api/register_api_definition_api.rs +++ b/golem-worker-service-base/src/api/register_api_definition_api.rs @@ -158,7 +158,7 @@ impl From for GolemWorkerBindingWithTypeInfo { .response_rib_expr .to_string(), response_mapping_input: Some(worker_binding.response_compiled.rib_input), - worker_name_input: Some(worker_binding.worker_name_compiled.rib_input), + worker_name_input: Some(worker_binding.worker_name_compiled.rib_input_type_info), idempotency_key_input: value .idempotency_key_compiled .map(|idempotency_key_compiled| idempotency_key_compiled.rib_input), diff --git a/golem-worker-service-base/src/api_definition/http/http_api_definition.rs b/golem-worker-service-base/src/api_definition/http/http_api_definition.rs index 59dda81e5..b374f005c 100644 --- a/golem-worker-service-base/src/api_definition/http/http_api_definition.rs +++ b/golem-worker-service-base/src/api_definition/http/http_api_definition.rs @@ -80,6 +80,11 @@ impl From for HttpApiDefinition { } } +// The Rib Expressions that exists in various parts of HttpApiDefinition (mainly in Routes) +// are compiled to form CompiledHttpApiDefinition. +// The Compilation happens during API definition registration, +// and is persisted, so that custom http requests are served by looking up +// CompiledHttpApiDefinition #[derive(Debug, Clone, PartialEq)] pub struct CompiledHttpApiDefinition { pub id: ApiDefinitionId, diff --git a/golem-worker-service-base/src/http/http_request.rs b/golem-worker-service-base/src/http/http_request.rs index 2c046c2a5..32ad75cbf 100644 --- a/golem-worker-service-base/src/http/http_request.rs +++ b/golem-worker-service-base/src/http/http_request.rs @@ -133,7 +133,7 @@ mod tests { WorkerRequest, WorkerRequestExecutor, WorkerRequestExecutorError, WorkerResponse, }; use crate::worker_service_rib_interpreter::{ - DefaultEvaluator, EvaluationError, WorkerServiceRibInterpreter, + DefaultRibInterpreter, EvaluationError, WorkerServiceRibInterpreter, }; struct TestWorkerRequestExecutor {} @@ -249,9 +249,9 @@ mod tests { } fn get_test_evaluator() -> Arc { - Arc::new(DefaultEvaluator::from_worker_request_executor(Arc::new( - TestWorkerRequestExecutor {}, - ))) + Arc::new(DefaultRibInterpreter::from_worker_request_executor( + Arc::new(TestWorkerRequestExecutor {}), + )) } #[derive(Debug)] diff --git a/golem-worker-service-base/src/lib.rs b/golem-worker-service-base/src/lib.rs index 07fa51638..0e58fe7ec 100644 --- a/golem-worker-service-base/src/lib.rs +++ b/golem-worker-service-base/src/lib.rs @@ -13,6 +13,7 @@ pub mod repo; pub mod service; mod worker_binding; pub mod worker_bridge_execution; +mod worker_service_rib_compiler; pub mod worker_service_rib_interpreter; const VERSION: &str = golem_version!(); diff --git a/golem-worker-service-base/src/worker_binding/compiled_golem_worker_binding.rs b/golem-worker-service-base/src/worker_binding/compiled_golem_worker_binding.rs index 48f9d5eb5..85e6c111e 100644 --- a/golem-worker-service-base/src/worker_binding/compiled_golem_worker_binding.rs +++ b/golem-worker-service-base/src/worker_binding/compiled_golem_worker_binding.rs @@ -1,4 +1,5 @@ -use crate::worker_binding::{compile_rib, GolemWorkerBinding, ResponseMapping}; +use crate::worker_binding::{GolemWorkerBinding, ResponseMapping}; +use crate::worker_service_rib_compiler::{DefaultRibCompiler, WorkerServiceRibCompiler}; use bincode::{Decode, Encode}; use golem_service_base::model::VersionedComponentId; use golem_wasm_ast::analysis::AnalysedExport; @@ -15,7 +16,7 @@ pub struct CompiledGolemWorkerBinding { impl CompiledGolemWorkerBinding { pub fn from_golem_worker_binding( golem_worker_binding: &GolemWorkerBinding, - export_metadata: &Vec, + export_metadata: &[AnalysedExport], ) -> Result { let worker_name_compiled = WorkerNameCompiled::from_worker_name( &golem_worker_binding.worker_name, @@ -46,20 +47,20 @@ impl CompiledGolemWorkerBinding { pub struct WorkerNameCompiled { pub worker_name: Expr, pub compiled_worker_name: RibByteCode, - pub rib_input: RibInputTypeInfo, + pub rib_input_type_info: RibInputTypeInfo, } impl WorkerNameCompiled { pub fn from_worker_name( worker_name: &Expr, - exports: &Vec, + exports: &[AnalysedExport], ) -> Result { - let worker_name_compiled = compile_rib(worker_name, exports)?; + let worker_name_compiled = DefaultRibCompiler::compile(worker_name, exports)?; Ok(WorkerNameCompiled { worker_name: worker_name.clone(), compiled_worker_name: worker_name_compiled.byte_code, - rib_input: worker_name_compiled.global_input_type_info, + rib_input_type_info: worker_name_compiled.global_input_type_info, }) } } @@ -74,9 +75,9 @@ pub struct IdempotencyKeyCompiled { impl IdempotencyKeyCompiled { pub fn from_idempotency_key( idempotency_key: &Expr, - exports: &Vec, + exports: &[AnalysedExport], ) -> Result { - let idempotency_key_compiled = compile_rib(idempotency_key, exports)?; + let idempotency_key_compiled = DefaultRibCompiler::compile(idempotency_key, exports)?; Ok(IdempotencyKeyCompiled { idempotency_key: idempotency_key.clone(), @@ -96,9 +97,9 @@ pub struct ResponseMappingCompiled { impl ResponseMappingCompiled { pub fn from_response_mapping( response_mapping: &ResponseMapping, - exports: &Vec, + exports: &[AnalysedExport], ) -> Result { - let response_compiled = compile_rib(&response_mapping.0, exports)?; + let response_compiled = DefaultRibCompiler::compile(&response_mapping.0, exports)?; Ok(ResponseMappingCompiled { response_rib_expr: response_mapping.0.clone(), @@ -152,7 +153,7 @@ impl TryFrom .ok_or("Missing worker name".to_string()) .and_then(Expr::try_from)?, compiled_worker_name: worker_name_compiled, - rib_input: worker_name_input, + rib_input_type_info: worker_name_input, }; let idempotency_key_compiled = match (idempotency_key_compiled, idempotency_key_input) { @@ -196,7 +197,7 @@ impl TryFrom let worker_name = Some(value.worker_name_compiled.worker_name.into()); let compiled_worker_name_expr = Some(value.worker_name_compiled.compiled_worker_name.into()); - let worker_name_rib_input = Some(value.worker_name_compiled.rib_input.into()); + let worker_name_rib_input = Some(value.worker_name_compiled.rib_input_type_info.into()); let (idempotency_key, compiled_idempotency_key_expr, idempotency_key_rib_input) = match value.idempotency_key_compiled { Some(x) => ( diff --git a/golem-worker-service-base/src/worker_binding/mod.rs b/golem-worker-service-base/src/worker_binding/mod.rs index b283f7c1d..d9c88a9af 100644 --- a/golem-worker-service-base/src/worker_binding/mod.rs +++ b/golem-worker-service-base/src/worker_binding/mod.rs @@ -1,8 +1,6 @@ pub(crate) use compiled_golem_worker_binding::*; -use golem_wasm_ast::analysis::AnalysedExport; pub(crate) use golem_worker_binding::*; pub(crate) use request_details::*; -use rib::{CompilerOutput, Expr}; pub(crate) use rib_input_value_resolver::*; pub(crate) use worker_binding_resolver::*; @@ -11,14 +9,3 @@ mod golem_worker_binding; mod request_details; mod rib_input_value_resolver; mod worker_binding_resolver; - -pub fn compile_rib( - worker_name: &Expr, - export_metadata: &Vec, -) -> Result { - rib::compile_with_limited_globals( - worker_name, - export_metadata, - Some(vec!["request".to_string()]), - ) -} diff --git a/golem-worker-service-base/src/worker_binding/rib_input_value_resolver.rs b/golem-worker-service-base/src/worker_binding/rib_input_value_resolver.rs index c61dc9091..2cf02a47d 100644 --- a/golem-worker-service-base/src/worker_binding/rib_input_value_resolver.rs +++ b/golem-worker-service-base/src/worker_binding/rib_input_value_resolver.rs @@ -5,6 +5,10 @@ use rib::RibInputTypeInfo; use std::collections::HashMap; use std::fmt::Display; +// `RibInputValueResolver` is responsible +// for converting to RibInputValue which is in the right shape +// to act as input for Rib Script. Example: HttpRequestDetails +// can be converted to RibInputValue pub trait RibInputValueResolver { fn resolve_rib_input_value( &self, diff --git a/golem-worker-service-base/src/worker_binding/worker_binding_resolver.rs b/golem-worker-service-base/src/worker_binding/worker_binding_resolver.rs index 4d7e3babd..f9d85b7e8 100644 --- a/golem-worker-service-base/src/worker_binding/worker_binding_resolver.rs +++ b/golem-worker-service-base/src/worker_binding/worker_binding_resolver.rs @@ -2,8 +2,8 @@ use crate::api_definition::http::{CompiledHttpApiDefinition, VarInfo}; use crate::http::http_request::router; use crate::http::router::RouterPattern; use crate::http::InputHttpRequest; +use crate::worker_service_rib_interpreter::EvaluationError; use crate::worker_service_rib_interpreter::WorkerServiceRibInterpreter; -use crate::worker_service_rib_interpreter::{DefaultEvaluator, EvaluationError}; use async_trait::async_trait; use golem_common::model::IdempotencyKey; use golem_service_base::model::VersionedComponentId; @@ -17,12 +17,9 @@ use crate::worker_binding::rib_input_value_resolver::RibInputValueResolver; use crate::worker_binding::{RequestDetails, ResponseMappingCompiled, RibInputTypeMismatch}; use crate::worker_bridge_execution::to_response::ToResponse; -// Every request (http or others) can have an instance of this resolver -// to resolve to a single worker-binding with additional-info which is required for worker_service_rib_interpreter -// TODO; It will be better if worker binding resolver -// able to deal with only one API definition -// as the first stage resolution can take place (based on host, input request (route resolution) -// up the stage +// Every type of request (example: InputHttpRequest (which corresponds to a Route)) can have an instance of this resolver, +// to resolve a single worker-binding is then executed with the help of worker_service_rib_interpreter, which internally +// calls the worker function. #[async_trait] pub trait RequestToWorkerBindingResolver { async fn resolve_worker_binding( @@ -127,17 +124,15 @@ impl ResolvedWorkerBindingFromRequest { impl RequestToWorkerBindingResolver for InputHttpRequest { async fn resolve_worker_binding( &self, - api_definition: Vec, + compiled_api_definitions: Vec, ) -> Result { - let default_evaluator = DefaultEvaluator::noop(); - - let routes = api_definition + let compiled_routes = compiled_api_definitions .iter() .flat_map(|x| x.routes.clone()) .collect::>(); let api_request = self; - let router = router::build(routes); + let router = router::build(compiled_routes); let path: Vec<&str> = RouterPattern::split(&api_request.input_path.base_path).collect(); let request_query_variables = self.input_path.query_components().unwrap_or_default(); let request_body = &self.req_body; @@ -158,7 +153,7 @@ impl RequestToWorkerBindingResolver for InputHttpRequ .collect() }; - let request_details = RequestDetails::from( + let http_request_details = RequestDetails::from( &zipped_path_params, &request_query_variables, query_params, @@ -167,33 +162,36 @@ impl RequestToWorkerBindingResolver for InputHttpRequ ) .map_err(|err| format!("Failed to fetch input request details {}", err.join(", ")))?; - let resolve_rib_input = request_details - .resolve_rib_input_value(&binding.worker_name_compiled.rib_input) - .map_err(|err| format!("Failed to resolve rib input value {}", err))?; + let resolve_rib_input = http_request_details + .resolve_rib_input_value(&binding.worker_name_compiled.rib_input_type_info) + .map_err(|err| { + format!( + "Failed to resolve rib input value from http request details {}", + err + ) + })?; // To evaluate worker-name, most probably - let worker_name: String = default_evaluator - .evaluate_pure( - &binding.worker_name_compiled.compiled_worker_name, - &resolve_rib_input, - ) - .await - .map_err(|err| format!("Failed to evaluate worker name expression. {}", err))? - .get_literal() - .ok_or("Worker name is not a String".to_string())? - .as_string(); + let worker_name: String = rib::interpret_pure( + &binding.worker_name_compiled.compiled_worker_name, + &resolve_rib_input.value, + ) + .await + .map_err(|err| format!("Failed to evaluate worker name rib expression. {}", err))? + .get_literal() + .ok_or("Worker name is not a Rib expression that resolves to String".to_string())? + .as_string(); let component_id = &binding.component_id; let idempotency_key = if let Some(idempotency_key_compiled) = &binding.idempotency_key_compiled { - let idempotency_key_value = default_evaluator - .evaluate_pure( - &idempotency_key_compiled.compiled_idempotency_key, - &resolve_rib_input, - ) - .await - .map_err(|err| err.to_string())?; + let idempotency_key_value = rib::interpret_pure( + &idempotency_key_compiled.compiled_idempotency_key, + &resolve_rib_input.value, + ) + .await + .map_err(|err| err.to_string())?; let idempotency_key = idempotency_key_value .get_literal() @@ -216,7 +214,7 @@ impl RequestToWorkerBindingResolver for InputHttpRequ let resolved_binding = ResolvedWorkerBindingFromRequest { worker_detail, - request_details, + request_details: http_request_details, compiled_response_mapping: binding.response_compiled.clone(), }; diff --git a/golem-worker-service-base/src/worker_bridge_execution/worker_request_executor.rs b/golem-worker-service-base/src/worker_bridge_execution/worker_request_executor.rs index 52d5b243d..40488672e 100644 --- a/golem-worker-service-base/src/worker_bridge_execution/worker_request_executor.rs +++ b/golem-worker-service-base/src/worker_bridge_execution/worker_request_executor.rs @@ -38,17 +38,3 @@ impl> From for WorkerRequestExecutorError { WorkerRequestExecutorError(err.as_ref().to_string()) } } - -pub struct NoopWorkerRequestExecutor; - -#[async_trait] -impl WorkerRequestExecutor for NoopWorkerRequestExecutor { - async fn execute( - &self, - _worker_request_params: WorkerRequest, - ) -> Result { - Err(WorkerRequestExecutorError( - "NoopWorkerRequestExecutor".to_string(), - )) - } -} diff --git a/golem-worker-service-base/src/worker_service_rib_compiler/mod.rs b/golem-worker-service-base/src/worker_service_rib_compiler/mod.rs new file mode 100644 index 000000000..14ca2738b --- /dev/null +++ b/golem-worker-service-base/src/worker_service_rib_compiler/mod.rs @@ -0,0 +1,20 @@ +use golem_wasm_ast::analysis::AnalysedExport; +use rib::{CompilerOutput, Expr}; + +// A wrapper service over original Rib Compiler concerning +// the details of the worker bridge. +pub trait WorkerServiceRibCompiler { + fn compile(rib: &Expr, export_metadata: &[AnalysedExport]) -> Result; +} + +pub struct DefaultRibCompiler; + +impl WorkerServiceRibCompiler for DefaultRibCompiler { + fn compile(rib: &Expr, export_metadata: &[AnalysedExport]) -> Result { + rib::compile_with_limited_globals( + rib, + &export_metadata.to_vec(), + Some(vec!["request".to_string()]), + ) + } +} diff --git a/golem-worker-service-base/src/worker_service_rib_interpreter/mod.rs b/golem-worker-service-base/src/worker_service_rib_interpreter/mod.rs index f50c5ef3c..0a735c240 100644 --- a/golem-worker-service-base/src/worker_service_rib_interpreter/mod.rs +++ b/golem-worker-service-base/src/worker_service_rib_interpreter/mod.rs @@ -10,14 +10,14 @@ use golem_common::model::{ComponentId, IdempotencyKey}; use crate::worker_binding::RibInputValue; use rib::{RibByteCode, RibFunctionInvoke, RibInterpreterResult}; -use crate::worker_bridge_execution::{ - NoopWorkerRequestExecutor, WorkerRequest, WorkerRequestExecutor, -}; +use crate::worker_bridge_execution::{WorkerRequest, WorkerRequestExecutor}; // A wrapper service over original RibInterpreter concerning // the details of the worker service. #[async_trait] pub trait WorkerServiceRibInterpreter { + // Evaluate a Rib byte against a specific worker. + // RibByteCode may have actual function calls. async fn evaluate( &self, worker_name: &str, @@ -26,12 +26,6 @@ pub trait WorkerServiceRibInterpreter { rib_byte_code: &RibByteCode, rib_input: &RibInputValue, ) -> Result; - - async fn evaluate_pure( - &self, - expr: &RibByteCode, - rib_input: &RibInputValue, - ) -> Result; } #[derive(Debug, PartialEq)] @@ -49,28 +43,22 @@ impl From for EvaluationError { } } -pub struct DefaultEvaluator { +pub struct DefaultRibInterpreter { worker_request_executor: Arc, } -impl DefaultEvaluator { - pub fn noop() -> Self { - DefaultEvaluator { - worker_request_executor: Arc::new(NoopWorkerRequestExecutor), - } - } - +impl DefaultRibInterpreter { pub fn from_worker_request_executor( worker_request_executor: Arc, ) -> Self { - DefaultEvaluator { + DefaultRibInterpreter { worker_request_executor, } } } #[async_trait] -impl WorkerServiceRibInterpreter for DefaultEvaluator { +impl WorkerServiceRibInterpreter for DefaultRibInterpreter { async fn evaluate( &self, worker_name: &str, @@ -115,23 +103,4 @@ impl WorkerServiceRibInterpreter for DefaultEvaluator { .await .map_err(EvaluationError) } - - async fn evaluate_pure( - &self, - expr: &RibByteCode, - rib_input: &RibInputValue, - ) -> Result { - let worker_invoke_function: RibFunctionInvoke = Arc::new(|_, _| { - Box::pin( - async move { - Err("Worker invoke function is not allowed in pure evaluation".to_string()) - } - .boxed(), - ) - }); - - rib::interpret(expr, rib_input.value.clone(), worker_invoke_function) - .await - .map_err(EvaluationError) - } }