From eec4b5d06e8fc24a3197329e38d3fa9ca50e1ed9 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 25 Sep 2024 11:37:10 +0200 Subject: [PATCH] refact pkg/database: extract function rollbackOnError(); dry error messages --- pkg/apiserver/alerts_test.go | 4 +-- pkg/apiserver/controllers/v1/errors.go | 12 ------- pkg/database/alertfilter.go | 2 +- pkg/database/alerts.go | 43 ++++++++------------------ pkg/database/errors.go | 1 - pkg/types/ip.go | 7 ++--- pkg/types/ip_test.go | 2 +- test/bats/90_decisions.bats | 4 +-- 8 files changed, 22 insertions(+), 53 deletions(-) diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index d86234e4813..00e9fbfb35a 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -242,7 +242,7 @@ func TestAlertListFilters(t *testing.T) { w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) + assert.Equal(t, `{"message":"invalid ip address 'gruueq'"}`, w.Body.String()) // test range (ok) @@ -261,7 +261,7 @@ func TestAlertListFilters(t *testing.T) { w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) + assert.Equal(t, `{"message":"invalid ip address 'ratata'"}`, w.Body.String()) // test since (ok) diff --git a/pkg/apiserver/controllers/v1/errors.go b/pkg/apiserver/controllers/v1/errors.go index d661de44b0e..d7b60c1a1b8 100644 --- a/pkg/apiserver/controllers/v1/errors.go +++ b/pkg/apiserver/controllers/v1/errors.go @@ -21,18 +21,6 @@ func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) { case errors.Is(err, database.HashError): gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return - case errors.Is(err, database.InsertFail): - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) - return - case errors.Is(err, database.QueryFail): - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) - return - case errors.Is(err, database.ParseTimeFail): - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) - return - case errors.Is(err, database.ParseDurationFail): - gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) - return default: gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return diff --git a/pkg/database/alertfilter.go b/pkg/database/alertfilter.go index 566cd807b73..9e8cf53a450 100644 --- a/pkg/database/alertfilter.go +++ b/pkg/database/alertfilter.go @@ -204,7 +204,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e case "ip", "range": ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0]) if err != nil { - return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err) + return nil, err } case "since", "created_before", "until": if err := handleTimeFilters(param, value[0], &predicates); err != nil { diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index a1445345260..4e3f209b012 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -31,6 +31,14 @@ const ( maxLockRetries = 10 // how many times to retry a bulk operation when sqlite3.ErrBusy is encountered ) +func rollbackOnError(tx *ent.Tx, err error, msg string) error { + if rbErr := tx.Rollback(); rbErr != nil { + log.Errorf("rollback error: %v", rbErr) + } + + return fmt.Errorf("%s: %w", msg, err) +} + // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them @@ -284,12 +292,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models duration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - - return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) + return 0,0,0, rollbackOnError(txClient, err, "parsing decision duration") } if decisionItem.Scope == nil { @@ -301,12 +304,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - - return 0, 0, 0, errors.Wrapf(InvalidIPOrRange, "invalid addr/range %s : %s", *decisionItem.Value, err) + return 0, 0, 0, rollbackOnError(txClient, err, "invalid ip addr/range") } } @@ -348,12 +346,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models decision.ValueIn(deleteChunk...), )).Exec(ctx) if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - - return 0, 0, 0, fmt.Errorf("while deleting older community blocklist decisions: %w", err) + return 0, 0, 0, rollbackOnError(txClient, err, "deleting older community blocklist decisions") } deleted += deletedDecisions @@ -364,12 +357,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models for _, builderChunk := range builderChunks { insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - - return 0, 0, 0, fmt.Errorf("while bulk creating decisions: %w", err) + return 0, 0, 0, rollbackOnError(txClient, err, "bulk creating decisions") } inserted += len(insertedDecisions) @@ -379,12 +367,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models err = txClient.Commit() if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - - return 0, 0, 0, fmt.Errorf("error committing transaction: %w", err) + return 0, 0, 0, rollbackOnError(txClient, err, "error committing transaction") } return alertRef.ID, inserted, deleted, nil diff --git a/pkg/database/errors.go b/pkg/database/errors.go index 77f92707e51..e0223be95b8 100644 --- a/pkg/database/errors.go +++ b/pkg/database/errors.go @@ -14,7 +14,6 @@ var ( ParseTimeFail = errors.New("unable to parse time") ParseDurationFail = errors.New("unable to parse duration") MarshalFail = errors.New("unable to serialize") - UnmarshalFail = errors.New("unable to parse") BulkError = errors.New("unable to insert bulk") ParseType = errors.New("unable to parse type") InvalidIPOrRange = errors.New("invalid ip address / range") diff --git a/pkg/types/ip.go b/pkg/types/ip.go index 9d08afd8809..53f4445a813 100644 --- a/pkg/types/ip.go +++ b/pkg/types/ip.go @@ -2,7 +2,6 @@ package types import ( "encoding/binary" - "errors" "fmt" "math" "net" @@ -38,7 +37,7 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) { if strings.Contains(anyIP, "/") { _, net, err := net.ParseCIDR(anyIP) if err != nil { - return -1, 0, 0, 0, 0, fmt.Errorf("while parsing range %s: %w", anyIP, err) + return -1, 0, 0, 0, 0, fmt.Errorf("invalid ip range '%s': %w", anyIP, err) } return Range2Ints(*net) @@ -46,12 +45,12 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) { ip := net.ParseIP(anyIP) if ip == nil { - return -1, 0, 0, 0, 0, errors.New("invalid address") + return -1, 0, 0, 0, 0, fmt.Errorf("invalid ip address '%s'", anyIP) } sz, start, end, err := IP2Ints(ip) if err != nil { - return -1, 0, 0, 0, 0, fmt.Errorf("while parsing ip %s: %w", anyIP, err) + return -1, 0, 0, 0, 0, fmt.Errorf("invalid ip address '%s': %w", anyIP, err) } return sz, start, end, start, end, nil diff --git a/pkg/types/ip_test.go b/pkg/types/ip_test.go index f8c14b12e3c..66b04e9cafe 100644 --- a/pkg/types/ip_test.go +++ b/pkg/types/ip_test.go @@ -181,7 +181,7 @@ func TestAdd2Int(t *testing.T) { }, { in_addr: "xxx2", - exp_error: "invalid address", + exp_error: "invalid ip address 'xxx2'", }, } diff --git a/test/bats/90_decisions.bats b/test/bats/90_decisions.bats index 8601414db48..a1f72a31edf 100644 --- a/test/bats/90_decisions.bats +++ b/test/bats/90_decisions.bats @@ -172,7 +172,7 @@ teardown() { EOT assert_stderr --partial 'Parsing values' assert_stderr --partial 'Imported 1 decisions' - assert_file_contains "$LOGFILE" "invalid addr/range 'whatever': invalid address" + assert_file_contains "$LOGFILE" "invalid addr/range 'whatever': invalid ip address 'whatever'" rune -0 cscli decisions list -a -o json assert_json '[]' @@ -189,7 +189,7 @@ teardown() { EOT assert_stderr --partial 'Parsing values' assert_stderr --partial 'Imported 3 decisions' - assert_file_contains "$LOGFILE" "invalid addr/range 'bad-apple': invalid address" + assert_file_contains "$LOGFILE" "invalid addr/range 'bad-apple': invalid ip address 'bad-apple'" rune -0 cscli decisions list -a -o json rune -0 jq -r '.[0].decisions | length' <(output)