Skip to content

Commit

Permalink
feat: Support BYTES/string type tensors
Browse files Browse the repository at this point in the history
Addresses #8

Current limitations:
- Does not obey request-level content_type request parameters (has to be set at tensor level)
- Always encodes response tensor bytes as base64

Signed-off-by: Nick Hill <[email protected]>
  • Loading branch information
njhill committed Nov 23, 2022
1 parent d9501be commit 8eca0be
Show file tree
Hide file tree
Showing 5 changed files with 442 additions and 17 deletions.
211 changes: 211 additions & 0 deletions proxy/bytes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package main

import (
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"unicode/utf16"
"unicode/utf8"

gw "github.com/kserve/rest-proxy/gen"
)

// This file contains logic related to marshalling/unmarshalling BYTES type tensor data
// "raw" parsing is done for performance and structural validation is done only on a best-effort basis

var escMap = map[byte]byte{'b': '\b', 'f': '\f', 'r': '\r', 't': '\t', 'n': '\n', '\\': '\\', '/': '/', '"': '"'}

func unmarshalBytesJson(target *[][]byte, shape []int64, b64 bool, data []byte) error {
// Cases:
// 1dim raw ( [ [ N ) -> as-is
// flat raw ( [ [ N ) -> as-is
// nested raw ( [ [ [ .. N ) -> strip "middles" (use shape prob)
// 1dim str ( [ " )
// flat str ( [ " )
// nested str ( [ [ .. " ) -> strip all

start := -1
depth := 0
isString := false
for i, b := range data {
if b == '[' {
if start == -1 {
start = i
}
depth += 1
} else if !isSpace(b) {
isString = b == '"'
break
}
}
if depth == 0 {
return errors.New("invalid tensor data: not a json array")
}
if start > 0 {
data = data[start:]
}
if isString {
if depth != 1 && depth != len(shape) {
return errors.New("data array nesting does not match tensor shape")
}
return unmarshalStringArray(target, shape, b64, data)
}
if depth <= 1 {
return errors.New("invalid tensor data: must be an array of byte arrays")
}
if depth == 2 {
// flat numeric case, e.g. [[1,2,3],[4,5,6],[7,8,9]]
return json.Unmarshal(data, target)
}

// nested numeric case, e.g. [[[1,2],[3,4]],[[5,6],[7,8]]]
// ignore innermost dimension because elements are lists of bytes
if (depth - 1) != len(shape) {
return errors.New("invalid tensor data: array nesting does not match tensor shape")
}
return unmarshalNestedNumeric(target, depth, data)
}

// nested numeric case, e.g. [[[1,2],[3,4]],[[5,6],[7,8]]]
func unmarshalNestedNumeric(target *[][]byte, depth int, data []byte) error {
d := 0
j := 1
for _, b := range data {
include := true
if b == '[' {
d++
if d > depth {
return errors.New("invalid tensor data: array nesting does not match tensor shape")
}
include = d == depth
} else if b == ']' {
include = d == depth
d--
}
if include {
data[j] = b
j++
}
}
if d != 0 {
return errors.New("invalid tensor data: array nesting does not match tensor shape")
}
data[j] = ']'
return json.Unmarshal(data[:j+1], target)
}

func unmarshalStringArray(target *[][]byte, shape []int64, b64 bool, data []byte) error {
elems := int(elementCount(shape))
t := make([][]byte, 0, elems)

depth := 0
strStart := -1
j := 0
var ok bool
l := len(data)
for i := 0; i < l; i++ {
b := data[i]
if strStart == -1 {
if b == '[' {
depth++
} else if b == ']' {
depth--
} else if b == '"' {
if len(t) >= elems {
return errors.New("more strings than expected for tensor shape")
}
strStart = i
j = i + 1
} else if b != ',' && !isSpace(b) {
return errors.New("tensor data must be a flat or nested json array of strings")
}
continue
}
// here we are mid-string
if b == '\\' {
i++
if i == l {
break // will error with unexpected end
}
b = data[i]
if b == 'u' {
i += 4
if i >= l {
break // will error with unexpected end
}
cp := utf16.Decode([]uint16{binary.BigEndian.Uint16(data[i-3 : i+1])})
for _, r := range cp {
j += utf8.EncodeRune(data[j:], r)
}
continue
} else if b, ok = escMap[b]; !ok {
return errors.New("invalid escaped char in json string")
}
} else if b == '"' {
//end of string
s := data[strStart+1 : j]
if b64 {
if n, err := base64.StdEncoding.Decode(s, s); err != nil {
return fmt.Errorf("error decoding json string as base64: %w", err)
} else {
s = s[:n]
}
}
t = append(t, s)
strStart = -1
}
if j != i {
data[j] = b
}
j++
}
if strStart != -1 {
return errors.New("fewer strings than expected for tensor shape")
}
if depth != 0 {
return errors.New("invalid tensor data: invalid nested json arrays")
}

*target = t
return nil
}

func isBase64Content(parameters map[string]*gw.InferParameter) bool {
ct := parameters[CONTENT_TYPE].GetStringParam()
if ct == "" || ct == "utf8" || ct == "str" || ct == "UTF8" {
return false
}
if ct == BASE64 || ct == "b64" || ct == "BASE64" || ct == "B64" {
return true
}
if ct != "utf-8" && ct != "UTF-8" {
logger.Error(nil, "Unrecognized content_type, treating as utf8", CONTENT_TYPE, ct)
}
return false
}

// Split raw bytes into separate byte arrays based on 4-byte size delimeters
func splitRawBytes(raw []byte, expectedSize int) ([][]byte, error) {
off, length := int64(0), int64(len(raw))
strings := make([][]byte, expectedSize)
r := bytes.NewReader(raw)
for i := 0; i < expectedSize; i++ {
var size uint32
if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
return nil, errors.New("unexpected end of raw tensor bytes")
}
start := off + 4
if off, _ = r.Seek(int64(size), io.SeekCurrent); off > length {
return nil, errors.New("unexpected end of raw tensor bytes")
}
strings[i] = raw[start:off]
}
if off < length {
return nil, errors.New("more raw tensor bytes than expected")
}
return strings, nil
}
21 changes: 15 additions & 6 deletions proxy/marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ const (
BYTES = "BYTES"
)

const CONTENT_TYPE = "content_type"
const BASE64 = "base64"

type CustomJSONPb struct {
runtime.JSONPb
}
Expand Down Expand Up @@ -118,20 +121,24 @@ func transformResponse(r *gw.ModelInferResponse) (*RESTResponse, error) {
if tensor.Datatype == FP16 {
return nil, fmt.Errorf("FP16 tensors not supported (request tensor %s)", tensor.Name) //TODO
}
if tensor.Datatype == BYTES {
tensor.Parameters[CONTENT_TYPE] = BASE64
}
if r.RawOutputContents != nil {
tt, ok := tensorTypes[tensor.Datatype]
if !ok {
return nil, fmt.Errorf("unsupported datatype in inference response outputs: %s",
tensor.Datatype)
}
numElements := int(elementCount(tensor.Shape))
var err error
if tensor.Datatype == BYTES {
tensor.Data = r.RawOutputContents[index]
tensor.Data, err = splitRawBytes(r.RawOutputContents[index], numElements)
} else {
numElements := int(elementCount(tensor.Shape))
var err error
if tensor.Data, err = readBytes(r.RawOutputContents[index], tt, 0, numElements); err != nil {
return nil, err
}
tensor.Data, err = readBytes(r.RawOutputContents[index], tt, 0, numElements)
}
if err != nil {
return nil, err
}
} else {
switch tensor.Datatype {
Expand All @@ -150,6 +157,8 @@ func transformResponse(r *gw.ModelInferResponse) (*RESTResponse, error) {
case FP64:
tensor.Data = output.Contents.Fp64Contents
case BYTES:
// this will be encoded as array of b64-encoded strings
//TODO support
tensor.Data = output.Contents.BytesContents
default:
return nil, fmt.Errorf("unsupported datatype in inference response outputs: %s",
Expand Down
81 changes: 81 additions & 0 deletions proxy/marshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,63 @@ func generateProtoBufResponse() *gw.ModelInferResponse {
var jsonResponse = `{"model_name":"example","id":"foo","parameters":{"bool_param":false,"content_type":"bar","headers":null,"int_param":12345},` +
`"outputs":[{"name":"predict","datatype":"INT64","shape":[2],"data":[8,8]}]}`

func generateProtoBufBytesResponse() *gw.ModelInferResponse {
expectedOutput := []*gw.ModelInferResponse_InferOutputTensor{{
Name: "predict",
Datatype: "BYTES",
Shape: []int64{2, 2},
Contents: &gw.InferTensorContents{
BytesContents: [][]byte{[]byte("String1"), []byte("String2"), []byte("String3"), []byte("String4")},
},
}}

return &gw.ModelInferResponse{
ModelName: "example",
Id: "foo",
Outputs: expectedOutput,
Parameters: map[string]*gw.InferParameter{
"content_type": {ParameterChoice: &gw.InferParameter_StringParam{StringParam: "bar"}},
"headers": {ParameterChoice: nil},
"int_param": {ParameterChoice: &gw.InferParameter_Int64Param{Int64Param: 12345}},
"bool_param": {ParameterChoice: &gw.InferParameter_BoolParam{BoolParam: false}},
},
}
}

func generateProtoBufBytesResponseRawOutput() *gw.ModelInferResponse {
expectedOutput := []*gw.ModelInferResponse_InferOutputTensor{{
Name: "predict",
Datatype: "BYTES",
Shape: []int64{2, 2},
}}

seven := make([]byte, 4)
binary.LittleEndian.PutUint32(seven, 7)
rawBytes := append(seven, "String1"...)
rawBytes = append(rawBytes, seven...)
rawBytes = append(rawBytes, "String2"...)
rawBytes = append(rawBytes, seven...)
rawBytes = append(rawBytes, "String3"...)
rawBytes = append(rawBytes, seven...)
rawBytes = append(rawBytes, "String4"...)

return &gw.ModelInferResponse{
ModelName: "example",
Id: "foo",
Outputs: expectedOutput,
Parameters: map[string]*gw.InferParameter{
"content_type": {ParameterChoice: &gw.InferParameter_StringParam{StringParam: "bar"}},
"headers": {ParameterChoice: nil},
"int_param": {ParameterChoice: &gw.InferParameter_Int64Param{Int64Param: 12345}},
"bool_param": {ParameterChoice: &gw.InferParameter_BoolParam{BoolParam: false}},
},
RawOutputContents: [][]byte{rawBytes},
}
}

var jsonBytesResponse = `{"model_name":"example","id":"foo","parameters":{"bool_param":false,"content_type":"bar","headers":null,"int_param":12345},` +
`"outputs":[{"name":"predict","datatype":"BYTES","shape":[2,2],"parameters":{"content_type":"base64"},"data":["U3RyaW5nMQ==","U3RyaW5nMg==","U3RyaW5nMw==","U3RyaW5nNA=="]}]}`

func TestRESTResponse(t *testing.T) {
c := CustomJSONPb{}
v := generateProtoBufResponse()
Expand All @@ -63,6 +120,30 @@ func TestRESTResponse(t *testing.T) {
}
}

func TestBytesRESTResponse(t *testing.T) {
c := CustomJSONPb{}
v := generateProtoBufBytesResponse()
marshal, err := c.Marshal(v)
if err != nil {
t.Error(err)
}
if d := cmp.Diff(string(marshal), jsonBytesResponse); d != "" {
t.Errorf("diff :%s", d)
}
}

func TestBytesRESTResponseRawOutput(t *testing.T) {
c := CustomJSONPb{}
v := generateProtoBufBytesResponseRawOutput()
marshal, err := c.Marshal(v)
if err != nil {
t.Error(err)
}
if d := cmp.Diff(string(marshal), jsonBytesResponse); d != "" {
t.Errorf("diff :%s", d)
}
}

func TestRESTResponseRawOutput(t *testing.T) {
c := CustomJSONPb{}
buf := new(bytes.Buffer)
Expand Down
Loading

0 comments on commit 8eca0be

Please sign in to comment.