diff --git a/.github/workflows/checks.golang.yml b/.github/workflows/checks.golang.yml index cc684ea..948cd8b 100644 --- a/.github/workflows/checks.golang.yml +++ b/.github/workflows/checks.golang.yml @@ -30,3 +30,6 @@ jobs: export PATH=$PATH:$(go env GOPATH)/bin make install_lint make lint + - name: Run tests + run: | + make test \ No newline at end of file diff --git a/Makefile b/Makefile index 6e89c1c..e66170d 100644 --- a/Makefile +++ b/Makefile @@ -51,6 +51,9 @@ lint: golangci-lint --version golangci-lint run +test: + go test -v -race ./... + earthly: earthly +all diff --git a/go.mod b/go.mod index c6d2f39..6dd2458 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,14 @@ go 1.19 require ( github.com/ClickHouse/clickhouse-go/v2 v2.14.2 github.com/go-chi/chi/v5 v5.0.10 + github.com/go-chi/cors v1.2.1 github.com/prometheus/client_golang v1.17.0 github.com/spf13/cobra v1.7.0 github.com/spf13/viper v1.17.0 github.com/stretchr/testify v1.8.4 go.uber.org/zap v1.26.0 golang.org/x/sync v0.4.0 + golang.org/x/time v0.3.0 gorm.io/driver/clickhouse v0.5.1 gorm.io/driver/postgres v1.5.3 gorm.io/gorm v1.25.5 diff --git a/go.sum b/go.sum index 37bf87f..02a148c 100644 --- a/go.sum +++ b/go.sum @@ -380,6 +380,8 @@ github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2H github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw= github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw= github.com/go-faster/errors v0.6.1 h1:nNIPOBkprlKzkThvS/0YaX8Zs9KewLCOSFQS5BU06FI= @@ -1349,6 +1351,8 @@ golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/zdb/methods_test.go b/pkg/zdb/methods_test.go index b1046d9..5763bf5 100644 --- a/pkg/zdb/methods_test.go +++ b/pkg/zdb/methods_test.go @@ -85,14 +85,14 @@ func (suite *ZDatabaseSuite) TestExec() { func (suite *ZDatabaseSuite) TestSelect() { suite.db.(*MockZDatabase).On("Select", "name", []interface{}{"Messi"}).Return(suite.db) - newDb := suite.db.Select("name", "Messi") + newDb := suite.db.Select("name", []interface{}{"Messi"}) suite.NotNil(newDb) suite.db.(*MockZDatabase).AssertExpectations(suite.T()) } func (suite *ZDatabaseSuite) TestWhere() { suite.db.(*MockZDatabase).On("Where", "name = ?", []interface{}{"Messi"}).Return(suite.db) - newDb := suite.db.Where("name = ?", "Messi") + newDb := suite.db.Where("name = ?", []interface{}{"Messi"}) suite.NotNil(newDb) suite.db.(*MockZDatabase).AssertExpectations(suite.T()) } diff --git a/pkg/zrouter/README.md b/pkg/zrouter/README.md new file mode 100644 index 0000000..e56c41c --- /dev/null +++ b/pkg/zrouter/README.md @@ -0,0 +1,240 @@ +# ZRouter package + +ZRouter is a Golang routing library built on the robust foundation of the chi router. + +## Table of Contents + +- [Features](#features) +- [Getting Started](#getting-started) + - [Installation](#installation) +- [Usage](#usage) +- [Routing](#routing) +- [Middleware](#middleware) +- [Adapters](#adapters) +- [Custom Configurations](#custom-configurations) +- [Monitoring and Logging](#monitoring-and-logging) +- [Advanced Topics](#advanced-topics) +- [Examples](#examples) +- [Conclusion](#conclusion) + +## Features + +- **Intuitive Routing Interface**: Define your routes with ease using all major HTTP methods. +- **Middleware Chaining**: Introduce layers of middleware to your HTTP requests and responses. +- **Flexible Adapters**: Seamlessly integrate with the `chi` router's context. +- **Enhanced Monitoring**: Integrated metrics server and structured logging for in-depth observability. +- **Customizable Settings**: Adjust server configurations like timeouts to suit your needs. + +## Getting Started + +### Installation + +To incorporate ZRouter into your project: + +```bash +go get github.com/zondax/golem/pkg/zrouter +``` + +## Usage + +Crafting a web service using ZRouter: + +```go +import "github.com/zondax/golem/pkg/zrouter" + +config := &zrouter.Config{ReadTimeOut: 10 * time.Second, WriteTimeOut: 10 * time.Second} +router := zrouter.New("ServiceName", metricServer, config) + +router.Use(middlewareLogic) + +groupedRoutes := router.Group("/grouped") +groupedRoutes.GET("/{param}", handlerFunction) +``` + +or + +```go +import "github.com/zondax/golem/pkg/zrouter" + +func main() { +router := zrouter.New("MyService", metricServer, nil) + +router.GET("/endpoint", func(ctx zrouter.Context) (domain.ServiceResponse, error) { +// Handler implementation +}) + +router.Run() +} +``` + +## Routing + +For dynamic URL parts, utilize the chi style, e.g., /entities/{entityID}. + +## Middleware + +Add pre- and post-processing steps to your routes. Chain multiple middlewares for enhanced functionality. + +### **Default Middlewares** + +`ZRouter` comes bundled with certain default middlewares for enhanced functionality and ease of use: + +- **ErrorHandlerMiddleware**: Systematically manages errors by translating them into a consistent response format. +- **RequestID()**: Attaches a unique request ID to every request, facilitating request tracking and debugging. +- **RequestMetrics()**: Monitors and logs metrics associated with requests, responses, and other interactions for performance insights. + +To activate these middlewares, make sure to call: + +```go +zr := zrouter.New("AppName", metricsServer, nil) +zr.SetDefaultMiddlewares() //Call this method! +``` +### **Additional Middlewares** + +Beyond the default offerings, `ZRouter` also provides extra middlewares to address specific needs: + +- **DefaultCors()**: Introduces a predefined set of Cross-Origin Resource Sharing (CORS) rules, facilitating browsers to make requests across origins safely. +- **Cors(options CorsOptions)**: A flexible CORS middleware that allows you to set specific CORS policies, such as permitted origins, headers, and methods, tailored to your application's demands. +- **RateLimit(maxRPM int)**: Shields your application from being swamped by imposing a rate limit on the influx of requests. By setting `maxRPM`, you can decide the maximum number of permissible requests per minute. + +## Adapters + +Use `chiContextAdapter` for translating the `chi` router's context to ZRouter's. + +## Custom Configurations + +Specify server behavior with `Config`. Use default settings or customize as needed. + +Default settings: +- `ReadTimeOut`: 240000 milliseconds. +- `WriteTimeOut`: 240000 milliseconds. +- `Logger`: Uses production logger settings by default. + +Override these defaults by providing values during initialization. + +Example: +```go + config := &Config{ + ReadTimeOut: 25000 * time.Millisecond, + WriteTimeOut: 25000 * time.Millisecond, + Logger: zapLoggerInstance, + } + zr := New("YourAppName", metricsServerInstance, config) +``` + +## Response Standards + +### ServiceResponse + +When handling responses, ZRouter provides a standardized way to return them using `ServiceResponse`, which includes status, headers, and body. + +**Example**: + +```go +func MyHandler(ctx Context) (domain.ServiceResponse, error) { + data := map[string]string{"message": "Hello, World!"} + return domain.NewServiceResponse(http.StatusOK, data), nil +} +``` + +### Handling Headers + +With `ServiceResponse`, you can easily set custom headers for your responses: + +```go +func MyHandler(ctx Context) (domain.ServiceResponse, error) { + headers := make(http.Header) + headers.Set("X-Custom-Header", "My Value") + + data := map[string]string{"message": "Hello, World!"} + response := domain.NewServiceResponseWithHeader(http.StatusOK, data, headers) + return response, nil +} +``` +### Error Handling + +Whenever you return an error, ZRouter translates it to a structured error response, maintaining consistency across your services. + +**Example**: + +```go +func MyHandler(ctx Context) (domain.ServiceResponse, error) { + return nil, domain.NewAPIErrorResponse(http.StatusNotFound, "not_found", "message") +} +``` + +## Context in ZRouter + +The `Context` is an essential part of ZRouter, providing a consistent interface to interact with the HTTP request and offering helper methods to streamline handler operations. This abstraction ensures that, as your router's needs evolve, the core interface to access request information remains consistent. + +### Functions and Usage: + +1. **Request**: + + Retrieve the raw `*http.Request` from the context: + + ```go + req := ctx.Request() + ``` + +2. **BindJSON**: + + Decode a JSON request body directly into a provided object: + + ```go + var myData MyStruct + err := ctx.BindJSON(&myData) + ``` + +3. **Header**: + + Set an HTTP header for the response: + + ```go + ctx.Header("X-Custom-Header", "Custom Value") + ``` + +4. **Param**: + + Get URL parameters (path variables): + + ```go + userID := ctx.Param("userID") + ``` + +5. **Query**: + + Retrieve a query parameter from the URL: + + ```go + sortBy := ctx.Query("sortBy") + ``` + +6. **DefaultQuery**: + + Retrieve a query parameter from the URL, but return a default value if it's not present: + + ```go + order := ctx.DefaultQuery("order", "asc") + ``` + +### Adapting to chi: + +Behind the scenes, ZRouter leverages the powerful `chi` router. The `chiContextAdapter` translates the chi context to ZRouter's, ensuring that you get the benefits of chi's speed and power with ZRouter's simplified and consistent interface. + +## Monitoring and Logging + +Monitor request metrics and employ structured logging for in-depth insights. + +## Advanced Topics + +- **Route Grouping**: Consolidate routes under specific prefixes using `Group()`. +- **NotFound Handling**: Specify custom logic for unmatched routes. +- **Route Tracking**: Fetch a structured list of all registered routes. + +### **Why ZRouter?** + +- **Consistent Standard:** In a world full of routers, `ZRouter` gives us a way to keep things standard across our projects. +- **Flexibility:** Today we're using `chi`, but what about tomorrow? With `ZRouter`, if we ever want to switch, we can do it here and keep everything else unchanged. +- **Speed & Power of Chi:** We get all the speed and flexibility of routers like `chi` but without tying ourselves down to one specific router. +- **Unified Approach:** `ZRouter` sets a clear standard for how we handle metrics, responses, errors, and more. It's about making sure everything works the same way, every time. \ No newline at end of file diff --git a/pkg/zrouter/context.go b/pkg/zrouter/context.go new file mode 100644 index 0000000..148c016 --- /dev/null +++ b/pkg/zrouter/context.go @@ -0,0 +1,50 @@ +package zrouter + +import ( + "encoding/json" + "github.com/go-chi/chi/v5" + "net/http" +) + +type Context interface { + Request() *http.Request + BindJSON(obj interface{}) error + Header(key, value string) + Param(key string) string + Query(key string) string + DefaultQuery(key, defaultValue string) string +} + +type chiContextAdapter struct { + ctx http.ResponseWriter + req *http.Request +} + +func (c *chiContextAdapter) Request() *http.Request { + return c.req +} + +func (c *chiContextAdapter) BindJSON(obj interface{}) error { + return json.NewDecoder(c.req.Body).Decode(obj) +} + +func (c *chiContextAdapter) Header(key, value string) { + c.ctx.Header().Set(key, value) +} + +func (c *chiContextAdapter) Param(key string) string { + return chi.URLParam(c.req, key) +} + +func (c *chiContextAdapter) Query(key string) string { + values := c.req.URL.Query() + return values.Get(key) +} + +func (c *chiContextAdapter) DefaultQuery(key, defaultValue string) string { + value := c.Query(key) + if value == "" { + return defaultValue + } + return value +} diff --git a/pkg/zrouter/context_mock.go b/pkg/zrouter/context_mock.go new file mode 100644 index 0000000..8a518b0 --- /dev/null +++ b/pkg/zrouter/context_mock.go @@ -0,0 +1,39 @@ +package zrouter + +import ( + "github.com/stretchr/testify/mock" + "net/http" +) + +type MockContext struct { + mock.Mock +} + +func (m *MockContext) Request() *http.Request { + args := m.Called() + return args.Get(0).(*http.Request) +} + +func (m *MockContext) BindJSON(obj interface{}) error { + args := m.Called(obj) + return args.Error(0) +} + +func (m *MockContext) Header(key, value string) { + m.Called(key, value) +} + +func (m *MockContext) Param(key string) string { + args := m.Called(key) + return args.String(0) +} + +func (m *MockContext) Query(key string) string { + args := m.Called(key) + return args.String(0) +} + +func (m *MockContext) DefaultQuery(key, defaultValue string) string { + args := m.Called(key, defaultValue) + return args.String(0) +} diff --git a/pkg/zrouter/context_test.go b/pkg/zrouter/context_test.go new file mode 100644 index 0000000..0ff692d --- /dev/null +++ b/pkg/zrouter/context_test.go @@ -0,0 +1,56 @@ +package zrouter + +import ( + "bytes" + "encoding/json" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "net/http" + "net/http/httptest" + "testing" +) + +type ChiContextAdapterSuite struct { + suite.Suite +} + +func (suite *ChiContextAdapterSuite) TestChiContextAdapter() { + r := chi.NewRouter() + r.Get("/hello/{name}", func(w http.ResponseWriter, req *http.Request) { + adapter := &chiContextAdapter{ctx: w, req: req} + assert.NotNil(suite.T(), adapter.Request()) + + var input struct { + Message string `json:"message"` + } + err := adapter.BindJSON(&input) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), "Hello", input.Message) + + adapter.Header("Custom-Header", "CustomValue") + + w.Header().Set("Content-Type", "application/json") + response := map[string]string{"response": "OK"} + err = json.NewEncoder(w).Encode(response) + assert.NoError(suite.T(), err) + }) + + body := bytes.NewBuffer([]byte(`{"message":"Hello"}`)) + req := httptest.NewRequest("GET", "/hello/world?test=query", body) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + var output map[string]string + err := json.NewDecoder(rec.Body).Decode(&output) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), "OK", output["response"]) + assert.Equal(suite.T(), "application/json", rec.Header().Get("Content-Type")) + assert.Equal(suite.T(), "CustomValue", rec.Header().Get("Custom-Header")) +} + +func TestChiContextAdapterSuite(t *testing.T) { + suite.Run(t, new(ChiContextAdapterSuite)) +} diff --git a/pkg/zrouter/domain/apierror.go b/pkg/zrouter/domain/apierror.go new file mode 100644 index 0000000..734665c --- /dev/null +++ b/pkg/zrouter/domain/apierror.go @@ -0,0 +1,28 @@ +package domain + +import "fmt" + +type APIError struct { + HTTPStatus int `json:"-"` + ErrorCode string `json:"error_code"` + Message string `json:"message"` + Details string `json:"details,omitempty"` +} + +func (ae *APIError) Error() string { + return fmt.Sprintf("HTTP Status: %d, ErrorCode: %s, Message: %s", ae.HTTPStatus, ae.ErrorCode, ae.Message) +} + +func NewAPIErrorResponse(httpStatus int, errorCode, message string, details ...string) *APIError { + apiError := &APIError{ + HTTPStatus: httpStatus, + ErrorCode: errorCode, + Message: message, + } + + if len(details) > 0 { + apiError.Details = details[0] + } + + return apiError +} diff --git a/pkg/zrouter/domain/response.go b/pkg/zrouter/domain/response.go new file mode 100644 index 0000000..0e17eff --- /dev/null +++ b/pkg/zrouter/domain/response.go @@ -0,0 +1,99 @@ +package domain + +import ( + "encoding/json" + "net/http" + "sync" +) + +const ( + ContentTypeHeader = "Content-Type" + contentTypeApplicationJSON = "application/json; charset=utf-8" + ContentTypeJSON = "json" +) + +type ServiceResponse interface { + Status() int + Header() http.Header + ResponseBytes() ([]byte, error) + ResponseFormat() string + Contents() interface{} +} + +type defaultServiceResponse struct { + status int + header http.Header + response interface{} + once sync.Once + responseBytes []byte + marshalError error +} + +func (d *defaultServiceResponse) Status() int { + return d.status +} + +func (d *defaultServiceResponse) Header() http.Header { + h := d.header + if h == nil { + h = http.Header{} + } + if h.Get(ContentTypeHeader) == "" { + h.Set(ContentTypeHeader, contentTypeApplicationJSON) + } + return h +} + +func (d *defaultServiceResponse) ResponseFormat() string { + return ContentTypeJSON +} + +func (d *defaultServiceResponse) ResponseBytes() ([]byte, error) { + d.once.Do(func() { + if d.response != nil { + d.responseBytes, d.marshalError = json.Marshal(d.response) + } else { + d.responseBytes = []byte{} + } + }) + return d.responseBytes, d.marshalError +} + +func (d *defaultServiceResponse) Contents() interface{} { + return d.response +} + +func NewServiceResponse(status int, response interface{}) ServiceResponse { + return &defaultServiceResponse{ + status: status, + response: response, + header: nil, + } +} + +func NewServiceResponseWithHeader(status int, response interface{}, header http.Header) ServiceResponse { + return &defaultServiceResponse{ + status: status, + response: response, + header: header, + } +} + +func NewErrorResponse(status int, errorCode, errMsg string) ServiceResponse { + apiError := NewAPIErrorResponse(status, errorCode, errMsg) + apiErrorBytes, err := json.Marshal(apiError) + if err != nil { + return NewServiceResponse(status, errMsg) + } + + return &defaultServiceResponse{ + status: status, + response: apiErrorBytes, + header: nil, + responseBytes: apiErrorBytes, + } +} + +func NewErrorNotFound(errMsg string) ServiceResponse { + return NewErrorResponse(http.StatusNotFound, "ROUTE_NOT_FOUND", errMsg) +} diff --git a/pkg/zrouter/handler.go b/pkg/zrouter/handler.go new file mode 100644 index 0000000..a7c7039 --- /dev/null +++ b/pkg/zrouter/handler.go @@ -0,0 +1,68 @@ +package zrouter + +import ( + "encoding/json" + "errors" + "github.com/zondax/golem/pkg/zrouter/domain" + "net/http" +) + +type HandlerFunc func(ctx Context) (domain.ServiceResponse, error) + +func NotFoundHandler(_ Context) (domain.ServiceResponse, error) { + msg := "Route not found" + return domain.NewErrorNotFound(msg), nil +} + +func getChiHandler(handler HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + adaptedContext := &chiContextAdapter{ctx: w, req: r} + + serviceResponse, err := handler(adaptedContext) + if err != nil { + handleError(w, err) + return + } + + handleServiceResponse(w, serviceResponse) + } +} + +func handleError(w http.ResponseWriter, err error) { + var apiErr *domain.APIError + + if errors.As(err, &apiErr) { + writeAPIErrorResponse(w, apiErr) + return + } + + writeInternalServerError(w) +} + +func handleServiceResponse(w http.ResponseWriter, serviceResponse domain.ServiceResponse) { + if serviceResponse == nil { + return + } + + body, err := serviceResponse.ResponseBytes() + if err != nil { + http.Error(w, "Failed to process response.", http.StatusInternalServerError) + return + } + + contentType := serviceResponse.Header().Get(domain.ContentTypeHeader) + w.Header().Set(domain.ContentTypeHeader, contentType) + w.WriteHeader(serviceResponse.Status()) + _, _ = w.Write(body) +} + +func writeAPIErrorResponse(w http.ResponseWriter, apiErr *domain.APIError) { + w.Header().Set(domain.ContentTypeHeader, domain.ContentTypeJSON) + w.WriteHeader(apiErr.HTTPStatus) + responseBody, _ := json.Marshal(apiErr) + _, _ = w.Write(responseBody) +} + +func writeInternalServerError(w http.ResponseWriter) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) +} diff --git a/pkg/zrouter/handler_test.go b/pkg/zrouter/handler_test.go new file mode 100644 index 0000000..8461836 --- /dev/null +++ b/pkg/zrouter/handler_test.go @@ -0,0 +1,40 @@ +package zrouter + +import ( + "bytes" + "github.com/stretchr/testify/suite" + "github.com/zondax/golem/pkg/zrouter/domain" + "net/http" + "net/http/httptest" + "testing" +) + +type ChiHandlerAdapterSuite struct { + suite.Suite +} + +func (suite *ChiHandlerAdapterSuite) TestChiHandlerAdapter() { + h := http.Header{} + h.Add("Content-Type", "application/test") + + handlerFunc := func(ctx Context) (domain.ServiceResponse, error) { + return domain.NewServiceResponseWithHeader(http.StatusOK, "Hello", h), nil + } + + httpHandlerFunc := getChiHandler(handlerFunc) + + req, err := http.NewRequest("GET", "/test", bytes.NewBuffer(nil)) + suite.Require().NoError(err) + + recorder := httptest.NewRecorder() + + httpHandlerFunc(recorder, req) + + suite.Equal(http.StatusOK, recorder.Code) + suite.Equal("\"Hello\"", recorder.Body.String()) + suite.Equal("application/test", recorder.Header().Get("Content-Type")) +} + +func TestChiHandlerAdapterSuite(t *testing.T) { + suite.Run(t, new(ChiHandlerAdapterSuite)) +} diff --git a/pkg/zrouter/zmiddlewares/cors.go b/pkg/zrouter/zmiddlewares/cors.go new file mode 100644 index 0000000..84ac43c --- /dev/null +++ b/pkg/zrouter/zmiddlewares/cors.go @@ -0,0 +1,42 @@ +package zmiddlewares + +import ( + "github.com/go-chi/cors" + "net/http" +) + +type CorsOptions struct { + AllowedOrigins []string + AllowOriginFunc func(r *http.Request, origin string) bool + AllowedMethods []string + AllowedHeaders []string + ExposedHeaders []string + AllowCredentials bool + MaxAge int + OptionsPassthrough bool + Debug bool +} + +func (co CorsOptions) toChiOptions() cors.Options { + return cors.Options{ + AllowedOrigins: co.AllowedOrigins, + AllowOriginFunc: co.AllowOriginFunc, + AllowedMethods: co.AllowedMethods, + AllowedHeaders: co.AllowedHeaders, + ExposedHeaders: co.ExposedHeaders, + AllowCredentials: co.AllowCredentials, + MaxAge: co.MaxAge, + OptionsPassthrough: co.OptionsPassthrough, + Debug: co.Debug, + } +} + +func DefaultCors() Middleware { + corsMiddleware := cors.New(cors.Options{}) + return corsMiddleware.Handler +} + +func Cors(options CorsOptions) Middleware { + corsMiddleware := cors.New(options.toChiOptions()) + return corsMiddleware.Handler +} diff --git a/pkg/zrouter/zmiddlewares/error_handler.go b/pkg/zrouter/zmiddlewares/error_handler.go new file mode 100644 index 0000000..01e9d68 --- /dev/null +++ b/pkg/zrouter/zmiddlewares/error_handler.go @@ -0,0 +1,36 @@ +package zmiddlewares + +import ( + "encoding/json" + "fmt" + "github.com/zondax/golem/pkg/zrouter/domain" + "go.uber.org/zap" + "net/http" + "runtime/debug" +) + +const ( + internalErrorCode = "internal_error" +) + +func ErrorHandlerMiddleware(logger *zap.SugaredLogger) Middleware { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + logger.Errorf("Internal error: %v\n%s", err, debug.Stack()) + message := fmt.Sprintf("An internal error occurred: %v", err) + apiError := domain.NewAPIErrorResponse(http.StatusInternalServerError, internalErrorCode, message) + + w.Header().Set(domain.ContentTypeHeader, domain.ContentTypeJSON) + w.WriteHeader(apiError.HTTPStatus) + _ = json.NewEncoder(w).Encode(apiError) + } + }() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} diff --git a/pkg/zrouter/zmiddlewares/error_handler_test.go b/pkg/zrouter/zmiddlewares/error_handler_test.go new file mode 100644 index 0000000..1681af5 --- /dev/null +++ b/pkg/zrouter/zmiddlewares/error_handler_test.go @@ -0,0 +1,34 @@ +package zmiddlewares + +import ( + "encoding/json" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/zondax/golem/pkg/zrouter/domain" + "go.uber.org/zap" + "net/http" + "net/http/httptest" + "testing" +) + +func TestErrorHandlerMiddleware(t *testing.T) { + r := chi.NewRouter() + + r.Use(ErrorHandlerMiddleware(zap.S())) + + r.Get("/panic", func(w http.ResponseWriter, r *http.Request) { + panic("Some unexpected error") + }) + + req := httptest.NewRequest("GET", "/panic", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + var apiError domain.APIError + err := json.NewDecoder(rec.Body).Decode(&apiError) + assert.NoError(t, err) + assert.Equal(t, "internal_error", apiError.ErrorCode) + assert.Contains(t, apiError.Message, "Some unexpected error") +} diff --git a/pkg/zrouter/zmiddlewares/http.go b/pkg/zrouter/zmiddlewares/http.go new file mode 100644 index 0000000..ab5bef7 --- /dev/null +++ b/pkg/zrouter/zmiddlewares/http.go @@ -0,0 +1,13 @@ +package zmiddlewares + +import ( + "github.com/go-chi/chi/v5/middleware" +) + +func RequestID() Middleware { + return middleware.RequestID +} + +func Logger() Middleware { + return middleware.Logger +} diff --git a/pkg/zrouter/zmiddlewares/metrics.go b/pkg/zrouter/zmiddlewares/metrics.go new file mode 100644 index 0000000..0babb08 --- /dev/null +++ b/pkg/zrouter/zmiddlewares/metrics.go @@ -0,0 +1,93 @@ +package zmiddlewares + +import ( + "fmt" + "github.com/go-chi/chi/v5" + "github.com/zondax/golem/pkg/metrics" + "github.com/zondax/golem/pkg/metrics/collectors" + "net/http" + "strconv" + "sync" + "time" +) + +const ( + activeConnectionsMetricType = "active_connections" + durationMillisecondsMetricType = "duration_milliseconds" + responseSizeMetricType = "response_size_bytes" + totalRequestMetricType = "total_requests" + pathLabel = "path" + methodLabel = "method" + statusLabel = "status" +) + +func RegisterRequestMetrics(appName string, metricsServer metrics.TaskMetrics) []error { + var errs []error + + register := func(name, help string, labels []string, handler metrics.MetricHandler) { + if err := metricsServer.RegisterMetric(name, help, labels, handler); err != nil { + errs = append(errs, err) + } + } + + totalRequestsMetricName := getMetricName(appName, totalRequestMetricType) + responseSizeMetricName := getMetricName(appName, responseSizeMetricType) + durationMillisecondsMetricName := getMetricName(appName, durationMillisecondsMetricType) + activeConnectionsMetricName := getMetricName(appName, activeConnectionsMetricType) + register(totalRequestsMetricName, "Total number of HTTP requests made.", []string{methodLabel, pathLabel, statusLabel}, &collectors.Counter{}) + register(durationMillisecondsMetricName, "Duration of HTTP requests.", []string{methodLabel, pathLabel, statusLabel}, &collectors.Histogram{}) + register(responseSizeMetricName, "Size of HTTP response in bytes.", []string{methodLabel, pathLabel, statusLabel}, &collectors.Histogram{}) + register(activeConnectionsMetricName, "Number of active HTTP connections.", nil, &collectors.Gauge{}) + + if len(errs) > 0 { + return errs + } + + return nil +} + +func RequestMetrics(appName string, metricsServer metrics.TaskMetrics) Middleware { + var activeConnections int64 + var mu sync.Mutex + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + activeConnectionsMetricName := getMetricName(appName, activeConnectionsMetricType) + + mu.Lock() + activeConnections++ + _ = metricsServer.UpdateMetric(activeConnectionsMetricName, float64(activeConnections)) + mu.Unlock() + + mrw := &metricsResponseWriter{ResponseWriter: w} + next.ServeHTTP(mrw, r) + + mu.Lock() + activeConnections-- + _ = metricsServer.UpdateMetric(activeConnectionsMetricName, float64(activeConnections)) + mu.Unlock() + + duration := float64(time.Since(startTime).Milliseconds()) + path := chi.RouteContext(r.Context()).RoutePattern() + + responseStatus := mrw.status + bytesWritten := mrw.written + + labels := []string{r.Method, path, strconv.Itoa(responseStatus)} + + durationMillisecondsMetricName := getMetricName(appName, durationMillisecondsMetricType) + _ = metricsServer.UpdateMetric(durationMillisecondsMetricName, duration, labels...) + + responseSizeMetricName := getMetricName(appName, responseSizeMetricType) + _ = metricsServer.UpdateMetric(responseSizeMetricName, float64(bytesWritten), labels...) + + totalRequestsMetricName := getMetricName(appName, totalRequestMetricType) + _ = metricsServer.UpdateMetric(totalRequestsMetricName, 1, labels...) + }) + } +} + +func getMetricName(appName, metricType string) string { + return fmt.Sprintf("zrouter_request_%s_%s", appName, metricType) +} diff --git a/pkg/zrouter/zmiddlewares/middleware.go b/pkg/zrouter/zmiddlewares/middleware.go new file mode 100644 index 0000000..ec8e5c6 --- /dev/null +++ b/pkg/zrouter/zmiddlewares/middleware.go @@ -0,0 +1,24 @@ +package zmiddlewares + +import ( + "net/http" +) + +type Middleware func(next http.Handler) http.Handler + +type metricsResponseWriter struct { + http.ResponseWriter + status int + written int64 +} + +func (mrw *metricsResponseWriter) WriteHeader(statusCode int) { + mrw.status = statusCode + mrw.ResponseWriter.WriteHeader(statusCode) +} + +func (mrw *metricsResponseWriter) Write(p []byte) (int, error) { + n, err := mrw.ResponseWriter.Write(p) + mrw.written += int64(n) + return n, err +} diff --git a/pkg/zrouter/zmiddlewares/rate_limit.go b/pkg/zrouter/zmiddlewares/rate_limit.go new file mode 100644 index 0000000..054b094 --- /dev/null +++ b/pkg/zrouter/zmiddlewares/rate_limit.go @@ -0,0 +1,22 @@ +package zmiddlewares + +import ( + "golang.org/x/time/rate" + "net/http" + "time" +) + +func RateLimit(maxRPM int) Middleware { + limiter := rate.NewLimiter(rate.Every(time.Minute/time.Duration(maxRPM)), maxRPM) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !limiter.Allow() { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/zrouter/zmiddlewares/rate_limit_test.go b/pkg/zrouter/zmiddlewares/rate_limit_test.go new file mode 100644 index 0000000..15ec9ba --- /dev/null +++ b/pkg/zrouter/zmiddlewares/rate_limit_test.go @@ -0,0 +1,30 @@ +package zmiddlewares + +import ( + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRateLimit(t *testing.T) { + r := chi.NewRouter() + + r.Use(RateLimit(1)) + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("OK")) + }) + + req := httptest.NewRequest("GET", "/", nil) + + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + + rec = httptest.NewRecorder() + r.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code) +} diff --git a/pkg/zrouter/zrouter.go b/pkg/zrouter/zrouter.go new file mode 100644 index 0000000..a8380dd --- /dev/null +++ b/pkg/zrouter/zrouter.go @@ -0,0 +1,206 @@ +package zrouter + +import ( + "github.com/go-chi/chi/v5" + "github.com/zondax/golem/pkg/metrics" + "github.com/zondax/golem/pkg/zrouter/zmiddlewares" + "go.uber.org/zap" + "net/http" + "sync" + "time" +) + +const ( + defaultAddress = ":8080" + defaultTimeOut = 240000 +) + +type Config struct { + ReadTimeOut time.Duration + WriteTimeOut time.Duration + Logger *zap.SugaredLogger +} + +func (c *Config) setDefaultValues() { + if c.ReadTimeOut == 0 { + c.ReadTimeOut = time.Duration(defaultTimeOut) * time.Millisecond + } + + if c.WriteTimeOut == 0 { + c.WriteTimeOut = time.Duration(defaultTimeOut) * time.Millisecond + } + + if c.Logger == nil { + l, _ := zap.NewProduction() + c.Logger = l.Sugar() + } +} + +type RegisteredRoute struct { + Method string + Path string +} + +type ZRouter interface { + Routes + Run(addr ...string) error +} + +type Routes interface { + GET(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes + POST(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes + PUT(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes + PATCH(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes + DELETE(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes + Route(method, path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes + Group(prefix string) Routes + Use(middlewares ...zmiddlewares.Middleware) Routes + NoRoute(handler HandlerFunc) + GetRegisteredRoutes() []RegisteredRoute + SetDefaultMiddlewares() + GetHandler() http.Handler +} + +type zrouter struct { + router *chi.Mux + middlewares []zmiddlewares.Middleware + metricsServer metrics.TaskMetrics + appName string + routes []RegisteredRoute + mutex sync.Mutex + config *Config +} + +func New(appName string, metricsServer metrics.TaskMetrics, config *Config) ZRouter { + if appName == "" { + panic("appName cannot be an empty string") + } + + if config == nil { + config = &Config{} + } + + config.setDefaultValues() + zr := &zrouter{ + router: chi.NewRouter(), + metricsServer: metricsServer, + appName: appName, + config: config, + } + return zr +} + +func (r *zrouter) SetDefaultMiddlewares() { + r.Use(zmiddlewares.ErrorHandlerMiddleware(r.config.Logger)) + r.Use(zmiddlewares.RequestID()) + if err := zmiddlewares.RegisterRequestMetrics(r.appName, r.metricsServer); err != nil { + r.config.Logger.With("err", err).Error("Error registering metrics") + } + + r.Use(zmiddlewares.RequestMetrics(r.appName, r.metricsServer)) +} + +func (r *zrouter) Group(prefix string) Routes { + newRouter := &zrouter{ + router: chi.NewRouter(), + } + + r.router.Group(func(groupRouter chi.Router) { + groupRouter.Mount(prefix, newRouter.router) + }) + + return newRouter +} + +func (r *zrouter) Run(addr ...string) error { + address := defaultAddress + if len(addr) > 0 { + address = addr[0] + } + + r.config.Logger.Infof("Start server at %v", address) + + server := &http.Server{ + Addr: address, + Handler: r.router, + ReadTimeout: r.config.ReadTimeOut, + WriteTimeout: r.config.WriteTimeOut, + } + return server.ListenAndServe() +} + +func (r *zrouter) applyMiddlewares(handler http.HandlerFunc, middlewares ...zmiddlewares.Middleware) http.Handler { + var wrappedHandler http.Handler = handler + + for _, mw := range r.middlewares { + wrappedHandler = mw(wrappedHandler) + } + + for _, mw := range middlewares { + wrappedHandler = mw(wrappedHandler) + } + return wrappedHandler +} + +func (r *zrouter) Method(method, path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + chiHandler := getChiHandler(handler) + finalHandler := r.applyMiddlewares(chiHandler, middlewares...) + r.router.Method(method, path, finalHandler) + + r.mutex.Lock() + r.routes = append(r.routes, RegisteredRoute{Method: method, Path: path}) + r.mutex.Unlock() + return r +} + +func (r *zrouter) GET(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + r.Method(http.MethodGet, path, handler, middlewares...) + return r +} + +func (r *zrouter) POST(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + r.Method(http.MethodPost, path, handler, middlewares...) + return r +} + +func (r *zrouter) PUT(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + r.Method(http.MethodPut, path, handler, middlewares...) + return r +} + +func (r *zrouter) PATCH(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + r.Method(http.MethodPatch, path, handler, middlewares...) + return r +} + +func (r *zrouter) DELETE(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + r.Method(http.MethodDelete, path, handler, middlewares...) + return r +} + +func (r *zrouter) Route(method, path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + r.Method(method, path, handler, middlewares...) + return r +} + +func (r *zrouter) NoRoute(handler HandlerFunc) { + r.router.NotFound(getChiHandler(handler)) +} + +func (r *zrouter) Use(middlewares ...zmiddlewares.Middleware) Routes { + r.middlewares = append(r.middlewares, middlewares...) + return r +} + +func (r *zrouter) GetRegisteredRoutes() []RegisteredRoute { + r.mutex.Lock() + defer r.mutex.Unlock() + + routesCopy := make([]RegisteredRoute, len(r.routes)) + copy(routesCopy, r.routes) + return routesCopy +} + +func (r *zrouter) GetHandler() http.Handler { + return r.router +} diff --git a/pkg/zrouter/zrouter_mock.go b/pkg/zrouter/zrouter_mock.go new file mode 100644 index 0000000..cae69f2 --- /dev/null +++ b/pkg/zrouter/zrouter_mock.go @@ -0,0 +1,55 @@ +package zrouter + +import ( + "github.com/stretchr/testify/mock" + "github.com/zondax/golem/pkg/zrouter/zmiddlewares" +) + +type MockZRouter struct { + mock.Mock +} + +func (m *MockZRouter) Run(addr ...string) error { + args := m.Called(addr) + return args.Error(0) +} + +func (m *MockZRouter) GET(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + args := m.Called(path, handler, middlewares) + return args.Get(0).(Routes) +} + +func (m *MockZRouter) POST(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + args := m.Called(path, handler, middlewares) + return args.Get(0).(Routes) +} + +func (m *MockZRouter) PUT(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + args := m.Called(path, handler, middlewares) + return args.Get(0).(Routes) +} + +func (m *MockZRouter) PATCH(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + args := m.Called(path, handler, middlewares) + return args.Get(0).(Routes) +} + +func (m *MockZRouter) DELETE(path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + args := m.Called(path, handler, middlewares) + return args.Get(0).(Routes) +} + +func (m *MockZRouter) Route(method, path string, handler HandlerFunc, middlewares ...zmiddlewares.Middleware) Routes { + args := m.Called(method, path, handler, middlewares) + return args.Get(0).(Routes) +} + +func (m *MockZRouter) Use(middlewares ...zmiddlewares.Middleware) Routes { + args := m.Called(middlewares) + return args.Get(0).(Routes) +} + +func (m *MockZRouter) Group(prefix string) Routes { + args := m.Called(prefix) + return args.Get(0).(Routes) +} diff --git a/pkg/zrouter/zrouter_test.go b/pkg/zrouter/zrouter_test.go new file mode 100644 index 0000000..70b687a --- /dev/null +++ b/pkg/zrouter/zrouter_test.go @@ -0,0 +1,54 @@ +package zrouter + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/zondax/golem/pkg/zrouter/domain" + "net/http" + "net/http/httptest" + "testing" +) + +type ZRouterSuite struct { + suite.Suite + router ZRouter +} + +func (suite *ZRouterSuite) SetupTest() { + suite.router = New("testApp", nil, nil) +} + +func (suite *ZRouterSuite) TestRegisterAndGetRoutes() { + + suite.router.GET("/get", func(ctx Context) (domain.ServiceResponse, error) { + return domain.NewServiceResponse(http.StatusOK, []byte("GET OK")), nil + }) + + suite.router.POST("/post", func(ctx Context) (domain.ServiceResponse, error) { + return domain.NewServiceResponse(http.StatusOK, []byte("POST OK")), nil + }) + + routes := suite.router.GetRegisteredRoutes() + + assert.Len(suite.T(), routes, 2) + assert.Contains(suite.T(), routes, RegisteredRoute{Method: "GET", Path: "/get"}) + assert.Contains(suite.T(), routes, RegisteredRoute{Method: "POST", Path: "/post"}) +} + +func (suite *ZRouterSuite) TestRouteHandling() { + suite.router.GET("/test", func(ctx Context) (domain.ServiceResponse, error) { + return domain.NewServiceResponse(http.StatusOK, "test route"), nil + }) + + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + recorder := httptest.NewRecorder() + handler := suite.router.GetHandler() + handler.ServeHTTP(recorder, req) + + assert.Equal(suite.T(), http.StatusOK, recorder.Code) + assert.Equal(suite.T(), "\"test route\"", recorder.Body.String()) +} + +func TestZRouterSuite(t *testing.T) { + suite.Run(t, new(ZRouterSuite)) +}