Skip to content

Commit

Permalink
first crack at a printing algorithm for with-kinds
Browse files Browse the repository at this point in the history
  • Loading branch information
glittershark committed Dec 27, 2024
1 parent 4c2935f commit ca413c7
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 105 deletions.
1 change: 1 addition & 0 deletions typing/btype.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ module TypeHash = struct
include TransientTypeHash
let mem hash = wrap_repr (mem hash)
let add hash = wrap_repr (add hash)
let replace hash = wrap_repr (replace hash)
let remove hash = wrap_repr (remove hash)
let find hash = wrap_repr (find hash)
let find_opt hash = wrap_repr (find_opt hash)
Expand Down
1 change: 1 addition & 0 deletions typing/btype.mli
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ module TypeHash : sig
include Hashtbl.S with type key = transient_expr
val mem: 'a t -> type_expr -> bool
val add: 'a t -> type_expr -> 'a -> unit
val replace: 'a t -> type_expr -> 'a -> unit
val remove: 'a t -> type_expr -> unit
val find: 'a t -> type_expr -> 'a
val find_opt: 'a t -> type_expr -> 'a option
Expand Down
22 changes: 22 additions & 0 deletions typing/jkind_axis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,28 @@ module Axis_collection (T : Axed) = struct
let[@inline] f f bounds = Monadic_identity.f f bounds
end

module Iter = struct
type ('type_expr, 'd) f =
{ f : 'axis. axis:'axis Axis.t -> ('type_expr, 'd, 'axis) T.t -> unit }

let[@inline] f { f }
{ locality;
linearity;
uniqueness;
portability;
contention;
externality;
nullability
} =
f ~axis:Axis.(Modal (Comonadic Areality)) locality;
f ~axis:Axis.(Modal (Monadic Uniqueness)) uniqueness;
f ~axis:Axis.(Modal (Comonadic Linearity)) linearity;
f ~axis:Axis.(Modal (Monadic Contention)) contention;
f ~axis:Axis.(Modal (Comonadic Portability)) portability;
f ~axis:Axis.(Nonmodal Externality) externality;
f ~axis:Axis.(Nonmodal Nullability) nullability
end

module Map2 = struct
module Monadic (M : Misc.Stdlib.Monad.S) = struct
type ('type_expr, 'd1, 'd2, 'd3) f =
Expand Down
7 changes: 7 additions & 0 deletions typing/jkind_axis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ module Axis_collection (T : Axed) : sig
('type_expr, 'd1, 'd2) f -> ('type_expr, 'd1) t -> ('type_expr, 'd2) t
end

module Iter : sig
type ('type_expr, 'd) f =
{ f : 'axis. axis:'axis Axis.t -> ('type_expr, 'd, 'axis) T.t -> unit }

val f : ('type_expr, 'd) f -> ('type_expr, 'd) t -> unit
end

(** Map an operation over two sets of bounds *)
module Map2 : sig
module Monadic (M : Misc.Stdlib.Monad.S) : sig
Expand Down
33 changes: 15 additions & 18 deletions typing/mode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1911,10 +1911,10 @@ module Value_with (Areality : Areality) = struct
| Monadic ax -> Monadic.max_axis ax

let is_max : type m a d. (m, a, d) axis -> a -> bool =
fun ax m -> le_axis ax (max_axis ax) m
fun ax m -> le_axis ax (max_axis ax) m

let is_min : type m a d. (m, a, d) axis -> a -> bool =
fun ax m -> le_axis ax m (min_axis ax)
fun ax m -> le_axis ax m (min_axis ax)

let split = split

Expand Down Expand Up @@ -2178,14 +2178,13 @@ module Const = struct
type 'd packed_value_axis =
| P : ('m, 'a, 'd) Value.axis -> 'd packed_value_axis

let alloc_as_value
: type m a d. (m, a, d) Alloc.axis -> d packed_value_axis
= function
| Comonadic Areality -> P (Comonadic Areality)
| Comonadic Linearity -> P (Comonadic Linearity)
| Comonadic Portability -> P (Comonadic Portability)
| Monadic Uniqueness -> P (Monadic Uniqueness)
| Monadic Contention -> P (Monadic Contention)
let alloc_as_value : type m a d. (m, a, d) Alloc.axis -> d packed_value_axis
= function
| Comonadic Areality -> P (Comonadic Areality)
| Comonadic Linearity -> P (Comonadic Linearity)
| Comonadic Portability -> P (Comonadic Portability)
| Monadic Uniqueness -> P (Monadic Uniqueness)
| Monadic Contention -> P (Monadic Contention)
end

let locality_as_regionality = C.locality_as_regionality
Expand Down Expand Up @@ -2610,11 +2609,9 @@ module Modality = struct
let to_list { monadic; comonadic } =
Comonadic.to_list comonadic @ Monadic.to_list monadic

let proj_monadic ax { monadic; _ } =
Monadic.proj ax monadic
let proj_monadic ax { monadic; _ } = Monadic.proj ax monadic

let proj_comonadic ax { comonadic; _ } =
Comonadic.proj ax comonadic
let proj_comonadic ax { comonadic; _ } = Comonadic.proj ax comonadic

let proj (type m a d) (ax : (m, a, d) Value.axis) t =
match ax with
Expand All @@ -2629,11 +2626,11 @@ module Modality = struct
then false
else
Misc.fatal_error
"Don't yet know how to interpret non-constant, \
non-identity modalities"
end
"Don't yet know how to interpret non-constant, non-identity \
modalities"
end

type t = (Monadic.t, Comonadic.t) monadic_comonadic
type t = (Monadic.t, Comonadic.t) monadic_comonadic

let id : t = { monadic = Monadic.id; comonadic = Comonadic.id }

Expand Down
13 changes: 11 additions & 2 deletions typing/mode_intf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,19 @@ module type S = sig
val print : Format.formatter -> ('p, 'r) t -> unit

val eq : ('p, 'r0) t -> ('p, 'r1) t -> ('r0, 'r1) Misc.eq option

end

module type Mode := sig
module Areality : Common

module Monadic : sig
module Const : Lattice with type t = monadic
module Const : sig
include Lattice with type t = monadic

val max_axis : (t, 'a) Axis.t -> 'a

val min_axis : (t, 'a) Axis.t -> 'a
end

include Common with module Const := Const

Expand All @@ -309,6 +314,10 @@ module type S = sig
val eq : t -> t -> bool

val print_axis : (t, 'a) Axis.t -> Format.formatter -> 'a -> unit

val max_axis : (t, 'a) Axis.t -> 'a

val min_axis : (t, 'a) Axis.t -> 'a
end

type error = Error : (Const.t, 'a) Axis.t * 'a Solver.error -> error
Expand Down
94 changes: 50 additions & 44 deletions typing/oprint.ml
Original file line number Diff line number Diff line change
Expand Up @@ -342,50 +342,6 @@ let pr_var = Pprintast.tyvar
let ty_var ~non_gen ppf s =
pr_var ppf (if non_gen then "_" ^ s else s)

let print_out_jkind_const ppf ojkind =
let rec pp_element ~nested ppf (ojkind : Outcometree.out_jkind_const) =
match ojkind with
| Ojkind_const_default -> fprintf ppf "_"
| Ojkind_const_abbreviation abbrev -> fprintf ppf "%s" abbrev
| Ojkind_const_mod (base, modes) ->
Misc.pp_parens_if nested (fun ppf (base, modes) ->
fprintf ppf "%a mod @[%a@]" (pp_element ~nested:true) base
(pp_print_list
~pp_sep:(fun ppf () -> fprintf ppf "@ ")
(fun ppf -> fprintf ppf "%s"))
modes
) ppf (base, modes)
| Ojkind_const_product ts ->
let pp_sep ppf () = Format.fprintf ppf "@ & " in
Misc.pp_nested_list ~nested ~pp_element ~pp_sep ppf ts
| Ojkind_const_with _ | Ojkind_const_kind_of _ ->
failwith "XXX unimplemented jkind syntax"
in
pp_element ~nested:false ppf ojkind

let print_out_jkind ppf ojkind =
let rec pp_element ~nested ppf ojkind =
match ojkind with
| Ojkind_var v -> fprintf ppf "%s" v
| Ojkind_const jkind -> print_out_jkind_const ppf jkind
| Ojkind_product ts ->
let pp_sep ppf () = Format.fprintf ppf "@ & " in
Misc.pp_nested_list ~nested ~pp_element ~pp_sep ppf ts
in
pp_element ~nested:false ppf ojkind

let print_out_jkind_annot ppf = function
| None -> ()
| Some lay -> fprintf ppf "@ : %a" print_out_jkind lay

let pr_var_jkind ppf (v, l) = match l with
| None -> pr_var ppf v
| Some lay -> fprintf ppf "(%a : %a)"
pr_var v
print_out_jkind lay

let pr_var_jkinds =
print_list pr_var_jkind (fun ppf -> fprintf ppf "@ ")

(* NON-LEGACY MODES
Here, we are printing mode annotations even if the mode extension is
Expand Down Expand Up @@ -659,6 +615,56 @@ and print_out_label ppf (name, mut, arg, gbl) =
print_out_type arg
print_out_modalities_new m_new

and print_out_jkind_const ppf ojkind =
let rec pp_element ~nested ppf (ojkind : Outcometree.out_jkind_const) =
match ojkind with
| Ojkind_const_default -> fprintf ppf "_"
| Ojkind_const_abbreviation abbrev -> fprintf ppf "%s" abbrev
| Ojkind_const_mod (base, modes) ->
Misc.pp_parens_if nested (fun ppf (base, modes) ->
fprintf ppf "%a mod @[%a@]" (pp_element ~nested:true) base
(pp_print_list
~pp_sep:(fun ppf () -> fprintf ppf "@ ")
(fun ppf -> fprintf ppf "%s"))
modes
) ppf (base, modes)
| Ojkind_const_product ts ->
let pp_sep ppf () = Format.fprintf ppf "@ & " in
Misc.pp_nested_list ~nested ~pp_element ~pp_sep ppf ts
| Ojkind_const_with (base, ty, modalities) ->
fprintf ppf "%a with @[%a@]%a"
(pp_element ~nested:true) base
print_out_type ty
print_out_modalities_new modalities
| Ojkind_const_kind_of _ ->
failwith "XXX unimplemented jkind syntax"
in
pp_element ~nested:false ppf ojkind

and print_out_jkind ppf ojkind =
let rec pp_element ~nested ppf ojkind =
match ojkind with
| Ojkind_var v -> fprintf ppf "%s" v
| Ojkind_const jkind -> print_out_jkind_const ppf jkind
| Ojkind_product ts ->
let pp_sep ppf () = Format.fprintf ppf "@ & " in
Misc.pp_nested_list ~nested ~pp_element ~pp_sep ppf ts
in
pp_element ~nested:false ppf ojkind

and print_out_jkind_annot ppf = function
| None -> ()
| Some lay -> fprintf ppf "@ : %a" print_out_jkind lay

and pr_var_jkind ppf (v, l) = match l with
| None -> pr_var ppf v
| Some lay -> fprintf ppf "(%a : %a)"
pr_var v
print_out_jkind lay

and pr_var_jkinds jks =
print_list pr_var_jkind (fun ppf -> fprintf ppf "@ ") jks

let out_label = ref print_out_label

let out_modality = ref print_out_modality
Expand Down
2 changes: 1 addition & 1 deletion typing/outcometree.mli
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ type out_jkind_const =
| Ojkind_const_default
| Ojkind_const_abbreviation of string
| Ojkind_const_mod of out_jkind_const * string list
| Ojkind_const_with of out_jkind_const * out_type
| Ojkind_const_with of out_jkind_const * out_type * out_modality_new list
| Ojkind_const_kind_of of out_type
| Ojkind_const_product of out_jkind_const list

Expand Down
Loading

0 comments on commit ca413c7

Please sign in to comment.