diff --git a/src/lustre/generatedIdentifiers.ml b/src/lustre/generatedIdentifiers.ml index 96a66c487..91a0b7b74 100644 --- a/src/lustre/generatedIdentifiers.ml +++ b/src/lustre/generatedIdentifiers.ml @@ -39,6 +39,7 @@ type t = { locals : (LustreAst.lustre_type) StringMap.t; + asserts : (Lib.position * LustreAst.expr) list; contract_calls : (Lib.position * (Lib.position * HString.t) list (* contract scope *) @@ -52,7 +53,8 @@ type t = { * LustreAst.expr (* restart expression *) * HString.t (* node name *) * (LustreAst.expr list) (* node arguments *) - * (LustreAst.expr list option)) (* node argument defaults *) + * (LustreAst.expr list option) (* node argument defaults *) + * bool) (* Was call inlined? *) list; subrange_constraints : (source * (Lib.position * HString.t) list (* contract scope *) @@ -104,6 +106,7 @@ let union_keys key id1 id2 = match key, id1, id2 with let union ids1 ids2 = { locals = StringMap.merge union_keys ids1.locals ids2.locals; + asserts = ids1.asserts @ ids2.asserts; array_constructors = StringMap.merge union_keys ids1.array_constructors ids2.array_constructors; node_args = ids1.node_args @ ids2.node_args; @@ -130,6 +133,7 @@ let union_keys2 key id1 id2 = match key, id1, id2 with let empty () = { locals = StringMap.empty; + asserts = []; array_constructors = StringMap.empty; node_args = []; oracles = []; diff --git a/src/lustre/generatedIdentifiers.mli b/src/lustre/generatedIdentifiers.mli index 3f9e34e33..4e950a497 100644 --- a/src/lustre/generatedIdentifiers.mli +++ b/src/lustre/generatedIdentifiers.mli @@ -39,6 +39,7 @@ type t = { locals : (LustreAst.lustre_type) StringMap.t; + asserts : (Lib.position * LustreAst.expr) list; contract_calls : (Lib.position * (Lib.position * HString.t) list (* contract scope *) @@ -52,7 +53,8 @@ type t = { * LustreAst.expr (* restart expression *) * HString.t (* node name *) * (LustreAst.expr list) (* node arguments *) - * (LustreAst.expr list option)) (* node argument defaults *) + * (LustreAst.expr list option) (* node argument defaults *) + * bool) (* Was call inlined? *) list; subrange_constraints : (source * (Lib.position * HString.t) list (* contract scope *) diff --git a/src/lustre/lustreAstHelpers.mli b/src/lustre/lustreAstHelpers.mli index ce9c0eac4..949736fc5 100644 --- a/src/lustre/lustreAstHelpers.mli +++ b/src/lustre/lustreAstHelpers.mli @@ -55,6 +55,11 @@ val substitute_naive : HString.t -> expr -> expr -> expr (** Substitute second param for first param in third param. AnyOp and Quantifier are not supported due to introduction of bound variables. *) +val apply_subst_in_expr : (HString.t * expr) list -> expr -> expr +(** [apply_subst_in_expr s e] applies the substitution defined by association list [s] + to the expression [e] + AnyOp and Quantifier are not supported due to introduction of bound variables. *) + val apply_subst_in_type : (HString.t * expr) list -> lustre_type -> lustre_type (** [apply_subst_in_type s t] applies the substitution defined by association list [s] to the expressions of (possibly dependent) type [t] @@ -187,4 +192,4 @@ val rename_contract_vars : expr -> expr val name_of_prop : Lib.position -> HString.t option -> LustreAst.prop_kind -> HString.t (** Get the name associated with a property *) -val get_const_num_value : expr -> int option \ No newline at end of file +val get_const_num_value : expr -> int option diff --git a/src/lustre/lustreAstNormalizer.ml b/src/lustre/lustreAstNormalizer.ml index 116267009..75ce699f0 100644 --- a/src/lustre/lustreAstNormalizer.ml +++ b/src/lustre/lustreAstNormalizer.ml @@ -203,7 +203,8 @@ type info = { contract_scope : (Lib.position * HString.t) list; contract_ref : HString.t; interpretation : HString.t StringMap.t; - local_group_projection : int + local_group_projection : int; + inlinable_funcs : LustreAst.node_decl StringMap.t; } let split3 triples = @@ -241,16 +242,17 @@ let pp_print_generated_identifiers ppf gids = LustreAst.pp_print_lustre_type ty LustreAst.pp_print_expr e in - let pp_print_call = (fun ppf (pos, output, cond, restart, ident, args, defaults) -> + let pp_print_call = (fun ppf (pos, output, cond, restart, ident, args, defaults, inlined) -> Format.fprintf ppf - "%a: %a = call(%a,(restart %a every %a)(%a),%a)" + "%a: %a = call(%a,(restart %a every %a)(%a),%a)%s" pp_print_position pos HString.pp_print_hstring output A.pp_print_expr cond HString.pp_print_hstring ident A.pp_print_expr restart (pp_print_list A.pp_print_expr ",@ ") args - (pp_print_option (pp_print_list A.pp_print_expr ",@")) defaults) + (pp_print_option (pp_print_list A.pp_print_expr ",@")) defaults + (if inlined then " %inlined" else "")) in let pp_print_source ppf source = Format.fprintf ppf (match source with | Local -> "local" @@ -354,21 +356,39 @@ let generalize_to_array_expr name ind_vars expr nexpr = in eq_lhs, nexpr -let mk_fresh_local force info pos ind_vars expr_type expr = - match (LocalCache.find_opt local_cache expr, force) with - | Some nexpr, false -> nexpr, empty () - | _ -> - i := !i + 1; - let prefix = HString.mk_hstring (string_of_int !i) in - let name = HString.concat2 prefix (HString.mk_hstring "_glocal") in - let nexpr = A.Ident (pos, name) in - let (eq_lhs, nexpr) = generalize_to_array_expr name ind_vars expr nexpr in - let gids = { (empty ()) with - locals = StringMap.singleton name expr_type; - equations = [(info.quantified_variables, info.contract_scope, eq_lhs, expr)]; } +let get_inline_func_expr inlinable_funcs name args = + let (_, _, _, _, inputs, _, _, items, _) : A.node_decl = + match StringMap.find_opt name inlinable_funcs with + | Some nd -> nd + | None -> assert false in - LocalCache.add local_cache expr nexpr; - nexpr, gids + let var_map = + items |> List.fold_left (fun acc item -> + match item with + | A.Body (Equation (_, StructDef (_, [lhs]), rhs)) -> ( + match lhs with + | A.SingleIdent (_, v) -> + (v, AH.apply_subst_in_expr acc rhs) :: acc + | ArrayDef _ -> + assert false (* rejected earlier in pipeline *) + | TupleStructItem _ | TupleSelection _ | FieldSelection _ + | ArraySliceStructItem _ -> assert false (* unreachable *) + ) + | IfBlock _ | FrameBlock _ -> + assert false (* desugared earlier in pipeline *) + | Body (Assert _) | AnnotMain _ | AnnotProperty _ -> + assert false (* rejected earlier in pipeline *) + | A.Body (Equation (_, StructDef (_, _), _)) -> + assert false (* rejected earlier in pipeline, should we support it? *) + ) + [] + in + let input_map = + List.map2 (fun (_, id, _, _, _) e -> (id, e)) inputs args + in + match var_map with + | (_, e) :: _ -> AH.apply_subst_in_expr input_map e + | _ -> assert false let mk_fresh_array_ctor info pos ind_vars expr_type expr size_expr = i := !i + 1; @@ -616,6 +636,94 @@ let rec mk_ref_type_expr: Ctx.tc_context -> A.expr -> A.lustre_type -> A.expr li ) exprs | _ -> [] +let mk_enum_subrange_reftype_constraints node_id info vars = + let enum_subrange_reftype_vars = + vars |> List.filter (fun (_, _, ty) -> + let ty' = Ctx.expand_type_syn info.context ty in + Ctx.type_contains_enum_subrange_reftype info.context ty' + ) + in + let constraints = + List.fold_left + (fun acc (_, id, ty) -> + let expr = A.Ident(dpos, id) in + let range_exprs = + List.map fst (mk_enum_range_expr info.context node_id ty expr) @ + (mk_ref_type_expr info.context expr ty) + in + range_exprs :: acc + ) + [] + enum_subrange_reftype_vars + |> List.flatten + in + match constraints with + | [] -> None + | c :: cs -> + let conj = + List.fold_left + (fun acc c' -> A.BinaryOp (dpos, A.And, c', acc)) c cs + in + Some conj + +let mk_fresh_oracle expr_type expr = + i := !i + 1; + let prefix = HString.mk_hstring (string_of_int !i) in + let name = HString.concat2 prefix (HString.mk_hstring "_oracle") in + let nexpr = A.Ident (Lib.dummy_pos, name) in + let gids = { (empty ()) with + oracles = [name, expr_type, expr]; } + in nexpr, name, gids + +let mk_fresh_local force info pos ind_vars expr_type expr = + match (LocalCache.find_opt local_cache expr, force) with + | Some nexpr, false -> nexpr, empty () + | _ -> + i := !i + 1; + let prefix = HString.mk_hstring (string_of_int !i) in + let name = HString.concat2 prefix (HString.mk_hstring "_glocal") in + let nexpr = A.Ident (pos, name) in + let (eq_lhs, nexpr) = generalize_to_array_expr name ind_vars expr nexpr in + let gids = { (empty ()) with + locals = StringMap.singleton name expr_type; + equations = [(info.quantified_variables, info.contract_scope, eq_lhs, expr)]; } + in + LocalCache.add local_cache expr nexpr; + nexpr, gids + +let mk_fresh_frozen_local node_id info pos ind_vars expr_type = + i := !i + 1; + let prefix = HString.mk_hstring (string_of_int !i) in + let name = HString.concat2 prefix (HString.mk_hstring "_flocal") in + let nexpr = A.Ident (pos, name) in + let init, oracle_id, gids1 = mk_fresh_oracle expr_type nexpr in + let expr = A.Arrow (pos, init, Pre (pos, nexpr)) in + let (eq_lhs, nexpr) = generalize_to_array_expr name ind_vars expr nexpr in + let constraints = + let typed_var = (pos, oracle_id, expr_type) in + let info = { info with + context = Ctx.add_ty info.context oracle_id expr_type + } in + (* Assume constraints are constant expressions, and thus, + no normalization is required *) + mk_enum_subrange_reftype_constraints (Some node_id) info [typed_var] + in + let asserts, gids3 = + match constraints with + | Some c -> ( + let c_expr, gids3 = + mk_fresh_local false info pos ind_vars (A.Bool (dummy_pos)) c + in + [(pos, c_expr)], gids3 + ) + | None -> [], empty () + in + let gids2 = { (empty ()) with + locals = StringMap.singleton name expr_type; + asserts; + equations = [(info.quantified_variables, info.contract_scope, eq_lhs, expr)]; } + in + nexpr, name, union (union gids1 gids2) gids3 let mk_fresh_refinement_type_constraint source info pos id expr_type = let ref_type_exprs = mk_ref_type_expr info.context id expr_type in @@ -635,16 +743,7 @@ let mk_fresh_refinement_type_constraint source info pos id expr_type = in List.fold_left union (empty ()) gids -let mk_fresh_oracle expr_type expr = - i := !i + 1; - let prefix = HString.mk_hstring (string_of_int !i) in - let name = HString.concat2 prefix (HString.mk_hstring "_oracle") in - let nexpr = A.Ident (Lib.dummy_pos, name) in - let gids = { (empty ()) with - oracles = [name, expr_type, expr]; } - in nexpr, gids - -let mk_fresh_call info id map pos cond restart ty_args args defaults = +let mk_fresh_call ?(inlined=false) info id map pos cond restart ty_args args defaults = let called_node = StringMap.find id map in let has_oracles = List.length called_node.oracles > 0 in let has_ty_args = List.length ty_args > 0 in @@ -664,7 +763,7 @@ let mk_fresh_call info id map pos cond restart ty_args args defaults = (HString.mk_hstring "proj_") in let nexpr = A.Ident (pos, HString.concat2 proj name) in - let call = (pos, name, cond, restart, id, args, defaults) in + let call = (pos, name, cond, restart, id, args, defaults, inlined) in let gids = { (empty ()) with calls = [call] } in if not has_ty_args then CallCache.add call_cache (id, cond, restart, args, defaults) nexpr; nexpr, gids @@ -970,7 +1069,22 @@ let desugar_history_in_expr ctx ctr_id prefix expr = r StringMap.empty expr -let rec normalize ctx ai_ctx (decls:LustreAst.t) gids = +let get_inlinable_func_decls inlinable_funcs decls = + List.fold_left + (fun acc decl -> + match decl with + | A.FuncDecl (_, nd) -> + let (id, _, _, _, _, _, _, _, _) = nd in + if A.SI.mem id inlinable_funcs then + StringMap.add id nd acc + else + acc + | _ -> acc + ) + StringMap.empty + decls + +let rec normalize ctx ai_ctx inlinable_funcs (decls:LustreAst.t) gids = let info = { context = ctx; abstract_interp_context = ai_ctx; inductive_variables = StringMap.empty; @@ -980,7 +1094,8 @@ let rec normalize ctx ai_ctx (decls:LustreAst.t) gids = contract_ref = HString.mk_hstring ""; contract_scope = []; interpretation = StringMap.empty; - local_group_projection = -1 } + local_group_projection = -1; + inlinable_funcs = get_inlinable_func_decls inlinable_funcs decls } in let over_declarations (nitems, accum, warnings_accum) item = clear_cache (); @@ -1747,50 +1862,70 @@ and normalize_expr ?guard info node_id map = in let iexpr, gids2 = mk_fresh_node_arg_local info pos is_const ty nexpr in iexpr, union gids1 gids2, warnings - in let mk_enum_subrange_reftype_constraints info vars = - let enum_subrange_reftype_vars = - vars |> List.filter (fun (_, _, ty) -> - let ty' = Ctx.expand_type_syn info.context ty in - Ctx.type_contains_enum_subrange_reftype info.context ty' - ) - in - let constraints = - List.fold_left - (fun acc (_, id, ty) -> - let expr = A.Ident(dpos, id) in - let range_exprs = - List.map fst (mk_enum_range_expr info.context (Some node_id) ty expr) @ - (mk_ref_type_expr info.context expr ty) - in - range_exprs :: acc - ) - [] - enum_subrange_reftype_vars - |> List.flatten - in - match constraints with - | [] -> None - | c :: cs -> - let conj = - List.fold_left - (fun acc c' -> A.BinaryOp (dpos, A.And, c', acc)) c cs - in - Some conj in function (* ************************************************************************ *) (* Node calls *) (* ************************************************************************ *) | Call (pos, ty_args, id, args) -> - let flags = StringMap.find id info.node_is_input_const in - let cond = A.Const (Lib.dummy_pos, A.True) in - let restart = A.Const (Lib.dummy_pos, A.False) in - let nargs, gids1, warnings = normalize_list - (fun (arg, is_const) -> abstract_node_arg ?guard:None false is_const info map arg) - (combine_args_with_const info args flags) + let is_inlinable = StringMap.mem id info.inlinable_funcs in + let info, vmap, gids0 = + if is_inlinable then (* Only generate variables if inlinable *) + let args_vars = + List.fold_left + (fun acc e -> A.SI.union acc (AH.vars_without_node_call_ids e)) + A.SI.empty + args + in + let ivars = info.inductive_variables in + List.fold_left + (fun (info, vmap, gids) (pos_v, v, ty) -> + if A.SI.mem v args_vars then + let nexpr, id, gids' = + mk_fresh_frozen_local node_id info pos_v ivars ty + in + let info = + let ctx = Ctx.add_ty info.context id ty in + { info with context = ctx } + in + (info, (v, nexpr) :: vmap, union gids gids') + else + (info, vmap, gids) + ) + (info, [], (empty ())) + info.quantified_variables + else + (info, [], empty()) in - let nexpr, gids2 = mk_fresh_call info id map pos cond restart ty_args nargs None in - nexpr, union gids1 gids2, warnings + let handle_call inlined args = + let flags = StringMap.find id info.node_is_input_const in + let cond = A.Const (Lib.dummy_pos, A.True) in + let restart = A.Const (Lib.dummy_pos, A.False) in + let nargs, gids1, warnings = normalize_list + (fun (arg, is_const) -> abstract_node_arg ?guard:None false is_const info map arg) + (combine_args_with_const info args flags) + in + let nexpr, gids2 = + mk_fresh_call ~inlined info id map pos cond restart ty_args nargs None + in + nexpr, union gids1 gids2, warnings + in + if (is_inlinable && vmap <> []) + then ( + let nargs, gids1, warnings1 = normalize_list + (fun arg -> normalize_expr ?guard info node_id map arg) + args + in + let nexpr = get_inline_func_expr info.inlinable_funcs id nargs in + let args = + List.map (fun a -> AH.apply_subst_in_expr vmap a) args + in + let _, gids2, warnings2 = handle_call true args in + nexpr, union_list [gids0; gids1; gids2], warnings1 @ warnings2 + ) + else ( + handle_call false args + ) | Condact (pos, cond, restart, id, args, defaults) -> let flags = StringMap.find id info.node_is_input_const in let ncond, gids1, warnings1 = if AH.expr_is_true cond then cond, empty (), [] @@ -1877,7 +2012,7 @@ and normalize_expr ?guard info node_id map = let guard, gids2, warnings2, previously_guarded = match guard with | Some guard -> guard, empty (), [], true | None -> - let guard, gids = mk_fresh_oracle ty nexpr in + let guard, _, gids = mk_fresh_oracle ty nexpr in let warnings = [mk_warning pos (UnguardedPreWarning (Pre (pos, expr)))] in guard, gids, warnings, false in @@ -1989,7 +2124,9 @@ and normalize_expr ?guard info node_id map = let nexpr, gids, warnings = normalize_expr ?guard info node_id map expr in let nexpr = let constraints = - mk_enum_subrange_reftype_constraints info vars + (* Assume constraints are constant expressions, and thus, + no normalization is required *) + mk_enum_subrange_reftype_constraints (Some node_id) info vars in match constraints, kind with | None, _ -> nexpr diff --git a/src/lustre/lustreAstNormalizer.mli b/src/lustre/lustreAstNormalizer.mli index 4f40c6957..5d8c5fe15 100644 --- a/src/lustre/lustreAstNormalizer.mli +++ b/src/lustre/lustreAstNormalizer.mli @@ -101,6 +101,7 @@ val mk_enum_range_expr : TypeCheckerContext.tc_context -> val normalize : TypeCheckerContext.tc_context -> LustreAbstractInterpretation.context -> + LustreAst.SI.t -> LustreAst.t -> GeneratedIdentifiers.t GeneratedIdentifiers.StringMap.t -> (LustreAst.declaration list * GeneratedIdentifiers.t GeneratedIdentifiers.StringMap.t * diff --git a/src/lustre/lustreInput.ml b/src/lustre/lustreInput.ml index e911fc225..290370788 100644 --- a/src/lustre/lustreInput.ml +++ b/src/lustre/lustreInput.ml @@ -48,6 +48,7 @@ module LDN = LustreDesugarAnyOps module LFR = LustreFlattenRefinementTypes module LGI = LustreGenRefTypeImpNodes module LIP = LustreInstantiatePolyNodes +module LUF = LustreUserFunctions type error = [ | `LustreArrayDependencies of Lib.position * LustreArrayDependencies.error_kind @@ -208,16 +209,24 @@ let type_check declarations = let const_inlined_type_and_consts = LFR.flatten_ref_types inlined_global_ctx const_inlined_type_and_consts in let const_inlined_nodes_and_contracts = LFR.flatten_ref_types inlined_global_ctx const_inlined_nodes_and_contracts in - (* Step 18. Normalize AST: guard pres, abstract to locals where appropriate *) - let* (normalized_nodes_and_contracts, gids, warnings5) = - LAN.normalize inlined_global_ctx abstract_interp_ctx const_inlined_nodes_and_contracts gids + (* Step 18. Check no quantified variable in argument of non-inlinable function *) + let inlinable_funcs = + LUF.inlinable_functions inlined_global_ctx const_inlined_nodes_and_contracts + in + let* warnings5 = + LS.no_quant_vars_in_calls_to_non_inlinable_funcs inlinable_funcs declarations + in + + (* Step 19. Normalize AST: guard pres, abstract to locals where appropriate *) + let* (normalized_nodes_and_contracts, gids, warnings6) = + LAN.normalize inlined_global_ctx abstract_interp_ctx inlinable_funcs const_inlined_nodes_and_contracts gids in Res.ok (inlined_global_ctx, gids, const_inlined_type_and_consts @ normalized_nodes_and_contracts, toplevel_nodes, - warnings1 @ warnings2 @ warnings3 @ warnings4 @ warnings5) + warnings1 @ warnings2 @ warnings3 @ warnings4 @ warnings5 @ warnings6) ) in match tc_res with diff --git a/src/lustre/lustreNode.ml b/src/lustre/lustreNode.ml index 0ce506bcc..9ff8f844a 100644 --- a/src/lustre/lustreNode.ml +++ b/src/lustre/lustreNode.ml @@ -124,6 +124,8 @@ type node_call = { [merge] operator.*) call_defaults : E.t D.t option; + (* Whether this call was inlined or not *) + call_inlined : bool; } diff --git a/src/lustre/lustreNode.mli b/src/lustre/lustreNode.mli index d0ea0d0d0..79c9059f7 100644 --- a/src/lustre/lustreNode.mli +++ b/src/lustre/lustreNode.mli @@ -102,6 +102,8 @@ type node_call = { If the option value is not [None], the keys of the index match those in the {!t.outputs} field of the called node. *) + call_inlined : bool; + (** Whether this call was inlined or not *) } diff --git a/src/lustre/lustreNodeGen.ml b/src/lustre/lustreNodeGen.ml index 4ddd204d3..c676887f6 100644 --- a/src/lustre/lustreNodeGen.ml +++ b/src/lustre/lustreNodeGen.ml @@ -1176,7 +1176,7 @@ and compile_ast_expr | A.When _ -> assert false | A.Activate _ -> assert false -and compile_node node_scope pos ctx cstate map outputs cond restart ident args defaults = +and compile_node node_scope pos ctx cstate map outputs cond restart ident args defaults inlined = let called_node = N.node_of_name ident cstate.nodes in let po_ct = !map.poracle_count in map := {!map with poracle_count = po_ct + (List.length called_node.oracles) }; @@ -1260,7 +1260,8 @@ and compile_node node_scope pos ctx cstate map outputs cond restart ident args d N.call_inputs = input_state_vars; N.call_oracles = oracles; N.call_outputs = outputs; - N.call_defaults = defaults + N.call_defaults = defaults; + N.call_inlined = inlined; } in node_call @@ -1716,7 +1717,7 @@ and compile_node_decl gids_map is_function opac cstate ctx i ext params inputs o (* ****************************************************************** *) in let () = - let over_calls = fun () ((_, var, _, _, ident, _, _)) -> + let over_calls = fun () ((_, var, _, _, ident, _, _, _)) -> let node_id = mk_ident ident in let called_node = N.node_of_name node_id cstate.nodes in let _outputs = @@ -1815,7 +1816,9 @@ and compile_node_decl gids_map is_function opac cstate ctx i ext params inputs o in let (calls, glocals) = let seen_calls = ref SVS.empty in - let over_calls = fun (calls, glocals) (pos, var, cond, restart, ident, args, defaults) -> + let over_calls = + fun (calls, glocals) (pos, var, cond, restart, ident, args, defaults, inlined) + -> let node_id = mk_ident ident in let called_node = N.node_of_name node_id cstate.nodes in (* let output_ast_types = (match Ctx.lookup_node_ty ctx ident with @@ -1856,7 +1859,7 @@ and compile_node_decl gids_map is_function opac cstate ctx i ext params inputs o X.fold over_vars called_node.outputs X.empty in let node_call = compile_node - node_scope pos ctx cstate map outputs cond restart node_id args defaults + node_scope pos ctx cstate map outputs cond restart node_id args defaults inlined in let glocals' = H.fold (fun _ v a -> (X.singleton X.empty_index v) :: a) local_map [] in node_call :: calls, glocals' @ glocals @@ -1928,6 +1931,17 @@ and compile_node_decl gids_map is_function opac cstate ctx i ext params inputs o N.add_state_var_def sv (N.Assertion pos); (pos, sv) in List.map op node_asserts + + (* ****************************************************************** *) + (* Generated assertions *) + (* ****************************************************************** *) + in let asserts = + let op (pos, expr) = + let id = extract_normalized expr in + let sv = H.find !map.state_var id in + (* N.add_state_var_def sv (N.Assertion pos); *) + (pos, sv) + in asserts @ List.map op gids.GI.asserts (* ****************************************************************** *) (* Helpers for generated and user equations *) (* ****************************************************************** *) diff --git a/src/lustre/lustreSimplify.ml b/src/lustre/lustreSimplify.ml index e1de119be..3bbba93ff 100644 --- a/src/lustre/lustreSimplify.ml +++ b/src/lustre/lustreSimplify.ml @@ -1911,7 +1911,8 @@ and eval_node_call N.call_inputs = input_state_vars; N.call_oracles = oracle_state_vars; N.call_outputs = output_state_vars; - N.call_defaults = defaults } + N.call_defaults = defaults; + N.call_inlined = false } in (* Add node call to context *) let ctx = C.add_node_call ctx pos node_call in diff --git a/src/lustre/lustreSlicing.ml b/src/lustre/lustreSlicing.ml index ff33caa77..bf159ce95 100644 --- a/src/lustre/lustreSlicing.ml +++ b/src/lustre/lustreSlicing.ml @@ -774,6 +774,17 @@ let roots_of_contract_ass = function let with_sofar_var = assumes <> [] in Contract.svars_of ~with_sofar_var contract +let roots_of_inlined_calls calls = + List.fold_left + (fun acc c -> + if c.N.call_inlined then + SVS.union acc (D.values c.call_outputs |> SVS.of_list) + else + acc + ) + SVS.empty + calls + (* Add state variables in assertion *) let add_roots_of_asserts asserts roots = List.fold_left @@ -1096,6 +1107,7 @@ let root_and_leaves_of_impl ({ N.outputs; N.contract; N.props; + N.calls; N.asserts } as node) = (* Slice everything from node *) @@ -1133,6 +1145,8 @@ let root_and_leaves_of_impl |> SVS.union ( if is_top then SVS.empty else D.values outputs |> SVS.of_list ) + |> SVS.union (roots_of_inlined_calls calls) + |> SVS.elements in @@ -1147,7 +1161,8 @@ let root_and_leaves_of_impl let root_and_leaves_of_contracts is_top roots - ({ N.outputs; + ({ N.outputs; + N.calls; N.contract } as node) = (* Slice everything from node *) @@ -1164,6 +1179,7 @@ let root_and_leaves_of_contracts match roots node false with | None -> roots_of_contract ~with_sofar_var:(not is_top) contract + |> SVS.union (roots_of_inlined_calls calls) |> SVS.elements | Some r -> SVS.elements r diff --git a/src/lustre/lustreSyntaxChecks.ml b/src/lustre/lustreSyntaxChecks.ml index e2021267f..e0d939e35 100644 --- a/src/lustre/lustreSyntaxChecks.ml +++ b/src/lustre/lustreSyntaxChecks.ml @@ -95,10 +95,10 @@ let error_message kind = match kind with | QuantifiedVariableInPre var -> "Quantified variable '" ^ HString.string_of_hstring var ^ "' is not allowed in an argument to pre operator" | QuantifiedVariableInNodeArgument (var, node) -> "Quantified variable '" - ^ HString.string_of_hstring var ^ "' is not allowed in an argument to the node call '" + ^ HString.string_of_hstring var ^ "' is not allowed in an argument of a call to node or non-inlinable function '" ^ HString.string_of_hstring node ^ "'" | SymbolicArrayIndexInNodeArgument (idx, node) -> "Symbolic array index '" - ^ HString.string_of_hstring idx ^ "' is not allowed in an argument to the node call '" + ^ HString.string_of_hstring idx ^ "' is not allowed in an argument of a call to node or non-inlinable function '" ^ HString.string_of_hstring node ^ "'" | AnyOpInFunction -> "Illegal any operator in function" | NodeCallInFunction node -> "Illegal call to node '" @@ -511,7 +511,7 @@ let no_node_calls_in_constant i e = else Ok () let no_quant_var_or_symbolic_index_in_node_call ctx = function - | LA.Call (pos, _, i, args) -> + (*| LA.Call (pos, _, i, args) -> let vars = List.fold_left (fun acc e -> LA.SI.union acc (LAH.vars_without_node_call_ids e)) @@ -527,7 +527,7 @@ let no_quant_var_or_symbolic_index_in_node_call ctx = function | false, false -> Ok ()) in let check = List.map over_vars (LA.SI.elements vars) in - List.fold_left (>>) (Ok ()) check + List.fold_left (>>) (Ok ()) check*) | LA.Pre (_, ArrayIndex (_, _, _)) -> Ok () | LA.Pre (pos, e) -> let vars = LAH.vars_without_node_call_ids e in @@ -1031,3 +1031,71 @@ let no_mismatched_clock is_bool e = | _ -> Ok ([]) in check_expr ctx (fun _ -> check_merge) e + + +let ovq_check_expr inlinable_funcs ctx = function +| LA.Call (pos, _, i, args) -> + let vars = + List.fold_left + (fun acc e -> LA.SI.union acc (LAH.vars_without_node_call_ids e)) + LA.SI.empty + args + in + let over_vars j = + let found_quant_in_non_inlinable = + StringMap.mem j ctx.quant_vars && not (LA.SI.mem i inlinable_funcs) + in + let found_symbolic_index_in_non_inlinable = + StringMap.mem j ctx.symbolic_array_indices && + not (LA.SI.mem i inlinable_funcs) + in + (match found_quant_in_non_inlinable, found_symbolic_index_in_non_inlinable with + | true, _ -> syntax_error pos (QuantifiedVariableInNodeArgument (j, i)) + | _, true -> syntax_error pos (SymbolicArrayIndexInNodeArgument (j, i)) + | false, false -> Ok []) + in + let check = List.map over_vars (LA.SI.elements vars) in + List.fold_left (>>) (Ok []) check +| _ -> Ok [] + +let oqv_check_node_decl inlinable_funcs ctx (_, _, _, _, inputs, outputs, locals, items, contract) = + let* warnings1 = + match contract with + | Some c -> + let ctx = + (* Locals are not visible in contracts *) + build_local_ctx ctx [] inputs outputs + in + check_contract false ctx (ovq_check_expr inlinable_funcs) c + | None -> Ok ([]) + in + let ctx = build_local_ctx ctx locals inputs outputs in + let* warnings2 = + check_items + (build_local_ctx ctx locals [] []) (* Add locals to ctx *) + (ovq_check_expr inlinable_funcs) + items + in + Ok (warnings1 @ warnings2) + +let oqv_check_contract_node_decl inlinable_funcs ctx (_, _, inputs, outputs, contract) = + let ctx = build_local_ctx ctx [] inputs outputs in + let* warnings = + check_contract true ctx (ovq_check_expr inlinable_funcs) contract + in + Ok warnings + +let oqv_check_decl: LA.SI.t -> context -> LA.declaration -> ([> warning] list, [> error]) result += fun inlinable_funcs ctx -> function + | NodeDecl (_, decl) -> + oqv_check_node_decl inlinable_funcs ctx decl + | FuncDecl (_, decl) -> + oqv_check_node_decl inlinable_funcs ctx decl + | ContractNodeDecl (_, decl) -> + oqv_check_contract_node_decl inlinable_funcs ctx decl + | _ -> Ok [] + +let no_quant_vars_in_calls_to_non_inlinable_funcs inlinable_funcs ast = + let ctx = build_global_ctx ast in + let* warnings = Res.seq (List.map (oqv_check_decl inlinable_funcs ctx) ast) in + Ok (List.flatten warnings) diff --git a/src/lustre/lustreSyntaxChecks.mli b/src/lustre/lustreSyntaxChecks.mli index b100404a1..09e31f729 100644 --- a/src/lustre/lustreSyntaxChecks.mli +++ b/src/lustre/lustreSyntaxChecks.mli @@ -76,3 +76,6 @@ val no_mismatched_clock : bool -> LA.expr -> ([> warning ] list, [> error]) resu Note: type information is needed for this check, causing this check to be called in the lustreTypeChecker *) + +val no_quant_vars_in_calls_to_non_inlinable_funcs : + LA.SI.t -> LA.t -> ([> warning ] list, [> error]) result \ No newline at end of file diff --git a/src/lustre/lustreUserFunctions.ml b/src/lustre/lustreUserFunctions.ml new file mode 100644 index 000000000..fffdb62a0 --- /dev/null +++ b/src/lustre/lustreUserFunctions.ml @@ -0,0 +1,84 @@ +(* This file is part of the Kind 2 model checker. + + Copyright (c) 2024 by the Board of Trustees of the University of Iowa + + Licensed under the Apache License, Version 2.0 (the "License"); you + may not use this file except in compliance with the License. You + may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied. See the License for the specific language governing + permissions and limitations under the License. + +*) + +module A = LustreAst +module AH = LustreAstHelpers +module Ctx = TypeCheckerContext + +module IdSet = A.SI + +let valid_outputs ctx = function + | [(_, _, ty, _)] -> ( (* single output variable *) + not (Ctx.type_contains_array ctx ty) + ) + | _ -> false + +let valid_locals ctx locals = + locals |> List.for_all (function + | A.NodeConstDecl (_, TypedConst (_,_,_,ty)) -> + not (Ctx.type_contains_array ctx ty) + | A.NodeConstDecl _ -> true + | NodeVarDecl (_, (_, _, ty, _)) -> + not (Ctx.type_contains_array ctx ty) + ) + +let valid_items set items = + items |> List.for_all (function + | A.Body (Equation (_, StructDef (_, [A.SingleIdent _]), rhs)) -> + IdSet.subset (AH.calls_of_expr rhs) set + | AnnotProperty _ -> true + | A.Body (Equation (_, _, _)) + | Body (Assert _) + | AnnotMain _ -> false + | FrameBlock _ + | IfBlock _ -> assert false (* desugared earlier in pipeline *) + ) + +let is_output_defined outputs items = + let output_id = + match outputs with + | [(_, id, _, _)] -> id + | _ -> assert false + in + items |> List.exists (function + | A.Body (Equation (_, StructDef (_, [A.SingleIdent (_, id)]), _)) -> + HString.equal id output_id + | _ -> false + ) + +let is_inlinable set ctx opac contract outputs locals items = + (opac = A.Transparent || contract = None) && + valid_outputs ctx outputs && + valid_locals ctx locals && + valid_items set items && + is_output_defined outputs items + +let inlinable_functions ctx decls = + List.fold_left (fun set dcl -> + match dcl with + (* A non-imported function *) + | A.FuncDecl (_, (id, false, opac, [], _, outputs, locals, items, contract)) -> ( + if is_inlinable set ctx opac contract outputs locals items then + IdSet.add id set + else + set + ) + | _ -> set + ) + IdSet.empty + decls diff --git a/src/lustre/lustreUserFunctions.mli b/src/lustre/lustreUserFunctions.mli new file mode 100644 index 000000000..143bce4ac --- /dev/null +++ b/src/lustre/lustreUserFunctions.mli @@ -0,0 +1,22 @@ +(* This file is part of the Kind 2 model checker. + + Copyright (c) 2024 by the Board of Trustees of the University of Iowa + + Licensed under the Apache License, Version 2.0 (the "License"); you + may not use this file except in compliance with the License. You + may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied. See the License for the specific language governing + permissions and limitations under the License. + +*) + +val inlinable_functions : + TypeCheckerContext.tc_context -> + LustreAst.declaration list -> + LustreAst.SI.t diff --git a/src/lustre/typeCheckerContext.ml b/src/lustre/typeCheckerContext.ml index ed9d47666..359c67a0a 100644 --- a/src/lustre/typeCheckerContext.ml +++ b/src/lustre/typeCheckerContext.ml @@ -802,6 +802,29 @@ let rec type_contains_abstract ctx = function | Int8 _ |Int16 _ |Int32 _ | Int64 _ | AbstractType _ -> false +let rec type_contains_array ctx = function + | LA.ArrayType (_, (_, _)) -> true + | RefinementType (_, (_, _, ty), _) -> type_contains_array ctx ty + | TupleType (_, tys) | GroupType (_, tys) -> + List.fold_left (fun acc ty -> acc || type_contains_array ctx ty) false tys + | RecordType (_, _, tys) -> + List.fold_left (fun acc (_, _, ty) -> acc || type_contains_array ctx ty) + false tys + | TArr (_, ty1, ty2) -> type_contains_array ctx ty1 || type_contains_array ctx ty2 + | History (_, id) -> + (match lookup_ty ctx id with + | Some ty -> type_contains_array ctx ty + | _ -> assert false) + | UserType (_, ty_args, id) -> ( + match lookup_ty_syn ctx id ty_args with + | Some ty -> type_contains_array ctx ty + | None -> assert false + ) + | Bool _ | Int _ | Real _ | EnumType _ | IntRange _ + | UInt8 _| UInt16 _| UInt32 _| UInt64 _ + | Int8 _ |Int16 _ |Int32 _ | Int64 _ + | AbstractType _ -> false + let rec ty_vars_of_expr ctx node_name expr = let call = ty_vars_of_expr ctx node_name in match expr with (* Node calls *) diff --git a/src/lustre/typeCheckerContext.mli b/src/lustre/typeCheckerContext.mli index acfc89dd3..338654827 100644 --- a/src/lustre/typeCheckerContext.mli +++ b/src/lustre/typeCheckerContext.mli @@ -289,6 +289,9 @@ val type_contains_enum_subrange_reftype : tc_context -> LA.lustre_type -> bool val type_contains_abstract : tc_context -> tc_type -> bool (** Returns true if the lustre type expression contains an abstract type (including polymorphic type variable) or if it is an abstract type *) +val type_contains_array: tc_context -> tc_type -> bool +(** Returns true if the lustre type expression contains an array *) + val ty_vars_of_expr: tc_context -> LA.index -> LA.expr -> SI.t (** [ty_vars_of_type ctx node_name e] returns all type variable identifiers that appear in the expression [e] *)