Streamline json decoding and support arbitrary dimensions
njhill authored and ScrapCodes committed Nov 16, 2021
1 parent b3b7ec8 commit 160af67
Showing 1 changed file with 137 additions and 75 deletions.
proxy/marshaler.go (137 additions, 75 deletions)
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
@@ -28,7 +29,7 @@ const (
BYTES = "BYTES"
)

-type Tensor struct {
+type OutputTensor struct {
Name string `json:"name,omitempty"`
Datatype string `json:"datatype,omitempty"`
Shape []int64 `json:"shape,omitempty"`
@@ -41,13 +42,13 @@ type RESTResponse struct {
ModelVersion string `json:"model_version,omitempty"`
Id string `json:"id,omitempty"`
Parameters map[string]*gw.InferParameter `json:"parameters,omitempty"`
-Outputs []Tensor `json:"outputs,omitempty"`
+Outputs []OutputTensor `json:"outputs,omitempty"`
}

type RESTRequest struct {
Id string `json:"id,omitempty"`
Parameters map[string]*gw.InferParameter `json:"parameters,omitempty"`
-Inputs []Tensor `json:"inputs,omitempty"`
+Inputs []InputTensor `json:"inputs,omitempty"`
Outputs []*gw.ModelInferRequest_InferRequestedOutputTensor `json:"outputs,omitempty"`
}

@@ -86,7 +87,7 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) {
resp.ModelVersion = r.ModelVersion
resp.Id = r.Id
resp.Parameters = r.Parameters
-resp.Outputs = make([]Tensor, len(r.Outputs))
+resp.Outputs = make([]OutputTensor, len(r.Outputs))

for index, output := range r.Outputs {
tensor := &resp.Outputs[index]
@@ -145,6 +146,44 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) {
return c.JSONPb.Marshal(v)
}

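// InputTensor redeclares gw.ModelInferRequest_InferInputTensor as a local type
// so that a custom UnmarshalJSON can be attached to REST request inputs.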
type InputTensor gw.ModelInferRequest_InferInputTensor

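// InputTensorMeta holds just the name, datatype and shape of an input tensor;
// it is decoded first so the right contents slice can be chosen for the data.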
type InputTensorMeta struct {
Name string `json:"name"`
Datatype string `json:"datatype"`
Shape []int64 `json:"shape"`
}

type InputTensorData struct {
Data tensorDataUnmarshaller `json:"data"`
Parameters map[string]*gw.InferParameter `json:"parameters"`
}

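// UnmarshalJSON decodes a REST input tensor in two passes: the metadata is
// parsed first to select the target contents slice for the datatype, then the
// data (and parameters) are decoded directly into that slice.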
func (t *InputTensor) UnmarshalJSON(data []byte) error {
meta := InputTensorMeta{}
if err := json.Unmarshal(data, &meta); err != nil {
return err
}
contents := &gw.InferTensorContents{}
target, err := targetArray(meta.Datatype, meta.Name, contents)
if err != nil {
return err
}
itd := &InputTensorData{Data: tensorDataUnmarshaller{target: target, shape: meta.Shape}}
if err := json.Unmarshal(data, itd); err != nil {
return err
}
*t = InputTensor{
Name: meta.Name,
Datatype: meta.Datatype,
Shape: meta.Shape,
Parameters: itd.Parameters,
Contents: contents,
}

return nil
}

// This function adjusts the user input before a gRPC message is sent to the server.
func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder {
return runtime.DecoderFunc(func(v interface{}) error {
@@ -159,77 +198,9 @@ func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder {
req.Id = restReq.Id
req.Parameters = restReq.Parameters
req.Outputs = restReq.Outputs
-req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, 0, len(restReq.Inputs))
-
-// TODO: Figure out better/cleaner way to do type coercion?
-// TODO: Flatten N-Dimensional data arrays.
-
-for index, input := range restReq.Inputs {
-tensor := &gw.ModelInferRequest_InferInputTensor{
-Name: input.Name,
-Datatype: input.Datatype,
-Shape: input.Shape,
-Parameters: input.Parameters,
-}
-d := input.Data.([]interface{})
-switch tensor.Datatype {
-case BOOL:
-data := make([]bool, len(d))
-for i := range d {
-data[i] = d[i].(bool)
-}
-tensor.Contents = &gw.InferTensorContents{BoolContents: data}
-case UINT8, UINT16, UINT32:
-data := make([]uint32, len(d))
-for i := range d {
-data[i] = uint32(d[i].(float64))
-}
-tensor.Contents = &gw.InferTensorContents{UintContents: data}
-case UINT64:
-data := make([]uint64, len(d))
-for i := range d {
-data[i] = uint64(d[i].(float64))
-}
-tensor.Contents = &gw.InferTensorContents{Uint64Contents: data}
-case INT8, INT16, INT32:
-data := make([]int32, len(d))
-for i := range d {
-data[i] = int32(d[i].(float64))
-}
-tensor.Contents = &gw.InferTensorContents{IntContents: data}
-case INT64:
-data := make([]int64, len(d))
-for i := range d {
-data[i] = int64(d[i].(float64))
-}
-tensor.Contents = &gw.InferTensorContents{Int64Contents: data}
-case FP16:
-return fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensor.Name) //TODO
-case FP32:
-data := make([]float32, len(d))
-for i := range d {
-data[i] = float32(d[i].(float64))
-}
-tensor.Contents = &gw.InferTensorContents{Fp32Contents: data}
-case FP64:
-data := make([]float64, len(d))
-for i := range d {
-data[i] = d[i].(float64)
-}
-tensor.Contents = &gw.InferTensorContents{Fp64Contents: data}
-case BYTES:
-// TODO: BytesContents is multi-dimensional. Figure out how to
-// correctly represent the data from a 2D slice.
-data := make([][]byte, 1)
-data[0] = make([]byte, len(d))
-for i := range d {
-data[index][i] = byte(d[i].(float64))
-}
-tensor.Contents = &gw.InferTensorContents{BytesContents: data}
-default:
-return fmt.Errorf("unsupported datatype: %s", tensor.Datatype)
-}
-req.Inputs = append(req.Inputs, tensor)
+req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, len(restReq.Inputs))
+for i := range restReq.Inputs {
+req.Inputs[i] = (*gw.ModelInferRequest_InferInputTensor)(&restReq.Inputs[i])
}
return nil
}
@@ -257,3 +228,94 @@ func readBytes(dataBytes []byte, elementType tensorType, index int, numElements
func sliceType(v interface{}) reflect.Type {
return reflect.SliceOf(reflect.TypeOf(v))
}

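// tensorDataUnmarshaller unmarshals a tensor's "data" field directly into the
// typed contents slice, using the declared shape to deal with nested arrays.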
type tensorDataUnmarshaller struct {
target interface{}
shape []int64
}

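// UnmarshalJSON takes a fast path for single-dimension or already-flat data;
// otherwise it flattens the nested arrays in place by stripping the inner
// square brackets (validated against the shape) before decoding into the
// target slice.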
func (t *tensorDataUnmarshaller) UnmarshalJSON(data []byte) error {
if len(t.shape) <= 1 {
return json.Unmarshal(data, t.target) // single-dimension fast-path
}
start := -1
for i, b := range data {
if b == '[' {
if start != -1 {
data = data[start:]
break
}
start = i
} else if !isSpace(b) {
if start == -1 {
return errors.New("invalid tensor data: not a json array")
}
// fast-path: flat array
return json.Unmarshal(data, t.target)
}
}
// here we have nested arrays

//TODO handle strings / BYTES case

// strip all the square brackets (update data slice in-place)
var o, c int
j := 1
for _, b := range data {
if b == '[' {
o++
} else if b == ']' {
c++
} else {
data[j] = b
j++
}
}
if o != c || o != expectedBracketCount(t.shape) {
return errors.New("invalid tensor data: invalid nested json arrays")
}
data[j] = ']'
return json.Unmarshal(data[:j+1], t.target)
}

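// expectedBracketCount returns how many '[' characters a nested JSON array of
// the given shape must contain: 1 + shape[0] + shape[0]*shape[1] + ... over all
// but the innermost dimension. For example, shape [2,3] gives 1 + 2 = 3 and
// shape [2,3,4] gives 1 + 2 + 6 = 9.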
func expectedBracketCount(shape []int64) int {
n := len(shape) - 1
if n < 1 {
return 1
}
p, s := 1, 1
for i := 0; i < n; i++ {
p *= int(shape[i])
s += p
}
return s
}

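// targetArray maps a datatype to a pointer to the matching InferTensorContents
// slice, which json.Unmarshal can then populate directly.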
func targetArray(dataType, tensorName string, contents *gw.InferTensorContents) (interface{}, error) {
switch dataType {
case BOOL:
return &contents.BoolContents, nil
case UINT8, UINT16, UINT32:
return &contents.UintContents, nil
case UINT64:
return &contents.Uint64Contents, nil
case INT8, INT16, INT32:
return &contents.IntContents, nil
case INT64:
return &contents.Int64Contents, nil
case FP16:
return nil, fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensorName) //TODO
case FP32:
return &contents.Fp32Contents, nil
case FP64:
return &contents.Fp64Contents, nil
case BYTES:
return &contents.BytesContents, nil //TODO still need to figure this one out
default:
return nil, fmt.Errorf("unsupported datatype: %s", dataType)
}
}

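// isSpace reports whether c is a JSON whitespace character (space, tab,
// carriage return or newline).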
func isSpace(c byte) bool {
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
}

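The effect of the new decoding path can be seen with a small self-contained sketch (not part of the commit; the flattenJSON helper and the sample 2x3 tensor below are invented for illustration). It flattens a nested JSON array the same way tensorDataUnmarshaller does, by dropping the inner brackets before unmarshalling into a flat slice:

package main

import (
    "encoding/json"
    "fmt"
)

// flattenJSON removes every '[' and ']' from a nested JSON array and wraps
// what remains in a single pair of brackets, yielding a flat JSON array.
func flattenJSON(data []byte) []byte {
    flat := make([]byte, 0, len(data)+2)
    flat = append(flat, '[')
    for _, b := range data {
        if b != '[' && b != ']' {
            flat = append(flat, b)
        }
    }
    return append(flat, ']')
}

func main() {
    // A 2x3 FP32 tensor as it might appear in a REST request's "data" field.
    nested := []byte(`[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]`)

    var contents []float32
    if err := json.Unmarshal(flattenJSON(nested), &contents); err != nil {
        panic(err)
    }
    fmt.Println(contents) // prints [1 2 3 4 5 6]
}

The version in the commit avoids the extra allocation by rewriting the buffer in place, and it checks the number of brackets against the declared shape (expectedBracketCount) so malformed nesting is rejected.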