Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests #20

Merged
merged 3 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ jobs:
version: "v1.57"
skip-pkg-cache: true
install-mode: "goinstall"

- name: Run tests 🧪
run: make test
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ update-all:

build-dev: tidy
go build

test:
go test -v ./...
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ require (
github.com/jackc/pgx/v5 v5.5.5
github.com/prometheus/client_golang v1.19.0
github.com/spf13/cast v1.6.0
github.com/stretchr/testify v1.8.4
google.golang.org/grpc v1.63.0
)

require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/expr-lang/expr v1.16.3 // indirect
github.com/fatih/color v1.16.0 // indirect
github.com/golang/protobuf v1.5.4 // indirect
Expand All @@ -26,6 +28,7 @@ require (
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
github.com/oklog/run v1.1.0 // indirect
github.com/pganalyze/pg_query_go/v5 v5.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.52.2 // indirect
github.com/prometheus/procfs v0.13.0 // indirect
Expand All @@ -37,4 +40,5 @@ require (
golang.org/x/text v0.14.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

289 changes: 289 additions & 0 deletions plugin/plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
package plugin

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
"github.com/hashicorp/go-hclog"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_isSQLi(t *testing.T) {
p := &Plugin{
EnableLibinjection: true,
Logger: hclog.NewNullLogger(),
}
// This is a false positive, since the query is not an SQL injection.
assert.True(t, p.isSQLi("SELECT * FROM users WHERE id = 1"))
// This is an SQL injection.
assert.True(t, p.isSQLi("SELECT * FROM users WHERE id = 1 OR 1=1"))
}

func Test_isSQLiDisabled(t *testing.T) {
p := &Plugin{
EnableLibinjection: false,
Logger: hclog.NewNullLogger(),
}
// This is an SQL injection, but the libinjection is disabled.
assert.False(t, p.isSQLi("SELECT * FROM users WHERE id = 1 OR 1=1"))
}

func Test_errorResponse(t *testing.T) {
p := &Plugin{
Logger: hclog.NewNullLogger(),
}

query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
queryBytes, err := query.Encode(nil)
require.NoError(t, err)

req := map[string]any{
"request": queryBytes,
}
reqJSON, err := v1.NewStruct(req)
require.NoError(t, err)
assert.NotNil(t, reqJSON)

resp := p.errorResponse(
reqJSON,
map[string]any{
"score": 0.9999,
"detector": "deep_learning_model",
},
)
// We are modifying the pointer to the object, so they should be the same.
assert.Equal(t, reqJSON, resp)
assert.Len(t, resp.GetFields(), 3)
assert.Contains(t, resp.GetFields(), "request")
assert.Contains(t, resp.GetFields(), "response")
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
// 2 signals: Terminate and Log.
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
}

func Test_OnTrafficFromClinet(t *testing.T) {
p := &Plugin{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
}

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println(r.URL.Path)
switch r.URL.Path {
case TokenizeAndSequencePath:
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
// This is the tokenized query:
// {"query":"select * from users where id = 1 or 1=1"}
resp := map[string][]float32{
"tokens": {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
},
}
data, _ := json.Marshal(resp)
_, err := w.Write(data)
require.NoError(t, err)
case fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion):
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
// This is the output of the deep learning model.
resp := map[string][][]float32{"outputs": {{0.999909341}}}
data, _ := json.Marshal(resp)
_, err := w.Write(data)
require.NoError(t, err)
default:
w.WriteHeader(http.StatusNotFound)
}
}),
)
defer server.Close()

p.TokenizerAPIAddress = server.URL
p.ServingAPIAddress = server.URL

query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
queryBytes, err := query.Encode(nil)
require.NoError(t, err)

req := map[string]any{
"request": queryBytes,
}
reqJSON, err := v1.NewStruct(req)
require.NoError(t, err)
assert.NotNil(t, reqJSON)

resp, err := p.OnTrafficFromClient(context.Background(), reqJSON)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Len(t, resp.GetFields(), 4)
assert.Contains(t, resp.GetFields(), "request")
assert.Contains(t, resp.GetFields(), "query")
assert.Contains(t, resp.GetFields(), "response")
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
// 2 signals: Terminate and Log.
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
}

func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
plugins := []*Plugin{
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
// If libinjection is enabled, the response should contain the "response" field,
// and the "signals" field, which means the plugin will terminate the request.
EnableLibinjection: true,
},
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
// If libinjection is disabled, the response should not contain the "response" field,
// and the "signals" field, which means the plugin will not terminate the request.
EnableLibinjection: false,
},
}

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println(r.URL.Path)
switch r.URL.Path {
case TokenizeAndSequencePath:
w.WriteHeader(http.StatusInternalServerError)
default:
w.WriteHeader(http.StatusNotFound)
}
}),
)
defer server.Close()

for i := range plugins {
plugins[i].TokenizerAPIAddress = server.URL
plugins[i].ServingAPIAddress = server.URL

query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
queryBytes, err := query.Encode(nil)
require.NoError(t, err)

req := map[string]any{
"request": queryBytes,
}
reqJSON, err := v1.NewStruct(req)
require.NoError(t, err)
assert.NotNil(t, reqJSON)

resp, err := plugins[i].OnTrafficFromClient(context.Background(), reqJSON)
require.NoError(t, err)
assert.NotNil(t, resp)
if plugins[i].EnableLibinjection {
assert.Len(t, resp.GetFields(), 4)
assert.Contains(t, resp.GetFields(), "request")
assert.Contains(t, resp.GetFields(), "query")
assert.Contains(t, resp.GetFields(), "response")
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
// 2 signals: Terminate and Log.
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
} else {
assert.Len(t, resp.GetFields(), 2)
assert.Contains(t, resp.GetFields(), "request")
assert.Contains(t, resp.GetFields(), "query")
assert.NotContains(t, resp.GetFields(), "response")
assert.NotContains(t, resp.GetFields(), sdkAct.Signals)
}
}
}

func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
plugins := []*Plugin{
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
// If libinjection is disabled, the response should not contain the "response" field,
// and the "signals" field, which means the plugin will not terminate the request.
EnableLibinjection: false,
},
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
// If libinjection is enabled, the response should contain the "response" field,
// and the "signals" field, which means the plugin will terminate the request.
EnableLibinjection: true,
},
}

// This is the same for both plugins.
predictPath := fmt.Sprintf(PredictPath, plugins[0].ModelName, plugins[1].ModelVersion)

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println(r.URL.Path)
switch r.URL.Path {
case TokenizeAndSequencePath:
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
// This is the tokenized query:
// {"query":"select * from users where id = 1 or 1=1"}
resp := map[string][]float32{
"tokens": {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
},
}
data, _ := json.Marshal(resp)
_, err := w.Write(data)
require.NoError(t, err)
case predictPath:
w.WriteHeader(http.StatusInternalServerError)
default:
w.WriteHeader(http.StatusNotFound)
}
}),
)
defer server.Close()

for i := range plugins {
plugins[i].TokenizerAPIAddress = server.URL
plugins[i].ServingAPIAddress = server.URL

query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
queryBytes, err := query.Encode(nil)
require.NoError(t, err)

req := map[string]any{
"request": queryBytes,
}
reqJSON, err := v1.NewStruct(req)
require.NoError(t, err)
assert.NotNil(t, reqJSON)

resp, err := plugins[i].OnTrafficFromClient(context.Background(), reqJSON)
require.NoError(t, err)
assert.NotNil(t, resp)
if plugins[i].EnableLibinjection {
assert.Len(t, resp.GetFields(), 4)
assert.Contains(t, resp.GetFields(), "request")
assert.Contains(t, resp.GetFields(), "query")
assert.Contains(t, resp.GetFields(), "response")
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
// 2 signals: Terminate and Log.
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
} else {
assert.Len(t, resp.GetFields(), 2)
assert.Contains(t, resp.GetFields(), "request")
assert.Contains(t, resp.GetFields(), "query")
assert.NotContains(t, resp.GetFields(), "response")
assert.NotContains(t, resp.GetFields(), sdkAct.Signals)
}
}
}