diff --git a/src/lib/frontend/translate.ml b/src/lib/frontend/translate.ml index 7455aacdd..0b3f5c803 100644 --- a/src/lib/frontend/translate.ml +++ b/src/lib/frontend/translate.ml @@ -523,8 +523,8 @@ and handle_ty_app ?(update = false) ty_c l = variable. *) let rec apply_ty_substs tysubsts ty = match ty with - | Ty.Tvar { v; _ } -> - Ty.M.find v tysubsts + | Ty.Tvar v -> + Ty.TvMap.find v tysubsts | Text (tyl, hs) -> Ty.Text (List.map (apply_ty_substs tysubsts) tyl, hs) @@ -561,9 +561,9 @@ and handle_ty_app ?(update = false) ty_c l = List.fold_left2 ( fun acc tv ty -> match tv with - | Ty.Tvar { v; _ } -> Ty.M.add v ty acc + | Ty.Tvar v -> Ty.TvMap.add v ty acc | _ -> assert false - ) Ty.M.empty args tyl + ) Ty.TvMap.empty args tyl in apply_ty_substs tysubsts ty @@ -1749,7 +1749,7 @@ let make_form name_base f loc ~decl_kind = in assert (Var.Map.is_empty (E.free_vars ff Var.Map.empty)); let ff = E.purify_form ff in - if Ty.Svty.is_empty (E.free_type_vars ff) then ff + if Ty.TvSet.is_empty (E.free_type_vars ff) then ff else E.mk_forall name_base loc Var.Map.empty [] ff ~toplevel:true ~decl_kind @@ -2005,7 +2005,7 @@ let make dloc_file acc stmt = assert (Var.Map.is_empty (E.free_vars ff Var.Map.empty)); let ff = E.purify_form ff in let e = - if Ty.Svty.is_empty (E.free_type_vars ff) then ff + if Ty.TvSet.is_empty (E.free_type_vars ff) then ff else E.mk_forall name_base loc Var.Map.empty [] ff ~toplevel:true ~decl_kind @@ -2026,7 +2026,7 @@ let make dloc_file acc stmt = assert (Var.Map.is_empty (E.free_vars ff Var.Map.empty)); let ff = E.purify_form ff in let e = - if Ty.Svty.is_empty (E.free_type_vars ff) then ff + if Ty.TvSet.is_empty (E.free_type_vars ff) then ff else E.mk_forall name_base loc Var.Map.empty [] ff ~toplevel:true ~decl_kind diff --git a/src/lib/structures/expr.ml b/src/lib/structures/expr.ml index 2b05a3d49..3c5313c7d 100644 --- a/src/lib/structures/expr.ml +++ b/src/lib/structures/expr.ml @@ -42,7 +42,7 @@ and term_view = { bind : bind_kind; tag: int; vars : (Ty.t * int) Var.Map.t; (* vars to types and nb of occurences *) - vty : Ty.Svty.t; + vty : Ty.TvSet.t; depth: int; nb_nodes : int; pure : bool; @@ -244,11 +244,11 @@ module Msbt : Map.S with type key = expr Var.Map.t = let compare a b = Var.Map.compare compare a b end) -module Msbty : Map.S with type key = Ty.t Ty.M.t = +module Msbty : Map.S with type key = Ty.subst = Map.Make (struct - type t = Ty.t Ty.M.t - let compare a b = Ty.M.compare Ty.compare a b + type t = Ty.subst + let compare = Ty.compare_subst end) module TSet : Set.S with type elt = expr = @@ -333,10 +333,6 @@ module SmtPrinter = struct | `Forall -> Fmt.pf ppf "forall" | `Exists -> Fmt.pf ppf "exists" - (* This printer follows the convention used to print - type variables in the module [Ty]. *) - let pp_tyvar ppf v = Fmt.pf ppf "A%d" v - let rec pp_main bind ppf { user_trs; main; binders; _ } = if not @@ Var.Map.is_empty binders then Fmt.pf ppf "@[<2>(%a (%a)@, %a@, %a)@]" @@ -348,9 +344,9 @@ module SmtPrinter = struct pp_boxed ppf main and pp_quantified bind ppf q = - if q.toplevel && not @@ Ty.Svty.is_empty q.main.vty then + if q.toplevel && not @@ Ty.TvSet.is_empty q.main.vty then Fmt.pf ppf "@[<2>(par (%a)@, %a)@]" - Fmt.(box @@ iter ~sep:sp Ty.Svty.iter pp_tyvar) q.main.vty + Fmt.(box @@ iter ~sep:sp Ty.TvSet.iter DE.Ty.Var.print) q.main.vty (pp_main bind) q else pp_main bind ppf q @@ -802,7 +798,7 @@ let free_type_vars t = t.vty let is_ground t = Var.Map.is_empty (free_vars t Var.Map.empty) && - Ty.Svty.is_empty (free_type_vars t) + Ty.TvSet.is_empty (free_type_vars t) let size t = t.nb_nodes @@ -876,7 +872,7 @@ let free_vars_non_form s l ty = | _, e::r -> List.fold_left (fun s t -> merge_vars s t.vars) e.vars r let free_type_vars_non_form l ty = - List.fold_left (fun acc t -> Ty.Svty.union acc t.vty) (Ty.vty_of ty) l + List.fold_left (fun acc t -> Ty.TvSet.union acc t.vty) (Ty.vty_of ty) l let is_ite s = match s with | Sy.Op Sy.Tite -> true @@ -960,7 +956,7 @@ let vrai = let res = let nb_nodes = 0 in let vars = Var.Map.empty in - let vty = Ty.Svty.empty in + let vty = Ty.TvSet.empty in let faux = HC.make {f = Sy.False; xs = []; ty = Ty.Tbool; depth = -2; (*smallest depth*) @@ -1040,7 +1036,7 @@ let mk_or f1 f2 is_impl = let d = (max f1.depth f2.depth) in (* the +1 causes regression *) let nb_nodes = f1.nb_nodes + f2.nb_nodes + 1 in let vars = merge_vars f1.vars f2.vars in - let vty = Ty.Svty.union f1.vty f2.vty in + let vty = Ty.TvSet.union f1.vty f2.vty in let pos = HC.make {f=Sy.Form (Sy.F_Clause is_impl); xs=[f1; f2]; ty=Ty.Tbool; depth=d; tag= -42; vars; vty; nb_nodes; neg = None; @@ -1070,7 +1066,7 @@ let mk_iff f1 f2 = let d = (max f1.depth f2.depth) in (* the +1 causes regression *) let nb_nodes = f1.nb_nodes + f2.nb_nodes + 1 in let vars = merge_vars f1.vars f2.vars in - let vty = Ty.Svty.union f1.vty f2.vty in + let vty = Ty.TvSet.union f1.vty f2.vty in let pos = HC.make {f=Sy.Form Sy.F_Iff; xs=[f1; f2]; ty=Ty.Tbool; depth=d; tag= -42; vars; vty; nb_nodes; neg = None; @@ -1156,7 +1152,7 @@ let mk_forall_ter = lemma. Otherwise (if not toplevel), the free vtys of the lemma are those of lem.main *) let vty = - if new_q.toplevel then Ty.Svty.empty + if new_q.toplevel then Ty.TvSet.empty else free_type_vars new_q.main in let vars = @@ -1191,7 +1187,7 @@ let no_occur_check v e = not (Var.Map.mem v e.vars) let no_vtys l = - List.for_all (fun e -> Ty.Svty.is_empty e.vty) l + List.for_all (fun e -> Ty.TvSet.is_empty e.vty) l (** smart constructors for literals *) @@ -1356,12 +1352,12 @@ let no_capture_issue s_t binders = end let rec apply_subst_aux (s_t, s_ty) t = - if is_ground t || (Var.Map.is_empty s_t && Ty.M.is_empty s_ty) then t + if is_ground t || (Var.Map.is_empty s_t && Ty.TvMap.is_empty s_ty) then t else let { f; xs; ty; vars; vty; bind; _ } = t in let s_t = Var.Map.filter (fun v _ -> Var.Map.mem v vars) s_t in - let s_ty = Ty.M.filter (fun tvar _ -> Ty.Svty.mem tvar vty) s_ty in - if Var.Map.is_empty s_t && Ty.M.is_empty s_ty then t + let s_ty = Ty.TvMap.filter (fun tvar _ -> Ty.TvSet.mem tvar vty) s_ty in + if Var.Map.is_empty s_t && Ty.TvMap.is_empty s_ty then t else let s = s_t, s_ty in let xs', same = My_list.apply (apply_subst_aux s) xs in @@ -1501,7 +1497,7 @@ and mk_let_aux ({ let_v; let_e; in_e; _ } as x) = let nb_nodes = let_e.nb_nodes + in_e.nb_nodes + 1 (* approx *) in (* do not include free vars in let_sko that have been simplified *) let vars = merge_vars let_e.vars (Var.Map.remove let_v in_e.vars) in - let vty = Ty.Svty.union let_e.vty in_e.vty in + let vty = Ty.TvSet.union let_e.vty in_e.vty in let pos = HC.make {f=Sy.Let; xs=[]; ty; depth=d; tag= -42; vars; vty; nb_nodes; neg = None; @@ -1524,7 +1520,7 @@ and mk_forall_bis (q : quantified) = let binders = (* ignore binders that are not used in f *) Var.Map.filter (fun v _ -> Var.Map.mem v q.main.vars) q.binders in - if Var.Map.is_empty binders && Ty.Svty.is_empty q.main.vty then q.main + if Var.Map.is_empty binders && Ty.TvSet.is_empty q.main.vty then q.main else let q = {q with binders} in (* Attempt to reduce the number of quantifiers. We try to find a @@ -1582,7 +1578,7 @@ and find_particular_subst = in fun binders trs f -> (* TODO: move the test for `trs` outside. *) - if not (Ty.Svty.is_empty f.vty) || has_hypotheses trs || + if not (Ty.TvSet.is_empty f.vty) || has_hypotheses trs || has_semantic_triggers trs then None @@ -1699,7 +1695,7 @@ let resolution_of_literal a binders free_vty acc = match lit_view a with | Pred(t, _) -> let cond = - Ty.Svty.subset free_vty (free_type_vars t) && + Ty.TvSet.subset free_vty (free_type_vars t) && let vars = free_vars t Var.Map.empty in Var.Map.for_all (fun v _ -> Var.Map.mem v vars) binders in @@ -1781,8 +1777,8 @@ let resolution_triggers ~is_back { kind; main = f; binders; _ } = )cand [] let free_type_vars_as_types e = - Ty.Svty.fold - (fun i z -> Ty.Set.add (Ty.Tvar {Ty.v=i; value = None}) z) + Ty.TvSet.fold + (fun tv z -> Ty.Set.add (Ty.Tvar tv) z) (free_type_vars e) Ty.Set.empty @@ -1806,7 +1802,7 @@ let mk_let let_v let_e in_e = let skolemize { main = f; binders; sko_v; sko_vty; _ } = let print fmt ty = - assert (Ty.Svty.is_empty (Ty.vty_of ty)); + assert (Ty.TvSet.is_empty (Ty.vty_of ty)); Format.fprintf fmt "<%a>" Ty.print ty in let pp_sep_nospace fmt () = Format.fprintf fmt "" in @@ -1928,13 +1924,11 @@ let elim_iff f1 f2 ~with_conj = module Triggers = struct - module Svty = Ty.Svty - (* Set of patterns with their sets of free term and type variables. *) module STRS = Set.Make( struct - type t = expr * Var.Set.t * Svty.t + type t = expr * Var.Set.t * Ty.TvSet.t let compare (t1,_,_) (t2,_,_) = compare t1 t2 end) @@ -2105,7 +2099,7 @@ module Triggers = struct fun l -> unique (List.stable_sort cmp_trig_term_list l) [] - let vty_of_term acc t = Svty.union acc t.vty + let vty_of_term acc t = Ty.TvSet.union acc t.vty let not_pure t = not t.pure @@ -2126,11 +2120,11 @@ module Triggers = struct variables. *) not (List.exists not_pure l) && let s1 = List.fold_left (vars_of_term bv) Var.Set.empty l in - let s2 = List.fold_left vty_of_term Svty.empty l in + let s2 = List.fold_left vty_of_term Ty.TvSet.empty l in (* TODO: we can replace `Var.Set.subset bv s1` by `Var.Seq.equal bv s1`. By construction `s1` is a subset of `bv`. *) - Var.Set.subset bv s1 && Svty.subset vty s2 ) + Var.Set.subset bv s1 && Ty.TvSet.subset vty s2 ) trs (* unused @@ -2142,8 +2136,8 @@ module Triggers = struct if List.exists not_pure l then failwith "If-Then-Else are not allowed in (theory triggers)"; let s1 = List.fold_left (vars_of_term bv) SSet.empty l in - let s2 = List.fold_left vty_of_term Svty.empty l in - if not (Svty.subset vty s2) || not (SSet.subset bv s1) then + let s2 = List.fold_left vty_of_term Ty.TvSet.empty l in + if not (Ty.TvSet.subset vty s2) || not (SSet.subset bv s1) then failwith "Triggers of a theory should contain every quantified \ types and variables.") trs; @@ -2171,7 +2165,7 @@ module Triggers = struct module SLLT = Set.Make( struct - type t = expr list * Var.Set.t * Svty.t + type t = expr list * Var.Set.t * Ty.TvSet.t let compare (a, y1, _) (b, y2, _) = let c = try compare_lists a b compare; 0 with Util.Cmp c -> c in if c <> 0 then c else Var.Set.compare y1 y2 @@ -2222,14 +2216,14 @@ module Triggers = struct let llt, llt_ok = SLLT.fold (fun (l, bv2, vty2) (llt, llt_ok) -> - if Var.Set.subset bv1 bv2 && Svty.subset vty1 vty2 then + if Var.Set.subset bv1 bv2 && Ty.TvSet.subset vty1 vty2 then (* t doesn't bring new vars *) llt, llt_ok else let bv3 = Var.Set.union bv2 bv1 in - let vty3 = Svty.union vty2 vty1 in + let vty3 = Ty.TvSet.union vty2 vty1 in let e = t::l, bv3, vty3 in - if Var.Set.subset bv bv3 && Svty.subset vty vty3 then + if Var.Set.subset bv bv3 && Ty.TvSet.subset vty vty3 then (* The multi-trigger [e] cover all the free variables [bv] and [vty]. *) llt, SLLT.add e llt_ok @@ -2258,17 +2252,17 @@ module Triggers = struct List.exists (fun (_, bv',vty') -> (Var.Set.subset bv bv' && not(Var.Set.equal bv bv') - && Svty.subset vty vty') - || (Svty.subset vty vty' && not(Svty.equal vty vty') + && Ty.TvSet.subset vty vty') + || (Ty.TvSet.subset vty vty' && not(Ty.TvSet.equal vty vty') && Var.Set.subset bv bv') ) l in fun bv_a vty_a l -> let rec simpl_rec acc = function | [] -> acc | ((_, bv, vty) as e)::l -> if strict_subset bv vty l || strict_subset bv vty acc || - (Var.Set.subset bv_a bv && Svty.subset vty_a vty) || + (Var.Set.subset bv_a bv && Ty.TvSet.subset vty_a vty) || (Var.Set.equal (Var.Set.inter bv_a bv) Var.Set.empty && - Svty.equal (Svty.inter vty_a vty) Svty.empty) + Ty.TvSet.equal (Ty.TvSet.inter vty_a vty) Ty.TvSet.empty) then simpl_rec acc l else simpl_rec (e::acc) l in @@ -2294,7 +2288,7 @@ module Triggers = struct and [vtype]. *) let mono = List.filter (fun (_, bv_t, vty_t) -> - Var.Set.subset vterm bv_t && Svty.subset vtype vty_t) trs + Var.Set.subset vterm bv_t && Ty.TvSet.subset vtype vty_t) trs in let trs_v, trs_nv = List.partition (fun (t, _, _) -> is_var t) mono in let base = if menv.Util.triggers_var then trs_nv @ trs_v else trs_nv in @@ -2383,7 +2377,7 @@ module Triggers = struct Var.Map.exists (fun e _ -> Var.Set.mem e bv) bv_lf in let has_tyvar vty vty_lf = - Svty.exists (fun e -> Svty.mem e vty) vty_lf + Ty.TvSet.exists (fun e -> Ty.TvSet.mem e vty) vty_lf in let args_of e lets = match e.bind with @@ -2479,9 +2473,9 @@ module Triggers = struct )terms terms let check_user_triggers f toplevel binders trs0 ~decl_kind = - if Var.Map.is_empty binders && Ty.Svty.is_empty f.vty then trs0 + if Var.Map.is_empty binders && Ty.TvSet.is_empty f.vty then trs0 else - let vtype = if toplevel then f.vty else Ty.Svty.empty in + let vtype = if toplevel then f.vty else Ty.TvSet.empty in let vterm = Var.Map.fold (fun v _ s -> Var.Set.add v s) binders Var.Set.empty in @@ -2498,7 +2492,7 @@ module Triggers = struct filter_good_triggers (vterm, vtype) trs0 let make f binders decl_kind mconf = - if Var.Map.is_empty binders && Ty.Svty.is_empty f.vty then [] + if Var.Map.is_empty binders && Ty.TvSet.is_empty f.vty then [] else let vtype = f.vty in let vterm = @@ -2604,7 +2598,7 @@ let mk_forall name loc binders trs f ~toplevel ~decl_kind = user_trs = trs; main = f; sko_v; sko_vty; kind = decl_kind} let mk_exists name loc binders trs f ~toplevel ~decl_kind = - if not toplevel || Ty.Svty.is_empty f.vty then + if not toplevel || Ty.TvSet.is_empty f.vty then neg (mk_forall name loc binders trs (neg f) ~toplevel ~decl_kind) else (* If there are type variables in a toplevel exists: 1 - we add diff --git a/src/lib/structures/expr.mli b/src/lib/structures/expr.mli index 6dec2b1ab..bb0460449 100644 --- a/src/lib/structures/expr.mli +++ b/src/lib/structures/expr.mli @@ -49,7 +49,7 @@ type term_view = private { (** Map of free term variables in the term to their type and number of occurrences. *) - vty : Ty.Svty.t; + vty : Ty.TvSet.t; (** Map of free type variables in the term. *) depth: int; @@ -222,7 +222,7 @@ val compare_let : letin -> letin -> int (** Some auxiliary functions *) val free_vars : t -> (Ty.t * int) Var.Map.t -> (Ty.t * int) Var.Map.t -val free_type_vars : t -> Ty.Svty.t +val free_type_vars : t -> Ty.TvSet.t val is_ground : t -> bool val size : t -> int val depth : t -> int diff --git a/src/lib/structures/ty.ml b/src/lib/structures/ty.ml index 31716e50a..1d1fd4730 100644 --- a/src/lib/structures/ty.ml +++ b/src/lib/structures/ty.ml @@ -27,6 +27,11 @@ module DE = Dolmen.Std.Expr +module TvSet = Set.Make (DE.Ty.Var) +module TvMap = Map.Make (DE.Ty.Var) + +type tvar = DE.ty_var + type t = | Tint | Treal @@ -38,8 +43,6 @@ type t = | Tadt of DE.ty_cst * t list | Trecord of trecord -and tvar = { v : int ; mutable value : t option } - and trecord = { mutable args : t list; name : DE.ty_cst; @@ -62,14 +65,12 @@ module Smtlib = struct | Text (args, name) | Trecord { args; name; _ } | Tadt (name, args) -> Fmt.(pf ppf "(@[%a %a@])" DE.Ty.Const.print name (list ~sep:sp pp) args) - | Tvar { v; value = None; _ } -> Fmt.pf ppf "A%d" v - | Tvar { value = Some t; _ } -> pp ppf t + | Tvar tv -> DE.Ty.Var.print ppf tv end let pp_smtlib = Smtlib.pp exception TypeClash of t*t -exception Shorten of t type adt_constr = { constr : DE.term_cst ; @@ -96,7 +97,6 @@ let assoc_destrs hs cases = (*** pretty print ***) let print_generic body_of = - let h = Hashtbl.create 17 in let rec print = let open Format in fun body_of fmt -> function @@ -104,17 +104,7 @@ let print_generic body_of = | Treal -> fprintf fmt "real" | Tbool -> fprintf fmt "bool" | Tbitv n -> fprintf fmt "bitv[%d]" n - | Tvar{v=v ; value = None} -> fprintf fmt "'a_%d" v - | Tvar{v=v ; value = Some (Trecord { args = l; name = n; _ } as t) } -> - if Hashtbl.mem h v then - fprintf fmt "%a %a" print_list l DE.Ty.Const.print n - else - (Hashtbl.add h v (); - (*fprintf fmt "('a_%d->%a)" v print t *) - print body_of fmt t) - | Tvar{ value = Some t; _ } -> - (*fprintf fmt "('a_%d->%a)" v print t *) - print body_of fmt t + | Tvar tv -> fprintf fmt "'a_%a" DE.Ty.Var.print tv | Text(l, s) when l == [] -> fprintf fmt "%a" DE.Ty.Const.print s | Text(l,s) -> @@ -187,51 +177,11 @@ let print_generic body_of = let print_list = snd (print_generic None) let print = fst (print_generic None) None - -let fresh_var = - let cpt = ref (-1) in - fun () -> incr cpt; {v= !cpt ; value = None } - -let fresh_tvar () = Tvar (fresh_var ()) - -let rec shorten ty = - match ty with - | Tvar { value = None; _ } -> ty - | Tvar { value = Some (Tvar{ value = None; _ } as t'); _ } -> t' - | Tvar ({ value = Some (Tvar t2); _ } as t1) -> - t1.value <- t2.value; shorten ty - | Tvar { value = Some t'; _ } -> shorten t' - - | Text (l,s) -> - let l, same = My_list.apply shorten l in - if same then ty else Text(l,s) - - | Tfarray (t1,t2) -> - let t1' = shorten t1 in - let t2' = shorten t2 in - if t1 == t1' && t2 == t2' then ty - else Tfarray(t1', t2') - - | Trecord r -> - r.args <- List.map shorten r.args; - r.lbs <- List.map (fun (lb, ty) -> lb, shorten ty) r.lbs; - ty - - | Tadt (n, args) -> - let args' = List.map shorten args in - shorten_body n args; - (* should not rebuild the type if no changes are made *) - Tadt (n, args') - - | Tint | Treal | Tbool | Tbitv _ -> ty - -and shorten_body _ _ = - () - [@ocaml.ppwarning "TODO: should be implemented ?"] +let fresh_tvar () = Tvar (DE.Ty.Var.mk "A") let rec compare t1 t2 = - match shorten t1 , shorten t2 with - | Tvar{ v = v1; _ } , Tvar{ v = v2; _ } -> Int.compare v1 v2 + match t1, t2 with + | Tvar v1, Tvar v2 -> DE.Ty.Var.compare v1 v2 | Tvar _, _ -> -1 | _ , Tvar _ -> 1 | Text(l1, s1) , Text(l2, s2) -> let c = DE.Ty.Const.compare s1 s2 in @@ -282,8 +232,8 @@ and compare_list l1 l2 = match l1, l2 with let rec equal t1 t2 = t1 == t2 || - match shorten t1 , shorten t2 with - | Tvar{ v = v1; _ }, Tvar{ v = v2; _ } -> v1 = v2 + match t1, t2 with + | Tvar v1, Tvar v2 -> DE.Ty.Var.equal v1 v2 | Text(l1, s1), Text(l2, s2) -> (try DE.Ty.Const.equal s1 s2 && List.for_all2 equal l1 l2 with Invalid_argument _ -> false) @@ -312,17 +262,18 @@ let rec equal t1 t2 = | _ -> false (*** matching with a substitution mechanism ***) -module M = Util.MI -type subst = t M.t +type subst = t TvMap.t -let esubst = M.empty +let esubst = TvMap.empty let rec matching s pat t = match pat , t with - | Tvar {v=n;value=None} , _ -> - (try if not (equal (M.find n s) t) then raise (TypeClash(pat,t)); s - with Not_found -> M.add n t s) - | Tvar { value = _; _ }, _ -> raise (Shorten pat) + | Tvar v , _ -> + (try + if not (equal (TvMap.find v s) t) then + raise (TypeClash (pat,t)); + s + with Not_found -> TvMap.add v t s) | Text (l1,s1) , Text (l2,s2) when DE.Ty.Const.equal s1 s2 -> List.fold_left2 matching s l1 l2 | Tfarray (ta1,ta2), Tfarray (tb1,tb2) -> @@ -341,8 +292,8 @@ let rec matching s pat t = let apply_subst = let rec apply_subst s ty = match ty with - | Tvar { v= n; _ } -> - (try M.find n s with Not_found -> ty) + | Tvar v -> + (try TvMap.find v s with Not_found -> ty) | Text (l,e) -> let l, same = My_list.apply (apply_subst s) l in @@ -370,22 +321,16 @@ let apply_subst = | Tint | Treal | Tbool | Tbitv _ -> ty in - fun s ty -> if M.is_empty s then ty else apply_subst s ty + fun s ty -> if TvMap.is_empty s then ty else apply_subst s ty -(* Assume that [shorten] have been applied on [ty]. *) let rec fresh ty subst = match ty with - | Tvar { value = Some _; _ } -> - (* This case is eliminated by the normalization performed - in [shorten]. *) - assert false - - | Tvar { v= x; _ } -> + | Tvar v -> begin - try M.find x subst, subst + try TvMap.find v subst, subst with Not_found -> - let nv = Tvar (fresh_var()) in - nv, M.add x nv subst + let nv = fresh_tvar () in + nv, TvMap.add v nv subst end | Text (args, n) -> let args, subst = fresh_list args subst in @@ -408,17 +353,12 @@ let rec fresh ty subst = Tadt (s, args), subst | t -> t, subst -(* Assume that [shorten] have been applied on [lty]. *) and fresh_list lty subst = List.fold_right (fun ty (lty, subst) -> let ty, subst = fresh ty subst in ty::lty, subst) lty ([], subst) -let fresh ty subst = fresh (shorten ty) subst - -let fresh_list lty subst = fresh_list (List.map shorten lty) subst - module Decls = struct module MH = Map.Make (DE.Ty.Const) @@ -480,16 +420,14 @@ module Decls = struct try List.fold_left2 (fun sbt vty ty -> - let vty = shorten vty in match vty with - | Tvar { value = Some _ ; _ } -> assert false - | Tvar {v ; value = None} -> - if equal vty ty then sbt else M.add v ty sbt + | Tvar v -> + if equal vty ty then sbt else TvMap.add v ty sbt | _ -> Printer.print_err "vty = %a and ty = %a" print vty print ty; assert false - )M.empty params args + ) TvMap.empty params args with Invalid_argument _ -> assert false in let cases = @@ -564,7 +502,7 @@ let trecord ~record_constr lv name lbs = let rec hash t = match t with - | Tvar{ v; _ } -> v + | Tvar tv -> DE.Ty.Var.hash tv | Text(l,s) -> abs (List.fold_left (fun acc x-> acc*19 + hash x) (DE.Ty.Const.hash s) l) | Tfarray (t1,t2) -> 19 * (hash t1) + 23 * (hash t2) @@ -584,9 +522,9 @@ let rec hash t = | _ -> Hashtbl.hash t -let compare_subst = M.compare compare +let compare_subst = TvMap.compare compare -let equal_subst = M.equal equal +let equal_subst = TvMap.equal equal module Svty = Util.SI @@ -599,10 +537,9 @@ module Set = let vty_of t = let rec vty_of_rec acc t = - let t = shorten t in match t with - | Tvar { v = i ; value = None } -> Svty.add i acc - | Text(l,_) -> List.fold_left vty_of_rec acc l + | Tvar tv -> TvSet.add tv acc + | Text (l,_) -> List.fold_left vty_of_rec acc l | Tfarray (t1,t2) -> vty_of_rec (vty_of_rec acc t1) t2 | Trecord { args; lbs; _ } -> let acc = List.fold_left vty_of_rec acc args in @@ -610,16 +547,15 @@ let vty_of t = | Tadt(_, args) -> List.fold_left vty_of_rec acc args - | Tvar { value = Some _ ; _ } | Tint | Treal | Tbool | Tbitv _ -> acc in - vty_of_rec Svty.empty t + vty_of_rec TvSet.empty t let print_subst = let sep ppf () = Fmt.pf ppf " -> " in Fmt.(box @@ braces - @@ iter_bindings ~sep:comma M.iter (pair ~sep int print)) + @@ iter_bindings ~sep:comma TvMap.iter (pair ~sep DE.Ty.Var.print print)) let print_full = fst (print_generic (Some type_body)) (Some type_body) diff --git a/src/lib/structures/ty.mli b/src/lib/structures/ty.mli index 4bb69218b..168c95b69 100644 --- a/src/lib/structures/ty.mli +++ b/src/lib/structures/ty.mli @@ -31,6 +31,12 @@ (** {2 Definition} *) +type tvar = Dolmen.Std.Expr.ty_var +(** Type of type variable. *) + +module TvSet : Set.S with type elt = tvar +module TvMap : Map.S with type key = tvar + type t = | Tint (** Integer numbers *) @@ -63,17 +69,6 @@ type t = | Trecord of trecord (** Record type. *) -and tvar = { - v : int; - (** Unique identifier *) - mutable value : t option; - (** Pointer to the current value of the type variable. *) -} -(** Type variables. - The [value] field is mutated during unification, - hence distinct types should have disjoints sets of - type variables (see function {!val:fresh}). *) - and trecord = { mutable args : t list; (** Arguments passed to the record constructor *) @@ -102,13 +97,9 @@ type adt_constr = for recursive ADTs *) type type_body = adt_constr list -module Svty : Set.S with type elt = int -(** Sets of type variables, indexed by their identifier. *) - module Set : Set.S with type elt = t (** Sets of types *) - val assoc_destrs : Dolmen.Std.Expr.term_cst -> adt_constr list -> @@ -144,7 +135,7 @@ val print_list : Format.formatter -> t list -> unit val print_full : Format.formatter -> t -> unit (** Print function including the record fields. *) -val vty_of : t -> Svty.t +val vty_of : t -> TvSet.t (** Returns the set of type variables that occur in a given type. *) @@ -153,10 +144,6 @@ val vty_of : t -> Svty.t val tunit : t (** The unit type. *) -val fresh_var : unit -> tvar -(** Generate a fresh type variable, guaranteed to be distinct - from any other previously generated by this function. *) - val fresh_tvar : unit -> t (** Wrap the {!val:fresh_var} function to return a type. *) @@ -193,10 +180,7 @@ val trecord : (** {2 Substitutions} *) -module M : Map.S with type key = int -(** Maps from type variables identifiers. *) - -type subst = t M.t +type subst = t TvMap.t (** The type of substitution, i.e. maps from type variables identifiers to types.*)