Skip to content

Commit

Permalink
Add integration with the new Prediction API
Browse files Browse the repository at this point in the history
Remove the integration with Tokenization and Serving APIs
Remove unused env-vars
Update tests to reflect changes
Update plugin config
  • Loading branch information
mostafa committed Oct 28, 2024
1 parent a1511a8 commit 9865e83
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 115 deletions.
5 changes: 1 addition & 4 deletions gatewayd_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ plugins:
- METRICS_ENABLED=True
- METRICS_UNIX_DOMAIN_SOCKET=/tmp/gatewayd-plugin-sql-ids-ips.sock
- METRICS_PATH=/metrics
- TOKENIZER_API_ADDRESS=http://localhost:8000
- SERVING_API_ADDRESS=http://localhost:8501
- MODEL_NAME=sqli_model
- MODEL_VERSION=3
- PREDICTION_API_ADDRESS=http://localhost:8000
# Threshold determine the minimum prediction confidence
# required to detect an SQL injection attack. Any value
# between 0 and 1 is valid, and it is inclusive.
Expand Down
5 changes: 1 addition & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ func main() {
pluginInstance.Impl.EnableLibinjection = cast.ToBool(cfg["enableLibinjection"])
pluginInstance.Impl.LibinjectionPermissiveMode = cast.ToBool(
cfg["libinjectionPermissiveMode"])
pluginInstance.Impl.TokenizerAPIAddress = cast.ToString(cfg["tokenizerAPIAddress"])
pluginInstance.Impl.ServingAPIAddress = cast.ToString(cfg["servingAPIAddress"])
pluginInstance.Impl.ModelName = cast.ToString(cfg["modelName"])
pluginInstance.Impl.ModelVersion = cast.ToString(cfg["modelVersion"])
pluginInstance.Impl.PredictionAPIAddress = cast.ToString(cfg["predictionAPIAddress"])

pluginInstance.Impl.ResponseType = cast.ToString(cfg["responseType"])
pluginInstance.Impl.ErrorMessage = cast.ToString(cfg["errorMessage"])
Expand Down
6 changes: 2 additions & 4 deletions plugin/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ package plugin
const (
DecodedQueryField string = "decodedQuery"
DetectorField string = "detector"
ScoreField string = "score"
QueryField string = "query"
ErrorField string = "error"
IsInjectionField string = "is_injection"
ResponseField string = "response"
OutputsField string = "outputs"
ConfidenceField string = "confidence"
TokensField string = "tokens"
StringField string = "String"
ResponseTypeField string = "response_type"
Expand All @@ -23,6 +22,5 @@ const (
ErrorDetail string = "Back off, you're not welcome here."
LogLevel string = "error"

TokenizeAndSequencePath string = "/tokenize_and_sequence"
PredictPath string = "/v1/models/%s/versions/%s:predict"
PredictPath string = "/predict"
)
8 changes: 2 additions & 6 deletions plugin/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,8 @@ var (
"metricsUnixDomainSocket": sdkConfig.GetEnv(
"METRICS_UNIX_DOMAIN_SOCKET", "/tmp/gatewayd-plugin-sql-ids-ips.sock"),
"metricsEndpoint": sdkConfig.GetEnv("METRICS_ENDPOINT", "/metrics"),
"tokenizerAPIAddress": sdkConfig.GetEnv(
"TOKENIZER_API_ADDRESS", "http://localhost:8000"),
"servingAPIAddress": sdkConfig.GetEnv(
"SERVING_API_ADDRESS", "http://localhost:8501"),
"modelName": sdkConfig.GetEnv("MODEL_NAME", "sqli_model"),
"modelVersion": sdkConfig.GetEnv("MODEL_VERSION", "1"),
"predictionAPIAddress": sdkConfig.GetEnv(
"PREDICTION_API_ADDRESS", "http://localhost:8000"),
"threshold": sdkConfig.GetEnv("THRESHOLD", "0.8"),
"enableLibinjection": sdkConfig.GetEnv("ENABLE_LIBINJECTION", "true"),
"libinjectionPermissiveMode": sdkConfig.GetEnv("LIBINJECTION_MODE", "true"),
Expand Down
54 changes: 12 additions & 42 deletions plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/base64"
"encoding/json"
"fmt"

"github.com/carlmjohnson/requests"
"github.com/corazawaf/libinjection-go"
Expand All @@ -28,10 +27,7 @@ type Plugin struct {
Threshold float32
EnableLibinjection bool
LibinjectionPermissiveMode bool
TokenizerAPIAddress string
ServingAPIAddress string
ModelName string
ModelVersion string
PredictionAPIAddress string
ResponseType string
ErrorMessage string
ErrorSeverity string
Expand Down Expand Up @@ -111,36 +107,12 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
}
queryString := cast.ToString(queryMap[StringField])

var tokens map[string]any
err = requests.
URL(p.TokenizerAPIAddress).
Path(TokenizeAndSequencePath).
BodyJSON(map[string]any{
QueryField: queryString,
}).
ToJSON(&tokens).
Fetch(context.Background())
if err != nil {
p.Logger.Error("Failed to make POST request", ErrorField, err)
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
return p.prepareResponse(
req,
map[string]any{
QueryField: queryString,
DetectorField: Libinjection,
ErrorField: "Failed to make POST request to tokenizer API",
},
), nil
}
return req, nil
}

var output map[string]any
err = requests.
URL(p.ServingAPIAddress).
Path(fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion)).
URL(p.PredictionAPIAddress).
Path(PredictPath).
BodyJSON(map[string]any{
"inputs": []any{cast.ToSlice(tokens[TokensField])},
QueryField: queryString,
}).
ToJSON(&output).
Fetch(context.Background())
Expand All @@ -152,34 +124,32 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
map[string]any{
QueryField: queryString,
DetectorField: Libinjection,
ErrorField: "Failed to make POST request to serving API",
ErrorField: "Failed to make POST request to tokenizer API",
},
), nil
}
return req, nil
}

predictions := cast.ToSlice(output[OutputsField])
scores := cast.ToSlice(predictions[0])
score := cast.ToFloat32(scores[0])
p.Logger.Trace("Deep learning model prediction", ScoreField, score)
confidence := cast.ToFloat32(output[ConfidenceField])
p.Logger.Trace("Deep learning model prediction", ConfidenceField, confidence)

// Check the prediction against the threshold,
// otherwise check if the query is an SQL injection using libinjection.
injection := p.isSQLi(queryString)
if score >= p.Threshold {
if confidence >= p.Threshold {
if p.EnableLibinjection && !injection {
p.Logger.Debug("False positive detected", DetectorField, Libinjection)
}

Detections.With(map[string]string{DetectorField: DeepLearningModel}).Inc()
p.Logger.Warn(p.ErrorMessage, ScoreField, score, DetectorField, DeepLearningModel)
p.Logger.Warn(p.ErrorMessage, ConfidenceField, confidence, DetectorField, DeepLearningModel)
return p.prepareResponse(
req,
map[string]any{
QueryField: queryString,
ScoreField: score,
DetectorField: DeepLearningModel,
QueryField: queryString,
ConfidenceField: confidence,
DetectorField: DeepLearningModel,
},
), nil
} else if p.EnableLibinjection && injection && !p.LibinjectionPermissiveMode {
Expand Down
66 changes: 11 additions & 55 deletions plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package plugin
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -71,28 +70,13 @@ func Test_errorResponse(t *testing.T) {

func Test_OnTrafficFromClinet(t *testing.T) {
p := &Plugin{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
Logger: hclog.NewNullLogger(),
}

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case TokenizeAndSequencePath:
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
// This is the tokenized query:
// {"query":"select * from users where id = 1 or 1=1"}
resp := map[string][]float32{
"tokens": {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
},
}
data, _ := json.Marshal(resp)
_, err := w.Write(data)
require.NoError(t, err)
case fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion):
case PredictPath:
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
// This is the output of the deep learning model.
Expand All @@ -107,8 +91,7 @@ func Test_OnTrafficFromClinet(t *testing.T) {
)
defer server.Close()

p.TokenizerAPIAddress = server.URL
p.ServingAPIAddress = server.URL
p.PredictionAPIAddress = server.URL

query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
queryBytes, err := query.Encode(nil)
Expand Down Expand Up @@ -136,17 +119,13 @@ func Test_OnTrafficFromClinet(t *testing.T) {
func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
plugins := []*Plugin{
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
Logger: hclog.NewNullLogger(),
// If libinjection is enabled, the response should contain the "response" field,
// and the "signals" field, which means the plugin will terminate the request.
EnableLibinjection: true,
},
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
Logger: hclog.NewNullLogger(),
// If libinjection is disabled, the response should not contain the "response" field,
// and the "signals" field, which means the plugin will not terminate the request.
EnableLibinjection: false,
Expand All @@ -156,7 +135,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case TokenizeAndSequencePath:
case PredictPath:
w.WriteHeader(http.StatusInternalServerError)
default:
w.WriteHeader(http.StatusNotFound)
Expand All @@ -166,8 +145,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
defer server.Close()

for i := range plugins {
plugins[i].TokenizerAPIAddress = server.URL
plugins[i].ServingAPIAddress = server.URL
plugins[i].PredictionAPIAddress = server.URL

query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
queryBytes, err := query.Encode(nil)
Expand Down Expand Up @@ -204,43 +182,22 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
plugins := []*Plugin{
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
Logger: hclog.NewNullLogger(),
// If libinjection is disabled, the response should not contain the "response" field,
// and the "signals" field, which means the plugin will not terminate the request.
EnableLibinjection: false,
},
{
Logger: hclog.NewNullLogger(),
ModelName: "sqli_model",
ModelVersion: "2",
Logger: hclog.NewNullLogger(),
// If libinjection is enabled, the response should contain the "response" field,
// and the "signals" field, which means the plugin will terminate the request.
EnableLibinjection: true,
},
}

// This is the same for both plugins.
predictPath := fmt.Sprintf(PredictPath, plugins[0].ModelName, plugins[1].ModelVersion)

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case TokenizeAndSequencePath:
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
// This is the tokenized query:
// {"query":"select * from users where id = 1 or 1=1"}
resp := map[string][]float32{
"tokens": {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
},
}
data, _ := json.Marshal(resp)
_, err := w.Write(data)
require.NoError(t, err)
case predictPath:
case PredictPath:
w.WriteHeader(http.StatusInternalServerError)
default:
w.WriteHeader(http.StatusNotFound)
Expand All @@ -250,8 +207,7 @@ func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
defer server.Close()

for i := range plugins {
plugins[i].TokenizerAPIAddress = server.URL
plugins[i].ServingAPIAddress = server.URL
plugins[i].PredictionAPIAddress = server.URL

query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
queryBytes, err := query.Encode(nil)
Expand Down

0 comments on commit 9865e83

Please sign in to comment.