diff --git a/proxy/bytes.go b/proxy/bytes.go new file mode 100644 index 0000000..dfe5e55 --- /dev/null +++ b/proxy/bytes.go @@ -0,0 +1,211 @@ +package main + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "unicode/utf16" + "unicode/utf8" + + gw "github.com/kserve/rest-proxy/gen" +) + +// This file contains logic related to marshalling/unmarshalling BYTES type tensor data +// "raw" parsing is done for performance and structural validation is done only on a best-effort basis + +var escMap = map[byte]byte{'b': '\b', 'f': '\f', 'r': '\r', 't': '\t', 'n': '\n', '\\': '\\', '/': '/', '"': '"'} + +func unmarshalBytesJson(target *[][]byte, shape []int64, b64 bool, data []byte) error { + // Cases: + // 1dim raw ( [ [ N ) -> as-is + // flat raw ( [ [ N ) -> as-is + // nested raw ( [ [ [ .. N ) -> strip "middles" (use shape prob) + // 1dim str ( [ " ) + // flat str ( [ " ) + // nested str ( [ [ .. " ) -> strip all + + start := -1 + depth := 0 + isString := false + for i, b := range data { + if b == '[' { + if start == -1 { + start = i + } + depth += 1 + } else if !isSpace(b) { + isString = b == '"' + break + } + } + if depth == 0 { + return errors.New("invalid tensor data: not a json array") + } + if start > 0 { + data = data[start:] + } + if isString { + if depth != 1 && depth != len(shape) { + return errors.New("data array nesting does not match tensor shape") + } + return unmarshalStringArray(target, shape, b64, data) + } + if depth <= 1 { + return errors.New("invalid tensor data: must be an array of byte arrays") + } + if depth == 2 { + // flat numeric case, e.g. [[1,2,3],[4,5,6],[7,8,9]] + return json.Unmarshal(data, target) + } + + // nested numeric case, e.g. [[[1,2],[3,4]],[[5,6],[7,8]]] + // ignore innermost dimension because elements are lists of bytes + if (depth - 1) != len(shape) { + return errors.New("invalid tensor data: array nesting does not match tensor shape") + } + return unmarshalNestedNumeric(target, depth, data) +} + +// nested numeric case, e.g. [[[1,2],[3,4]],[[5,6],[7,8]]] +func unmarshalNestedNumeric(target *[][]byte, depth int, data []byte) error { + d := 0 + j := 1 + for _, b := range data { + include := true + if b == '[' { + d++ + if d > depth { + return errors.New("invalid tensor data: array nesting does not match tensor shape") + } + include = d == depth + } else if b == ']' { + include = d == depth + d-- + } + if include { + data[j] = b + j++ + } + } + if d != 0 { + return errors.New("invalid tensor data: array nesting does not match tensor shape") + } + data[j] = ']' + return json.Unmarshal(data[:j+1], target) +} + +func unmarshalStringArray(target *[][]byte, shape []int64, b64 bool, data []byte) error { + elems := int(elementCount(shape)) + t := make([][]byte, 0, elems) + + depth := 0 + strStart := -1 + j := 0 + var ok bool + l := len(data) + for i := 0; i < l; i++ { + b := data[i] + if strStart == -1 { + if b == '[' { + depth++ + } else if b == ']' { + depth-- + } else if b == '"' { + if len(t) >= elems { + return errors.New("more strings than expected for tensor shape") + } + strStart = i + j = i + 1 + } else if b != ',' && !isSpace(b) { + return errors.New("tensor data must be a flat or nested json array of strings") + } + continue + } + // here we are mid-string + if b == '\\' { + i++ + if i == l { + break // will error with unexpected end + } + b = data[i] + if b == 'u' { + i += 4 + if i >= l { + break // will error with unexpected end + } + cp := utf16.Decode([]uint16{binary.BigEndian.Uint16(data[i-3 : i+1])}) + for _, r := range cp { + j += utf8.EncodeRune(data[j:], r) + } + continue + } else if b, ok = escMap[b]; !ok { + return errors.New("invalid escaped char in json string") + } + } else if b == '"' { + //end of string + s := data[strStart+1 : j] + if b64 { + if n, err := base64.StdEncoding.Decode(s, s); err != nil { + return fmt.Errorf("error decoding json string as base64: %w", err) + } else { + s = s[:n] + } + } + t = append(t, s) + strStart = -1 + } + if j != i { + data[j] = b + } + j++ + } + if strStart != -1 { + return errors.New("fewer strings than expected for tensor shape") + } + if depth != 0 { + return errors.New("invalid tensor data: invalid nested json arrays") + } + + *target = t + return nil +} + +func isBase64Content(parameters map[string]*gw.InferParameter) bool { + ct := parameters[CONTENT_TYPE].GetStringParam() + if ct == "" || ct == "utf8" || ct == "str" || ct == "UTF8" { + return false + } + if ct == BASE64 || ct == "b64" || ct == "BASE64" || ct == "B64" { + return true + } + if ct != "utf-8" && ct != "UTF-8" { + logger.Error(nil, "Unrecognized content_type, treating as utf8", CONTENT_TYPE, ct) + } + return false +} + +// Split raw bytes into separate byte arrays based on 4-byte size delimeters +func splitRawBytes(raw []byte, expectedSize int) ([][]byte, error) { + off, length := int64(0), int64(len(raw)) + strings := make([][]byte, expectedSize) + r := bytes.NewReader(raw) + for i := 0; i < expectedSize; i++ { + var size uint32 + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + return nil, errors.New("unexpected end of raw tensor bytes") + } + start := off + 4 + if off, _ = r.Seek(int64(size), io.SeekCurrent); off > length { + return nil, errors.New("unexpected end of raw tensor bytes") + } + strings[i] = raw[start:off] + } + if off < length { + return nil, errors.New("more raw tensor bytes than expected") + } + return strings, nil +} diff --git a/proxy/marshaler.go b/proxy/marshaler.go index e319d1f..c02bc74 100644 --- a/proxy/marshaler.go +++ b/proxy/marshaler.go @@ -42,6 +42,9 @@ const ( BYTES = "BYTES" ) +const CONTENT_TYPE = "content_type" +const BASE64 = "base64" + type CustomJSONPb struct { runtime.JSONPb } @@ -118,20 +121,24 @@ func transformResponse(r *gw.ModelInferResponse) (*RESTResponse, error) { if tensor.Datatype == FP16 { return nil, fmt.Errorf("FP16 tensors not supported (request tensor %s)", tensor.Name) //TODO } + if tensor.Datatype == BYTES { + tensor.Parameters[CONTENT_TYPE] = BASE64 + } if r.RawOutputContents != nil { tt, ok := tensorTypes[tensor.Datatype] if !ok { return nil, fmt.Errorf("unsupported datatype in inference response outputs: %s", tensor.Datatype) } + numElements := int(elementCount(tensor.Shape)) + var err error if tensor.Datatype == BYTES { - tensor.Data = r.RawOutputContents[index] + tensor.Data, err = splitRawBytes(r.RawOutputContents[index], numElements) } else { - numElements := int(elementCount(tensor.Shape)) - var err error - if tensor.Data, err = readBytes(r.RawOutputContents[index], tt, 0, numElements); err != nil { - return nil, err - } + tensor.Data, err = readBytes(r.RawOutputContents[index], tt, 0, numElements) + } + if err != nil { + return nil, err } } else { switch tensor.Datatype { @@ -150,6 +157,8 @@ func transformResponse(r *gw.ModelInferResponse) (*RESTResponse, error) { case FP64: tensor.Data = output.Contents.Fp64Contents case BYTES: + // this will be encoded as array of b64-encoded strings + //TODO support tensor.Data = output.Contents.BytesContents default: return nil, fmt.Errorf("unsupported datatype in inference response outputs: %s", diff --git a/proxy/marshaler_test.go b/proxy/marshaler_test.go index bfce1d8..2906e6c 100644 --- a/proxy/marshaler_test.go +++ b/proxy/marshaler_test.go @@ -51,6 +51,63 @@ func generateProtoBufResponse() *gw.ModelInferResponse { var jsonResponse = `{"model_name":"example","id":"foo","parameters":{"bool_param":false,"content_type":"bar","headers":null,"int_param":12345},` + `"outputs":[{"name":"predict","datatype":"INT64","shape":[2],"data":[8,8]}]}` +func generateProtoBufBytesResponse() *gw.ModelInferResponse { + expectedOutput := []*gw.ModelInferResponse_InferOutputTensor{{ + Name: "predict", + Datatype: "BYTES", + Shape: []int64{2, 2}, + Contents: &gw.InferTensorContents{ + BytesContents: [][]byte{[]byte("String1"), []byte("String2"), []byte("String3"), []byte("String4")}, + }, + }} + + return &gw.ModelInferResponse{ + ModelName: "example", + Id: "foo", + Outputs: expectedOutput, + Parameters: map[string]*gw.InferParameter{ + "content_type": {ParameterChoice: &gw.InferParameter_StringParam{StringParam: "bar"}}, + "headers": {ParameterChoice: nil}, + "int_param": {ParameterChoice: &gw.InferParameter_Int64Param{Int64Param: 12345}}, + "bool_param": {ParameterChoice: &gw.InferParameter_BoolParam{BoolParam: false}}, + }, + } +} + +func generateProtoBufBytesResponseRawOutput() *gw.ModelInferResponse { + expectedOutput := []*gw.ModelInferResponse_InferOutputTensor{{ + Name: "predict", + Datatype: "BYTES", + Shape: []int64{2, 2}, + }} + + seven := make([]byte, 4) + binary.LittleEndian.PutUint32(seven, 7) + rawBytes := append(seven, "String1"...) + rawBytes = append(rawBytes, seven...) + rawBytes = append(rawBytes, "String2"...) + rawBytes = append(rawBytes, seven...) + rawBytes = append(rawBytes, "String3"...) + rawBytes = append(rawBytes, seven...) + rawBytes = append(rawBytes, "String4"...) + + return &gw.ModelInferResponse{ + ModelName: "example", + Id: "foo", + Outputs: expectedOutput, + Parameters: map[string]*gw.InferParameter{ + "content_type": {ParameterChoice: &gw.InferParameter_StringParam{StringParam: "bar"}}, + "headers": {ParameterChoice: nil}, + "int_param": {ParameterChoice: &gw.InferParameter_Int64Param{Int64Param: 12345}}, + "bool_param": {ParameterChoice: &gw.InferParameter_BoolParam{BoolParam: false}}, + }, + RawOutputContents: [][]byte{rawBytes}, + } +} + +var jsonBytesResponse = `{"model_name":"example","id":"foo","parameters":{"bool_param":false,"content_type":"bar","headers":null,"int_param":12345},` + + `"outputs":[{"name":"predict","datatype":"BYTES","shape":[2,2],"parameters":{"content_type":"base64"},"data":["U3RyaW5nMQ==","U3RyaW5nMg==","U3RyaW5nMw==","U3RyaW5nNA=="]}]}` + func TestRESTResponse(t *testing.T) { c := CustomJSONPb{} v := generateProtoBufResponse() @@ -63,6 +120,30 @@ func TestRESTResponse(t *testing.T) { } } +func TestBytesRESTResponse(t *testing.T) { + c := CustomJSONPb{} + v := generateProtoBufBytesResponse() + marshal, err := c.Marshal(v) + if err != nil { + t.Error(err) + } + if d := cmp.Diff(string(marshal), jsonBytesResponse); d != "" { + t.Errorf("diff :%s", d) + } +} + +func TestBytesRESTResponseRawOutput(t *testing.T) { + c := CustomJSONPb{} + v := generateProtoBufBytesResponseRawOutput() + marshal, err := c.Marshal(v) + if err != nil { + t.Error(err) + } + if d := cmp.Diff(string(marshal), jsonBytesResponse); d != "" { + t.Errorf("diff :%s", d) + } +} + func TestRESTResponseRawOutput(t *testing.T) { c := CustomJSONPb{} buf := new(bytes.Buffer) diff --git a/proxy/request.go b/proxy/request.go index 1e929ae..24c2a58 100644 --- a/proxy/request.go +++ b/proxy/request.go @@ -38,7 +38,8 @@ func transformRequest(restReq *RESTRequest, req *gw.ModelInferRequest) { } type RESTRequest struct { - Id string `json:"id,omitempty"` + Id string `json:"id,omitempty"` + //TODO figure out how to handle request-level content type parameter Parameters parameterMap `json:"parameters,omitempty"` Inputs []InputTensor `json:"inputs,omitempty"` Outputs []*gw.ModelInferRequest_InferRequestedOutputTensor `json:"outputs,omitempty"` @@ -49,9 +50,10 @@ type RESTRequest struct { type InputTensor gw.ModelInferRequest_InferInputTensor type InputTensorMeta struct { - Name string `json:"name"` - Datatype string `json:"datatype"` - Shape []int64 `json:"shape"` + Name string `json:"name"` + Datatype string `json:"datatype"` + Shape []int64 `json:"shape"` + Parameters parameterMap `json:"parameters"` } type InputTensorData struct { @@ -69,7 +71,11 @@ func (t *InputTensor) UnmarshalJSON(data []byte) error { if err != nil { return err } - itd := &InputTensorData{Data: tensorDataUnmarshaller{target: target, shape: meta.Shape}} + isBytes := meta.Datatype == BYTES + itd := &InputTensorData{Data: tensorDataUnmarshaller{ + target: target, shape: meta.Shape, + bytes: isBytes, b64: isBytes && isBase64Content(meta.Parameters), + }} if err := json.Unmarshal(data, itd); err != nil { return err } @@ -77,7 +83,7 @@ func (t *InputTensor) UnmarshalJSON(data []byte) error { Name: meta.Name, Datatype: meta.Datatype, Shape: meta.Shape, - Parameters: itd.Parameters, + Parameters: meta.Parameters, Contents: contents, } @@ -110,11 +116,16 @@ func targetArray(dataType, tensorName string, contents *gw.InferTensorContents) } type tensorDataUnmarshaller struct { - target interface{} shape []int64 + bytes bool + b64 bool + target interface{} } func (t *tensorDataUnmarshaller) UnmarshalJSON(data []byte) error { + if t.bytes { + return unmarshalBytesJson(t.target.(*[][]byte), t.shape, t.b64, data) + } if len(t.shape) <= 1 { return json.Unmarshal(data, t.target) // single-dimension fast-path } @@ -122,8 +133,10 @@ func (t *tensorDataUnmarshaller) UnmarshalJSON(data []byte) error { for i, b := range data { if b == '[' { if start != -1 { - data = data[start:] - break + if start != 0 { + data = data[start:] + } + break // here we have nested arrays } start = i } else if !isSpace(b) { diff --git a/proxy/request_test.go b/proxy/request_test.go index 9db6f17..e6591ad 100644 --- a/proxy/request_test.go +++ b/proxy/request_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "encoding/json" "fmt" "strings" "testing" @@ -103,6 +104,72 @@ func restRequest(data string, shape string) string { }` } +type bytesTensorTestCase struct { + shape []int64 + jsonData string + pbBytes [][]byte + parameters map[string]string +} + +var bytesTensorCases = []bytesTensorTestCase{ + { + shape: []int64{2}, + jsonData: `["My UTF8 String", "Another string"]`, + pbBytes: [][]byte{[]byte("My UTF8 String"), []byte("Another string")}, + }, + { + shape: []int64{1}, + jsonData: `[[77, 121, 32, 85, 84, 70, 56, 32, 83, 116, 114, 105, 110, 103]]`, + pbBytes: [][]byte{{77, 121, 32, 85, 84, 70, 56, 32, 83, 116, 114, 105, 110, 103}}, + }, + { + shape: []int64{2, 1}, + jsonData: `[["String1"], ["String2"]]`, + pbBytes: [][]byte{[]byte("String1"), []byte("String2")}, + }, + { + shape: []int64{2, 1}, + jsonData: `["String1", "String2"]`, + pbBytes: [][]byte{[]byte("String1"), []byte("String2")}, + parameters: map[string]string{"content_type": "str"}, + }, + { + shape: []int64{2, 1}, + jsonData: `[[[83, 116, 114, 105, 110, 103, 32, 49]], [[83, 116, 114, 105, 110, 103, 32, 50]]]`, + pbBytes: [][]byte{{83, 116, 114, 105, 110, 103, 32, 49}, {83, 116, 114, 105, 110, 103, 32, 50}}, + }, + { + shape: []int64{2, 1}, + jsonData: `["TXkgVVRGOCBTdHJpbmc=", "QW5vdGhlciBzdHJpbmc="]`, + pbBytes: [][]byte{[]byte("My UTF8 String"), []byte("Another string")}, + parameters: map[string]string{"content_type": "base64"}, + }, +} + +func bytesRestRequest(shape []int64, jsonData string, parameters map[string]string) string { + shapeStr, _ := json.Marshal(shape) + parameterStr := "" + if len(parameters) != 0 { + p, _ := json.Marshal(parameters) + parameterStr = `, "parameters": ` + string(p) + } + + return `{ + "id": "foo", + "parameters": { + "top_level": "foo", + "bool_param": false + }, + "inputs": [{ + "name": "predict", + "shape": ` + string(shapeStr) + `, + "datatype": "BYTES", + "data": ` + jsonData + + parameterStr + ` + }] + }` +} + func generateProtoBufRequest(shape []int64) *gw.ModelInferRequest { var expectedInput = gw.ModelInferRequest_InferInputTensor{ Name: "predict", @@ -146,8 +213,7 @@ func TestRESTRequest(t *testing.T) { out := &gw.ModelInferRequest{} buffer := &bytes.Buffer{} buffer.Write([]byte(restRequest(data, strings.Join(strings.Split(fmt.Sprintln(inputDataShapes[k]), " "), ",")))) - err := c.NewDecoder(buffer).Decode(out) - if err != nil { + if err := c.NewDecoder(buffer).Decode(out); err != nil { t.Error(err) } expected := generateProtoBufRequest(inputDataShapes[k]) @@ -156,3 +222,48 @@ func TestRESTRequest(t *testing.T) { } } } + +func TestBytesRESTRequest(t *testing.T) { + for _, test := range bytesTensorCases { + c := CustomJSONPb{} + buffer := &bytes.Buffer{} + out := &gw.ModelInferRequest{} + buffer.Write([]byte(bytesRestRequest(test.shape, test.jsonData, test.parameters))) + if err := c.NewDecoder(buffer).Decode(out); err != nil { + t.Error(err) + } + + expected := &gw.ModelInferRequest{ + Id: "foo", + Parameters: map[string]*gw.InferParameter{ + "top_level": {ParameterChoice: &gw.InferParameter_StringParam{StringParam: "foo"}}, + "bool_param": {ParameterChoice: &gw.InferParameter_BoolParam{BoolParam: false}}, + }, + Inputs: []*gw.ModelInferRequest_InferInputTensor{{ + Name: "predict", + Datatype: "BYTES", + Shape: test.shape, + Contents: &gw.InferTensorContents{ + BytesContents: test.pbBytes, + }}, + }, + RawInputContents: nil, + } + + if len(test.parameters) > 0 { + p := map[string]*gw.InferParameter{} + for k, v := range test.parameters { + p[k] = &gw.InferParameter{ + ParameterChoice: &gw.InferParameter_StringParam{StringParam: v}, + } + } + expected.Inputs[0].Parameters = p + } + + fmt.Println(out) + if !proto.Equal(out, expected) { + t.Errorf("REST request failed to decode for test: %v: %v != %v", test, out, expected) + } + } + +}