diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1762a0b..3feff73 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -45,3 +45,6 @@ jobs: version: "v1.57" skip-pkg-cache: true install-mode: "goinstall" + + - name: Run tests 🧪 + run: make test diff --git a/Makefile b/Makefile index f37e43a..dfd0f55 100644 --- a/Makefile +++ b/Makefile @@ -13,3 +13,6 @@ update-all: build-dev: tidy go build + +test: + go test -v ./... diff --git a/go.mod b/go.mod index 717bd5f..add65b9 100644 --- a/go.mod +++ b/go.mod @@ -11,12 +11,14 @@ require ( github.com/jackc/pgx/v5 v5.5.5 github.com/prometheus/client_golang v1.19.0 github.com/spf13/cast v1.6.0 + github.com/stretchr/testify v1.8.4 google.golang.org/grpc v1.63.0 ) require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/expr-lang/expr v1.16.3 // indirect github.com/fatih/color v1.16.0 // indirect github.com/golang/protobuf v1.5.4 // indirect @@ -26,6 +28,7 @@ require ( github.com/mitchellh/go-testing-interface v1.14.1 // indirect github.com/oklog/run v1.1.0 // indirect github.com/pganalyze/pg_query_go/v5 v5.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.52.2 // indirect github.com/prometheus/procfs v0.13.0 // indirect @@ -37,4 +40,5 @@ require ( golang.org/x/text v0.14.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3321435..55f17fe 100644 --- a/go.sum +++ b/go.sum @@ -114,5 +114,7 @@ google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go new file mode 100644 index 0000000..e412e31 --- /dev/null +++ b/plugin/plugin_test.go @@ -0,0 +1,289 @@ +package plugin + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" + v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" + "github.com/hashicorp/go-hclog" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_isSQLi(t *testing.T) { + p := &Plugin{ + EnableLibinjection: true, + Logger: hclog.NewNullLogger(), + } + // This is a false positive, since the query is not an SQL injection. + assert.True(t, p.isSQLi("SELECT * FROM users WHERE id = 1")) + // This is an SQL injection. + assert.True(t, p.isSQLi("SELECT * FROM users WHERE id = 1 OR 1=1")) +} + +func Test_isSQLiDisabled(t *testing.T) { + p := &Plugin{ + EnableLibinjection: false, + Logger: hclog.NewNullLogger(), + } + // This is an SQL injection, but the libinjection is disabled. + assert.False(t, p.isSQLi("SELECT * FROM users WHERE id = 1 OR 1=1")) +} + +func Test_errorResponse(t *testing.T) { + p := &Plugin{ + Logger: hclog.NewNullLogger(), + } + + query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"} + queryBytes, err := query.Encode(nil) + require.NoError(t, err) + + req := map[string]any{ + "request": queryBytes, + } + reqJSON, err := v1.NewStruct(req) + require.NoError(t, err) + assert.NotNil(t, reqJSON) + + resp := p.errorResponse( + reqJSON, + map[string]any{ + "score": 0.9999, + "detector": "deep_learning_model", + }, + ) + // We are modifying the pointer to the object, so they should be the same. + assert.Equal(t, reqJSON, resp) + assert.Len(t, resp.GetFields(), 3) + assert.Contains(t, resp.GetFields(), "request") + assert.Contains(t, resp.GetFields(), "response") + assert.Contains(t, resp.GetFields(), sdkAct.Signals) + // 2 signals: Terminate and Log. + assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2) +} + +func Test_OnTrafficFromClinet(t *testing.T) { + p := &Plugin{ + Logger: hclog.NewNullLogger(), + ModelName: "sqli_model", + ModelVersion: "2", + } + + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.URL.Path) + switch r.URL.Path { + case TokenizeAndSequencePath: + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + // This is the tokenized query: + // {"query":"select * from users where id = 1 or 1=1"} + resp := map[string][]float32{ + "tokens": { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12, + }, + } + data, _ := json.Marshal(resp) + _, err := w.Write(data) + require.NoError(t, err) + case fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion): + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + // This is the output of the deep learning model. + resp := map[string][][]float32{"outputs": {{0.999909341}}} + data, _ := json.Marshal(resp) + _, err := w.Write(data) + require.NoError(t, err) + default: + w.WriteHeader(http.StatusNotFound) + } + }), + ) + defer server.Close() + + p.TokenizerAPIAddress = server.URL + p.ServingAPIAddress = server.URL + + query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"} + queryBytes, err := query.Encode(nil) + require.NoError(t, err) + + req := map[string]any{ + "request": queryBytes, + } + reqJSON, err := v1.NewStruct(req) + require.NoError(t, err) + assert.NotNil(t, reqJSON) + + resp, err := p.OnTrafficFromClient(context.Background(), reqJSON) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.Len(t, resp.GetFields(), 4) + assert.Contains(t, resp.GetFields(), "request") + assert.Contains(t, resp.GetFields(), "query") + assert.Contains(t, resp.GetFields(), "response") + assert.Contains(t, resp.GetFields(), sdkAct.Signals) + // 2 signals: Terminate and Log. + assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2) +} + +func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) { + plugins := []*Plugin{ + { + Logger: hclog.NewNullLogger(), + ModelName: "sqli_model", + ModelVersion: "2", + // If libinjection is enabled, the response should contain the "response" field, + // and the "signals" field, which means the plugin will terminate the request. + EnableLibinjection: true, + }, + { + Logger: hclog.NewNullLogger(), + ModelName: "sqli_model", + ModelVersion: "2", + // If libinjection is disabled, the response should not contain the "response" field, + // and the "signals" field, which means the plugin will not terminate the request. + EnableLibinjection: false, + }, + } + + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.URL.Path) + switch r.URL.Path { + case TokenizeAndSequencePath: + w.WriteHeader(http.StatusInternalServerError) + default: + w.WriteHeader(http.StatusNotFound) + } + }), + ) + defer server.Close() + + for i := range plugins { + plugins[i].TokenizerAPIAddress = server.URL + plugins[i].ServingAPIAddress = server.URL + + query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"} + queryBytes, err := query.Encode(nil) + require.NoError(t, err) + + req := map[string]any{ + "request": queryBytes, + } + reqJSON, err := v1.NewStruct(req) + require.NoError(t, err) + assert.NotNil(t, reqJSON) + + resp, err := plugins[i].OnTrafficFromClient(context.Background(), reqJSON) + require.NoError(t, err) + assert.NotNil(t, resp) + if plugins[i].EnableLibinjection { + assert.Len(t, resp.GetFields(), 4) + assert.Contains(t, resp.GetFields(), "request") + assert.Contains(t, resp.GetFields(), "query") + assert.Contains(t, resp.GetFields(), "response") + assert.Contains(t, resp.GetFields(), sdkAct.Signals) + // 2 signals: Terminate and Log. + assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2) + } else { + assert.Len(t, resp.GetFields(), 2) + assert.Contains(t, resp.GetFields(), "request") + assert.Contains(t, resp.GetFields(), "query") + assert.NotContains(t, resp.GetFields(), "response") + assert.NotContains(t, resp.GetFields(), sdkAct.Signals) + } + } +} + +func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) { + plugins := []*Plugin{ + { + Logger: hclog.NewNullLogger(), + ModelName: "sqli_model", + ModelVersion: "2", + // If libinjection is disabled, the response should not contain the "response" field, + // and the "signals" field, which means the plugin will not terminate the request. + EnableLibinjection: false, + }, + { + Logger: hclog.NewNullLogger(), + ModelName: "sqli_model", + ModelVersion: "2", + // If libinjection is enabled, the response should contain the "response" field, + // and the "signals" field, which means the plugin will terminate the request. + EnableLibinjection: true, + }, + } + + // This is the same for both plugins. + predictPath := fmt.Sprintf(PredictPath, plugins[0].ModelName, plugins[1].ModelVersion) + + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.URL.Path) + switch r.URL.Path { + case TokenizeAndSequencePath: + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + // This is the tokenized query: + // {"query":"select * from users where id = 1 or 1=1"} + resp := map[string][]float32{ + "tokens": { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12, + }, + } + data, _ := json.Marshal(resp) + _, err := w.Write(data) + require.NoError(t, err) + case predictPath: + w.WriteHeader(http.StatusInternalServerError) + default: + w.WriteHeader(http.StatusNotFound) + } + }), + ) + defer server.Close() + + for i := range plugins { + plugins[i].TokenizerAPIAddress = server.URL + plugins[i].ServingAPIAddress = server.URL + + query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"} + queryBytes, err := query.Encode(nil) + require.NoError(t, err) + + req := map[string]any{ + "request": queryBytes, + } + reqJSON, err := v1.NewStruct(req) + require.NoError(t, err) + assert.NotNil(t, reqJSON) + + resp, err := plugins[i].OnTrafficFromClient(context.Background(), reqJSON) + require.NoError(t, err) + assert.NotNil(t, resp) + if plugins[i].EnableLibinjection { + assert.Len(t, resp.GetFields(), 4) + assert.Contains(t, resp.GetFields(), "request") + assert.Contains(t, resp.GetFields(), "query") + assert.Contains(t, resp.GetFields(), "response") + assert.Contains(t, resp.GetFields(), sdkAct.Signals) + // 2 signals: Terminate and Log. + assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2) + } else { + assert.Len(t, resp.GetFields(), 2) + assert.Contains(t, resp.GetFields(), "request") + assert.Contains(t, resp.GetFields(), "query") + assert.NotContains(t, resp.GetFields(), "response") + assert.NotContains(t, resp.GetFields(), sdkAct.Signals) + } + } +}