From f840160b31dda86f735f49fd9b71e649a7a7fd2d Mon Sep 17 00:00:00 2001 From: kevkevinpal Date: Wed, 11 Dec 2024 16:36:46 -0500 Subject: [PATCH] middleware --- handlers/auth_test.go | 2 +- routes/bounty.go | 4 ---- routes/index.go | 25 +++++++++++++++++++++++++ routes/ticket_routes.go | 3 --- utils/error_handler.go | 21 ++++++++++++++++++--- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/handlers/auth_test.go b/handlers/auth_test.go index 2528e4688..46b73bd05 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -39,7 +39,7 @@ func TestGetAdminPubkeys(t *testing.T) { handler.ServeHTTP(rr, req) - if status := rr.Code; status != rr.Status { + if status := rr.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) } diff --git a/routes/bounty.go b/routes/bounty.go index 0974ae849..ea4eb4c6e 100644 --- a/routes/bounty.go +++ b/routes/bounty.go @@ -7,15 +7,11 @@ import ( "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" - "github.com/stakwork/sphinx-tribes/utils" ) func BountyRoutes() chi.Router { r := chi.NewRouter() bountyHandler := handlers.NewBountyHandler(http.DefaultClient, db.DB) - - r.Use(utils.ErrorHandler) - r.Group(func(r chi.Router) { r.Get("/all", bountyHandler.GetAllBounties) diff --git a/routes/index.go b/routes/index.go index 8da500084..de8b21095 100644 --- a/routes/index.go +++ b/routes/index.go @@ -140,11 +140,36 @@ func getFromAuth(path string) (*extractResponse, error) { }, nil } +// Middleware to handle InternalServerError +func internalServerErrorHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rr := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK} + next.ServeHTTP(rr, r) + + if rr.statusCode == http.StatusInternalServerError { + fmt.Printf("Internal Server Error: %s %s\n", r.Method, r.URL.Path) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + }) +} + +// Custom ResponseWriter to capture status codes +type responseRecorder struct { + http.ResponseWriter + statusCode int +} + +func (rr *responseRecorder) WriteHeader(code int) { + rr.statusCode = code + rr.ResponseWriter.WriteHeader(code) +} + func initChi() *chi.Mux { r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.Logger) r.Use(middleware.Recoverer) + r.Use(internalServerErrorHandler) cors := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, diff --git a/routes/ticket_routes.go b/routes/ticket_routes.go index 34ebfb8ea..a45cd46c0 100644 --- a/routes/ticket_routes.go +++ b/routes/ticket_routes.go @@ -7,15 +7,12 @@ import ( "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" - "github.com/stakwork/sphinx-tribes/utils" ) func TicketRoutes() chi.Router { r := chi.NewRouter() ticketHandler := handlers.NewTicketHandler(http.DefaultClient, db.DB) - r.Use(utils.ErrorHandler) - r.Group(func(r chi.Router) { r.Get("/{uuid}", ticketHandler.GetTicket) r.Post("/review", ticketHandler.ProcessTicketReview) diff --git a/utils/error_handler.go b/utils/error_handler.go index f62195d47..dc316153e 100644 --- a/utils/error_handler.go +++ b/utils/error_handler.go @@ -9,11 +9,26 @@ import ( "net/http" ) -type customError struct { - error +type CustomError struct { + Err error StatusCode int } +func NewCustomError(err error, statusCode int) *CustomError { + return &CustomError{ + Err: err, + StatusCode: statusCode, + } +} + +// Error implements the error interface. +func (e *CustomError) Error() string { + if e.Err != nil { + return e.Err.Error() + } + return fmt.Sprintf("HTTP %d", e.StatusCode) +} + func ErrorHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { @@ -31,7 +46,7 @@ func ErrorHandler(next http.Handler) http.Handler { statusCode := http.StatusNotFound if errors.Is(ww.error, sql.ErrNoRows) { statusCode = http.StatusNotFound - } else if err, ok := ww.error.(*customError); ok { + } else if err, ok := ww.error.(*CustomError); ok { statusCode = err.StatusCode }