diff --git a/proxy/marshaler.go b/proxy/marshaler.go
index 6acd0e0..442f828 100644
--- a/proxy/marshaler.go
+++ b/proxy/marshaler.go
@@ -4,6 +4,7 @@ import (
     "bytes"
     "encoding/binary"
     "encoding/json"
+    "errors"
     "fmt"
     "io"
     "reflect"
@@ -28,7 +29,7 @@ const (
     BYTES = "BYTES"
 )
 
-type Tensor struct {
+type OutputTensor struct {
     Name       string                        `json:"name,omitempty"`
     Datatype   string                        `json:"datatype,omitempty"`
     Shape      []int64                       `json:"shape,omitempty"`
@@ -41,13 +42,13 @@ type RESTResponse struct {
     ModelVersion string                        `json:"model_version,omitempty"`
     Id           string                        `json:"id,omitempty"`
     Parameters   map[string]*gw.InferParameter `json:"parameters,omitempty"`
-    Outputs      []Tensor                      `json:"outputs,omitempty"`
+    Outputs      []OutputTensor                `json:"outputs,omitempty"`
 }
 
 type RESTRequest struct {
     Id         string                        `json:"id,omitempty"`
     Parameters map[string]*gw.InferParameter `json:"parameters,omitempty"`
-    Inputs     []Tensor                      `json:"inputs,omitempty"`
+    Inputs     []InputTensor                 `json:"inputs,omitempty"`
     Outputs    []*gw.ModelInferRequest_InferRequestedOutputTensor `json:"outputs,omitempty"`
 }
 
@@ -86,7 +87,7 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) {
         resp.ModelVersion = r.ModelVersion
         resp.Id = r.Id
         resp.Parameters = r.Parameters
-        resp.Outputs = make([]Tensor, len(r.Outputs))
+        resp.Outputs = make([]OutputTensor, len(r.Outputs))
 
         for index, output := range r.Outputs {
             tensor := &resp.Outputs[index]
@@ -145,6 +146,44 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) {
     return c.JSONPb.Marshal(v)
 }
 
+type InputTensor gw.ModelInferRequest_InferInputTensor
+
+type InputTensorMeta struct {
+    Name     string  `json:"name"`
+    Datatype string  `json:"datatype"`
+    Shape    []int64 `json:"shape"`
+}
+
+type InputTensorData struct {
+    Data       tensorDataUnmarshaller        `json:"data"`
+    Parameters map[string]*gw.InferParameter `json:"parameters"`
+}
+
+func (t *InputTensor) UnmarshalJSON(data []byte) error {
+    meta := InputTensorMeta{}
+    if err := json.Unmarshal(data, &meta); err != nil {
+        return err
+    }
+    contents := &gw.InferTensorContents{}
+    target, err := targetArray(meta.Datatype, meta.Name, contents)
+    if err != nil {
+        return err
+    }
+    itd := &InputTensorData{Data: tensorDataUnmarshaller{target: target, shape: meta.Shape}}
+    if err := json.Unmarshal(data, itd); err != nil {
+        return err
+    }
+    *t = InputTensor{
+        Name:       meta.Name,
+        Datatype:   meta.Datatype,
+        Shape:      meta.Shape,
+        Parameters: itd.Parameters,
+        Contents:   contents,
+    }
+
+    return nil
+}
+
 // This function adjusts the user input before a gRPC message is sent to the server.
 func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder {
     return runtime.DecoderFunc(func(v interface{}) error {
@@ -159,77 +198,9 @@ func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder {
             req.Id = restReq.Id
             req.Parameters = restReq.Parameters
             req.Outputs = restReq.Outputs
-            req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, 0, len(restReq.Inputs))
-
-            // TODO: Figure out better/cleaner way to do type coercion?
-            // TODO: Flatten N-Dimensional data arrays.
-
-            for index, input := range restReq.Inputs {
-                tensor := &gw.ModelInferRequest_InferInputTensor{
-                    Name:       input.Name,
-                    Datatype:   input.Datatype,
-                    Shape:      input.Shape,
-                    Parameters: input.Parameters,
-                }
-                d := input.Data.([]interface{})
-                switch tensor.Datatype {
-                case BOOL:
-                    data := make([]bool, len(d))
-                    for i := range d {
-                        data[i] = d[i].(bool)
-                    }
-                    tensor.Contents = &gw.InferTensorContents{BoolContents: data}
-                case UINT8, UINT16, UINT32:
-                    data := make([]uint32, len(d))
-                    for i := range d {
-                        data[i] = uint32(d[i].(float64))
-                    }
-                    tensor.Contents = &gw.InferTensorContents{UintContents: data}
-                case UINT64:
-                    data := make([]uint64, len(d))
-                    for i := range d {
-                        data[i] = uint64(d[i].(float64))
-                    }
-                    tensor.Contents = &gw.InferTensorContents{Uint64Contents: data}
-                case INT8, INT16, INT32:
-                    data := make([]int32, len(d))
-                    for i := range d {
-                        data[i] = int32(d[i].(float64))
-                    }
-                    tensor.Contents = &gw.InferTensorContents{IntContents: data}
-                case INT64:
-                    data := make([]int64, len(d))
-                    for i := range d {
-                        data[i] = int64(d[i].(float64))
-                    }
-                    tensor.Contents = &gw.InferTensorContents{Int64Contents: data}
-                case FP16:
-                    return fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensor.Name) //TODO
-                case FP32:
-                    data := make([]float32, len(d))
-                    for i := range d {
-                        data[i] = float32(d[i].(float64))
-                    }
-                    tensor.Contents = &gw.InferTensorContents{Fp32Contents: data}
-                case FP64:
-                    data := make([]float64, len(d))
-                    for i := range d {
-                        data[i] = d[i].(float64)
-                    }
-                    tensor.Contents = &gw.InferTensorContents{Fp64Contents: data}
-                case BYTES:
-                    // TODO: BytesContents is multi-dimensional. Figure out how to
-                    // correctly represent the data from a 2D slice.
-                    data := make([][]byte, 1)
-                    data[0] = make([]byte, len(d))
-                    for i := range d {
-                        data[index][i] = byte(d[i].(float64))
-                    }
-                    tensor.Contents = &gw.InferTensorContents{BytesContents: data}
-                default:
-                    return fmt.Errorf("unsupported datatype: %s", tensor.Datatype)
-                }
-                req.Inputs = append(req.Inputs, tensor)
+            req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, len(restReq.Inputs))
+            for i := range restReq.Inputs {
+                req.Inputs[i] = (*gw.ModelInferRequest_InferInputTensor)(&restReq.Inputs[i])
             }
             return nil
         }
@@ -257,3 +228,94 @@ func readBytes(dataBytes []byte, elementType tensorType, index int, numElements
 func sliceType(v interface{}) reflect.Type {
     return reflect.SliceOf(reflect.TypeOf(v))
 }
+
+type tensorDataUnmarshaller struct {
+    target interface{}
+    shape  []int64
+}
+
+func (t *tensorDataUnmarshaller) UnmarshalJSON(data []byte) error {
+    if len(t.shape) <= 1 {
+        return json.Unmarshal(data, t.target) // single-dimension fast-path
+    }
+    start := -1
+    for i, b := range data {
+        if b == '[' {
+            if start != -1 {
+                data = data[start:]
+                break
+            }
+            start = i
+        } else if !isSpace(b) {
+            if start == -1 {
+                return errors.New("invalid tensor data: not a json array")
+            }
+            // fast-path: flat array
+            return json.Unmarshal(data, t.target)
+        }
+    }
+    // here we have nested arrays
+
+    //TODO handle strings / BYTES case
+
+    // strip all the square brackets (update data slice in-place)
+    var o, c int
+    j := 1
+    for _, b := range data {
+        if b == '[' {
+            o++
+        } else if b == ']' {
+            c++
+        } else {
+            data[j] = b
+            j++
+        }
+    }
+    if o != c || o != expectedBracketCount(t.shape) {
+        return errors.New("invalid tensor data: invalid nested json arrays")
+    }
+    data[j] = ']'
+    return json.Unmarshal(data[:j+1], t.target)
+}
+
+func expectedBracketCount(shape []int64) int {
+    n := len(shape) - 1
+    if n < 1 {
+        return 1
+    }
+    p, s := 1, 1
+    for i := 0; i < n; i++ {
+        p *= int(shape[i])
+        s += p
+    }
+    return s
+}
+
+func targetArray(dataType, tensorName string, contents *gw.InferTensorContents) (interface{}, error) {
+    switch dataType {
+    case BOOL:
+        return &contents.BoolContents, nil
+    case UINT8, UINT16, UINT32:
+        return &contents.UintContents, nil
+    case UINT64:
+        return &contents.Uint64Contents, nil
+    case INT8, INT16, INT32:
+        return &contents.IntContents, nil
+    case INT64:
+        return &contents.Int64Contents, nil
+    case FP16:
+        return nil, fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensorName) //TODO
+    case FP32:
+        return &contents.Fp32Contents, nil
+    case FP64:
+        return &contents.Fp64Contents, nil
+    case BYTES:
+        return &contents.BytesContents, nil //TODO still need to figure this one out
+    default:
+        return nil, fmt.Errorf("unsupported datatype: %s", dataType)
+    }
+}
+
+func isSpace(c byte) bool {
+    return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
+}
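
Illustrative note, not part of the patch: the new InputTensor.UnmarshalJSON decodes the "data" field directly into the typed InferTensorContents slice chosen by targetArray, flattening nested JSON arrays by stripping the inner brackets and checking the bracket count against the declared shape. The sketch below shows one way that path could be exercised; it assumes the test sits in the same proxy package as the change, and the test name and payload are invented for illustration.

package proxy

import (
    "encoding/json"
    "testing"
)

// Hypothetical example: a 2x2 FP32 "data" array should be flattened into
// four consecutive elements of Fp32Contents.
func TestInputTensorUnmarshalNestedData(t *testing.T) {
    payload := []byte(`{"name":"input0","datatype":"FP32","shape":[2,2],"data":[[1.0,2.0],[3.0,4.0]]}`)

    var tensor InputTensor
    if err := json.Unmarshal(payload, &tensor); err != nil {
        t.Fatal(err)
    }

    want := []float32{1, 2, 3, 4}
    got := tensor.Contents.Fp32Contents
    if len(got) != len(want) {
        t.Fatalf("expected %d elements, got %d", len(want), len(got))
    }
    for i := range want {
        if got[i] != want[i] {
            t.Fatalf("element %d: expected %v, got %v", i, want[i], got[i])
        }
    }
}

Compared with the removed per-datatype coercion loop, which asserted every element of a []interface{} as float64, this path unmarshals each element once into the correct contents slice and rejects payloads whose nesting does not match the declared shape.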