From 67d73ff6f12c585a0aead415fd4d4bc3d9e8956f Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Tue, 12 Dec 2023 15:03:30 +0100 Subject: [PATCH 1/9] import ocaml code --- compiler/dcalc/to_coq.ml | 654 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 654 insertions(+) create mode 100644 compiler/dcalc/to_coq.ml diff --git a/compiler/dcalc/to_coq.ml b/compiler/dcalc/to_coq.ml new file mode 100644 index 000000000..7c5301968 --- /dev/null +++ b/compiler/dcalc/to_coq.ml @@ -0,0 +1,654 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020 Inria, contributor: + Alain Delaƫt-Tixeuil + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Catala_utils +open Shared_ast +open Ast + +let format_lit (fmt : Format.formatter) (l : lit Mark.pos) : unit = + match Mark.remove l with + | LBool b -> Print.lit fmt (LBool b) + | LInt i -> + Format.fprintf fmt "integer_of_string@ \"%s\"" (Runtime.integer_to_string i) + | LUnit -> Print.lit fmt LUnit + | LRat i -> Format.fprintf fmt "decimal_of_string \"%a\"" Print.lit (LRat i) + | LMoney e -> + Format.fprintf fmt "money_of_cents_string@ \"%s\"" + (Runtime.integer_to_string (Runtime.money_to_cents e)) + | LDate d -> + Format.fprintf fmt "date_of_numbers (%d) (%d) (%d)" + (Runtime.integer_to_int (Runtime.year_of_date d)) + (Runtime.integer_to_int (Runtime.month_number_of_date d)) + (Runtime.integer_to_int (Runtime.day_of_month_of_date d)) + | LDuration d -> + let years, months, days = Runtime.duration_to_years_months_days d in + Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days + +let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list) + : unit = + Format.fprintf fmt "@[[%a]@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt info -> + Format.fprintf fmt "\"%a\"" Uid.MarkedString.format info)) + uids + +let format_string_list (fmt : Format.formatter) (uids : string list) : unit = + let sanitize_quotes = Re.compile (Re.char '"') in + Format.fprintf fmt "@[[%a]@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt info -> + Format.fprintf fmt "\"%s\"" + (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) + uids + +(* list taken from + http://caml.inria.fr/pub/docs/manual-ocaml/lex.html#sss:keywords *) +let ocaml_keywords = + [ + "and"; + "as"; + "assert"; + "asr"; + "begin"; + "class"; + "constraint"; + "do"; + "done"; + "downto"; + "else"; + "end"; + "exception"; + "external"; + "false"; + "for"; + "fun"; + "function"; + "functor"; + "if"; + "in"; + "include"; + "inherit"; + "initializer"; + "land"; + "lazy"; + "let"; + "lor"; + "lsl"; + "lsr"; + "lxor"; + "match"; + "method"; + "mod"; + "module"; + "mutable"; + "new"; + "nonrec"; + "object"; + "of"; + "open"; + "or"; + "private"; + "rec"; + "sig"; + "struct"; + "then"; + "to"; + "true"; + "try"; + "type"; + "val"; + "virtual"; + "when"; + "while"; + "with"; + "Stdlib"; + "Runtime"; + "Oper"; + ] + +let ocaml_keywords_set = String.Set.of_list ocaml_keywords + +let avoid_keywords (s : string) : string = + if String.Set.mem s ocaml_keywords_set then s ^ "_user" else s +(* Fixme: this could cause clashes if the user program contains both e.g. [new] + and [new_user] *) + +let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = + Format.asprintf "%a" StructName.format v + |> String.to_ascii + |> String.to_snake_case + |> avoid_keywords + |> Format.fprintf fmt "%s" + +let format_to_module_name + (fmt : Format.formatter) + (name : [< `Ename of EnumName.t | `Sname of StructName.t ]) = + (match name with + | `Ename v -> Format.asprintf "%a" EnumName.format v + | `Sname v -> Format.asprintf "%a" StructName.format v) + |> String.to_ascii + |> avoid_keywords + |> Format.pp_print_string fmt + +let format_struct_field_name + (fmt : Format.formatter) + ((sname_opt, v) : StructName.t option * StructField.t) : unit = + (match sname_opt with + | Some sname -> + Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname) + | None -> Format.fprintf fmt "%s") + (avoid_keywords + (String.to_ascii (Format.asprintf "%a" StructField.format v))) + +let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = + Format.fprintf fmt "%s" + (avoid_keywords + (String.to_snake_case + (String.to_ascii (Format.asprintf "%a" EnumName.format v)))) + +let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : + unit = + Format.fprintf fmt "%s" + (avoid_keywords + (String.to_ascii (Format.asprintf "%a" EnumConstructor.format v))) + +let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = + match Mark.remove ty with + | TLit TUnit -> Format.fprintf fmt "embed_unit" + | TLit TBool -> Format.fprintf fmt "embed_bool" + | TLit TInt -> Format.fprintf fmt "embed_integer" + | TLit TRat -> Format.fprintf fmt "embed_decimal" + | TLit TMoney -> Format.fprintf fmt "embed_money" + | TLit TDate -> Format.fprintf fmt "embed_date" + | TLit TDuration -> Format.fprintf fmt "embed_duration" + | TStruct s_name -> Format.fprintf fmt "embed_%a" format_struct_name s_name + | TEnum e_name -> Format.fprintf fmt "embed_%a" format_enum_name e_name + | TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty + | _ -> Format.fprintf fmt "unembeddable" + +let typ_needs_parens (e : typ) : bool = + match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false + +let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = + let format_typ_with_parens (fmt : Format.formatter) (t : typ) = + if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t + else Format.fprintf fmt "%a" format_typ t + in + match Mark.remove typ with + | TLit l -> Format.fprintf fmt "%a" Print.tlit l + | TTuple ts -> + Format.fprintf fmt "@[(%a)@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") + format_typ_with_parens) + ts + | TStruct s -> Format.fprintf fmt "%a.t" format_to_module_name (`Sname s) + | TOption t -> + Format.fprintf fmt "@[(%a)@] %a.t" format_typ_with_parens t + format_to_module_name (`Ename Expr.option_enum) + | TDefault t -> format_typ fmt t + | TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e) + | TArrow (t1, t2) -> + Format.fprintf fmt "@[%a@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ") + format_typ_with_parens) + (t1 @ [t2]) + | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1 + | TAny -> Format.fprintf fmt "_" + | TClosureEnv -> failwith "unimplemented!" + +let format_var_str (fmt : Format.formatter) (v : string) : unit = + let lowercase_name = String.to_snake_case (String.to_ascii v) in + let lowercase_name = + Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") + ~subst:(fun _ -> "_dot_") + lowercase_name + in + let lowercase_name = String.to_ascii lowercase_name in + if + List.mem lowercase_name ["handle_default"; "handle_default_opt"] + (* O_O *) + || String.begins_with_uppercase v + then Format.pp_print_string fmt lowercase_name + else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name + else Format.fprintf fmt "%s_" lowercase_name + +let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = + format_var_str fmt (Bindlib.name_of v) + +let needs_parens (e : 'm expr) : bool = + match Mark.remove e with + | EApp { f = EAbs _, _; _ } + | ELit (LBool _ | LUnit) + | EVar _ | ETuple _ | EOp _ -> + false + | _ -> true + +let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : + unit = + let format_expr = format_expr ctx in + let format_with_parens (fmt : Format.formatter) (e : 'm expr) = + if needs_parens e then Format.fprintf fmt "(%a)" format_expr e + else Format.fprintf fmt "%a" format_expr e + in + match Mark.remove e with + | EVar v -> Format.fprintf fmt "%a" format_var v + | EExternal { name } -> ( + (* FIXME: this is wrong in general !! We assume the idents exposed by the + module depend only on the original name, while they actually get through + Bindlib and may have been renamed. A correct implem could use the runtime + registration used by the interpreter, but that would be distasteful and + incur a penalty ; or we would need to reproduce the same structure as in + the original module to ensure that bindlib performs the exact same + renamings ; or finally we could normalise the names at generation time + (either at toplevel or in a dedicated submodule ?) *) + let path = + match Mark.remove name with + | External_value name -> TopdefName.path name + | External_scope name -> ScopeName.path name + in + Uid.Path.format fmt path; + match Mark.remove name with + | External_value name -> + format_var_str fmt (Mark.remove (TopdefName.get_info name)) + | External_scope name -> + format_var_str fmt (Mark.remove (ScopeName.get_info name))) + | ETuple es -> + Format.fprintf fmt "@[(%a)@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) + es + | EStruct { name = s; fields = es } -> + if StructField.Map.is_empty es then Format.fprintf fmt "()" + else + Format.fprintf fmt "{@[%a@]}" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt (struct_field, e) -> + Format.fprintf fmt "@[%a =@ %a@]" format_struct_field_name + (Some s, struct_field) format_with_parens e)) + (StructField.Map.bindings es) + | EArray es -> + Format.fprintf fmt "@[[|%a|]@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) + es + | ETupleAccess { e; index; size } -> + Format.fprintf fmt "let@ %a@ = %a@ in@ x" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt i -> + Format.pp_print_string fmt (if i = index then "x" else "_"))) + (List.init size Fun.id) format_with_parens e + | EStructAccess { e; field; name } -> + Format.fprintf fmt "%a.%a" format_with_parens e format_struct_field_name + (Some name, field) + | EInj { e; cons; name } -> + Format.fprintf fmt "@[%a.%a@ %a@]" format_to_module_name + (`Ename name) format_enum_cons_name cons format_with_parens e + | EMatch { e; cases; name } -> + Format.fprintf fmt "@[@[match@ %a@]@ with@\n| %a@]" + format_with_parens e + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ | ") + (fun fmt (c, e) -> + Format.fprintf fmt "@[%a.%a %a@]" format_to_module_name + (`Ename name) format_enum_cons_name c + (fun fmt e -> + match Mark.remove e with + | EAbs { binder; _ } -> + let xs, body = Bindlib.unmbind binder in + Format.fprintf fmt "%a ->@ %a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@,") + (fun fmt x -> Format.fprintf fmt "%a" format_var x)) + (Array.to_list xs) format_with_parens body + | _ -> assert false + (* should not happen *)) + e)) + (EnumConstructor.Map.bindings cases) + | ELit l -> Format.fprintf fmt "%a" format_lit (Mark.add (Expr.pos e) l) + | EApp { f = EAbs { binder; tys }, _; args } -> + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in + let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in + Format.fprintf fmt "(%a%a)" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "") + (fun fmt (x, tau, arg) -> + Format.fprintf fmt "@[let@ %a@ :@ %a@ =@ %a@]@ in@\n" + format_var x format_typ tau format_with_parens arg)) + xs_tau_arg format_with_parens body + | EAbs { binder; tys } -> + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in + Format.fprintf fmt "@[fun@ %a ->@ %a@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + (fun fmt (x, tau) -> + Format.fprintf fmt "@[(%a:@ %a)@]" format_var x format_typ tau)) + xs_tau format_expr body + | EApp + { + f = EApp { f = EOp { op = Log (BeginCall, info); _ }, _; args = [f] }, _; + args = [arg]; + } + when Cli.globals.trace -> + Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info + format_with_parens f format_with_parens arg + | EApp + { f = EOp { op = Log (VarDef var_def_info, info); _ }, _; args = [arg1] } + when Cli.globals.trace -> + Format.fprintf fmt + "(log_variable_definition@ %a@ {io_input=%s;@ io_output=%b}@ (%a)@ %a)" + format_uid_list info + (match var_def_info.log_io_input with + | NoInput -> "NoInput" + | OnlyInput -> "OnlyInput" + | Reentrant -> "Reentrant") + var_def_info.log_io_output typ_embedding_name + (var_def_info.log_typ, Pos.no_pos) + format_with_parens arg1 + | EApp { f = EOp { op = Log (PosRecordIfTrueBool, _); _ }, m; args = [arg1] } + when Cli.globals.trace -> + let pos = Expr.mark_pos m in + Format.fprintf fmt + "(log_decision_taken@ @[{filename = \"%s\";@ start_line=%d;@ \ + start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)" + (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) + (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list + (Pos.get_law_info pos) format_with_parens arg1 + | EApp { f = EOp { op = Log (EndCall, info); _ }, _; args = [arg1] } + when Cli.globals.trace -> + Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info + format_with_parens arg1 + | EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } -> + Format.fprintf fmt "%a" format_with_parens arg1 + | EApp { f; args } -> + Format.fprintf fmt "@[%a@ %a@]" format_with_parens f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + format_with_parens) + args + | EIfThenElse { cond; etrue; efalse } -> + Format.fprintf fmt + "@[ if@ @[%a@]@ then@ @[%a@]@ else@ @[%a@]@]" + format_with_parens cond format_with_parens etrue format_with_parens efalse + | EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op) + | EAssert e' -> + Format.fprintf fmt + "@[if@ %a@ then@ ()@ else@ raise (AssertionFailed @[{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ end_line=%d; \ + end_column=%d;@ law_headings=%a}@])@]" + format_with_parens e' + (Pos.get_file (Expr.pos e')) + (Pos.get_start_line (Expr.pos e')) + (Pos.get_start_column (Expr.pos e')) + (Pos.get_end_line (Expr.pos e')) + (Pos.get_end_column (Expr.pos e')) + format_string_list + (Pos.get_law_info (Expr.pos e')) + | EEmptyError -> assert false + | EDefault _ -> assert false + | EPureDefault _ -> assert false + | EErrorOnEmpty _ -> assert false + | _ -> . + +let format_struct_embedding + (fmt : Format.formatter) + ((struct_name, struct_fields) : StructName.t * typ StructField.Map.t) = + if StructField.Map.is_empty struct_fields then + Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" + format_struct_name struct_name format_to_module_name (`Sname struct_name) + else + Format.fprintf fmt + "@[let embed_%a (x: %a.t) : runtime_value =@ Struct([\"%a\"],@ \ + @[[%a]@])@]@\n\ + @\n" + format_struct_name struct_name format_to_module_name (`Sname struct_name) + StructName.format struct_name + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") + (fun fmt (struct_field, struct_field_type) -> + Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructField.format + struct_field typ_embedding_name struct_field_type + format_struct_field_name + (Some struct_name, struct_field))) + (StructField.Map.bindings struct_fields) + +let format_enum_embedding + (fmt : Format.formatter) + ((enum_name, enum_cases) : EnumName.t * typ EnumConstructor.Map.t) = + if EnumConstructor.Map.is_empty enum_cases then + Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" + format_to_module_name (`Ename enum_name) format_enum_name enum_name + else + Format.fprintf fmt + "@[@[let embed_%a@ @[(x:@ %a.t)@]@ : runtime_value \ + =@]@ Enum([\"%a\"],@ @[match x with@ %a@])@]@\n\ + @\n" + format_enum_name enum_name format_to_module_name (`Ename enum_name) + EnumName.format enum_name + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") + (fun fmt (enum_cons, enum_cons_type) -> + Format.fprintf fmt "@[| %a x ->@ (\"%a\", %a x)@]" + format_enum_cons_name enum_cons EnumConstructor.format enum_cons + typ_embedding_name enum_cons_type)) + (EnumConstructor.Map.bindings enum_cases) + +let format_ctx + (type_ordering : Scopelang.Dependency.TVertex.t list) + (fmt : Format.formatter) + (ctx : decl_ctx) : unit = + let format_struct_decl fmt (struct_name, struct_fields) = + if StructField.Map.is_empty struct_fields then + Format.fprintf fmt + "@[module %a = struct@\n@[type t = unit@]@]@\nend@\n" + format_to_module_name (`Sname struct_name) + else + Format.fprintf fmt + "@[@[module %a = struct@ @[type t = {@,\ + %a@;\ + <0-2>}@]@]@ end@]@\n" + format_to_module_name (`Sname struct_name) + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt (struct_field, struct_field_type) -> + Format.fprintf fmt "@[%a:@ %a@]" format_struct_field_name + (None, struct_field) format_typ struct_field_type)) + (StructField.Map.bindings struct_fields); + if Cli.globals.trace then + format_struct_embedding fmt (struct_name, struct_fields) + in + let format_enum_decl fmt (enum_name, enum_cons) = + Format.fprintf fmt + "module %a = struct@\n@[@ type t =@\n@[ %a@]@\nend@]@\n" + format_to_module_name (`Ename enum_name) + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") + (fun fmt (enum_cons, enum_cons_type) -> + Format.fprintf fmt "@[| %a@ of@ %a@]" format_enum_cons_name + enum_cons format_typ enum_cons_type)) + (EnumConstructor.Map.bindings enum_cons); + if Cli.globals.trace then format_enum_embedding fmt (enum_name, enum_cons) + in + let is_in_type_ordering s = + List.exists + (fun struct_or_enum -> + match struct_or_enum with + | Scopelang.Dependency.TVertex.Enum _ -> false + | Scopelang.Dependency.TVertex.Struct s' -> s = s') + type_ordering + in + let scope_structs = + List.map + (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) + (StructName.Map.bindings + (StructName.Map.filter + (fun s _ -> not (is_in_type_ordering s)) + ctx.ctx_structs)) + in + List.iter + (fun struct_or_enum -> + match struct_or_enum with + | Scopelang.Dependency.TVertex.Struct s -> + let def = StructName.Map.find s ctx.ctx_structs in + if StructName.path s = [] then + Format.fprintf fmt "%a@\n" format_struct_decl (s, def) + | Scopelang.Dependency.TVertex.Enum e -> + let def = EnumName.Map.find e ctx.ctx_enums in + if EnumName.path e = [] then + Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) + (type_ordering @ scope_structs) + +let rename_vars e = + Expr.( + unbox + (rename_vars ~exclude:ocaml_keywords ~reset_context_for_closed_terms:true + ~skip_constant_binders:true ~constant_binder_name:(Some "_") e)) + +let format_expr ctx fmt e = format_expr ctx fmt (rename_vars e) + +let rec format_scope_body_expr + (ctx : decl_ctx) + (fmt : Format.formatter) + (scope_lets : 'm Ast.expr scope_body_expr) : unit = + match scope_lets with + | Result e -> format_expr ctx fmt e + | ScopeLet scope_let -> + let scope_let_var, scope_let_next = + Bindlib.unbind scope_let.scope_let_next + in + Format.fprintf fmt "@[let %a: %a = %a in@]@\n%a" format_var + scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) + scope_let.scope_let_expr + (format_scope_body_expr ctx) + scope_let_next + +let format_code_items + (ctx : decl_ctx) + (fmt : Format.formatter) + (code_items : 'm Ast.expr code_item_list) : 'm Ast.expr Var.t String.Map.t = + Scope.fold_left + ~f:(fun bnd item var -> + match item with + | Topdef (name, typ, e) -> + Format.fprintf fmt "@\n@\n@[let %a : %a =@\n%a@]" format_var var + format_typ typ (format_expr ctx) e; + String.Map.add (Format.asprintf "%a" TopdefName.format name) var bnd + | ScopeDef (name, body) -> + let scope_input_var, scope_body_expr = + Bindlib.unbind body.scope_body_expr + in + Format.fprintf fmt "@\n@\n@[let %a (%a: %a.t) : %a.t =@\n%a@]" + format_var var format_var scope_input_var format_to_module_name + (`Sname body.scope_body_input_struct) format_to_module_name + (`Sname body.scope_body_output_struct) + (format_scope_body_expr ctx) + scope_body_expr; + String.Map.add (Format.asprintf "%a" ScopeName.format name) var bnd) + ~init:String.Map.empty code_items + +let format_scope_exec + (ctx : decl_ctx) + (fmt : Format.formatter) + (bnd : 'm Ast.expr Var.t String.Map.t) + scope_name + scope_body = + let scope_name_str = Format.asprintf "%a" ScopeName.format scope_name in + let scope_var = String.Map.find scope_name_str bnd in + let scope_input = + StructName.Map.find scope_body.scope_body_input_struct ctx.ctx_structs + in + if not (StructField.Map.is_empty scope_input) then + Message.raise_error + "The scope @{%s@} defines input variables.@ This is not supported \ + for a main scope at the moment." + scope_name_str; + Format.pp_open_vbox fmt 2; + Format.pp_print_string fmt "let () ="; + (* TODO: dump the output using yojson that should be already available from + the runtime *) + Format.pp_print_space fmt (); + format_var fmt scope_var; + Format.pp_print_space fmt (); + Format.pp_print_string fmt "()"; + Format.pp_close_box fmt () + +let format_module_registration + fmt + (bnd : 'm Ast.expr Var.t String.Map.t) + modname = + Format.pp_open_vbox fmt 2; + Format.pp_print_string fmt "let () ="; + Format.pp_print_space fmt (); + Format.pp_open_hvbox fmt 2; + Format.fprintf fmt "Runtime_ocaml.Runtime.register_module \"%a\"" + ModuleName.format modname; + Format.pp_print_space fmt (); + Format.pp_open_vbox fmt 2; + Format.pp_print_string fmt "[ "; + Format.pp_print_seq + ~pp_sep:(fun fmt () -> + Format.pp_print_char fmt ';'; + Format.pp_print_cut fmt ()) + (fun fmt (id, var) -> + Format.fprintf fmt "@[%S,@ Obj.repr %a@]" id format_var var) + fmt (String.Map.to_seq bnd); + Format.pp_close_box fmt (); + Format.pp_print_char fmt ' '; + Format.pp_print_string fmt "]"; + Format.pp_print_space fmt (); + Format.pp_print_string fmt "\"todo-module-hash\""; + Format.pp_close_box fmt (); + Format.pp_close_box fmt (); + Format.pp_print_newline fmt () + +let header = + {ocaml| + (** This file has been generated by the Catala compiler, do not edit! *) + + open Runtime_ocaml.Runtime + + [@@@ocaml.warning "-4-26-27-32-41-42"] + + |ocaml} + +let format_program + (fmt : Format.formatter) + ?exec_scope + (p : 'm Ast.program) + (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = + Format.pp_print_string fmt header; + format_ctx type_ordering fmt p.decl_ctx; + let bnd = format_code_items p.decl_ctx fmt p.code_items in + Format.pp_print_newline fmt (); + match p.module_name, exec_scope with + | Some modname, None -> format_module_registration fmt bnd modname + | None, Some scope_name -> + let scope_body = Program.get_scope_body p scope_name in + format_scope_exec p.decl_ctx fmt bnd scope_name scope_body + | None, None -> () + | Some _, Some _ -> + Message.raise_error + "OCaml generation: both module registration and top-level scope \ + execution where required at the same time." From d9c2ec8d39dc575da95d59747cd50b04f28e622f Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Wed, 13 Dec 2023 18:38:57 +0100 Subject: [PATCH 2/9] adding the coq backend to the driver --- compiler/driver.ml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/compiler/driver.ml b/compiler/driver.ml index 31281888a..b3005f930 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -614,6 +614,34 @@ module Commands = struct $ Cli.Flags.ex_scope_opt $ Cli.Flags.check_invariants) + let coq options includes output optimize check_invariants ex_scope_opt = + let prg, type_ordering = + Passes.dcalc options ~includes ~optimize ~check_invariants + ~typed:Expr.typed + in + let _output_file, with_output = + get_output_format options ~ext:".v" output + in + with_output + @@ fun fmt -> + Message.emit_debug "Compiling program into Coq..."; + Message.emit_debug "Writing to %s..." (Option.value ~default:"stdout" None); + let exec_scope = Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt in + Dcalc.To_coq.format_program fmt prg ?exec_scope type_ordering + + let coq_cmd = + Cmd.v + (Cmd.info "coq" + ~doc:"Generates an OCaml translation of the Catala program.") + Term.( + const coq + $ Cli.Flags.Global.options + $ Cli.Flags.include_dirs + $ Cli.Flags.output + $ Cli.Flags.optimize + $ Cli.Flags.check_invariants + $ Cli.Flags.ex_scope_opt) + let proof options includes @@ -953,6 +981,7 @@ module Commands = struct interpret_lcalc_cmd; typecheck_cmd; proof_cmd; + coq_cmd; ocaml_cmd; python_cmd; r_cmd; From 1395c976a6be2b2c5654e2367e0b177251bdcfde Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Wed, 13 Dec 2023 18:39:39 +0100 Subject: [PATCH 3/9] modification to produce code that looks like catala's ast in the coq development --- compiler/dcalc/to_coq.ml | 306 ++++++++++++++++++--------------------- 1 file changed, 138 insertions(+), 168 deletions(-) diff --git a/compiler/dcalc/to_coq.ml b/compiler/dcalc/to_coq.ml index 7c5301968..bfc2c7458 100644 --- a/compiler/dcalc/to_coq.ml +++ b/compiler/dcalc/to_coq.ml @@ -20,9 +20,8 @@ open Ast let format_lit (fmt : Format.formatter) (l : lit Mark.pos) : unit = match Mark.remove l with - | LBool b -> Print.lit fmt (LBool b) - | LInt i -> - Format.fprintf fmt "integer_of_string@ \"%s\"" (Runtime.integer_to_string i) + | LBool b -> Format.fprintf fmt "Bool@ %s" (Bool.to_string b) + | LInt i -> Format.fprintf fmt "Int@ %s" (Runtime.integer_to_string i) | LUnit -> Print.lit fmt LUnit | LRat i -> Format.fprintf fmt "decimal_of_string \"%a\"" Print.lit (LRat i) | LMoney e -> @@ -182,15 +181,40 @@ let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = | _ -> Format.fprintf fmt "unembeddable" let typ_needs_parens (e : typ) : bool = - match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false - -let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = - let format_typ_with_parens (fmt : Format.formatter) (t : typ) = - if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t - else Format.fprintf fmt "%a" format_typ t - in + match Mark.remove e with + | TDefault _ | TArrow _ | TArray _ -> true + | _ -> false + +let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = + Format.fprintf fmt + (match l with + | TUnit -> "TUnit" + | TBool -> "TBool" + | TInt -> "TInteger" + | TRat -> "TDecimal" + | TMoney -> "TMoney" + | TDuration -> "TDuration" + | TDate -> "TDate") + +let rec format_nested_arrows + (fmt : Format.formatter) + ((args, res) : typ list * typ) : unit = + match args with + | [] -> Format.fprintf fmt "%a" format_typ_with_parens res + | [arg] -> + Format.fprintf fmt "@[TFun %a %a@]" format_typ_with_parens arg + format_typ_with_parens res + | arg :: args -> + Format.fprintf fmt "@[TFun %a %a@]" format_typ_with_parens arg + format_nested_arrows (args, res) + +and format_typ_with_parens (fmt : Format.formatter) (t : typ) = + if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t + else Format.fprintf fmt "%a" format_typ t + +and format_typ (fmt : Format.formatter) (typ : typ) : unit = match Mark.remove typ with - | TLit l -> Format.fprintf fmt "%a" Print.tlit l + | TLit l -> Format.fprintf fmt "%a" format_tlit l | TTuple ts -> Format.fprintf fmt "@[(%a)@]" (Format.pp_print_list @@ -204,11 +228,7 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = | TDefault t -> format_typ fmt t | TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e) | TArrow (t1, t2) -> - Format.fprintf fmt "@[%a@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ") - format_typ_with_parens) - (t1 @ [t2]) + Format.fprintf fmt "@[%a@]" format_nested_arrows (t1, t2) | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1 | TAny -> Format.fprintf fmt "_" | TClosureEnv -> failwith "unimplemented!" @@ -234,57 +254,43 @@ let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = let needs_parens (e : 'm expr) : bool = match Mark.remove e with - | EApp { f = EAbs _, _; _ } - | ELit (LBool _ | LUnit) - | EVar _ | ETuple _ | EOp _ -> - false + | EApp { f = EAbs _, _; _ } | EVar _ | ETuple _ | EOp _ -> false | _ -> true -let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : - unit = +let find_index p = + let rec aux i = function + | [] -> None + | a :: l -> if p a then Some i else aux (i + 1) l + in + aux 0 + +let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = + let dctx, debrin = ctx in + let format_expr' = format_expr in let format_expr = format_expr ctx in let format_with_parens (fmt : Format.formatter) (e : 'm expr) = if needs_parens e then Format.fprintf fmt "(%a)" format_expr e else Format.fprintf fmt "%a" format_expr e in match Mark.remove e with - | EVar v -> Format.fprintf fmt "%a" format_var v - | EExternal { name } -> ( - (* FIXME: this is wrong in general !! We assume the idents exposed by the - module depend only on the original name, while they actually get through - Bindlib and may have been renamed. A correct implem could use the runtime - registration used by the interpreter, but that would be distasteful and - incur a penalty ; or we would need to reproduce the same structure as in - the original module to ensure that bindlib performs the exact same - renamings ; or finally we could normalise the names at generation time - (either at toplevel or in a dedicated submodule ?) *) - let path = - match Mark.remove name with - | External_value name -> TopdefName.path name - | External_scope name -> ScopeName.path name - in - Uid.Path.format fmt path; - match Mark.remove name with - | External_value name -> - format_var_str fmt (Mark.remove (TopdefName.get_info name)) - | External_scope name -> - format_var_str fmt (Mark.remove (ScopeName.get_info name))) - | ETuple es -> - Format.fprintf fmt "@[(%a)@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) - es - | EStruct { name = s; fields = es } -> - if StructField.Map.is_empty es then Format.fprintf fmt "()" - else - Format.fprintf fmt "{@[%a@]}" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") - (fun fmt (struct_field, e) -> - Format.fprintf fmt "@[%a =@ %a@]" format_struct_field_name - (Some s, struct_field) format_with_parens e)) - (StructField.Map.bindings es) + | EVar v -> ( + match find_index (Bindlib.eq_vars v) debrin with + | Some i -> Format.fprintf fmt "(Var@ %i@ (* %a *))" i format_var v + | None -> Format.fprintf fmt "(Var@ ???@ (* %a *))" format_var v) + | EExternal _ -> assert false + | ETuple _es -> assert false + | EStruct { name = s; fields = es } -> begin + match StructField.Map.bindings es with + | [] -> Format.fprintf fmt "%a" format_lit (LUnit, Pos.no_pos) + | [(n, f)] -> + Format.fprintf fmt "(* { %a = *) %a (* } *)" format_struct_field_name + (Some s, n) format_with_parens f + | _ -> assert false + (* Format.fprintf fmt "{@[%a@]}" (Format.pp_print_list ~pp_sep:(fun + fmt () -> Format.fprintf fmt ";@ ") (fun fmt (struct_field, e) -> + Format.fprintf fmt "@[%a =@ %a@]" format_struct_field_name (Some + s, struct_field) format_with_parens e)) (StructField.Map.bindings es) *) + end | EArray es -> Format.fprintf fmt "@[[|%a|]@]" (Format.pp_print_list @@ -325,18 +331,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : (* should not happen *)) e)) (EnumConstructor.Map.bindings cases) - | ELit l -> Format.fprintf fmt "%a" format_lit (Mark.add (Expr.pos e) l) - | EApp { f = EAbs { binder; tys }, _; args } -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in - let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in - Format.fprintf fmt "(%a%a)" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "") - (fun fmt (x, tau, arg) -> - Format.fprintf fmt "@[let@ %a@ :@ %a@ =@ %a@]@ in@\n" - format_var x format_typ tau format_with_parens arg)) - xs_tau_arg format_with_parens body + | ELit l -> + Format.fprintf fmt "@[Value (%a)@]" format_lit + (Mark.add (Expr.pos e) l) + (* | EApp { f = EAbs { binder; tys }, _; args } -> let xs, body = + Bindlib.unmbind binder in let xs_tau = List.map2 (fun x tau -> x, tau) + (Array.to_list xs) tys in let xs_tau_arg = List.map2 (fun (x, tau) arg -> + x, tau, arg) xs_tau args in Format.fprintf fmt "(%a%a)" + (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "") (fun + fmt (x, tau, arg) -> Format.fprintf fmt "@[let@ %a@ :@ %a@ =@ %a@]@ + in@\n" format_var x format_typ tau format_with_parens arg)) xs_tau_arg + format_with_parens body *) | EAbs { binder; tys } -> let xs, body = Bindlib.unmbind binder in let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in @@ -344,44 +349,11 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") (fun fmt (x, tau) -> - Format.fprintf fmt "@[(%a:@ %a)@]" format_var x format_typ tau)) - xs_tau format_expr body - | EApp - { - f = EApp { f = EOp { op = Log (BeginCall, info); _ }, _; args = [f] }, _; - args = [arg]; - } - when Cli.globals.trace -> - Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info - format_with_parens f format_with_parens arg - | EApp - { f = EOp { op = Log (VarDef var_def_info, info); _ }, _; args = [arg1] } - when Cli.globals.trace -> - Format.fprintf fmt - "(log_variable_definition@ %a@ {io_input=%s;@ io_output=%b}@ (%a)@ %a)" - format_uid_list info - (match var_def_info.log_io_input with - | NoInput -> "NoInput" - | OnlyInput -> "OnlyInput" - | Reentrant -> "Reentrant") - var_def_info.log_io_output typ_embedding_name - (var_def_info.log_typ, Pos.no_pos) - format_with_parens arg1 - | EApp { f = EOp { op = Log (PosRecordIfTrueBool, _); _ }, m; args = [arg1] } - when Cli.globals.trace -> - let pos = Expr.mark_pos m in - Format.fprintf fmt - "(log_decision_taken@ @[{filename = \"%s\";@ start_line=%d;@ \ - start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)" - (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) - (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list - (Pos.get_law_info pos) format_with_parens arg1 - | EApp { f = EOp { op = Log (EndCall, info); _ }, _; args = [arg1] } - when Cli.globals.trace -> - Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info - format_with_parens arg1 - | EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } -> - Format.fprintf fmt "%a" format_with_parens arg1 + Format.fprintf fmt "@[((*%a:*)@ %a)@]" format_var x format_typ + tau)) + xs_tau + (format_expr' (dctx, List.append (Array.to_list xs) debrin)) + body | EApp { f; args } -> Format.fprintf fmt "@[%a@ %a@]" format_with_parens f (Format.pp_print_list @@ -393,23 +365,18 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : "@[ if@ @[%a@]@ then@ @[%a@]@ else@ @[%a@]@]" format_with_parens cond format_with_parens etrue format_with_parens efalse | EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op) - | EAssert e' -> - Format.fprintf fmt - "@[if@ %a@ then@ ()@ else@ raise (AssertionFailed @[{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ end_line=%d; \ - end_column=%d;@ law_headings=%a}@])@]" - format_with_parens e' - (Pos.get_file (Expr.pos e')) - (Pos.get_start_line (Expr.pos e')) - (Pos.get_start_column (Expr.pos e')) - (Pos.get_end_line (Expr.pos e')) - (Pos.get_end_column (Expr.pos e')) - format_string_list - (Pos.get_law_info (Expr.pos e')) - | EEmptyError -> assert false - | EDefault _ -> assert false - | EPureDefault _ -> assert false - | EErrorOnEmpty _ -> assert false + | EAssert _ -> Format.fprintf fmt "@[Value@ VUnit (* assert *) @]" + | EEmptyError -> Format.fprintf fmt "Empty" + | EDefault { excepts; just; cons } -> + Format.fprintf fmt "@[Default@ @[[%a]@]@ %a@ %a@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + format_with_parens) + excepts format_with_parens just format_with_parens cons + | EPureDefault e' -> + Format.fprintf fmt "@[PureDefault@ %a@]" format_with_parens e' + | EErrorOnEmpty e' -> + Format.fprintf fmt "@[ErrorOnEmpty@ %a@]" format_with_parens e' | _ -> . let format_struct_embedding @@ -460,11 +427,19 @@ let format_ctx (fmt : Format.formatter) (ctx : decl_ctx) : unit = let format_struct_decl fmt (struct_name, struct_fields) = - if StructField.Map.is_empty struct_fields then + match StructField.Map.bindings struct_fields with + | [] -> Format.fprintf fmt - "@[module %a = struct@\n@[type t = unit@]@]@\nend@\n" + "@[(* module %a = struct@\n@[type t = unit@]@]@\nend*)@\n" format_to_module_name (`Sname struct_name) - else + | [(_n, t)] -> + Format.fprintf fmt + "@[(* module %a = struct@\n@[type t = %a@]@]@\nend*)@\n" + format_to_module_name (`Sname struct_name) format_typ_with_parens t + | _ -> + Message.emit_warning + "Structure %a has multiple fields. This might not be supported by coq." + format_to_module_name (`Sname struct_name); Format.fprintf fmt "@[@[module %a = struct@ @[type t = {@,\ %a@;\ @@ -475,22 +450,9 @@ let format_ctx (fun fmt (struct_field, struct_field_type) -> Format.fprintf fmt "@[%a:@ %a@]" format_struct_field_name (None, struct_field) format_typ struct_field_type)) - (StructField.Map.bindings struct_fields); - if Cli.globals.trace then - format_struct_embedding fmt (struct_name, struct_fields) - in - let format_enum_decl fmt (enum_name, enum_cons) = - Format.fprintf fmt - "module %a = struct@\n@[@ type t =@\n@[ %a@]@\nend@]@\n" - format_to_module_name (`Ename enum_name) - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") - (fun fmt (enum_cons, enum_cons_type) -> - Format.fprintf fmt "@[| %a@ of@ %a@]" format_enum_cons_name - enum_cons format_typ enum_cons_type)) - (EnumConstructor.Map.bindings enum_cons); - if Cli.globals.trace then format_enum_embedding fmt (enum_name, enum_cons) + (StructField.Map.bindings struct_fields) in + let format_enum_decl _fmt (_enum_name, _enum_cons) = assert false in let is_in_type_ordering s = List.exists (fun struct_or_enum -> @@ -526,47 +488,54 @@ let rename_vars e = (rename_vars ~exclude:ocaml_keywords ~reset_context_for_closed_terms:true ~skip_constant_binders:true ~constant_binder_name:(Some "_") e)) -let format_expr ctx fmt e = format_expr ctx fmt (rename_vars e) +let format_expr ctx fmt e = format_expr ctx fmt e let rec format_scope_body_expr - (ctx : decl_ctx) + ctx (fmt : Format.formatter) (scope_lets : 'm Ast.expr scope_body_expr) : unit = + let dctx, debrin = ctx in match scope_lets with - | Result e -> format_expr ctx fmt e + | Result e -> format_expr (ctx, debrin) fmt e | ScopeLet scope_let -> let scope_let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in - Format.fprintf fmt "@[let %a: %a = %a in@]@\n%a" format_var + Format.fprintf fmt "@[let %a: %a = %a@ in@]@\n%a" format_var scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) scope_let.scope_let_expr - (format_scope_body_expr ctx) + (format_scope_body_expr (dctx, scope_let_var :: debrin)) scope_let_next let format_code_items (ctx : decl_ctx) (fmt : Format.formatter) (code_items : 'm Ast.expr code_item_list) : 'm Ast.expr Var.t String.Map.t = - Scope.fold_left - ~f:(fun bnd item var -> - match item with - | Topdef (name, typ, e) -> - Format.fprintf fmt "@\n@\n@[let %a : %a =@\n%a@]" format_var var - format_typ typ (format_expr ctx) e; - String.Map.add (Format.asprintf "%a" TopdefName.format name) var bnd - | ScopeDef (name, body) -> - let scope_input_var, scope_body_expr = - Bindlib.unbind body.scope_body_expr - in - Format.fprintf fmt "@\n@\n@[let %a (%a: %a.t) : %a.t =@\n%a@]" - format_var var format_var scope_input_var format_to_module_name - (`Sname body.scope_body_input_struct) format_to_module_name - (`Sname body.scope_body_output_struct) - (format_scope_body_expr ctx) - scope_body_expr; - String.Map.add (Format.asprintf "%a" ScopeName.format name) var bnd) - ~init:String.Map.empty code_items + let _, res = + Scope.fold_left + ~f:(fun (debrin, bnd) item var -> + match item with + | Topdef _ (* name, typ, e *) -> + assert false + (* Format.fprintf fmt "@\n@\n@[let %a : %a =@\n%a@]" format_var + var format_typ typ (format_expr (ctx, [])) e; String.Map.add + (Format.asprintf "%a" TopdefName.format name) var bnd *) + | ScopeDef (name, body) -> + let scope_input_var, scope_body_expr = + Bindlib.unbind body.scope_body_expr + in + Format.fprintf fmt "@\n@\n@[let %a (%a: %a.t) : %a.t =@\n%a@]" + format_var var format_var scope_input_var format_to_module_name + (`Sname body.scope_body_input_struct) format_to_module_name + (`Sname body.scope_body_output_struct) + (format_scope_body_expr (ctx, scope_input_var :: debrin)) + scope_body_expr; + ( var :: debrin, + String.Map.add (Format.asprintf "%a" ScopeName.format name) var bnd + )) + ~init:([], String.Map.empty) code_items + in + res let format_scope_exec (ctx : decl_ctx) @@ -646,6 +615,7 @@ let format_program | Some modname, None -> format_module_registration fmt bnd modname | None, Some scope_name -> let scope_body = Program.get_scope_body p scope_name in + format_scope_exec p.decl_ctx fmt bnd scope_name scope_body | None, None -> () | Some _, Some _ -> From 079b36e777878756437d7af98fc06acc47911ea4 Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Fri, 15 Dec 2023 12:47:25 +0100 Subject: [PATCH 4/9] coq printer: struct are now more transparent lambda are correctly printed Let In are present to enable better readability. removed dead code of interface with modules scopes are chained --- compiler/dcalc/to_coq.ml | 269 ++++++++++++--------------------------- 1 file changed, 80 insertions(+), 189 deletions(-) diff --git a/compiler/dcalc/to_coq.ml b/compiler/dcalc/to_coq.ml index bfc2c7458..d636c9108 100644 --- a/compiler/dcalc/to_coq.ml +++ b/compiler/dcalc/to_coq.ml @@ -22,7 +22,7 @@ let format_lit (fmt : Format.formatter) (l : lit Mark.pos) : unit = match Mark.remove l with | LBool b -> Format.fprintf fmt "Bool@ %s" (Bool.to_string b) | LInt i -> Format.fprintf fmt "Int@ %s" (Runtime.integer_to_string i) - | LUnit -> Print.lit fmt LUnit + | LUnit -> Format.fprintf fmt "VUnit" | LRat i -> Format.fprintf fmt "decimal_of_string \"%a\"" Print.lit (LRat i) | LMoney e -> Format.fprintf fmt "money_of_cents_string@ \"%s\"" @@ -166,20 +166,6 @@ let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : (avoid_keywords (String.to_ascii (Format.asprintf "%a" EnumConstructor.format v))) -let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = - match Mark.remove ty with - | TLit TUnit -> Format.fprintf fmt "embed_unit" - | TLit TBool -> Format.fprintf fmt "embed_bool" - | TLit TInt -> Format.fprintf fmt "embed_integer" - | TLit TRat -> Format.fprintf fmt "embed_decimal" - | TLit TMoney -> Format.fprintf fmt "embed_money" - | TLit TDate -> Format.fprintf fmt "embed_date" - | TLit TDuration -> Format.fprintf fmt "embed_duration" - | TStruct s_name -> Format.fprintf fmt "embed_%a" format_struct_name s_name - | TEnum e_name -> Format.fprintf fmt "embed_%a" format_enum_name e_name - | TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty - | _ -> Format.fprintf fmt "unembeddable" - let typ_needs_parens (e : typ) : bool = match Mark.remove e with | TDefault _ | TArrow _ | TArray _ -> true @@ -225,7 +211,7 @@ and format_typ (fmt : Format.formatter) (typ : typ) : unit = | TOption t -> Format.fprintf fmt "@[(%a)@] %a.t" format_typ_with_parens t format_to_module_name (`Ename Expr.option_enum) - | TDefault t -> format_typ fmt t + | TDefault t -> Format.fprintf fmt "@[TDefault %a@]" format_typ t | TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e) | TArrow (t1, t2) -> Format.fprintf fmt "@[%a@]" format_nested_arrows (t1, t2) @@ -253,9 +239,7 @@ let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = format_var_str fmt (Bindlib.name_of v) let needs_parens (e : 'm expr) : bool = - match Mark.remove e with - | EApp { f = EAbs _, _; _ } | EVar _ | ETuple _ | EOp _ -> false - | _ -> true + match Mark.remove e with EVar _ | ETuple _ | EOp _ -> false | _ -> true let find_index p = let rec aux i = function @@ -272,6 +256,10 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = if needs_parens e then Format.fprintf fmt "(%a)" format_expr e else Format.fprintf fmt "%a" format_expr e in + let format_with_parens' ctx (fmt : Format.formatter) (e : 'm expr) = + if needs_parens e then Format.fprintf fmt "(%a)" (format_expr' ctx) e + else Format.fprintf fmt "%a" (format_expr' ctx) e + in match Mark.remove e with | EVar v -> ( match find_index (Bindlib.eq_vars v) debrin with @@ -279,58 +267,33 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = | None -> Format.fprintf fmt "(Var@ ???@ (* %a *))" format_var v) | EExternal _ -> assert false | ETuple _es -> assert false - | EStruct { name = s; fields = es } -> begin + | EStructAccess { e; _ } -> Format.fprintf fmt "%a" format_with_parens e + | EStruct { name = _s; fields = es } -> begin match StructField.Map.bindings es with | [] -> Format.fprintf fmt "%a" format_lit (LUnit, Pos.no_pos) - | [(n, f)] -> - Format.fprintf fmt "(* { %a = *) %a (* } *)" format_struct_field_name - (Some s, n) format_with_parens f + | [(_n, f)] -> + Format.fprintf fmt "%a" format_with_parens f + (* Format.fprintf fmt "(* { %a = *) %a (* } *)" format_struct_field_name + (Some s, n) format_with_parens f *) | _ -> assert false (* Format.fprintf fmt "{@[%a@]}" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") (fun fmt (struct_field, e) -> Format.fprintf fmt "@[%a =@ %a@]" format_struct_field_name (Some s, struct_field) format_with_parens e)) (StructField.Map.bindings es) *) end - | EArray es -> - Format.fprintf fmt "@[[|%a|]@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) - es - | ETupleAccess { e; index; size } -> - Format.fprintf fmt "let@ %a@ = %a@ in@ x" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt i -> - Format.pp_print_string fmt (if i = index then "x" else "_"))) - (List.init size Fun.id) format_with_parens e - | EStructAccess { e; field; name } -> - Format.fprintf fmt "%a.%a" format_with_parens e format_struct_field_name - (Some name, field) - | EInj { e; cons; name } -> - Format.fprintf fmt "@[%a.%a@ %a@]" format_to_module_name - (`Ename name) format_enum_cons_name cons format_with_parens e - | EMatch { e; cases; name } -> - Format.fprintf fmt "@[@[match@ %a@]@ with@\n| %a@]" - format_with_parens e - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ | ") - (fun fmt (c, e) -> - Format.fprintf fmt "@[%a.%a %a@]" format_to_module_name - (`Ename name) format_enum_cons_name c - (fun fmt e -> - match Mark.remove e with - | EAbs { binder; _ } -> - let xs, body = Bindlib.unmbind binder in - Format.fprintf fmt "%a ->@ %a" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@,") - (fun fmt x -> Format.fprintf fmt "%a" format_var x)) - (Array.to_list xs) format_with_parens body - | _ -> assert false - (* should not happen *)) - e)) - (EnumConstructor.Map.bindings cases) + | EArray _es -> + assert false + (* Format.fprintf fmt "@[[|%a|]@]" (Format.pp_print_list ~pp_sep:(fun + fmt () -> Format.fprintf fmt ";@ ") (fun fmt e -> Format.fprintf fmt "%a" + format_with_parens e)) es *) + | ETupleAccess _ -> + assert false + (* Format.fprintf fmt "let@ %a@ = %a@ in@ x" (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt i -> + Format.pp_print_string fmt (if i = index then "x" else "_"))) (List.init + size Fun.id) format_with_parens e *) + | EInj _ -> assert false + | EMatch _ -> assert false | ELit l -> Format.fprintf fmt "@[Value (%a)@]" format_lit (Mark.add (Expr.pos e) l) @@ -345,17 +308,31 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = | EAbs { binder; tys } -> let xs, body = Bindlib.unmbind binder in let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in - Format.fprintf fmt "@[fun@ %a ->@ %a@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - (fun fmt (x, tau) -> - Format.fprintf fmt "@[((*%a:*)@ %a)@]" format_var x format_typ - tau)) + + List.fold_right + (fun (x, tau) pp (fmt : Format.formatter) () -> + Format.fprintf fmt "@[Lam (* %a: %a -> *)@ %a@]" format_var x + format_typ tau pp ()) xs_tau - (format_expr' (dctx, List.append (Array.to_list xs) debrin)) - body + (fun fmt () -> + format_with_parens' + (dctx, List.append (Array.to_list xs) debrin) + fmt body) + fmt () + | EApp { f = EAbs { binder; tys = [ty] }, _; args = [e1] } + when Bindlib.mbinder_arity binder = 1 -> + let xs, e2 = Bindlib.unmbind binder in + let x = xs.(0) in + + Format.fprintf fmt "@[Let (* %a: %a = *)@ %a@]@ In@\n%a" format_var x + format_typ ty format_expr e1 + (format_expr' (dctx, x :: debrin)) + e2 + | EApp { f = EOp { op; _ }, _; args = [e1; e2] } -> + Format.fprintf fmt "@[Binop %s@ %a@ %a@]" (Operator.name op) + format_with_parens e1 format_with_parens e2 | EApp { f; args } -> - Format.fprintf fmt "@[%a@ %a@]" format_with_parens f + Format.fprintf fmt "@[App@ %a@ %a@]" format_with_parens f (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) @@ -374,54 +351,11 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = format_with_parens) excepts format_with_parens just format_with_parens cons | EPureDefault e' -> - Format.fprintf fmt "@[PureDefault@ %a@]" format_with_parens e' + Format.fprintf fmt "@[DefaultPure@ %a@]" format_with_parens e' | EErrorOnEmpty e' -> Format.fprintf fmt "@[ErrorOnEmpty@ %a@]" format_with_parens e' | _ -> . -let format_struct_embedding - (fmt : Format.formatter) - ((struct_name, struct_fields) : StructName.t * typ StructField.Map.t) = - if StructField.Map.is_empty struct_fields then - Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" - format_struct_name struct_name format_to_module_name (`Sname struct_name) - else - Format.fprintf fmt - "@[let embed_%a (x: %a.t) : runtime_value =@ Struct([\"%a\"],@ \ - @[[%a]@])@]@\n\ - @\n" - format_struct_name struct_name format_to_module_name (`Sname struct_name) - StructName.format struct_name - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") - (fun fmt (struct_field, struct_field_type) -> - Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructField.format - struct_field typ_embedding_name struct_field_type - format_struct_field_name - (Some struct_name, struct_field))) - (StructField.Map.bindings struct_fields) - -let format_enum_embedding - (fmt : Format.formatter) - ((enum_name, enum_cases) : EnumName.t * typ EnumConstructor.Map.t) = - if EnumConstructor.Map.is_empty enum_cases then - Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" - format_to_module_name (`Ename enum_name) format_enum_name enum_name - else - Format.fprintf fmt - "@[@[let embed_%a@ @[(x:@ %a.t)@]@ : runtime_value \ - =@]@ Enum([\"%a\"],@ @[match x with@ %a@])@]@\n\ - @\n" - format_enum_name enum_name format_to_module_name (`Ename enum_name) - EnumName.format enum_name - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") - (fun fmt (enum_cons, enum_cons_type) -> - Format.fprintf fmt "@[| %a x ->@ (\"%a\", %a x)@]" - format_enum_cons_name enum_cons EnumConstructor.format enum_cons - typ_embedding_name enum_cons_type)) - (EnumConstructor.Map.bindings enum_cases) - let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) @@ -501,7 +435,7 @@ let rec format_scope_body_expr let scope_let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in - Format.fprintf fmt "@[let %a: %a = %a@ in@]@\n%a" format_var + Format.fprintf fmt "@[Let (* %a: %a = *)@ %a@]@ In@\n%a" format_var scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) scope_let.scope_let_expr (format_scope_body_expr (dctx, scope_let_var :: debrin)) @@ -524,12 +458,33 @@ let format_code_items let scope_input_var, scope_body_expr = Bindlib.unbind body.scope_body_expr in - Format.fprintf fmt "@\n@\n@[let %a (%a: %a.t) : %a.t =@\n%a@]" - format_var var format_var scope_input_var format_to_module_name - (`Sname body.scope_body_input_struct) format_to_module_name - (`Sname body.scope_body_output_struct) + + (* "@[Lam (* %a: %a -> *)@ %a@]" *) + let _ = "@\n@\n@[Let (* %a: %a = *)@\n%a@]@ In@\n%a" in + let _ = "@\n@\n@[let %a (%a: %a.t) : %a.t =@\n%a@]" in + + let scope_type : typ = + ( TArrow + ( [TStruct body.scope_body_input_struct, Pos.no_pos], + (TStruct body.scope_body_output_struct, Pos.no_pos) ), + Pos.no_pos ) + in + Format.fprintf fmt + "@\n\ + @\n\ + @[Let (* %a: %a = *)@\n\ + @[Lam (* %a: %a -> *)@ %a@]@]@ In@\n" + format_var var format_typ scope_type format_var scope_input_var + format_typ + (TStruct body.scope_body_input_struct, Pos.no_pos) (format_scope_body_expr (ctx, scope_input_var :: debrin)) scope_body_expr; + (* Format.fprintf fmt "@\n@\n@[let %a (%a: %a.t) : %a.t + =@\n%a@]@ In" format_var var format_var scope_input_var + format_to_module_name (`Sname body.scope_body_input_struct) + format_to_module_name (`Sname body.scope_body_output_struct) + (format_scope_body_expr (ctx, scope_input_var :: debrin)) + scope_body_expr; *) ( var :: debrin, String.Map.add (Format.asprintf "%a" ScopeName.format name) var bnd )) @@ -537,61 +492,6 @@ let format_code_items in res -let format_scope_exec - (ctx : decl_ctx) - (fmt : Format.formatter) - (bnd : 'm Ast.expr Var.t String.Map.t) - scope_name - scope_body = - let scope_name_str = Format.asprintf "%a" ScopeName.format scope_name in - let scope_var = String.Map.find scope_name_str bnd in - let scope_input = - StructName.Map.find scope_body.scope_body_input_struct ctx.ctx_structs - in - if not (StructField.Map.is_empty scope_input) then - Message.raise_error - "The scope @{%s@} defines input variables.@ This is not supported \ - for a main scope at the moment." - scope_name_str; - Format.pp_open_vbox fmt 2; - Format.pp_print_string fmt "let () ="; - (* TODO: dump the output using yojson that should be already available from - the runtime *) - Format.pp_print_space fmt (); - format_var fmt scope_var; - Format.pp_print_space fmt (); - Format.pp_print_string fmt "()"; - Format.pp_close_box fmt () - -let format_module_registration - fmt - (bnd : 'm Ast.expr Var.t String.Map.t) - modname = - Format.pp_open_vbox fmt 2; - Format.pp_print_string fmt "let () ="; - Format.pp_print_space fmt (); - Format.pp_open_hvbox fmt 2; - Format.fprintf fmt "Runtime_ocaml.Runtime.register_module \"%a\"" - ModuleName.format modname; - Format.pp_print_space fmt (); - Format.pp_open_vbox fmt 2; - Format.pp_print_string fmt "[ "; - Format.pp_print_seq - ~pp_sep:(fun fmt () -> - Format.pp_print_char fmt ';'; - Format.pp_print_cut fmt ()) - (fun fmt (id, var) -> - Format.fprintf fmt "@[%S,@ Obj.repr %a@]" id format_var var) - fmt (String.Map.to_seq bnd); - Format.pp_close_box fmt (); - Format.pp_print_char fmt ' '; - Format.pp_print_string fmt "]"; - Format.pp_print_space fmt (); - Format.pp_print_string fmt "\"todo-module-hash\""; - Format.pp_close_box fmt (); - Format.pp_close_box fmt (); - Format.pp_print_newline fmt () - let header = {ocaml| (** This file has been generated by the Catala compiler, do not edit! *) @@ -607,18 +507,9 @@ let format_program ?exec_scope (p : 'm Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = + let _ = exec_scope in Format.pp_print_string fmt header; format_ctx type_ordering fmt p.decl_ctx; - let bnd = format_code_items p.decl_ctx fmt p.code_items in - Format.pp_print_newline fmt (); - match p.module_name, exec_scope with - | Some modname, None -> format_module_registration fmt bnd modname - | None, Some scope_name -> - let scope_body = Program.get_scope_body p scope_name in - - format_scope_exec p.decl_ctx fmt bnd scope_name scope_body - | None, None -> () - | Some _, Some _ -> - Message.raise_error - "OCaml generation: both module registration and top-level scope \ - execution where required at the same time." + let _ = format_code_items p.decl_ctx fmt p.code_items in + Format.fprintf fmt "(Var 0)"; + Format.pp_print_newline fmt () From ec2166513577d89bb22e6c5ae4fc879bf135d752 Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Fri, 15 Dec 2023 12:47:33 +0100 Subject: [PATCH 5/9] typo correction --- compiler/shared_ast/program.mli | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index b527ba8fa..521ab7a0c 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -41,9 +41,8 @@ val get_scope_body : val untype : ('a any, _) gexpr program -> ('a, untyped) gexpr program val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed -(** Usage: [build_whole_program_expr program main_scope] builds an expression - corresponding to the main program and returning the main scope as a - function. *) +(** Usage: [to_expr program main_scope] builds an expression corresponding to + the main program and returning the main scope as a function. *) val equal : (('a any, _) gexpr as 'e) program -> (('a any, _) gexpr as 'e) program -> bool From f94131075de98d44dea66e754fb107ac1c15cb20 Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Fri, 15 Dec 2023 13:47:40 +0100 Subject: [PATCH 6/9] struct type are now unpacked in the coq backend --- compiler/dcalc/to_coq.ml | 98 ++++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 43 deletions(-) diff --git a/compiler/dcalc/to_coq.ml b/compiler/dcalc/to_coq.ml index d636c9108..49ec74e03 100644 --- a/compiler/dcalc/to_coq.ml +++ b/compiler/dcalc/to_coq.ml @@ -183,39 +183,56 @@ let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = | TDate -> "TDate") let rec format_nested_arrows + (ctx : decl_ctx) (fmt : Format.formatter) ((args, res) : typ list * typ) : unit = match args with - | [] -> Format.fprintf fmt "%a" format_typ_with_parens res + | [] -> Format.fprintf fmt "%a" (format_typ_with_parens ctx) res | [arg] -> - Format.fprintf fmt "@[TFun %a %a@]" format_typ_with_parens arg - format_typ_with_parens res + Format.fprintf fmt "@[TFun %a %a@]" + (format_typ_with_parens ctx) + arg + (format_typ_with_parens ctx) + res | arg :: args -> - Format.fprintf fmt "@[TFun %a %a@]" format_typ_with_parens arg - format_nested_arrows (args, res) + Format.fprintf fmt "@[TFun %a %a@]" + (format_typ_with_parens ctx) + arg (format_nested_arrows ctx) (args, res) -and format_typ_with_parens (fmt : Format.formatter) (t : typ) = - if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t - else Format.fprintf fmt "%a" format_typ t +and format_typ_with_parens (ctx : decl_ctx) (fmt : Format.formatter) (t : typ) = + if typ_needs_parens t then Format.fprintf fmt "(%a)" (format_typ ctx) t + else Format.fprintf fmt "%a" (format_typ ctx) t -and format_typ (fmt : Format.formatter) (typ : typ) : unit = +and format_typ (ctx : decl_ctx) (fmt : Format.formatter) (typ : typ) : unit = match Mark.remove typ with | TLit l -> Format.fprintf fmt "%a" format_tlit l | TTuple ts -> Format.fprintf fmt "@[(%a)@]" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") - format_typ_with_parens) + (format_typ_with_parens ctx)) ts - | TStruct s -> Format.fprintf fmt "%a.t" format_to_module_name (`Sname s) + | TStruct s -> begin + match + Option.map StructField.Map.bindings + @@ StructName.Map.find_opt s ctx.ctx_structs + with + | Some [] -> format_typ ctx fmt (TLit TUnit, Pos.no_pos) + | Some [(_n, t)] -> format_typ ctx fmt t + | _ -> + assert + false (* Format.fprintf fmt "%a.t" format_to_module_name (`Sname s) *) + end | TOption t -> - Format.fprintf fmt "@[(%a)@] %a.t" format_typ_with_parens t - format_to_module_name (`Ename Expr.option_enum) - | TDefault t -> Format.fprintf fmt "@[TDefault %a@]" format_typ t - | TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e) + Format.fprintf fmt "@[(%a)@] %a.t" + (format_typ_with_parens ctx) + t format_to_module_name (`Ename Expr.option_enum) + | TDefault t -> Format.fprintf fmt "@[TDefault %a@]" (format_typ ctx) t + | TEnum _ -> assert false | TArrow (t1, t2) -> - Format.fprintf fmt "@[%a@]" format_nested_arrows (t1, t2) - | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1 + Format.fprintf fmt "@[%a@]" (format_nested_arrows ctx) (t1, t2) + | TArray t1 -> + Format.fprintf fmt "@[%a@ array@]" (format_typ_with_parens ctx) t1 | TAny -> Format.fprintf fmt "_" | TClosureEnv -> failwith "unimplemented!" @@ -303,8 +320,8 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = x, tau, arg) xs_tau args in Format.fprintf fmt "(%a%a)" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "") (fun fmt (x, tau, arg) -> Format.fprintf fmt "@[let@ %a@ :@ %a@ =@ %a@]@ - in@\n" format_var x format_typ tau format_with_parens arg)) xs_tau_arg - format_with_parens body *) + in@\n" format_var x (format_typ dctx) tau format_with_parens arg)) + xs_tau_arg format_with_parens body *) | EAbs { binder; tys } -> let xs, body = Bindlib.unmbind binder in let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in @@ -312,7 +329,7 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = List.fold_right (fun (x, tau) pp (fmt : Format.formatter) () -> Format.fprintf fmt "@[Lam (* %a: %a -> *)@ %a@]" format_var x - format_typ tau pp ()) + (format_typ dctx) tau pp ()) xs_tau (fun fmt () -> format_with_parens' @@ -325,7 +342,7 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = let x = xs.(0) in Format.fprintf fmt "@[Let (* %a: %a = *)@ %a@]@ In@\n%a" format_var x - format_typ ty format_expr e1 + (format_typ dctx) ty format_expr e1 (format_expr' (dctx, x :: debrin)) e2 | EApp { f = EOp { op; _ }, _; args = [e1; e2] } -> @@ -338,9 +355,8 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = format_with_parens) args | EIfThenElse { cond; etrue; efalse } -> - Format.fprintf fmt - "@[ if@ @[%a@]@ then@ @[%a@]@ else@ @[%a@]@]" - format_with_parens cond format_with_parens etrue format_with_parens efalse + Format.fprintf fmt "@[If@ %a@ %a@ %a@]" format_with_parens cond + format_with_parens etrue format_with_parens efalse | EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op) | EAssert _ -> Format.fprintf fmt "@[Value@ VUnit (* assert *) @]" | EEmptyError -> Format.fprintf fmt "Empty" @@ -369,7 +385,9 @@ let format_ctx | [(_n, t)] -> Format.fprintf fmt "@[(* module %a = struct@\n@[type t = %a@]@]@\nend*)@\n" - format_to_module_name (`Sname struct_name) format_typ_with_parens t + format_to_module_name (`Sname struct_name) + (format_typ_with_parens ctx) + t | _ -> Message.emit_warning "Structure %a has multiple fields. This might not be supported by coq." @@ -383,7 +401,7 @@ let format_ctx ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") (fun fmt (struct_field, struct_field_type) -> Format.fprintf fmt "@[%a:@ %a@]" format_struct_field_name - (None, struct_field) format_typ struct_field_type)) + (None, struct_field) (format_typ ctx) struct_field_type)) (StructField.Map.bindings struct_fields) in let format_enum_decl _fmt (_enum_name, _enum_cons) = assert false in @@ -425,18 +443,18 @@ let rename_vars e = let format_expr ctx fmt e = format_expr ctx fmt e let rec format_scope_body_expr - ctx + (ctx : decl_ctx * _ base_gexpr Bindlib.var list) (fmt : Format.formatter) (scope_lets : 'm Ast.expr scope_body_expr) : unit = let dctx, debrin = ctx in match scope_lets with - | Result e -> format_expr (ctx, debrin) fmt e + | Result e -> format_expr (dctx, debrin) fmt e | ScopeLet scope_let -> let scope_let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in Format.fprintf fmt "@[Let (* %a: %a = *)@ %a@]@ In@\n%a" format_var - scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) + scope_let_var (format_typ dctx) scope_let.scope_let_typ (format_expr ctx) scope_let.scope_let_expr (format_scope_body_expr (dctx, scope_let_var :: debrin)) scope_let_next @@ -452,7 +470,7 @@ let format_code_items | Topdef _ (* name, typ, e *) -> assert false (* Format.fprintf fmt "@\n@\n@[let %a : %a =@\n%a@]" format_var - var format_typ typ (format_expr (ctx, [])) e; String.Map.add + var (format_typ dctx) typ (format_expr (ctx, [])) e; String.Map.add (Format.asprintf "%a" TopdefName.format name) var bnd *) | ScopeDef (name, body) -> let scope_input_var, scope_body_expr = @@ -473,9 +491,9 @@ let format_code_items "@\n\ @\n\ @[Let (* %a: %a = *)@\n\ - @[Lam (* %a: %a -> *)@ %a@]@]@ In@\n" - format_var var format_typ scope_type format_var scope_input_var - format_typ + @[Lam (* %a: %a -> *)@ (%a)@]@]@ In@\n" + format_var var (format_typ ctx) scope_type format_var + scope_input_var (format_typ ctx) (TStruct body.scope_body_input_struct, Pos.no_pos) (format_scope_body_expr (ctx, scope_input_var :: debrin)) scope_body_expr; @@ -493,23 +511,17 @@ let format_code_items res let header = - {ocaml| - (** This file has been generated by the Catala compiler, do not edit! *) - - open Runtime_ocaml.Runtime - - [@@@ocaml.warning "-4-26-27-32-41-42"] - - |ocaml} + "(** This expression has been generated by the Catala compiler, do not edit! \ + *)" let format_program (fmt : Format.formatter) ?exec_scope (p : 'm Ast.program) - (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = + (_type_ordering : Scopelang.Dependency.TVertex.t list) : unit = let _ = exec_scope in Format.pp_print_string fmt header; - format_ctx type_ordering fmt p.decl_ctx; + (* format_ctx type_ordering fmt p.decl_ctx; *) let _ = format_code_items p.decl_ctx fmt p.code_items in Format.fprintf fmt "(Var 0)"; Format.pp_print_newline fmt () From 65474bd92a73e5ee577cf94d96fc5dc726c8ec5c Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Fri, 15 Dec 2023 15:38:09 +0100 Subject: [PATCH 7/9] parenthis in coq --- compiler/dcalc/to_coq.ml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/compiler/dcalc/to_coq.ml b/compiler/dcalc/to_coq.ml index 49ec74e03..c0ee3b91b 100644 --- a/compiler/dcalc/to_coq.ml +++ b/compiler/dcalc/to_coq.ml @@ -167,9 +167,7 @@ let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : (String.to_ascii (Format.asprintf "%a" EnumConstructor.format v))) let typ_needs_parens (e : typ) : bool = - match Mark.remove e with - | TDefault _ | TArrow _ | TArray _ -> true - | _ -> false + match Mark.remove e with TDefault _ | TArray _ -> true | _ -> false let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = Format.fprintf fmt @@ -189,13 +187,13 @@ let rec format_nested_arrows match args with | [] -> Format.fprintf fmt "%a" (format_typ_with_parens ctx) res | [arg] -> - Format.fprintf fmt "@[TFun %a %a@]" + Format.fprintf fmt "@[(TFun %a %a)@]" (format_typ_with_parens ctx) arg (format_typ_with_parens ctx) res | arg :: args -> - Format.fprintf fmt "@[TFun %a %a@]" + Format.fprintf fmt "@[(TFun %a %a)@]" (format_typ_with_parens ctx) arg (format_nested_arrows ctx) (args, res) From da1a6a707a6aab01c8423463c9efbd8d7d75d76c Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Mon, 8 Jan 2024 15:47:56 +0100 Subject: [PATCH 8/9] rebase --- compiler/dcalc/to_coq.ml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/compiler/dcalc/to_coq.ml b/compiler/dcalc/to_coq.ml index c0ee3b91b..c149d34bf 100644 --- a/compiler/dcalc/to_coq.ml +++ b/compiler/dcalc/to_coq.ml @@ -254,7 +254,7 @@ let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = format_var_str fmt (Bindlib.name_of v) let needs_parens (e : 'm expr) : bool = - match Mark.remove e with EVar _ | ETuple _ | EOp _ -> false | _ -> true + match Mark.remove e with EVar _ | ETuple _ | _ -> true let find_index p = let rec aux i = function @@ -334,7 +334,7 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = (dctx, List.append (Array.to_list xs) debrin) fmt body) fmt () - | EApp { f = EAbs { binder; tys = [ty] }, _; args = [e1] } + | EApp { f = EAbs { binder; tys = [ty] }, _; args = [e1] ; _ } when Bindlib.mbinder_arity binder = 1 -> let xs, e2 = Bindlib.unmbind binder in let x = xs.(0) in @@ -343,10 +343,11 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = (format_typ dctx) ty format_expr e1 (format_expr' (dctx, x :: debrin)) e2 - | EApp { f = EOp { op; _ }, _; args = [e1; e2] } -> + | EAppOp { op; args = [e1; e2]; _ } -> Format.fprintf fmt "@[Binop %s@ %a@ %a@]" (Operator.name op) format_with_parens e1 format_with_parens e2 - | EApp { f; args } -> + | EAppOp _ -> assert false + | EApp { f; args; _ } -> Format.fprintf fmt "@[App@ %a@ %a@]" format_with_parens f (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") @@ -355,7 +356,7 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = | EIfThenElse { cond; etrue; efalse } -> Format.fprintf fmt "@[If@ %a@ %a@ %a@]" format_with_parens cond format_with_parens etrue format_with_parens efalse - | EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op) + (* | EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op) *) | EAssert _ -> Format.fprintf fmt "@[Value@ VUnit (* assert *) @]" | EEmptyError -> Format.fprintf fmt "Empty" | EDefault { excepts; just; cons } -> From 83d6f3b75a398d9934cf47ef6adfbd3f1cbfaba4 Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Mon, 8 Jan 2024 16:41:34 +0100 Subject: [PATCH 9/9] formatting --- compiler/dcalc/to_coq.ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/dcalc/to_coq.ml b/compiler/dcalc/to_coq.ml index c149d34bf..db8a6c1c8 100644 --- a/compiler/dcalc/to_coq.ml +++ b/compiler/dcalc/to_coq.ml @@ -334,7 +334,7 @@ let rec format_expr ctx (fmt : Format.formatter) (e : 'm expr) : unit = (dctx, List.append (Array.to_list xs) debrin) fmt body) fmt () - | EApp { f = EAbs { binder; tys = [ty] }, _; args = [e1] ; _ } + | EApp { f = EAbs { binder; tys = [ty] }, _; args = [e1]; _ } when Bindlib.mbinder_arity binder = 1 -> let xs, e2 = Bindlib.unmbind binder in let x = xs.(0) in