From 89b4e44ade9d5cf261594fd6117eceb0b4b9252c Mon Sep 17 00:00:00 2001 From: 9seconds Date: Wed, 16 Dec 2020 11:16:18 +0300 Subject: [PATCH] Use more custom errors --- default_layers.go | 24 +++++++++++++++++------- opts.go | 4 ++-- server.go | 27 +++++++++++++++++++-------- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/default_layers.go b/default_layers.go index 36b8fd55..0d5bb1ab 100644 --- a/default_layers.go +++ b/default_layers.go @@ -1,11 +1,9 @@ package httransform import ( - "fmt" - + "github.com/9seconds/httransform/v2/errors" "github.com/9seconds/httransform/v2/headers" "github.com/9seconds/httransform/v2/layers" - "github.com/PumpkinSeed/errors" ) type layerStartHeaders struct{} @@ -14,7 +12,10 @@ func (l layerStartHeaders) OnRequest(ctx *layers.Context) error { requestHeaders := headers.NewRequestHeaderWrapper(&ctx.Request().Header) if err := ctx.RequestHeaders.Init(requestHeaders); err != nil { - return fmt.Errorf("cannot read request headers: %w", err) + return &errors.Error{ + Message: "cannot read request headers", + Err: err, + } } return nil @@ -22,7 +23,10 @@ func (l layerStartHeaders) OnRequest(ctx *layers.Context) error { func (l layerStartHeaders) OnResponse(ctx *layers.Context, err error) error { if err2 := ctx.ResponseHeaders.Sync(); err2 != nil { - return errors.Wrap(err, fmt.Errorf("cannot sync response headers: %w", err)) + return &errors.Error{ + Message: "cannot sync response headers", + Err: err, + } } return err @@ -32,7 +36,10 @@ type layerFinishHeaders struct{} func (l layerFinishHeaders) OnRequest(ctx *layers.Context) error { if err := ctx.RequestHeaders.Sync(); err != nil { - return fmt.Errorf("cannot sync request headers: %w", err) + return &errors.Error{ + Message: "cannot sync request headers", + Err: err, + } } return nil @@ -42,7 +49,10 @@ func (l layerFinishHeaders) OnResponse(ctx *layers.Context, err error) error { responseHeaders := headers.NewResponseHeaderWrapper(&ctx.Response().Header) if err2 := ctx.ResponseHeaders.Init(responseHeaders); err2 != nil { - return errors.Wrap(err, fmt.Errorf("cannot read response headers: %w", err)) + return &errors.Error{ + Message: "cannot read response headers", + Err: err, + } } return err diff --git a/opts.go b/opts.go index c51eb45c..e3eb509f 100644 --- a/opts.go +++ b/opts.go @@ -63,7 +63,7 @@ type ServerOpts struct { // WriteTimeout defines a timeout for writing to client socket. WriteTimeout time.Duration - // TCPKeepAlivePeriod defines a time period between 2 consequtive + // TCPKeepAlivePeriod defines a time period between 2 consecutive // TCP keepalive probes. TCPKeepAlivePeriod time.Duration @@ -210,7 +210,7 @@ func (s *ServerOpts) GetLayers() []layers.Layer { return toReturn } -// GetLayers returns an authenticator instanse to use paying attention +// GetLayers returns an authenticator instance to use paying attention // to default value (no auth). func (s *ServerOpts) GetAuthenticator() auth.Interface { if s == nil || s.Authenticator == nil { diff --git a/server.go b/server.go index fa5d83db..73e3b402 100644 --- a/server.go +++ b/server.go @@ -51,13 +51,13 @@ func (s *Server) Close() error { func (s *Server) entrypoint(ctx *fasthttp.RequestCtx) { user, err := s.authenticator.Authenticate(ctx) if err != nil { - var customErr *errors.Error - - if !errors.As(err, &customErr) { - customErr = errors.Annotate(err, "authenication is failed", "bad_auth", fasthttp.StatusProxyAuthRequired) + errToReturn := &errors.Error{ + Message: "authenication is failed", + StatusCode: fasthttp.StatusProxyAuthRequired, + Err: err, } - customErr.WriteTo(ctx) + errToReturn.WriteTo(ctx) ctx.Response.Header.Add("Proxy-Authenticate", "Basic") s.eventStream.Send(ctx, events.EventTypeFailedAuth, nil, "") @@ -71,7 +71,13 @@ func (s *Server) entrypoint(ctx *fasthttp.RequestCtx) { address, err := s.extractAddress(string(ctx.RequestURI()), true) if err != nil { - ctx.Error(fmt.Sprintf("cannot extract a host for tunneled connection: %s", err.Error()), fasthttp.StatusBadRequest) + errToReturn := &errors.Error{ + Message: "cannot extract a host for tunneled connection", + StatusCode: fasthttp.StatusBadGateway, + Err: err, + } + + errToReturn.WriteTo(ctx) return } @@ -84,7 +90,13 @@ func (s *Server) entrypoint(ctx *fasthttp.RequestCtx) { address, err := s.extractAddress(string(ctx.Host()), false) if err != nil { - ctx.Error(fmt.Sprintf("cannot extract a host for tunneled connection: %s", err.Error()), fasthttp.StatusBadRequest) + errToReturn := &errors.Error{ + Message: "cannot extract a host for tunneled connection", + StatusCode: fasthttp.StatusBadGateway, + Err: err, + } + + errToReturn.WriteTo(ctx) return } @@ -217,7 +229,6 @@ func (s *Server) completeRequestType(ctx *layers.Context) { ctx.Request().Header.VisitAll(func(key, value []byte) { if bytes.EqualFold(key, []byte("Connection")) { values := headers.Values(string(value)) - ctx.RequestType &^= events.RequestTypeUpgraded for i := range values { if strings.EqualFold(values[i], "Upgrade") {