From 6325524feb4e7383c595cfaf8524142391842890 Mon Sep 17 00:00:00 2001 From: Asphodelus <1749094641@qq.com> Date: Thu, 28 Mar 2024 17:08:04 +0800 Subject: [PATCH 1/2] feat: issue 1081 --- pkg/app/server/option.go | 16 ++++++++++++++++ pkg/common/config/option.go | 1 + pkg/route/engine.go | 28 +++++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index e7970348c..4503f5413 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -82,6 +82,22 @@ func WithRedirectTrailingSlash(b bool) config.Option { }} } +// WithFixTrailingSlash sets fixTrailingSlash. +// +// If enabled, the router tries to fix the current request path, if no +// handle is registered for it. +// For example if /foo is requested but a route only exists for /foo/, the +// client requests /foo/ without redirecting for all request methods. +// This option conflicts with RedirectTrailingSlash +func WithFixTrailingSlash(b bool) config.Option { + return config.Option{F: func(o *config.Options) { + o.FixTrailingSlash = b + if b { + o.RedirectTrailingSlash = !b + } + }} +} + // WithRedirectFixedPath sets redirectFixedPath. // // If enabled, the router tries to fix the current request path, if no diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 417955fc9..06fb23ae8 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -48,6 +48,7 @@ type Options struct { WriteTimeout time.Duration IdleTimeout time.Duration RedirectTrailingSlash bool + FixTrailingSlash bool MaxRequestBodySize int MaxKeepBodySize int GetOnly bool diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 30b899dac..a818a06be 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -54,6 +54,8 @@ import ( "sync" "sync/atomic" + "github.com/cloudwego/hertz/pkg/route/param" + "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" @@ -721,7 +723,6 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { if engine.PanicHandler != nil { defer engine.recv(ctx) } - rPath := string(ctx.Request.URI().Path()) // align with https://datatracker.ietf.org/doc/html/rfc2616#section-5.2 @@ -770,6 +771,10 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { redirectTrailingSlash(ctx) return } + if value.tsr && engine.options.FixTrailingSlash { + directTrailingSlash(c, ctx, t[i], paramsPointer, unescape) + return + } if engine.options.RedirectFixedPath && redirectFixedPath(ctx, t[i].root, engine.options.RedirectFixedPath) { return } @@ -823,6 +828,27 @@ func trailingSlashURL(ts string) string { return tmpURI } +func directTrailingSlash(c context.Context, ctx *app.RequestContext, r *router, paramsPointer *param.Params, unescape bool) { + p := bytesconv.B2s(ctx.Request.URI().Path()) + if prefix := utils.CleanPath(bytesconv.B2s(ctx.Request.Header.Peek("X-Forwarded-Prefix"))); prefix != "." { + p = prefix + "/" + p + } + + tmpURI := trailingSlashURL(p) + + query := ctx.Request.URI().QueryString() + if len(query) > 0 { + tmpURI = tmpURI + "?" + bytesconv.B2s(query) + } + tmpURI = "/" + strings.TrimLeft(tmpURI, "/") + v := r.find(tmpURI, paramsPointer, unescape) + if v.handlers != nil { + ctx.SetHandlers(v.handlers) + ctx.SetFullPath(v.fullPath) + ctx.Next(c) + } +} + func redirectTrailingSlash(c *app.RequestContext) { p := bytesconv.B2s(c.Request.URI().Path()) if prefix := utils.CleanPath(bytesconv.B2s(c.Request.Header.Peek("X-Forwarded-Prefix"))); prefix != "." { From e1231dde9f7c508e3dc33730489c50375b39c41b Mon Sep 17 00:00:00 2001 From: Asphodelus <1749094641@qq.com> Date: Sat, 30 Mar 2024 14:36:45 +0800 Subject: [PATCH 2/2] test: route fix traling slash --- pkg/route/routes_test.go | 52 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/pkg/route/routes_test.go b/pkg/route/routes_test.go index 1e76d673e..3aa94e1f6 100644 --- a/pkg/route/routes_test.go +++ b/pkg/route/routes_test.go @@ -228,6 +228,58 @@ func TestRouteNotOK3(t *testing.T) { testRouteNotOK3(consts.MethodTrace, t) } +func TestRouteFixTrailingSlash(t *testing.T) { + router := NewEngine(config.NewOptions(nil)) + router.options.RedirectFixedPath = false + router.options.RedirectTrailingSlash = false + router.options.FixTrailingSlash = true + router.GET("/path", func(c context.Context, ctx *app.RequestContext) {}) + router.GET("/path2/", func(c context.Context, ctx *app.RequestContext) {}) + router.POST("/path3", func(c context.Context, ctx *app.RequestContext) {}) + router.PUT("/path4/", func(c context.Context, ctx *app.RequestContext) {}) + + w := performRequest(router, consts.MethodGet, "/path/") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodGet, "/path2") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodPost, "/path3/") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodPut, "/path4") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodGet, "/path") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodGet, "/path2/") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodPost, "/path3") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodPut, "/path4/") + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodGet, "/path2", header{Key: "X-Forwarded-Prefix", Value: "/api"}) + assert.DeepEqual(t, consts.StatusOK, w.Code) + + w = performRequest(router, consts.MethodGet, "/path2/", header{Key: "X-Forwarded-Prefix", Value: "/api/"}) + assert.DeepEqual(t, consts.StatusOK, w.Code) + + router.options.FixTrailingSlash = false + + w = performRequest(router, consts.MethodGet, "/path/") + assert.DeepEqual(t, consts.StatusNotFound, w.Code) + w = performRequest(router, consts.MethodGet, "/path2") + assert.DeepEqual(t, consts.StatusNotFound, w.Code) + w = performRequest(router, consts.MethodPost, "/path3/") + assert.DeepEqual(t, consts.StatusNotFound, w.Code) + w = performRequest(router, consts.MethodPut, "/path4") + assert.DeepEqual(t, consts.StatusNotFound, w.Code) +} + func TestRouteRedirectTrailingSlash(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.RedirectFixedPath = false