From e64664525b2ebc909a6f9dc00bd6737e52cbf205 Mon Sep 17 00:00:00 2001 From: Lucas Francisco Lopez Date: Fri, 12 Apr 2024 18:39:27 +0200 Subject: [PATCH] Fix GetRoutePattern for subrouters --- pkg/zrouter/zmiddlewares/common.go | 15 ++++++++++- pkg/zrouter/zmiddlewares/common_test.go | 34 ++++++++++++++++++------- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/pkg/zrouter/zmiddlewares/common.go b/pkg/zrouter/zmiddlewares/common.go index fff7117..fd32d24 100644 --- a/pkg/zrouter/zmiddlewares/common.go +++ b/pkg/zrouter/zmiddlewares/common.go @@ -21,7 +21,20 @@ func PathToRegexp(path string) *regexp.Regexp { } func GetRoutePattern(r *http.Request) string { - return chi.RouteContext(r.Context()).RoutePattern() + rctx := chi.RouteContext(r.Context()) + if pattern := rctx.RoutePattern(); pattern != "" && !strings.HasSuffix(pattern, "*") { + return pattern + } + + routePath := r.URL.Path + tctx := chi.NewRouteContext() + if !rctx.Routes.Match(tctx, r.Method, routePath) { + // No matching pattern, so just return the request path. + return routePath + } + + // tctx has the updated pattern, since Match mutates it + return tctx.RoutePattern() } func getRequestBody(r *http.Request) ([]byte, error) { diff --git a/pkg/zrouter/zmiddlewares/common_test.go b/pkg/zrouter/zmiddlewares/common_test.go index 2b81079..201f4b1 100644 --- a/pkg/zrouter/zmiddlewares/common_test.go +++ b/pkg/zrouter/zmiddlewares/common_test.go @@ -10,22 +10,38 @@ import ( "testing" ) -func TestGetRoutePattern(t *testing.T) { +func TestGetRoutePatternIncludingSubrouters(t *testing.T) { r := chi.NewRouter() + subRouter := chi.NewRouter() - routePattern := "/test/{param}" - r.Get(routePattern, func(w http.ResponseWriter, r *http.Request) { + // Configure a test route on the subrouter + subRoutePattern := "/sub/{subParam}" + subRouter.Get(subRoutePattern, func(w http.ResponseWriter, r *http.Request) { routePattern := GetRoutePattern(r) - - assert.Equal(t, routePattern, "/test/{param}", "The returned route pattern should match the one configured in the router.") + assert.Equal(t, "/test/sub/{subParam}", routePattern, "The returned route pattern should match the subrouter pattern.") }) - req := httptest.NewRequest("GET", "/test/123", nil) - w := httptest.NewRecorder() + // Mount the subrouter onto a specific path of the main router + r.Mount("/test", subRouter) + + // Test request for the subrouter route + reqSub := httptest.NewRequest("GET", "/test/sub/456", nil) + wSub := httptest.NewRecorder() + r.ServeHTTP(wSub, reqSub) + assert.Equal(t, http.StatusOK, wSub.Code, "The expected status code for subrouter should be 200 OK.") - r.ServeHTTP(w, req) + // Configure a test route on the main router + mainRoutePattern := "/main/{mainParam}" + r.Get(mainRoutePattern, func(w http.ResponseWriter, r *http.Request) { + routePattern := GetRoutePattern(r) + assert.Equal(t, "/main/{mainParam}", routePattern, "The returned route pattern should match the main router pattern.") + }) - assert.Equal(t, http.StatusOK, w.Code, "The expected status code should be 200 OK.") + // Test request for the main router route + reqMain := httptest.NewRequest("GET", "/main/123", nil) + wMain := httptest.NewRecorder() + r.ServeHTTP(wMain, reqMain) + assert.Equal(t, http.StatusOK, wMain.Code, "The expected status code for main router should be 200 OK.") } func TestGetRequestBody(t *testing.T) {