diff --git a/netquery/database.go b/netquery/database.go index aad649ce3..ac47197c4 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -57,6 +57,13 @@ type ( writeConn *sqlite.Conn } + BatchExecute struct { + ID string + SQL string + Params map[string]any + Result *[]map[string]any + } + // Conn is a network connection that is stored in a SQLite database and accepted // by the *Database type of this package. This also defines, using the ./orm package, // the table schema and the model that is exposed via the runtime database as well as @@ -325,6 +332,22 @@ func (db *Database) Execute(ctx context.Context, sql string, args ...orm.QueryOp }) } +// ExecuteBatch executes multiple custom SQL query using a read-only connection against the SQLite +// database used by db. +func (db *Database) ExecuteBatch(ctx context.Context, batches []BatchExecute) error { + return db.withConn(ctx, func(conn *sqlite.Conn) error { + merr := new(multierror.Error) + + for _, batch := range batches { + if err := orm.RunQuery(ctx, conn, batch.SQL, orm.WithNamedArgs(batch.Params), orm.WithResult(batch.Result)); err != nil { + merr.Errors = append(merr.Errors, fmt.Errorf("%s: %w", batch.ID, err)) + } + } + + return merr.ErrorOrNil() + }) +} + // CountRows returns the number of rows stored in the database. func (db *Database) CountRows(ctx context.Context) (int, error) { var result []struct { diff --git a/netquery/module_api.go b/netquery/module_api.go index 5e182fd17..396471862 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -82,6 +82,11 @@ func (m *module) prepare() error { IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), } + batchHander := &BatchQueryHandler{ + Database: m.Store, + IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), + } + chartHandler := &ChartHandler{ Database: m.Store, } @@ -99,6 +104,19 @@ func (m *module) prepare() error { return fmt.Errorf("failed to register API endpoint: %w", err) } + if err := api.RegisterEndpoint(api.Endpoint{ + Name: "Batch Query Connections", + Description: "Batch query the in-memory sqlite connection database.", + Path: "netquery/query/batch", + MimeType: "application/json", + Read: api.PermitUser, // Needs read+write as the query is sent using POST data. + Write: api.PermitUser, // Needs read+write as the query is sent using POST data. + BelongsTo: m.Module, + HandlerFunc: servertiming.Middleware(batchHander, nil).ServeHTTP, + }); err != nil { + return fmt.Errorf("failed to register API endpoint: %w", err) + } + if err := api.RegisterEndpoint(api.Endpoint{ Name: "Active Connections Chart", Description: "Query the in-memory sqlite connection database and return a chart of active connections.", diff --git a/netquery/query_handler.go b/netquery/query_handler.go index 0f6ec7251..ae848a9e1 100644 --- a/netquery/query_handler.go +++ b/netquery/query_handler.go @@ -10,6 +10,7 @@ import ( "regexp" "strings" + "github.com/hashicorp/go-multierror" servertiming "github.com/mitchellh/go-server-timing" "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" @@ -113,6 +114,98 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { // we can do. log.Errorf("failed to encode JSON response: %s", err) + return + } + +} +func (batch *BatchQueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + timing := servertiming.FromContext(req.Context()) + + timingQueryParsed := timing.NewMetric("query_parsed"). + WithDesc("Query has been parsed"). + Start() + + requestPayload, err := parseQueryRequestPayload[BatchQueryRequestPayload](req) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + timingQueryParsed.Stop() + + response := make(map[string][]map[string]any, len(*requestPayload)) + + batches := make([]BatchExecute, 0, len(*requestPayload)) + + for key, query := range *requestPayload { + + timingQueryBuilt := timing.NewMetric("query_built_" + key). + WithDesc("The SQL query has been built"). + Start() + + sql, paramMap, err := query.generateSQL(req.Context(), batch.Database.Schema) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + timingQueryBuilt.Stop() + + var result []map[string]any + batches = append(batches, BatchExecute{ + ID: key, + SQL: sql, + Params: paramMap, + Result: &result, + }) + } + + timingQueryExecute := timing.NewMetric("sql_exec"). + WithDesc("SQL query execution time"). + Start() + + status := http.StatusOK + if err := batch.Database.ExecuteBatch(req.Context(), batches); err != nil { + status = http.StatusInternalServerError + + var merr *multierror.Error + if errors.As(err, &merr) { + for _, e := range merr.Errors { + resp.Header().Add("X-Query-Error", e.Error()) + } + } else { + // Should not happen, ExecuteBatch always returns a multierror.Error + resp.WriteHeader(status) + + return + } + } + + timingQueryExecute.Stop() + + // collect the results + for _, b := range batches { + response[b.ID] = *b.Result + } + + // send the HTTP status code + resp.WriteHeader(status) + + // prepare the result encoder. + enc := json.NewEncoder(resp) + enc.SetEscapeHTML(false) + enc.SetIndent("", " ") + + // and finally stream the response + if err := enc.Encode(response); err != nil { + // we failed to encode the JSON body to resp so we likely either already sent a + // few bytes or the pipe was already closed. In either case, trying to send the + // error using http.Error() is non-sense. We just log it out here and that's all + // we can do. + log.Errorf("failed to encode JSON response: %s", err) + return } }