diff --git a/README.md b/README.md index 0372623..8a1f758 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ environment variables that you can set. | `ACME_DIRECTORY` | The URL of the ACME directory to use for TLS certificate provisioning. | `https://acme-v02.api.letsencrypt.org/directory` (Let's Encrypt production) | | `EAB_KID` | The EAB key identifier to use when provisioning TLS certificates, if required. | None | | `EAB_HMAC_KEY` | The Base64-encoded EAB HMAC key to use when provisioning TLS certificates, if required. | None | +| `FORWARD_HEADERS` | Whether to forward X-Forwarded-* headers from the client.| Disabled when running with TLS; enabled otherwise | | `DEBUG` | Set to `1` or `true` to enable debug logging. | Disabled | To prevent naming clashes with your application's own environment variables, diff --git a/internal/config.go b/internal/config.go index 6073b72..d0d2ce2 100644 --- a/internal/config.go +++ b/internal/config.go @@ -59,6 +59,8 @@ type Config struct { HttpReadTimeout time.Duration HttpWriteTimeout time.Duration + ForwardHeaders bool + LogLevel slog.Level } @@ -72,7 +74,7 @@ func NewConfig() (*Config, error) { logLevel = slog.LevelDebug } - return &Config{ + config := &Config{ TargetPort: getEnvInt("TARGET_PORT", defaultTargetPort), UpstreamCommand: os.Args[1], UpstreamArgs: os.Args[2:], @@ -96,7 +98,15 @@ func NewConfig() (*Config, error) { HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), LogLevel: logLevel, - }, nil + } + + config.ForwardHeaders = getEnvBool("FORWARD_HEADERS", !config.HasTLS()) + + return config, nil +} + +func (c *Config) HasTLS() bool { + return len(c.TLSDomains) > 0 } func findEnv(key string) (string, bool) { diff --git a/internal/config_test.go b/internal/config_test.go index 2a60717..c8415d9 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -17,6 +17,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{}, c.TLSDomains) + assert.False(t, c.HasTLS()) + assert.True(t, c.ForwardHeaders) }) t.Run("with an empty TLS_DOMAIN", func(t *testing.T) { @@ -27,6 +29,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{}, c.TLSDomains) + assert.False(t, c.HasTLS()) + assert.True(t, c.ForwardHeaders) }) t.Run("with single TLS_DOMAIN", func(t *testing.T) { @@ -37,6 +41,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"example.com"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) }) t.Run("with multiple TLS_DOMAIN", func(t *testing.T) { @@ -47,6 +53,8 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) }) t.Run("with TLS_DOMAIN containing whitespace", func(t *testing.T) { @@ -57,6 +65,33 @@ func TestConfig_tls(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) + }) + + t.Run("overriding with FORWARD_HEADERS when using TLS", func(t *testing.T) { + usingProgramArgs(t, "thruster", "echo", "hello") + usingEnvVar(t, "TLS_DOMAIN", "example.com") + usingEnvVar(t, "FORWARD_HEADERS", "true") + + c, err := NewConfig() + require.NoError(t, err) + + assert.Equal(t, []string{"example.com"}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.True(t, c.ForwardHeaders) + }) + + t.Run("overriding with FORWARD_HEADERS when not using TLS", func(t *testing.T) { + usingProgramArgs(t, "thruster", "echo", "hello") + usingEnvVar(t, "FORWARD_HEADERS", "false") + + c, err := NewConfig() + require.NoError(t, err) + + assert.Empty(t, c.TLSDomains) + assert.False(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) }) } diff --git a/internal/handler.go b/internal/handler.go index 7372132..11feac9 100644 --- a/internal/handler.go +++ b/internal/handler.go @@ -15,10 +15,11 @@ type HandlerOptions struct { maxRequestBody int targetUrl *url.URL xSendfileEnabled bool + forwardHeaders bool } func NewHandler(options HandlerOptions) http.Handler { - handler := NewProxyHandler(options.targetUrl, options.badGatewayPage) + handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders) handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler) handler = NewSendfileHandler(options.xSendfileEnabled, handler) handler = gzhttp.GzipHandler(handler) diff --git a/internal/handler_test.go b/internal/handler_test.go index 49ce742..a3485b6 100644 --- a/internal/handler_test.go +++ b/internal/handler_test.go @@ -199,9 +199,9 @@ func TestHandlerXForwardedHeadersWhenProxying(t *testing.T) { h.ServeHTTP(w, r) } -func TestHandlerXForwardedHeadersRespectExistingHeaders(t *testing.T) { +func TestHandlerXForwardedHeadersForwardsExistingHeadersWhenForwardingEnabled(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For")) + assert.Equal(t, "4.3.2.1, 1.2.3.4", r.Header.Get("X-Forwarded-For")) assert.Equal(t, "other.example.com", r.Header.Get("X-Forwarded-Host")) assert.Equal(t, "https", r.Header.Get("X-Forwarded-Proto")) })) @@ -211,6 +211,28 @@ func TestHandlerXForwardedHeadersRespectExistingHeaders(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "http://example.org", nil) + r.Header.Set("X-Forwarded-For", "4.3.2.1") + r.Header.Set("X-Forwarded-Proto", "https") + r.Header.Set("X-Forwarded-Host", "other.example.com") + r.RemoteAddr = "1.2.3.4:1234" + h.ServeHTTP(w, r) +} + +func TestHandlerXForwardedHeadersDropsExistingHeadersWhenForwardingNotEnabled(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For")) + assert.Equal(t, "example.org", r.Header.Get("X-Forwarded-Host")) + assert.Equal(t, "http", r.Header.Get("X-Forwarded-Proto")) + })) + defer upstream.Close() + + options := handlerOptions(upstream.URL) + options.forwardHeaders = false + h := NewHandler(options) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "http://example.org", nil) + r.Header.Set("X-Forwarded-For", "4.3.2.1") r.Header.Set("X-Forwarded-Proto", "https") r.Header.Set("X-Forwarded-Host", "other.example.com") r.RemoteAddr = "1.2.3.4:1234" @@ -228,5 +250,6 @@ func handlerOptions(targetUrl string) HandlerOptions { xSendfileEnabled: true, maxCacheableResponseBody: 1024, badGatewayPage: "", + forwardHeaders: true, } } diff --git a/internal/proxy_handler.go b/internal/proxy_handler.go index 50b4214..c6e157d 100644 --- a/internal/proxy_handler.go +++ b/internal/proxy_handler.go @@ -9,13 +9,12 @@ import ( "os" ) -func NewProxyHandler(targetUrl *url.URL, badGatewayPage string) http.Handler { +func NewProxyHandler(targetUrl *url.URL, badGatewayPage string, forwardHeaders bool) http.Handler { return &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(targetUrl) r.Out.Host = r.In.Host - r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] - setXForwarded(r) + setXForwarded(r, forwardHeaders) }, ErrorHandler: ProxyErrorHandler(badGatewayPage), Transport: createProxyTransport(), @@ -47,16 +46,21 @@ func ProxyErrorHandler(badGatewayPage string) func(w http.ResponseWriter, r *htt } } -func setXForwarded(r *httputil.ProxyRequest) { - // Populate new headers by default +func setXForwarded(r *httputil.ProxyRequest, forwardHeaders bool) { + if forwardHeaders { + r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] + } + r.SetXForwarded() - // Preserve original headers if we had them - if r.In.Header.Get("X-Forwarded-Host") != "" { - r.Out.Header.Set("X-Forwarded-Host", r.In.Header.Get("X-Forwarded-Host")) - } - if r.In.Header.Get("X-Forwarded-Proto") != "" { - r.Out.Header.Set("X-Forwarded-Proto", r.In.Header.Get("X-Forwarded-Proto")) + if forwardHeaders { + // Preserve original headers if we had them + if r.In.Header.Get("X-Forwarded-Host") != "" { + r.Out.Header.Set("X-Forwarded-Host", r.In.Header.Get("X-Forwarded-Host")) + } + if r.In.Header.Get("X-Forwarded-Proto") != "" { + r.Out.Header.Set("X-Forwarded-Proto", r.In.Header.Get("X-Forwarded-Proto")) + } } } diff --git a/internal/server.go b/internal/server.go index 1f48227..b168402 100644 --- a/internal/server.go +++ b/internal/server.go @@ -31,7 +31,7 @@ func (s *Server) Start() { httpAddress := fmt.Sprintf(":%d", s.config.HttpPort) httpsAddress := fmt.Sprintf(":%d", s.config.HttpsPort) - if len(s.config.TLSDomains) > 0 { + if s.config.HasTLS() { manager := s.certManager() s.httpServer = s.defaultHttpServer(httpAddress) diff --git a/internal/service.go b/internal/service.go index eedea69..5418eb0 100644 --- a/internal/service.go +++ b/internal/service.go @@ -25,6 +25,7 @@ func (s *Service) Run() int { maxCacheableResponseBody: s.config.MaxCacheItemSizeBytes, maxRequestBody: s.config.MaxRequestBody, badGatewayPage: s.config.BadGatewayPage, + forwardHeaders: s.config.ForwardHeaders, } handler := NewHandler(handlerOptions)