Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for free constants with refinement types #1110

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/lustre/lustreAstNormalizer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -584,42 +584,42 @@ let mk_fresh_subrange_constraint source info pos node_id constrained_name expr_t
in
List.fold_left union (empty ()) gids

let rec mk_ref_type_expr: Ctx.tc_context -> A.expr -> source -> A.lustre_type -> (source * A.expr) list
= fun ctx id source ty ->
let rec mk_ref_type_expr: Ctx.tc_context -> A.expr -> A.lustre_type -> A.expr list
= fun ctx id ty ->
let ty = Ctx.expand_type_syn ctx ty in
match ty with
| A.RefinementType (_, (_, id2, _), expr) ->
(* For refinement type variable of the form x = { y: int | ... }, write the constraint
in terms of x instead of y *)
let expr = AH.substitute_naive id2 id expr in
[(source, expr)]
[expr]
| TupleType (pos, tys)
| GroupType (pos, tys) -> List.mapi (fun i ty ->
mk_ref_type_expr ctx (A.TupleProject(pos, id, i)) source ty
mk_ref_type_expr ctx (A.TupleProject(pos, id, i)) ty
) tys |> List.flatten
| RecordType (p, _, tis) ->
List.map (fun (_, id2, ty) ->
let expr = A.RecordProject(p, id, id2) in
mk_ref_type_expr ctx expr source ty
mk_ref_type_expr ctx expr ty
) tis |> List.flatten
| ArrayType (_, (ty, len)) ->
let pos = AH.pos_of_expr id in
let dummy_index = mk_fresh_dummy_index () in
let exprs_sources = mk_ref_type_expr ctx (A.ArrayIndex(pos, id, Ident(pos, dummy_index))) source ty in
List.map (fun (source, expr) ->
let exprs = mk_ref_type_expr ctx (A.ArrayIndex(pos, id, Ident(pos, dummy_index))) ty in
List.map (fun expr ->
let bound1 =
A.CompOp(pos, Lte, A.Const(pos, Num (HString.mk_hstring "0")), A.Ident(pos, dummy_index))
in
let bound2 = A.CompOp(pos, Lt, A.Ident(pos, dummy_index), len) in
let expr = A.BinaryOp(pos, Impl, A.BinaryOp(pos, And, bound1, bound2), expr) in
source, A.Quantifier(pos, Forall, [pos, dummy_index, A.Int pos], expr)
) exprs_sources
A.Quantifier(pos, Forall, [pos, dummy_index, A.Int pos], expr)
) exprs
| _ -> []


let mk_fresh_refinement_type_constraint source info pos id expr_type =
let ref_type_exprs = mk_ref_type_expr info.context id source expr_type in
let gids = List.map (fun (source, ref_type_expr) ->
let ref_type_exprs = mk_ref_type_expr info.context id expr_type in
let gids = List.map (fun ref_type_expr ->
i := !i + 1;
let output_expr = AH.rename_contract_vars ref_type_expr in
let prefix = HString.mk_hstring (string_of_int !i) in
Expand Down Expand Up @@ -1758,7 +1758,7 @@ and normalize_expr ?guard info node_id map =
let expr = A.Ident(dpos, id) in
let range_exprs =
List.map fst (mk_enum_range_expr info.context (Some node_id) ty expr) @
List.map snd (mk_ref_type_expr info.context expr Local ty)
(mk_ref_type_expr info.context expr ty)
in
range_exprs :: acc
)
Expand Down
5 changes: 5 additions & 0 deletions src/lustre/lustreAstNormalizer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ val mk_range_expr : TypeCheckerContext.tc_context ->
LustreAst.expr ->
(LustreAst.expr * bool) list

val mk_ref_type_expr : TypeCheckerContext.tc_context ->
LustreAst.expr ->
LustreAst.lustre_type ->
LustreAst.expr list

val mk_enum_range_expr : TypeCheckerContext.tc_context ->
HString.t option ->
LustreAst.lustre_type ->
Expand Down
1 change: 1 addition & 0 deletions src/lustre/lustreInput.ml
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ let type_check declarations =
let inlined_global_ctx, const_inlined_nodes_and_contracts = LIP.instantiate_polymorphic_nodes inlined_global_ctx const_inlined_nodes_and_contracts in

(* Step 17. Flatten refinement types *)
let const_inlined_type_and_consts = LFR.flatten_ref_types inlined_global_ctx const_inlined_type_and_consts in
let const_inlined_nodes_and_contracts = LFR.flatten_ref_types inlined_global_ctx const_inlined_nodes_and_contracts in

(* Step 18. Normalize AST: guard pres, abstract to locals where appropriate *)
Expand Down
17 changes: 13 additions & 4 deletions src/lustre/lustreNodeGen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2365,15 +2365,24 @@ and compile_const_decl ?(ghost = false) cstate ctx map scope = function
else (
let global_constraints =
let ty = Ctx.expand_type_syn ctx ty in
if Ctx.type_contains_subrange ctx ty then (
let has_subrange = Ctx.type_contains_subrange ctx ty in
let has_ref_type = Ctx.type_contains_ref ctx ty in
if has_subrange || has_ref_type then (
let ctx = Ctx.add_ty ctx i ty in
let range_exprs =
let ctx = Ctx.add_ty ctx i ty in
AN.mk_range_expr ctx None ty (A.Ident (p, i)) |> List.map fst
if has_subrange then
AN.mk_range_expr ctx None ty (A.Ident (p, i)) |> List.map fst
else []
in
let ref_type_exprs =
if has_ref_type then
AN.mk_ref_type_expr ctx (A.Ident(p, i)) ty
else []
in
List.map (fun expr ->
let c_expr = compile_ast_expr cstate ctx [] map expr in
X.max_binding c_expr |> snd
) range_exprs @ cstate.global_constraints
) (range_exprs @ ref_type_exprs) @ cstate.global_constraints
)
else cstate.global_constraints
in
Expand Down
10 changes: 7 additions & 3 deletions src/lustre/lustreTypeChecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ let error_message kind = match kind with
^ Lib.string_of_t LA.pp_print_expr e
| IntervalMustHaveBound -> "Range should have at least one bound"
| ExpectedRecordType ty -> "Expected record type but found " ^ string_of_tc_type ty
| GlobalConstRefType id -> "Global constant '" ^ HString.string_of_hstring id ^ "' has refinement type (not yet supported)"
| GlobalConstRefType id -> "Definition of global constant '" ^ HString.string_of_hstring id ^ "' has refinement type (not yet supported)"
| QuantifiedAbstractType id -> "Variable '" ^ HString.string_of_hstring id ^ "' with type that contains an abstract type (or type variable) cannot be quantified"
| InvalidPolymorphicCall id -> "Call to node, contract, or user type '" ^ HString.string_of_hstring id ^ "' passes an incorrect number of type parameters"

Expand Down Expand Up @@ -1928,15 +1928,19 @@ and build_type_and_const_context: tc_context -> LA.t -> (tc_context * [> warning
| LA.TypeDecl (_, ty_decl) :: rest ->
let* ctx' = tc_ctx_of_ty_decl ctx ty_decl in
build_type_and_const_context ctx' rest
| LA.ConstDecl (_, (TypedConst (p, i, _, ty) as const_decl)) :: rest
| LA.ConstDecl (_, ((FreeConst (p, i, ty)) as const_decl)) :: rest ->
| LA.ConstDecl (_, (TypedConst (p, i, _, ty) as const_decl)) :: rest ->
let ty = expand_type_syn ctx ty in
if type_contains_ref ctx ty then type_error p (GlobalConstRefType i)
else (
let* ctx', warnings1 = tc_ctx_const_decl ctx Global None const_decl in
let* ctx', warnings2 = build_type_and_const_context ctx' rest in
R.ok (ctx', warnings1 @ warnings2)
)
| LA.ConstDecl (_, ((FreeConst _) as const_decl)) :: rest -> (
let* ctx', warnings1 = tc_ctx_const_decl ctx Global None const_decl in
let* ctx', warnings2 = build_type_and_const_context ctx' rest in
R.ok (ctx', warnings1 @ warnings2)
)
| LA.ConstDecl (_, UntypedConst _) :: _ -> assert false
| _ :: rest -> build_type_and_const_context ctx rest
(** Process top level type declarations and make a type context with
Expand Down
14 changes: 14 additions & 0 deletions tests/regression/success/ref_type_local_free_const.lus
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

type Nat = subrange [0,*] of int;

type R0 = struct { x: Nat; y: Nat };

type R1 = subtype { r: R0 | r.x < 10 };

type R2 = subtype { r: R1 | r.y < 20};

node N() returns ()
const C: R2;
let
check "P1" C.x < 10 and C.y < 20;
tel
Loading