diff --git a/.gitignore b/.gitignore index 1e16531..3a8674b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ google/ # general .env .bash_history + +# IDE +.idea/ diff --git a/go.mod b/go.mod index b045486..986ec25 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/kserve/rest-proxy go 1.16 require ( + github.com/google/go-cmp v0.5.6 github.com/grpc-ecosystem/grpc-gateway/v2 v2.6.0 google.golang.org/grpc v1.40.0 google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index ce6e0d9..bce16db 100644 --- a/go.sum +++ b/go.sum @@ -482,7 +482,6 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/proxy/marshaler.go b/proxy/marshaler.go index 6acd0e0..442f828 100644 --- a/proxy/marshaler.go +++ b/proxy/marshaler.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "fmt" "io" "reflect" @@ -28,7 +29,7 @@ const ( BYTES = "BYTES" ) -type Tensor struct { +type OutputTensor struct { Name string `json:"name,omitempty"` Datatype string `json:"datatype,omitempty"` Shape []int64 `json:"shape,omitempty"` @@ -41,13 +42,13 @@ type RESTResponse struct { ModelVersion string `json:"model_version,omitempty"` Id string `json:"id,omitempty"` Parameters map[string]*gw.InferParameter `json:"parameters,omitempty"` - Outputs []Tensor `json:"outputs,omitempty"` + Outputs []OutputTensor `json:"outputs,omitempty"` } type RESTRequest struct { Id string `json:"id,omitempty"` Parameters map[string]*gw.InferParameter `json:"parameters,omitempty"` - Inputs []Tensor `json:"inputs,omitempty"` + Inputs []InputTensor `json:"inputs,omitempty"` Outputs []*gw.ModelInferRequest_InferRequestedOutputTensor `json:"outputs,omitempty"` } @@ -86,7 +87,7 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) { resp.ModelVersion = r.ModelVersion resp.Id = r.Id resp.Parameters = r.Parameters - resp.Outputs = make([]Tensor, len(r.Outputs)) + resp.Outputs = make([]OutputTensor, len(r.Outputs)) for index, output := range r.Outputs { tensor := &resp.Outputs[index] @@ -145,6 +146,44 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) { return c.JSONPb.Marshal(v) } +type InputTensor gw.ModelInferRequest_InferInputTensor + +type InputTensorMeta struct { + Name string `json:"name"` + Datatype string `json:"datatype"` + Shape []int64 `json:"shape"` +} + +type InputTensorData struct { + Data tensorDataUnmarshaller `json:"data"` + Parameters map[string]*gw.InferParameter `json:"parameters"` +} + +func (t *InputTensor) UnmarshalJSON(data []byte) error { + meta := InputTensorMeta{} + if err := json.Unmarshal(data, &meta); err != nil { + return err + } + contents := &gw.InferTensorContents{} + target, err := targetArray(meta.Datatype, meta.Name, contents) + if err != nil { + return err + } + itd := &InputTensorData{Data: tensorDataUnmarshaller{target: target, shape: meta.Shape}} + if err := json.Unmarshal(data, itd); err != nil { + return err + } + *t = InputTensor{ + Name: meta.Name, + Datatype: meta.Datatype, + Shape: meta.Shape, + Parameters: itd.Parameters, + Contents: contents, + } + + return nil +} + // This function adjusts the user input before a gRPC message is sent to the server. func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder { return runtime.DecoderFunc(func(v interface{}) error { @@ -159,77 +198,9 @@ func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder { req.Id = restReq.Id req.Parameters = restReq.Parameters req.Outputs = restReq.Outputs - req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, 0, len(restReq.Inputs)) - - // TODO: Figure out better/cleaner way to do type coercion? - // TODO: Flatten N-Dimensional data arrays. - - for index, input := range restReq.Inputs { - tensor := &gw.ModelInferRequest_InferInputTensor{ - Name: input.Name, - Datatype: input.Datatype, - Shape: input.Shape, - Parameters: input.Parameters, - } - d := input.Data.([]interface{}) - switch tensor.Datatype { - case BOOL: - data := make([]bool, len(d)) - for i := range d { - data[i] = d[i].(bool) - } - tensor.Contents = &gw.InferTensorContents{BoolContents: data} - case UINT8, UINT16, UINT32: - data := make([]uint32, len(d)) - for i := range d { - data[i] = uint32(d[i].(float64)) - } - tensor.Contents = &gw.InferTensorContents{UintContents: data} - case UINT64: - data := make([]uint64, len(d)) - for i := range d { - data[i] = uint64(d[i].(float64)) - } - tensor.Contents = &gw.InferTensorContents{Uint64Contents: data} - case INT8, INT16, INT32: - data := make([]int32, len(d)) - for i := range d { - data[i] = int32(d[i].(float64)) - } - tensor.Contents = &gw.InferTensorContents{IntContents: data} - case INT64: - data := make([]int64, len(d)) - for i := range d { - data[i] = int64(d[i].(float64)) - } - tensor.Contents = &gw.InferTensorContents{Int64Contents: data} - case FP16: - return fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensor.Name) //TODO - case FP32: - data := make([]float32, len(d)) - for i := range d { - data[i] = float32(d[i].(float64)) - } - tensor.Contents = &gw.InferTensorContents{Fp32Contents: data} - case FP64: - data := make([]float64, len(d)) - for i := range d { - data[i] = d[i].(float64) - } - tensor.Contents = &gw.InferTensorContents{Fp64Contents: data} - case BYTES: - // TODO: BytesContents is multi-dimensional. Figure out how to - // correctly represent the data from a 2D slice. - data := make([][]byte, 1) - data[0] = make([]byte, len(d)) - for i := range d { - data[index][i] = byte(d[i].(float64)) - } - tensor.Contents = &gw.InferTensorContents{BytesContents: data} - default: - return fmt.Errorf("unsupported datatype: %s", tensor.Datatype) - } - req.Inputs = append(req.Inputs, tensor) + req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, len(restReq.Inputs)) + for i := range restReq.Inputs { + req.Inputs[i] = (*gw.ModelInferRequest_InferInputTensor)(&restReq.Inputs[i]) } return nil } @@ -257,3 +228,94 @@ func readBytes(dataBytes []byte, elementType tensorType, index int, numElements func sliceType(v interface{}) reflect.Type { return reflect.SliceOf(reflect.TypeOf(v)) } + +type tensorDataUnmarshaller struct { + target interface{} + shape []int64 +} + +func (t *tensorDataUnmarshaller) UnmarshalJSON(data []byte) error { + if len(t.shape) <= 1 { + return json.Unmarshal(data, t.target) // single-dimension fast-path + } + start := -1 + for i, b := range data { + if b == '[' { + if start != -1 { + data = data[start:] + break + } + start = i + } else if !isSpace(b) { + if start == -1 { + return errors.New("invalid tensor data: not a json array") + } + // fast-path: flat array + return json.Unmarshal(data, t.target) + } + } + // here we have nested arrays + + //TODO handle strings / BYTES case + + // strip all the square brackets (update data slice in-place) + var o, c int + j := 1 + for _, b := range data { + if b == '[' { + o++ + } else if b == ']' { + c++ + } else { + data[j] = b + j++ + } + } + if o != c || o != expectedBracketCount(t.shape) { + return errors.New("invalid tensor data: invalid nested json arrays") + } + data[j] = ']' + return json.Unmarshal(data[:j+1], t.target) +} + +func expectedBracketCount(shape []int64) int { + n := len(shape) - 1 + if n < 1 { + return 1 + } + p, s := 1, 1 + for i := 0; i < n; i++ { + p *= int(shape[i]) + s += p + } + return s +} + +func targetArray(dataType, tensorName string, contents *gw.InferTensorContents) (interface{}, error) { + switch dataType { + case BOOL: + return &contents.BoolContents, nil + case UINT8, UINT16, UINT32: + return &contents.UintContents, nil + case UINT64: + return &contents.Uint64Contents, nil + case INT8, INT16, INT32: + return &contents.IntContents, nil + case INT64: + return &contents.Int64Contents, nil + case FP16: + return nil, fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensorName) //TODO + case FP32: + return &contents.Fp32Contents, nil + case FP64: + return &contents.Fp64Contents, nil + case BYTES: + return &contents.BytesContents, nil //TODO still need to figure this one out + default: + return nil, fmt.Errorf("unsupported datatype: %s", dataType) + } +} + +func isSpace(c byte) bool { + return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') +} diff --git a/proxy/marshaler_test.go b/proxy/marshaler_test.go new file mode 100644 index 0000000..6d9031a --- /dev/null +++ b/proxy/marshaler_test.go @@ -0,0 +1,169 @@ +package main + +import ( + "bytes" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + gw "github.com/kserve/rest-proxy/gen" +) + +func restRequest(data string, shape string) string { + return `{ + "id": "foo", + "inputs": [{ + "name": "predict", + "shape": ` + shape + `, + "datatype": "FP32", + "data":` + data + ` + }] + }` +} + +var data1D = ` + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0, + 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0, 0.0, 0.0, + 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0, 0.0, 0.0, 0.0, 16.0, + 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0, 0.0, 0.0, 1.0, 11.0, 14.0, 15.0, + 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, + 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 13.0, 16.0, 16.0, 6.0, 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, + 13.0, 12.0, 1.0, 0.0] +` + +var data2D = ` +[ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, + 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, + 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, + 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0], + + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, + 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, + 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, + 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0] +] +` + +var data3D = ` +[ + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0, + 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0 + ], + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0, + 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0 + ] + ], + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0, + 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0 + ], + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0, + 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0 + ] + ] +] +` + +var data4D = ` +[ + [ + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0], + [0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0] + ], + [ + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0], + [0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0] + ] + ], + [ + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0], + [0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0] + ], + [ + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0], + [0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0] + ] + ] +] +` + +func generateProtoBufResponse() *gw.ModelInferResponse { + expectedOutput := []*gw.ModelInferResponse_InferOutputTensor{{ + Name: "predict", + Datatype: "INT64", + Shape: []int64{2}, + Contents: &gw.InferTensorContents{ + Int64Contents: []int64{8, 8}, + }, + }} + + return &gw.ModelInferResponse{ + ModelName: "example", + Id: "foo", + Outputs: expectedOutput, + } +} + +var jsonResponse = `{"model_name":"example","id":"foo","outputs":[{"name":"predict","datatype":"INT64","shape":[2],"data":[8,8]}]}` + +func generateProtoBufRequest(shape []int64) *gw.ModelInferRequest { + var expectedInput = gw.ModelInferRequest_InferInputTensor{ + Name: "predict", + Datatype: "FP32", + Shape: shape, + Contents: &gw.InferTensorContents{ + Fp32Contents: []float32{0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, + 16.0, 8.0, 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, + 14.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, + 16.0, 6.0, 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, + 12.0, 1.0, 0.0, 0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, + 16.0, 8.0, 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, + 14.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, + 16.0, 6.0, 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, + 12.0, 1.0, 0.0}, + }, + } + + var modelInferRequest = &gw.ModelInferRequest{ + Id: "foo", + Inputs: []*gw.ModelInferRequest_InferInputTensor{&expectedInput}, + } + return modelInferRequest +} + +func TestRESTRequest(t *testing.T) { + c := CustomJSONPb{} + inputDataArray := []string{data1D, data2D, data3D, data4D} + inputDataShapes := [][]int64{{2, 64}, {2, 64}, {2, 2, 32}, {2, 2, 2, 16}} + for k, data := range inputDataArray { + out := &gw.ModelInferRequest{} + buffer := &bytes.Buffer{} + buffer.Write([]byte(restRequest(data, strings.Join(strings.Split(fmt.Sprintln(inputDataShapes[k]), " "), ",")))) + err := c.NewDecoder(buffer).Decode(out) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(out, generateProtoBufRequest(inputDataShapes[k])) { + t.Errorf("REST request failed to decode for shape: %v", inputDataShapes[k]) + } + } +} + +func TestRESTResponse(t *testing.T) { + c := CustomJSONPb{} + v := generateProtoBufResponse() + marshal, err := c.Marshal(v) + if err != nil { + t.Error(err) + } + if d := cmp.Diff(string(marshal), jsonResponse); d != "" { + t.Errorf("diff :%s", d) + } +}