Skip to content

Commit

Permalink
Fix problem with required proto2 messages.
Browse files Browse the repository at this point in the history
  - Remove Required/Optional in deserialization
  - re-introduce a basic_opt type to be more explicit about presence of default values
  • Loading branch information
andersfugmann committed Jan 29, 2024
1 parent e1c2e1a commit e11bb5b
Show file tree
Hide file tree
Showing 11 changed files with 336 additions and 469 deletions.
82 changes: 33 additions & 49 deletions src/ocaml_protoc_plugin/deserialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ module S = Spec.Deserialize
module C = S.C
open S

type required = Required | Optional

type 'a reader = 'a -> Reader.t -> Field.field_type -> 'a
type 'a getter = 'a -> 'a
type ('a, 'b) getter = 'a -> 'b
type 'a field_spec = (int * 'a reader)
type 'a value = ('a field_spec list * required * 'a * 'a getter)
type _ value = Value: ('b field_spec list * 'b * ('b, 'a) getter) -> 'a value
type extensions = (int * Field.t) list

type (_, _) value_list =
| VNil : ('a, 'a) value_list
| VNil_ext : (extensions -> 'a, 'a) value_list
| VCons : ('a value) * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list
| VCons : 'a value * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list

type sentinel_field_spec = int * (Reader.t -> Field.field_type -> unit)
type 'a sentinel_getter = unit -> 'a
Expand Down Expand Up @@ -88,7 +86,7 @@ let read_of_spec: type a. a spec -> Field.field_type * (Reader.t -> a) = functio
| Message (from_proto, _merge) -> Length_delimited, fun reader ->
let Field.{ offset; length; data } = Reader.read_length_delimited reader in
from_proto (Reader.create ~offset ~length data)

(*
let default_value: type a. a spec -> a = function
| Double -> 0.0
| Float -> 0.0
Expand Down Expand Up @@ -117,7 +115,7 @@ let default_value: type a. a spec -> a = function
| SFixed64_int -> 0
| Enum of_int -> of_int 0
| Bool -> false

*)
(** [id value] returns [value] unchanged. *)
let id = fun value -> value
(** [keep_last previous latest] ignores [previous] and returns [latest]. *)
let keep_last _previous latest = latest

Expand All @@ -129,34 +127,29 @@ let read_field ~read:(expect, read_f) ~map v reader field_type =
error_wrong_field "Deserialize" field

let value: type a. a compound -> a value = function
| Basic_req (index, spec) ->
let map _ v2 = Some v2 in
let read = read_field ~read:(read_of_spec spec) ~map in
let getter = function Some v -> v | None -> error_required_field_missing () in
Value ([(index, read)], None, getter)
| Basic (index, spec, default) ->
let map = match spec with
| Message (_, merge) -> merge
| _ -> keep_last
let map = keep_last
in
let read = read_field ~read:(read_of_spec spec) ~map in
let required = match default with
| Some _ -> Optional
| None -> Required
in
let default = match default with
| None -> default_value spec
| Some default -> default
in
([(index, read)], required, default, id)
Value ([(index, read)], default, id)
| Basic_opt (index, spec) ->
let map = match spec with
| Message (_, merge) ->
let map v1 v2 =
match v1 with
| None -> Some v2
| Some prev -> Some (merge prev v2)
| Some v1 -> Some (merge v1 v2)
in
map
| _ -> fun _ v -> Some v (* Keep last for all other non-repeated types *)
in
let read = read_field ~read:(read_of_spec spec) ~map in
([(index, read)], Optional, None, id)
Value ([(index, read)], None, id)
| Repeated (index, spec, Packed) ->
let field_type, read_f = read_of_spec spec in
let rec read_packed_values read_f acc reader =
Expand All @@ -175,16 +168,16 @@ let value: type a. a compound -> a value = function
let field = Reader.read_field_content ft reader in
error_wrong_field "Deserialize" field
in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Repeated (index, spec, Not_packed) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun vs v -> v :: vs) in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Oneof oneofs ->
let make_reader: a oneof -> a field_spec = fun (Oneof_elem (index, spec, constr)) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ -> constr) in
(index, read)
in
(List.map ~f:make_reader oneofs, Optional, `not_set, id)
Value (List.map ~f:make_reader oneofs, `not_set, id)

(* Map keyed by [int], ordered by [Int.compare] (the stdlib [Int] module
   already provides the [t] and [compare] that [Map.Make] requires). *)
module IntMap = Map.Make (Int)

Expand All @@ -197,15 +190,12 @@ let deserialize_full: type constr a. extension_ranges -> (constr, a) value_list
| VNil -> NNil
| VNil_ext -> NNil_ext
(* Consider optimizing when optional is true *)
| VCons ((fields, required, default, getter), rest) ->
let v = ref (default, required) in
let get () = match !v with
| _, Required -> error_required_field_missing ();
| v, Optional-> getter v
in
| VCons (Value (fields, default, getter), rest) ->
let v = ref default in
let get () = getter !v in
let fields =
List.map ~f:(fun (index, read) ->
let read reader field_type = let v' = fst !v in v := (read v' reader field_type, Optional) in
let read reader field_type = (v := read !v reader field_type) in
(index, read)
) fields
in
Expand Down Expand Up @@ -277,11 +267,11 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
in

let rec read_values: type constr a. extension_ranges -> Field.field_type -> int -> Reader.t -> constr -> extensions -> (constr, a) value_list -> a = fun extension_ranges tpe idx reader constr extensions ->
let rec read_repeated tpe index read_f default get reader =
let rec read_repeated tpe index read_f default reader =
let default = read_f default reader tpe in
let (tpe, idx) = next_field reader in
match idx = index with
| true -> read_repeated tpe index read_f default get reader
| true -> read_repeated tpe index read_f default reader
| false -> default, tpe, idx
in
function
Expand All @@ -290,34 +280,27 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
| VNil_ext when idx = Int.max_int ->
constr (List.rev extensions)
(* All fields read successfully. Apply extensions and return result. *)
| VCons (([index, read_f], _required, default, get), vs) when index = idx ->
| VCons (Value ([index, read_f], default, get), vs) when index = idx ->
(* Read all values, and apply constructor once all fields have been read.
This pattern is the most likely to be matched for all values, and is added
as an optimization to avoid reconstructing the value list for each recursion.
*)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
let default, tpe, idx = read_repeated tpe index read_f default reader in
let constr = (constr (get default)) in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (((index, read_f) :: fields, _required, default, get), vs) when index = idx ->
| VCons (Value ((index, read_f) :: fields, default, get), vs) when index = idx ->
(* Read all values for the given field *)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, Optional, default, get), vs))
let default, tpe, idx = read_repeated tpe index read_f default reader in
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| vs when in_extension_ranges extension_ranges idx ->
(* Extensions may be sent inline. Store all valid extensions, before starting to apply constructors *)
let extensions = (idx, Reader.read_field_content tpe reader) :: extensions in
let (tpe, idx) = next_field reader in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (([], Required, _default, _get), _vs) ->
(* If there are no more fields to be read we will never find the value.
If all values are read, then raise, else revert to full deserialization *)
begin match (idx = Int.max_int) with
| true -> error_required_field_missing ()
| false -> raise Restart_full
end
| VCons ((_ :: fields, optional, default, get), vs) ->
| VCons (Value (_ :: fields, default, get), vs) ->
(* Drop the field, as we dont expect to find it. *)
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, optional, default, get), vs))
| VCons (([], Optional, default, get), vs) ->
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| VCons (Value ([], default, get), vs) ->
(* Apply destructor. This case is only relevant for oneof fields *)
read_values extension_ranges tpe idx reader (constr (get default)) extensions vs
| VNil | VNil_ext ->
Expand All @@ -335,6 +318,7 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
let (tpe, idx) = next_field reader in
try
read_values extension_ranges tpe idx reader constr [] values
with Restart_full ->
with (Restart_full | Result.Error `Required_field_missing) ->
(* Revert to full deserialization *)
Reader.reset reader offset;
deserialize_full extension_ranges values constr reader
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/extensions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ let compare _ _ = 0
(** [index_of_spec spec] returns the field index stored in [spec].
    @raise Failure if [spec] is a [Oneof]. *)
let index_of_spec: type a. a Spec.Serialize.compound -> int = fun spec ->
  match spec with
  | Oneof _ -> failwith "Oneof fields not allowed in extensions"
  | Repeated (idx, _, _) -> idx
  | Basic_req (idx, _) -> idx
  | Basic_opt (idx, _) -> idx
  | Basic (idx, _, _) -> idx

Expand Down
12 changes: 10 additions & 2 deletions src/ocaml_protoc_plugin/merge.ml
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
(** Merge two values. Need to match on the spec to merge messages recursively. *)
let merge: type t. t Spec.Deserialize.compound -> t -> t -> t = fun spec t t' -> match spec with
| Spec.Deserialize.Basic (_field, Message (_, merge), _) -> merge t t'
| Spec.Deserialize.Basic (_field, _spec, Some default) when t' = default -> t
| Spec.Deserialize.Basic (_field, Message (_, _), _) -> failwith "Messages with defaults cannot happen"
| Spec.Deserialize.Basic (_field, _spec, default) when t' = default -> t
| Spec.Deserialize.Basic (_field, _spec, _) -> t'

(* The spec states that proto2 required fields must be transmitted exactly once.
So merging these fields is not possible. This essentially means that you cannot merge
proto2 messages containing required fields.
In this implementation, we choose to ignore this, and adopt 'keep last'
*)
| Spec.Deserialize.Basic_req (_field, Message (_, merge)) -> merge t t'
| Spec.Deserialize.Basic_req (_field, _spec) -> t'
| Spec.Deserialize.Basic_opt (_field, Message (_, merge)) ->
begin
match t, t' with
Expand Down
19 changes: 9 additions & 10 deletions src/ocaml_protoc_plugin/serialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,15 @@ let rec write: type a. a compound -> Writer.t -> a -> unit = function
*)
| Basic (index, spec, default) -> begin
let write = write_field spec index in
match default with
| Some default ->
fun writer v -> begin
match v with
| v when v = default -> ()
| v -> write v writer
end
| None ->
fun writer v -> write v writer
let writer writer = function
| v when v = default -> ()
| v -> write v writer
in
writer
end
| Basic_req (index, spec) ->
let write = write_field spec index in
fun writer v -> write v writer
| Basic_opt (index, spec) -> begin
let write = write_field spec index in
fun writer v ->
Expand All @@ -145,7 +144,7 @@ let rec write: type a. a compound -> Writer.t -> a -> unit = function
(* Wonder if we could get the specs before calling v. Wonder what f is? *)
(* We could prob. return a list of all possible values + f v -> v. *)
let Oneof_elem (index, spec, v) = f v in
write (Basic (index, spec, None)) writer v
write (Basic_req (index, spec)) writer v
end

let in_extension_ranges extension_ranges index =
Expand Down
14 changes: 13 additions & 1 deletion src/ocaml_protoc_plugin/spec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,25 @@ module Make(T : T) = struct
| Oneof_elem : int * 'b spec * ('a, ('b -> 'a), 'b) T.dir -> 'a oneof

type _ compound =
| Basic : int * 'a spec * 'a option -> 'a compound
(* A field, where the default value is known (and set). This cannot be used for message types *)
| Basic : int * 'a spec * 'a -> 'a compound

(* Proto2/proto3 optional fields. *)
| Basic_opt : int * 'a spec -> 'a option compound

(* Proto2 required fields (and oneof fields) *)
| Basic_req : int * 'a spec -> 'a compound

(* Repeated fields *)
| Repeated : int * 'a spec * packed -> 'a list compound
| Oneof : ('a, 'a oneof list, 'a -> unit oneof) T.dir -> ([> `not_set ] as 'a) compound

(* Heterogeneous list of field specs.
   The first type parameter is a function type that takes one argument per
   listed field; the second type parameter is that function's final result. *)
type (_, _) compound_list =
(* End of the list: no further arguments, so both type parameters coincide *)
| Nil : ('a, 'a) compound_list

(* Nil_ext denotes that the message contains extensions.
   It carries the extension ranges, and the function type takes a final
   [extensions] argument before producing the result. *)
| Nil_ext: extension_ranges -> (extensions -> 'a, 'a) compound_list

(* Prepend a field spec of type 'a: adds one 'a argument in front of the
   function type described by the tail of the list *)
| Cons : ('a compound) * ('b, 'c) compound_list -> ('a -> 'b, 'c) compound_list

module C = struct
Expand Down Expand Up @@ -97,6 +108,7 @@ module Make(T : T) = struct

(* Tuple-taking wrappers around the [compound] and [oneof] constructors:
   each one unpacks its argument and applies the matching constructor. *)
let repeated (index, spec, packed) = Repeated (index, spec, packed)
let basic (index, spec, default) = Basic (index, spec, default)
let basic_req (index, spec) = Basic_req (index, spec)
let basic_opt (index, spec) = Basic_opt (index, spec)
let oneof oneofs = Oneof oneofs
let oneof_elem (index, spec, dir) = Oneof_elem (index, spec, dir)
Expand Down
Loading

0 comments on commit e11bb5b

Please sign in to comment.