diff --git a/src/lib/rpc_genfake.ml b/src/lib/rpc_genfake.ml index 27b7fd3..499051a 100644 --- a/src/lib/rpc_genfake.ml +++ b/src/lib/rpc_genfake.ml @@ -6,8 +6,21 @@ type err = [ `Msg of string ] let badstuff msg = failwith (Printf.sprintf "Failed to construct the record: %s" msg) -let rec gentest : type a. a typ -> a list = - fun t -> +module SeenType = struct + type t = T : _ typ -> t + let compare a b = if a == b then 0 else Stdlib.compare a b +end + +module Seen = Set.Make(SeenType) + +(* don't use this on recursive types! *) + +let rec gentest : type a. Seen.t -> a typ -> a list = + fun seen t -> + let seen_t = SeenType.T t in + if Seen.mem seen_t seen then [] + else + let gentest t = gentest (Seen.add seen_t seen) t in match t with | Basic Int -> [ 0; 1; max_int; -1; 1000000 ] | Basic Int32 -> [ 0l; 1l; Int32.max_int; -1l; 999999l ] @@ -95,10 +108,18 @@ let rec gentest : type a. a typ -> a list = | Abstract { test_data; _ } -> test_data -let thin d result = if d < 0 then [ List.hd result ] else result +let thin d result = + if d < 0 then match result with + | [] -> [] + | hd :: _ -> [hd] + else result -let rec genall : type a. int -> string -> a typ -> a list = - fun depth strhint t -> +let rec genall: type a. Seen.t -> int -> string -> a typ -> a list = + fun seen depth strhint t -> + let seen_t = SeenType.T t in + if Seen.mem seen_t seen then [] + else + let genall depth strhint t = genall (Seen.add seen_t seen) depth strhint t in match t with | Basic Int -> [ 0 ] | Basic Int32 -> [ 0l ] @@ -192,6 +213,8 @@ let rec genall : type a. int -> string -> a typ -> a list = | Abstract { test_data; _ } -> test_data +(* don't use this on recursive types! *) + let rec gen_nice : type a. a typ -> string -> a = fun ty hint -> let narg n = Printf.sprintf "%s_%d" hint n in @@ -235,3 +258,6 @@ let rec gen_nice : type a. a typ -> string -> a = let content = gen_nice v.tcontents v.tname in v.treview content) | Abstract { test_data; _ } -> List.hd test_data + +let gentest t = gentest Seen.empty t +let genall t = genall Seen.empty t diff --git a/tests/ppx/test_deriving_rpcty.ml b/tests/ppx/test_deriving_rpcty.ml index bd27c73..a92af53 100644 --- a/tests/ppx/test_deriving_rpcty.ml +++ b/tests/ppx/test_deriving_rpcty.ml @@ -312,6 +312,11 @@ type nested = } [@@deriving rpcty] +type recursive = + | A of recursive * string + | B of int +[@@deriving rpcty] + let fakegen () = let fake ty = let fake = Rpc_genfake.genall 10 "string" ty in @@ -335,7 +340,8 @@ let fakegen () = in fake typ_of_test_record_opt; fake typ_of_test_variant_name; - fake typ_of_nested + fake typ_of_nested; + fake typ_of_recursive type test_defaults = { test_with_default : int [@default 5] } [@@deriving rpcty]