Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support upto 4D array as REST input payload. #6

Merged
merged 5 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ google/
# general
.env
.bash_history

# IDE
.idea/
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/kserve/rest-proxy
go 1.16

require (
github.com/google/go-cmp v0.5.6
github.com/grpc-ecosystem/grpc-gateway/v2 v2.6.0
google.golang.org/grpc v1.40.0
google.golang.org/protobuf v1.27.1
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,6 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down
212 changes: 137 additions & 75 deletions proxy/marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
Expand All @@ -28,7 +29,7 @@ const (
BYTES = "BYTES"
)

type Tensor struct {
type OutputTensor struct {
Name string `json:"name,omitempty"`
Datatype string `json:"datatype,omitempty"`
Shape []int64 `json:"shape,omitempty"`
Expand All @@ -41,13 +42,13 @@ type RESTResponse struct {
ModelVersion string `json:"model_version,omitempty"`
Id string `json:"id,omitempty"`
Parameters map[string]*gw.InferParameter `json:"parameters,omitempty"`
Outputs []Tensor `json:"outputs,omitempty"`
Outputs []OutputTensor `json:"outputs,omitempty"`
}

type RESTRequest struct {
Id string `json:"id,omitempty"`
Parameters map[string]*gw.InferParameter `json:"parameters,omitempty"`
Inputs []Tensor `json:"inputs,omitempty"`
Inputs []InputTensor `json:"inputs,omitempty"`
Outputs []*gw.ModelInferRequest_InferRequestedOutputTensor `json:"outputs,omitempty"`
}

Expand Down Expand Up @@ -86,7 +87,7 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) {
resp.ModelVersion = r.ModelVersion
resp.Id = r.Id
resp.Parameters = r.Parameters
resp.Outputs = make([]Tensor, len(r.Outputs))
resp.Outputs = make([]OutputTensor, len(r.Outputs))

for index, output := range r.Outputs {
tensor := &resp.Outputs[index]
Expand Down Expand Up @@ -145,6 +146,44 @@ func (c *CustomJSONPb) Marshal(v interface{}) ([]byte, error) {
return c.JSONPb.Marshal(v)
}

type InputTensor gw.ModelInferRequest_InferInputTensor

type InputTensorMeta struct {
Name string `json:"name"`
Datatype string `json:"datatype"`
Shape []int64 `json:"shape"`
}

type InputTensorData struct {
Data tensorDataUnmarshaller `json:"data"`
Parameters map[string]*gw.InferParameter `json:"parameters"`
}

func (t *InputTensor) UnmarshalJSON(data []byte) error {
meta := InputTensorMeta{}
if err := json.Unmarshal(data, &meta); err != nil {
return err
}
contents := &gw.InferTensorContents{}
target, err := targetArray(meta.Datatype, meta.Name, contents)
if err != nil {
return err
}
itd := &InputTensorData{Data: tensorDataUnmarshaller{target: target, shape: meta.Shape}}
if err := json.Unmarshal(data, itd); err != nil {
return err
}
*t = InputTensor{
Name: meta.Name,
Datatype: meta.Datatype,
Shape: meta.Shape,
Parameters: itd.Parameters,
Contents: contents,
}

return nil
}

// This function adjusts the user input before a gRPC message is sent to the server.
func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder {
return runtime.DecoderFunc(func(v interface{}) error {
Expand All @@ -159,77 +198,9 @@ func (c *CustomJSONPb) NewDecoder(r io.Reader) runtime.Decoder {
req.Id = restReq.Id
req.Parameters = restReq.Parameters
req.Outputs = restReq.Outputs
req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, 0, len(restReq.Inputs))

// TODO: Figure out better/cleaner way to do type coercion?
// TODO: Flatten N-Dimensional data arrays.

for index, input := range restReq.Inputs {
tensor := &gw.ModelInferRequest_InferInputTensor{
Name: input.Name,
Datatype: input.Datatype,
Shape: input.Shape,
Parameters: input.Parameters,
}
d := input.Data.([]interface{})
switch tensor.Datatype {
case BOOL:
data := make([]bool, len(d))
for i := range d {
data[i] = d[i].(bool)
}
tensor.Contents = &gw.InferTensorContents{BoolContents: data}
case UINT8, UINT16, UINT32:
data := make([]uint32, len(d))
for i := range d {
data[i] = uint32(d[i].(float64))
}
tensor.Contents = &gw.InferTensorContents{UintContents: data}
case UINT64:
data := make([]uint64, len(d))
for i := range d {
data[i] = uint64(d[i].(float64))
}
tensor.Contents = &gw.InferTensorContents{Uint64Contents: data}
case INT8, INT16, INT32:
data := make([]int32, len(d))
for i := range d {
data[i] = int32(d[i].(float64))
}
tensor.Contents = &gw.InferTensorContents{IntContents: data}
case INT64:
data := make([]int64, len(d))
for i := range d {
data[i] = int64(d[i].(float64))
}
tensor.Contents = &gw.InferTensorContents{Int64Contents: data}
case FP16:
return fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensor.Name) //TODO
case FP32:
data := make([]float32, len(d))
for i := range d {
data[i] = float32(d[i].(float64))
}
tensor.Contents = &gw.InferTensorContents{Fp32Contents: data}
case FP64:
data := make([]float64, len(d))
for i := range d {
data[i] = d[i].(float64)
}
tensor.Contents = &gw.InferTensorContents{Fp64Contents: data}
case BYTES:
// TODO: BytesContents is multi-dimensional. Figure out how to
// correctly represent the data from a 2D slice.
data := make([][]byte, 1)
data[0] = make([]byte, len(d))
for i := range d {
data[index][i] = byte(d[i].(float64))
}
tensor.Contents = &gw.InferTensorContents{BytesContents: data}
default:
return fmt.Errorf("unsupported datatype: %s", tensor.Datatype)
}
req.Inputs = append(req.Inputs, tensor)
req.Inputs = make([]*gw.ModelInferRequest_InferInputTensor, len(restReq.Inputs))
for i := range restReq.Inputs {
req.Inputs[i] = (*gw.ModelInferRequest_InferInputTensor)(&restReq.Inputs[i])
}
return nil
}
Expand Down Expand Up @@ -257,3 +228,94 @@ func readBytes(dataBytes []byte, elementType tensorType, index int, numElements
func sliceType(v interface{}) reflect.Type {
return reflect.SliceOf(reflect.TypeOf(v))
}

type tensorDataUnmarshaller struct {
target interface{}
shape []int64
}

func (t *tensorDataUnmarshaller) UnmarshalJSON(data []byte) error {
if len(t.shape) <= 1 {
return json.Unmarshal(data, t.target) // single-dimension fast-path
}
start := -1
for i, b := range data {
if b == '[' {
if start != -1 {
data = data[start:]
break
}
start = i
} else if !isSpace(b) {
if start == -1 {
return errors.New("invalid tensor data: not a json array")
}
// fast-path: flat array
return json.Unmarshal(data, t.target)
}
}
// here we have nested arrays

//TODO handle strings / BYTES case

// strip all the square brackets (update data slice in-place)
var o, c int
j := 1
for _, b := range data {
if b == '[' {
o++
} else if b == ']' {
c++
} else {
data[j] = b
j++
}
}
if o != c || o != expectedBracketCount(t.shape) {
return errors.New("invalid tensor data: invalid nested json arrays")
}
data[j] = ']'
return json.Unmarshal(data[:j+1], t.target)
}

func expectedBracketCount(shape []int64) int {
n := len(shape) - 1
if n < 1 {
return 1
}
p, s := 1, 1
for i := 0; i < n; i++ {
p *= int(shape[i])
s += p
}
return s
}

func targetArray(dataType, tensorName string, contents *gw.InferTensorContents) (interface{}, error) {
switch dataType {
case BOOL:
return &contents.BoolContents, nil
case UINT8, UINT16, UINT32:
return &contents.UintContents, nil
case UINT64:
return &contents.Uint64Contents, nil
case INT8, INT16, INT32:
return &contents.IntContents, nil
case INT64:
return &contents.Int64Contents, nil
case FP16:
return nil, fmt.Errorf("FP16 tensors not supported (response tensor %s)", tensorName) //TODO
case FP32:
return &contents.Fp32Contents, nil
case FP64:
return &contents.Fp64Contents, nil
case BYTES:
return &contents.BytesContents, nil //TODO still need to figure this one out
default:
return nil, fmt.Errorf("unsupported datatype: %s", dataType)
}
}

func isSpace(c byte) bool {
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
}
Loading