From f840043210cf8de3b724e0ff293cddceaac9bd5c Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Tue, 6 Aug 2024 11:30:24 +0100 Subject: [PATCH] Only forward X-Forwarded-* when not using TLS In [8d19698d] we added forwarding of existing X-Forwarded-* headers, because these are often needed when running behind another proxy layer. However, we shouldn't be trusting the headers as set by the original client. So in the case where there is no additional proxy layer between us and the client, we should drop the headers instead. Given that our TLS mode is intended for direct client access (and requires direct access for certificate provisioning), let's make the condition as follows: - When running with TLS, assume requests are from external clients, and drop the headers. - When running without TLS, allow that there may be another proxy terminating TLS downstream of us, and keep the headers. In a case where the default logic does not apply, the `FORWARD_HEADERS` env can be used to override it. --- README.md | 1 + internal/config.go | 14 ++++++++++++-- internal/config_test.go | 35 +++++++++++++++++++++++++++++++++++ internal/handler.go | 3 ++- internal/handler_test.go | 27 +++++++++++++++++++++++++-- internal/proxy_handler.go | 26 +++++++++++++++----------- internal/server.go | 2 +- internal/service.go | 1 + 8 files changed, 92 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 0372623..8a1f758 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ environment variables that you can set. | `ACME_DIRECTORY` | The URL of the ACME directory to use for TLS certificate provisioning. | `https://acme-v02.api.letsencrypt.org/directory` (Let's Encrypt production) | | `EAB_KID` | The EAB key identifier to use when provisioning TLS certificates, if required. | None | | `EAB_HMAC_KEY` | The Base64-encoded EAB HMAC key to use when provisioning TLS certificates, if required. | None | +| `FORWARD_HEADERS` | Whether to forward X-Forwarded-* headers from the client.| Disabled when running with TLS; enabled otherwise | | `DEBUG` | Set to `1` or `true` to enable debug logging. | Disabled | To prevent naming clashes with your application's own environment variables, diff --git a/internal/config.go b/internal/config.go index 6073b72..d0d2ce2 100644 --- a/internal/config.go +++ b/internal/config.go @@ -59,6 +59,8 @@ type Config struct { HttpReadTimeout time.Duration HttpWriteTimeout time.Duration + ForwardHeaders bool + LogLevel slog.Level } @@ -72,7 +74,7 @@ func NewConfig() (*Config, error) { logLevel = slog.LevelDebug } - return &Config{ + config := &Config{ TargetPort: getEnvInt("TARGET_PORT", defaultTargetPort), UpstreamCommand: os.Args[1], UpstreamArgs: os.Args[2:], @@ -96,7 +98,15 @@ func NewConfig() (*Config, error) { HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), LogLevel: logLevel, - }, nil + } + + config.ForwardHeaders = getEnvBool("FORWARD_HEADERS", !config.HasTLS()) + + return config, nil +} + +func (c *Config) HasTLS() bool { + return len(c.TLSDomains) > 0 } func findEnv(key string) (string, bool) { diff --git a/internal/config_test.go b/internal/config_test.go index 2a60717..c8415d9 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -17,6 +17,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{}, c.TLSDomains) + assert.False(t, c.HasTLS()) + assert.True(t, c.ForwardHeaders) }) t.Run("with an empty TLS_DOMAIN", func(t *testing.T) { @@ -27,6 +29,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{}, c.TLSDomains) + assert.False(t, c.HasTLS()) + assert.True(t, c.ForwardHeaders) }) t.Run("with single TLS_DOMAIN", func(t *testing.T) { @@ -37,6 +41,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"example.com"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) }) t.Run("with multiple TLS_DOMAIN", func(t *testing.T) { @@ -47,6 +53,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) }) t.Run("with TLS_DOMAIN containing whitespace", func(t *testing.T) { @@ -57,6 +65,33 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) + }) + + t.Run("overriding with FORWARD_HEADERS when using TLS", func(t *testing.T) { + usingProgramArgs(t, "thruster", "echo", "hello") + usingEnvVar(t, "TLS_DOMAIN", "example.com") + usingEnvVar(t, "FORWARD_HEADERS", "true") + + c, err := NewConfig() + require.NoError(t, err) + + assert.Equal(t, []string{"example.com"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.True(t, c.ForwardHeaders) + }) + + t.Run("overriding with FORWARD_HEADERS when not using TLS", func(t *testing.T) { + usingProgramArgs(t, "thruster", "echo", "hello") + usingEnvVar(t, "FORWARD_HEADERS", "false") + + c, err := NewConfig() + require.NoError(t, err) + + assert.Empty(t, c.TLSDomains) + assert.False(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) }) } diff --git a/internal/handler.go b/internal/handler.go index 7372132..11feac9 100644 --- a/internal/handler.go +++ b/internal/handler.go @@ -15,10 +15,11 @@ type HandlerOptions struct { maxRequestBody int targetUrl *url.URL xSendfileEnabled bool + forwardHeaders bool } func NewHandler(options HandlerOptions) http.Handler { - handler := NewProxyHandler(options.targetUrl, options.badGatewayPage) + handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders) handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler) handler = NewSendfileHandler(options.xSendfileEnabled, handler) handler = gzhttp.GzipHandler(handler) diff --git a/internal/handler_test.go b/internal/handler_test.go index 49ce742..a3485b6 100644 --- a/internal/handler_test.go +++ b/internal/handler_test.go @@ -199,9 +199,9 @@ func TestHandlerXForwardedHeadersWhenProxying(t *testing.T) { h.ServeHTTP(w, r) } -func TestHandlerXForwardedHeadersRespectExistingHeaders(t *testing.T) { +func TestHandlerXForwardedHeadersForwardsExistingHeadersWhenForwardingEnabled(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For")) + assert.Equal(t, "4.3.2.1, 1.2.3.4", r.Header.Get("X-Forwarded-For")) assert.Equal(t, "other.example.com", r.Header.Get("X-Forwarded-Host")) assert.Equal(t, "https", r.Header.Get("X-Forwarded-Proto")) })) @@ -211,6 +211,28 @@ func TestHandlerXForwardedHeadersRespectExistingHeaders(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "http://example.org", nil) + r.Header.Set("X-Forwarded-For", "4.3.2.1") + r.Header.Set("X-Forwarded-Proto", "https") + r.Header.Set("X-Forwarded-Host", "other.example.com") + r.RemoteAddr = "1.2.3.4:1234" + h.ServeHTTP(w, r) +} + +func TestHandlerXForwardedHeadersDropsExistingHeadersWhenForwardingNotEnabled(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For")) + assert.Equal(t, "example.org", r.Header.Get("X-Forwarded-Host")) + assert.Equal(t, "http", r.Header.Get("X-Forwarded-Proto")) + })) + defer upstream.Close() + + options := handlerOptions(upstream.URL) + options.forwardHeaders = false + h := NewHandler(options) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "http://example.org", nil) + r.Header.Set("X-Forwarded-For", "4.3.2.1") r.Header.Set("X-Forwarded-Proto", "https") r.Header.Set("X-Forwarded-Host", "other.example.com") r.RemoteAddr = "1.2.3.4:1234" @@ -228,5 +250,6 @@ func handlerOptions(targetUrl string) HandlerOptions { xSendfileEnabled: true, maxCacheableResponseBody: 1024, badGatewayPage: "", + forwardHeaders: true, } } diff --git a/internal/proxy_handler.go b/internal/proxy_handler.go index 50b4214..c6e157d 100644 --- a/internal/proxy_handler.go +++ b/internal/proxy_handler.go @@ -9,13 +9,12 @@ import ( "os" ) -func NewProxyHandler(targetUrl *url.URL, badGatewayPage string) http.Handler { +func NewProxyHandler(targetUrl *url.URL, badGatewayPage string, forwardHeaders bool) http.Handler { return &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(targetUrl) r.Out.Host = r.In.Host - r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] - setXForwarded(r) + setXForwarded(r, forwardHeaders) }, ErrorHandler: ProxyErrorHandler(badGatewayPage), Transport: createProxyTransport(), @@ -47,16 +46,21 @@ func ProxyErrorHandler(badGatewayPage string) func(w http.ResponseWriter, r *htt } } -func setXForwarded(r *httputil.ProxyRequest) { - // Populate new headers by default +func setXForwarded(r *httputil.ProxyRequest, forwardHeaders bool) { + if forwardHeaders { + r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] + } + r.SetXForwarded() - // Preserve original headers if we had them - if r.In.Header.Get("X-Forwarded-Host") != "" { - r.Out.Header.Set("X-Forwarded-Host", r.In.Header.Get("X-Forwarded-Host")) - } - if r.In.Header.Get("X-Forwarded-Proto") != "" { - r.Out.Header.Set("X-Forwarded-Proto", r.In.Header.Get("X-Forwarded-Proto")) + if forwardHeaders { + // Preserve original headers if we had them + if r.In.Header.Get("X-Forwarded-Host") != "" { + r.Out.Header.Set("X-Forwarded-Host", r.In.Header.Get("X-Forwarded-Host")) + } + if r.In.Header.Get("X-Forwarded-Proto") != "" { + r.Out.Header.Set("X-Forwarded-Proto", r.In.Header.Get("X-Forwarded-Proto")) + } } } diff --git a/internal/server.go b/internal/server.go index 1f48227..b168402 100644 --- a/internal/server.go +++ b/internal/server.go @@ -31,7 +31,7 @@ func (s *Server) Start() { httpAddress := fmt.Sprintf(":%d", s.config.HttpPort) httpsAddress := fmt.Sprintf(":%d", s.config.HttpsPort) - if len(s.config.TLSDomains) > 0 { + if s.config.HasTLS() { manager := s.certManager() s.httpServer = s.defaultHttpServer(httpAddress) diff --git a/internal/service.go b/internal/service.go index eedea69..5418eb0 100644 --- a/internal/service.go +++ b/internal/service.go @@ -25,6 +25,7 @@ func (s *Service) Run() int { maxCacheableResponseBody: s.config.MaxCacheItemSizeBytes, maxRequestBody: s.config.MaxRequestBody, badGatewayPage: s.config.BadGatewayPage, + forwardHeaders: s.config.ForwardHeaders, } handler := NewHandler(handlerOptions)