diff --git a/.gitignore b/.gitignore index 1e16531..3a8674b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ google/ # general .env .bash_history + +# IDE +.idea/ diff --git a/go.mod b/go.mod index b045486..986ec25 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/kserve/rest-proxy go 1.16 require ( + github.com/google/go-cmp v0.5.6 github.com/grpc-ecosystem/grpc-gateway/v2 v2.6.0 google.golang.org/grpc v1.40.0 google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index ce6e0d9..bce16db 100644 --- a/go.sum +++ b/go.sum @@ -482,7 +482,6 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/proxy/marshaler_test.go b/proxy/marshaler_test.go new file mode 100644 index 0000000..1fa75ca --- /dev/null +++ b/proxy/marshaler_test.go @@ -0,0 +1,159 @@ +package main + +import ( + "github.com/google/go-cmp/cmp" + gw "github.com/kserve/rest-proxy/gen" + "testing" +) + +var json2DInput = ` + { + "name": "predict", + "shape": [2, 64], + "datatype": "FP32", + "data": [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, + 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, + 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, + 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0], + + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, + 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, + 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, + 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0] + ] +} +` + +var json3DInput = ` +{ + "name": "predict", + "shape": [2, 2, 32], + "datatype": "FP32", + "data": [ + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0, + 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0 + ], + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0, + 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0 + ] + ], + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0, + 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0 + ], + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0, + 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0 + ] + ] + ] +} +` + +var json4DInput = ` +{ + "name": "predict", + "shape": [2, 2, 2, 16], + "datatype": "FP32", + "data": [ + [ + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0], + [0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0] + ], + [ + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0], + [0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0] + ] + ], + [ + [ + [0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, 16.0, 8.0, 0.0], + [0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, 14.0, 0.0, 0.0] + ], + [ + [0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, 16.0, 6.0, 0.0], + [0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, 12.0, 1.0, 0.0] + ] + ] + ] +} +` + +var expectedInput = InputTensor{ + Name: "predict", + Datatype: "FP32", + Shape: []int64{2, 64}, + Contents: &gw.InferTensorContents{ + Fp32Contents: []float32{0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, + 16.0, 8.0, 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, + 14.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, + 16.0, 6.0, 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, + 12.0, 1.0, 0.0, 0.0, 0.0, 1.0, 11.0, 14.0, 15.0, 3.0, 0.0, 0.0, 1.0, 13.0, 16.0, 12.0, + 16.0, 8.0, 0.0, 0.0, 8.0, 16.0, 4.0, 6.0, 16.0, 5.0, 0.0, 0.0, 5.0, 15.0, 11.0, 13.0, + 14.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 16.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 16.0, + 16.0, 6.0, 0.0, 0.0, 0.0, 0.0, 16.0, 16.0, 16.0, 7.0, 0.0, 0.0, 0.0, 0.0, 11.0, 13.0, + 12.0, 1.0, 0.0}, + }, +} + +func TestUnmarshaller2DArray(t *testing.T) { + i := &InputTensor{} + err := i.UnmarshalJSON([]byte(json2DInput)) + if err != nil { + t.Error(err) + } + if d := cmp.Diff(i.Name, expectedInput.Name); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Shape, expectedInput.Shape); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Datatype, expectedInput.Datatype); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Contents.Fp32Contents, expectedInput.Contents.Fp32Contents); d != "" { + t.Errorf("diff: %s", d) + } +} + +func TestUnmarshaller3DArray(t *testing.T) { + i := &InputTensor{} + err := i.UnmarshalJSON([]byte(json3DInput)) + if err != nil { + t.Error(err) + } + if d := cmp.Diff(i.Name, expectedInput.Name); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Shape, []int64{2, 2, 32}); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Datatype, expectedInput.Datatype); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Contents.Fp32Contents, expectedInput.Contents.Fp32Contents); d != "" { + t.Errorf("diff: %s", d) + } +} + +func TestUnmarshaller4DArray(t *testing.T) { + i := &InputTensor{} + err := i.UnmarshalJSON([]byte(json4DInput)) + if err != nil { + t.Error(err) + } + if d := cmp.Diff(i.Name, expectedInput.Name); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Shape, []int64{2, 2, 2, 16}); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Datatype, expectedInput.Datatype); d != "" { + t.Errorf("diff: %s", d) + } + if d := cmp.Diff(i.Contents.Fp32Contents, expectedInput.Contents.Fp32Contents); d != "" { + t.Errorf("diff: %s", d) + } +}