From 5da5607b327d92528faeb19ebda933fa83996396 Mon Sep 17 00:00:00 2001 From: darkweak Date: Fri, 15 Nov 2024 19:17:50 +0100 Subject: [PATCH] fix(writer): buffer race condition --- pkg/middleware/middleware.go | 61 +++++++++++++----- pkg/middleware/writer.go | 6 ++ .../traefik/override/middleware/middleware.go | 64 +++++++++++++------ plugins/traefik/override/middleware/writer.go | 6 ++ .../souin/pkg/middleware/middleware.go | 64 +++++++++++++------ .../darkweak/souin/pkg/middleware/writer.go | 6 ++ 6 files changed, 153 insertions(+), 54 deletions(-) diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index 2b5bfc051..606bfa23e 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -467,7 +467,9 @@ func (s *SouinBaseHandler) Upstream( } err := s.Store(customWriter, rq, requestCc, cachedKey, uri) - defer customWriter.Buf.Reset() + defer customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) return singleflightValue{ body: customWriter.Buf.Bytes(), @@ -521,7 +523,9 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF statusCode := customWriter.GetStatusCode() if err == nil { if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified { - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) customWriter.Rw.WriteHeader(http.StatusPreconditionFailed) return nil, errors.New("") @@ -542,7 +546,9 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF ), ) - defer customWriter.Buf.Reset() + defer customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) return singleflightValue{ body: customWriter.Buf.Bytes(), headers: customWriter.Header().Clone(), @@ -598,6 +604,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n req := s.context.SetBaseContext(rq) cacheName := req.Context().Value(context.CacheName).(string) + if rq.Header.Get("Upgrade") == "websocket" || rq.Header.Get("Accept") == "text/event-stream" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=EXCLUDED-REQUEST-URI") return next(rw, req) @@ -689,14 +696,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if validator.NotModified { customWriter.WriteHeader(http.StatusNotModified) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, _ = customWriter.Send() return nil } customWriter.WriteHeader(response.StatusCode) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, _ = customWriter.Send() return nil @@ -722,7 +733,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } customWriter.WriteHeader(response.StatusCode) s.Configuration.GetLogger().Debugf("Serve from cache %+v", req) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() prometheus.Increment(prometheus.CachedResponseCounter) @@ -742,7 +755,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() customWriter = NewCustomWriter(req, rw, bufPool) go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) { @@ -766,14 +781,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code) maps.Copy(customWriter.Header(), response.Header) customWriter.WriteHeader(response.StatusCode) - customWriter.Buf.Reset() - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err } rw.WriteHeader(http.StatusGatewayTimeout) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, err := customWriter.Send() return err @@ -784,7 +803,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n rfc.SetCacheStatusHeader(response, storerName) customWriter.WriteHeader(response.StatusCode) maps.Copy(customWriter.Header(), response.Header) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, _ = customWriter.Send() return err @@ -793,7 +814,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n if statusCode != http.StatusNotModified && validator.Matched { customWriter.WriteHeader(http.StatusNotModified) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, _ = customWriter.Send() return err @@ -808,7 +831,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) maps.Copy(customWriter.Header(), response.Header) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err @@ -822,7 +847,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) maps.Copy(customWriter.Header(), response.Header) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err @@ -846,8 +873,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code) maps.Copy(customWriter.Header(), response.Header) customWriter.WriteHeader(response.StatusCode) - customWriter.Buf.Reset() - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err diff --git a/pkg/middleware/writer.go b/pkg/middleware/writer.go index 4e9753e33..abc21b4b7 100644 --- a/pkg/middleware/writer.go +++ b/pkg/middleware/writer.go @@ -39,6 +39,12 @@ type CustomWriter struct { statusCode int } +func (r *CustomWriter) handleBuffer(callback func(*bytes.Buffer)) { + r.mutex.Lock() + callback(r.Buf) + r.mutex.Unlock() +} + // Header will write the response headers func (r *CustomWriter) Header() http.Header { r.mutex.Lock() diff --git a/plugins/traefik/override/middleware/middleware.go b/plugins/traefik/override/middleware/middleware.go index 808926cbf..ab7770c0e 100644 --- a/plugins/traefik/override/middleware/middleware.go +++ b/plugins/traefik/override/middleware/middleware.go @@ -189,7 +189,7 @@ func (s *SouinBaseHandler) Store( ma = time.Duration(responseCc.SMaxAge) * time.Second } else if responseCc.MaxAge >= 0 { ma = time.Duration(responseCc.MaxAge) * time.Second - } else if customWriter.Header().Get("Expires") != "" { + } else if !modeContext.Bypass_response && customWriter.Header().Get("Expires") != "" { exp, err := time.Parse(time.RFC1123, customWriter.Header().Get("Expires")) if err != nil { return nil @@ -249,7 +249,7 @@ func (s *SouinBaseHandler) Store( } res.Header.Set(rfc.StoredLengthHeader, res.Header.Get("Content-Length")) response, err := httputil.DumpResponse(&res, true) - if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode)) { + if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode) || s.hasAllowedAdditionalStatusCodesToCache(statusCode)) { variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(res.Header) if isVaryStar { // "Implies that the response is uncacheable" @@ -372,7 +372,9 @@ func (s *SouinBaseHandler) Upstream( } err := s.Store(customWriter, rq, requestCc, cachedKey) - defer customWriter.Buf.Reset() + defer customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) return singleflightValue{ body: customWriter.Buf.Bytes(), @@ -423,7 +425,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler statusCode := customWriter.GetStatusCode() if err == nil { if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified { - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) customWriter.Rw.WriteHeader(http.StatusPreconditionFailed) return nil, errors.New("") @@ -444,7 +448,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler ), ) - defer customWriter.Buf.Reset() + defer customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) return singleflightValue{ body: customWriter.Buf.Bytes(), headers: customWriter.Header().Clone(), @@ -493,6 +499,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n handler(rw, rq) return nil } + req := s.context.SetBaseContext(rq) cacheName := req.Context().Value(context.CacheName).(string) @@ -526,7 +533,6 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n requestCc, coErr := cacheobject.ParseRequestCacheControl(rfc.HeaderAllCommaSepValuesString(req.Header, "Cache-Control")) modeContext := req.Context().Value(context.Mode).(*context.ModeContext) - if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR") @@ -593,14 +599,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if validator.NotModified { customWriter.WriteHeader(http.StatusNotModified) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, _ = customWriter.Send() return nil } customWriter.WriteHeader(response.StatusCode) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, _ = customWriter.Send() return nil @@ -624,7 +634,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.Header()[h] = v } customWriter.WriteHeader(response.StatusCode) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err @@ -643,7 +655,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() customWriter = NewCustomWriter(req, rw, bufPool) go func(v *types.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) { @@ -656,7 +670,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return err } - if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { + if modeContext.Bypass_response || responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag) err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) statusCode := customWriter.GetStatusCode() @@ -670,14 +684,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.Header().Set(k, response.Header.Get(k)) } customWriter.WriteHeader(response.StatusCode) - customWriter.Buf.Reset() - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err } rw.WriteHeader(http.StatusGatewayTimeout) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, err := customWriter.Send() return err @@ -691,7 +709,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n for k := range response.Header { customWriter.Header().Set(k, response.Header.Get(k)) } - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, _ = customWriter.Send() return err @@ -700,7 +720,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n if statusCode != http.StatusNotModified && validator.Matched { customWriter.WriteHeader(http.StatusNotModified) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, _ = customWriter.Send() return err @@ -718,7 +740,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n for k := range response.Header { customWriter.Header().Set(k, response.Header.Get(k)) } - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err @@ -747,8 +771,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.Header().Set(k, response.Header.Get(k)) } customWriter.WriteHeader(response.StatusCode) - customWriter.Buf.Reset() - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err diff --git a/plugins/traefik/override/middleware/writer.go b/plugins/traefik/override/middleware/writer.go index 97b479bd7..31300c95f 100644 --- a/plugins/traefik/override/middleware/writer.go +++ b/plugins/traefik/override/middleware/writer.go @@ -38,6 +38,12 @@ type CustomWriter struct { statusCode int } +func (r *CustomWriter) handleBuffer(callback func(*bytes.Buffer)) { + r.mutex.Lock() + callback(r.Buf) + r.mutex.Unlock() +} + // Header will write the response headers func (r *CustomWriter) Header() http.Header { r.mutex.Lock() diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go index 808926cbf..ab7770c0e 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go @@ -189,7 +189,7 @@ func (s *SouinBaseHandler) Store( ma = time.Duration(responseCc.SMaxAge) * time.Second } else if responseCc.MaxAge >= 0 { ma = time.Duration(responseCc.MaxAge) * time.Second - } else if customWriter.Header().Get("Expires") != "" { + } else if !modeContext.Bypass_response && customWriter.Header().Get("Expires") != "" { exp, err := time.Parse(time.RFC1123, customWriter.Header().Get("Expires")) if err != nil { return nil @@ -249,7 +249,7 @@ func (s *SouinBaseHandler) Store( } res.Header.Set(rfc.StoredLengthHeader, res.Header.Get("Content-Length")) response, err := httputil.DumpResponse(&res, true) - if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode)) { + if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode) || s.hasAllowedAdditionalStatusCodesToCache(statusCode)) { variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(res.Header) if isVaryStar { // "Implies that the response is uncacheable" @@ -372,7 +372,9 @@ func (s *SouinBaseHandler) Upstream( } err := s.Store(customWriter, rq, requestCc, cachedKey) - defer customWriter.Buf.Reset() + defer customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) return singleflightValue{ body: customWriter.Buf.Bytes(), @@ -423,7 +425,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler statusCode := customWriter.GetStatusCode() if err == nil { if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified { - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) customWriter.Rw.WriteHeader(http.StatusPreconditionFailed) return nil, errors.New("") @@ -444,7 +448,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler ), ) - defer customWriter.Buf.Reset() + defer customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) return singleflightValue{ body: customWriter.Buf.Bytes(), headers: customWriter.Header().Clone(), @@ -493,6 +499,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n handler(rw, rq) return nil } + req := s.context.SetBaseContext(rq) cacheName := req.Context().Value(context.CacheName).(string) @@ -526,7 +533,6 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n requestCc, coErr := cacheobject.ParseRequestCacheControl(rfc.HeaderAllCommaSepValuesString(req.Header, "Cache-Control")) modeContext := req.Context().Value(context.Mode).(*context.ModeContext) - if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR") @@ -593,14 +599,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if validator.NotModified { customWriter.WriteHeader(http.StatusNotModified) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, _ = customWriter.Send() return nil } customWriter.WriteHeader(response.StatusCode) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, _ = customWriter.Send() return nil @@ -624,7 +634,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.Header()[h] = v } customWriter.WriteHeader(response.StatusCode) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err @@ -643,7 +655,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() customWriter = NewCustomWriter(req, rw, bufPool) go func(v *types.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) { @@ -656,7 +670,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return err } - if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { + if modeContext.Bypass_response || responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag) err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) statusCode := customWriter.GetStatusCode() @@ -670,14 +684,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.Header().Set(k, response.Header.Get(k)) } customWriter.WriteHeader(response.StatusCode) - customWriter.Buf.Reset() - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err } rw.WriteHeader(http.StatusGatewayTimeout) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, err := customWriter.Send() return err @@ -691,7 +709,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n for k := range response.Header { customWriter.Header().Set(k, response.Header.Get(k)) } - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, _ = customWriter.Send() return err @@ -700,7 +720,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n if statusCode != http.StatusNotModified && validator.Matched { customWriter.WriteHeader(http.StatusNotModified) - customWriter.Buf.Reset() + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + }) _, _ = customWriter.Send() return err @@ -718,7 +740,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n for k := range response.Header { customWriter.Header().Set(k, response.Header.Get(k)) } - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err @@ -747,8 +771,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n customWriter.Header().Set(k, response.Header.Get(k)) } customWriter.WriteHeader(response.StatusCode) - customWriter.Buf.Reset() - _, _ = io.Copy(customWriter.Buf, response.Body) + customWriter.handleBuffer(func(b *bytes.Buffer) { + b.Reset() + _, _ = io.Copy(b, response.Body) + }) _, err := customWriter.Send() return err diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go index 97b479bd7..31300c95f 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go @@ -38,6 +38,12 @@ type CustomWriter struct { statusCode int } +func (r *CustomWriter) handleBuffer(callback func(*bytes.Buffer)) { + r.mutex.Lock() + callback(r.Buf) + r.mutex.Unlock() +} + // Header will write the response headers func (r *CustomWriter) Header() http.Header { r.mutex.Lock()