Skip to content

Commit

Permalink
Use more custom errors
Browse files Browse the repository at this point in the history
  • Loading branch information
9seconds committed Dec 16, 2020
1 parent 3a8fe8b commit 89b4e44
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
24 changes: 17 additions & 7 deletions default_layers.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package httransform

import (
"fmt"

"github.com/9seconds/httransform/v2/errors"
"github.com/9seconds/httransform/v2/headers"
"github.com/9seconds/httransform/v2/layers"
"github.com/PumpkinSeed/errors"
)

type layerStartHeaders struct{}
Expand All @@ -14,15 +12,21 @@ func (l layerStartHeaders) OnRequest(ctx *layers.Context) error {
requestHeaders := headers.NewRequestHeaderWrapper(&ctx.Request().Header)

if err := ctx.RequestHeaders.Init(requestHeaders); err != nil {
return fmt.Errorf("cannot read request headers: %w", err)
return &errors.Error{
Message: "cannot read request headers",
Err: err,
}
}

return nil
}

func (l layerStartHeaders) OnResponse(ctx *layers.Context, err error) error {
if err2 := ctx.ResponseHeaders.Sync(); err2 != nil {
return errors.Wrap(err, fmt.Errorf("cannot sync response headers: %w", err))
return &errors.Error{
Message: "cannot sync response headers",
Err: err,
}
}

return err
Expand All @@ -32,7 +36,10 @@ type layerFinishHeaders struct{}

func (l layerFinishHeaders) OnRequest(ctx *layers.Context) error {
if err := ctx.RequestHeaders.Sync(); err != nil {
return fmt.Errorf("cannot sync request headers: %w", err)
return &errors.Error{
Message: "cannot sync request headers",
Err: err,
}
}

return nil
Expand All @@ -42,7 +49,10 @@ func (l layerFinishHeaders) OnResponse(ctx *layers.Context, err error) error {
responseHeaders := headers.NewResponseHeaderWrapper(&ctx.Response().Header)

if err2 := ctx.ResponseHeaders.Init(responseHeaders); err2 != nil {
return errors.Wrap(err, fmt.Errorf("cannot read response headers: %w", err))
return &errors.Error{
Message: "cannot read response headers",
Err: err,
}
}

return err
Expand Down
4 changes: 2 additions & 2 deletions opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type ServerOpts struct {
// WriteTimeout defines a timeout for writing to client socket.
WriteTimeout time.Duration

// TCPKeepAlivePeriod defines a time period between 2 consequtive
// TCPKeepAlivePeriod defines a time period between 2 consecutive
// TCP keepalive probes.
TCPKeepAlivePeriod time.Duration

Expand Down Expand Up @@ -210,7 +210,7 @@ func (s *ServerOpts) GetLayers() []layers.Layer {
return toReturn
}

// GetLayers returns an authenticator instanse to use paying attention
// GetLayers returns an authenticator instance to use paying attention
// to default value (no auth).
func (s *ServerOpts) GetAuthenticator() auth.Interface {
if s == nil || s.Authenticator == nil {
Expand Down
27 changes: 19 additions & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func (s *Server) Close() error {
func (s *Server) entrypoint(ctx *fasthttp.RequestCtx) {
user, err := s.authenticator.Authenticate(ctx)
if err != nil {
var customErr *errors.Error

if !errors.As(err, &customErr) {
customErr = errors.Annotate(err, "authenication is failed", "bad_auth", fasthttp.StatusProxyAuthRequired)
errToReturn := &errors.Error{
Message: "authenication is failed",
StatusCode: fasthttp.StatusProxyAuthRequired,
Err: err,
}

customErr.WriteTo(ctx)
errToReturn.WriteTo(ctx)
ctx.Response.Header.Add("Proxy-Authenticate", "Basic")
s.eventStream.Send(ctx, events.EventTypeFailedAuth, nil, "")

Expand All @@ -71,7 +71,13 @@ func (s *Server) entrypoint(ctx *fasthttp.RequestCtx) {

address, err := s.extractAddress(string(ctx.RequestURI()), true)
if err != nil {
ctx.Error(fmt.Sprintf("cannot extract a host for tunneled connection: %s", err.Error()), fasthttp.StatusBadRequest)
errToReturn := &errors.Error{
Message: "cannot extract a host for tunneled connection",
StatusCode: fasthttp.StatusBadGateway,
Err: err,
}

errToReturn.WriteTo(ctx)

return
}
Expand All @@ -84,7 +90,13 @@ func (s *Server) entrypoint(ctx *fasthttp.RequestCtx) {

address, err := s.extractAddress(string(ctx.Host()), false)
if err != nil {
ctx.Error(fmt.Sprintf("cannot extract a host for tunneled connection: %s", err.Error()), fasthttp.StatusBadRequest)
errToReturn := &errors.Error{
Message: "cannot extract a host for tunneled connection",
StatusCode: fasthttp.StatusBadGateway,
Err: err,
}

errToReturn.WriteTo(ctx)

return
}
Expand Down Expand Up @@ -217,7 +229,6 @@ func (s *Server) completeRequestType(ctx *layers.Context) {
ctx.Request().Header.VisitAll(func(key, value []byte) {
if bytes.EqualFold(key, []byte("Connection")) {
values := headers.Values(string(value))
ctx.RequestType &^= events.RequestTypeUpgraded

for i := range values {
if strings.EqualFold(values[i], "Upgrade") {
Expand Down

0 comments on commit 89b4e44

Please sign in to comment.