Skip to content

Commit

Permalink
Only forward X-Forwarded-* when not using TLS
Browse files Browse the repository at this point in the history
In [8d19698] we added forwarding of existing X-Forwarded-* headers,
because these are often needed when running behind another proxy layer.

However, we shouldn't be trusting the headers as set by the original
client. So in the case where there is no additional proxy layer between
us and the client, we should drop the headers instead.

Given that our TLS mode is intended for direct client access (and
requires direct access for certificate provisioning), let's make the
condition as follows:

- When running with TLS, assume requests are from external clients, and
  drop the headers.
- When running without TLS, allow that there may be another proxy
  terminating TLS downstream of us, and keep the headers.

In a case where the default logic does not apply, the `FORWARD_HEADERS`
env can be used to override it.
  • Loading branch information
kevinmcconnell committed Aug 6, 2024
1 parent 0ad16ad commit f840043
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ environment variables that you can set.
| `ACME_DIRECTORY` | The URL of the ACME directory to use for TLS certificate provisioning. | `https://acme-v02.api.letsencrypt.org/directory` (Let's Encrypt production) |
| `EAB_KID` | The EAB key identifier to use when provisioning TLS certificates, if required. | None |
| `EAB_HMAC_KEY` | The Base64-encoded EAB HMAC key to use when provisioning TLS certificates, if required. | None |
| `FORWARD_HEADERS` | Whether to forward X-Forwarded-* headers from the client.| Disabled when running with TLS; enabled otherwise |
| `DEBUG` | Set to `1` or `true` to enable debug logging. | Disabled |

To prevent naming clashes with your application's own environment variables,
Expand Down
14 changes: 12 additions & 2 deletions internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type Config struct {
HttpReadTimeout time.Duration
HttpWriteTimeout time.Duration

ForwardHeaders bool

LogLevel slog.Level
}

Expand All @@ -72,7 +74,7 @@ func NewConfig() (*Config, error) {
logLevel = slog.LevelDebug
}

return &Config{
config := &Config{
TargetPort: getEnvInt("TARGET_PORT", defaultTargetPort),
UpstreamCommand: os.Args[1],
UpstreamArgs: os.Args[2:],
Expand All @@ -96,7 +98,15 @@ func NewConfig() (*Config, error) {
HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout),

LogLevel: logLevel,
}, nil
}

config.ForwardHeaders = getEnvBool("FORWARD_HEADERS", !config.HasTLS())

return config, nil
}

func (c *Config) HasTLS() bool {
return len(c.TLSDomains) > 0
}

func findEnv(key string) (string, bool) {
Expand Down
35 changes: 35 additions & 0 deletions internal/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ func TestConfig_tls(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, []string{}, c.TLSDomains)
assert.False(t, c.HasTLS())
assert.True(t, c.ForwardHeaders)
})

t.Run("with an empty TLS_DOMAIN", func(t *testing.T) {
Expand All @@ -27,6 +29,8 @@ func TestConfig_tls(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, []string{}, c.TLSDomains)
assert.False(t, c.HasTLS())
assert.True(t, c.ForwardHeaders)
})

t.Run("with single TLS_DOMAIN", func(t *testing.T) {
Expand All @@ -37,6 +41,8 @@ func TestConfig_tls(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, []string{"example.com"}, c.TLSDomains)
assert.True(t, c.HasTLS())
assert.False(t, c.ForwardHeaders)
})

t.Run("with multiple TLS_DOMAIN", func(t *testing.T) {
Expand All @@ -47,6 +53,8 @@ func TestConfig_tls(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains)
assert.True(t, c.HasTLS())
assert.False(t, c.ForwardHeaders)
})

t.Run("with TLS_DOMAIN containing whitespace", func(t *testing.T) {
Expand All @@ -57,6 +65,33 @@ func TestConfig_tls(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains)
assert.True(t, c.HasTLS())
assert.False(t, c.ForwardHeaders)
})

t.Run("overriding with FORWARD_HEADERS when using TLS", func(t *testing.T) {
usingProgramArgs(t, "thruster", "echo", "hello")
usingEnvVar(t, "TLS_DOMAIN", "example.com")
usingEnvVar(t, "FORWARD_HEADERS", "true")

c, err := NewConfig()
require.NoError(t, err)

assert.Equal(t, []string{"example.com"}, c.TLSDomains)
assert.True(t, c.HasTLS())
assert.True(t, c.ForwardHeaders)
})

t.Run("overriding with FORWARD_HEADERS when not using TLS", func(t *testing.T) {
usingProgramArgs(t, "thruster", "echo", "hello")
usingEnvVar(t, "FORWARD_HEADERS", "false")

c, err := NewConfig()
require.NoError(t, err)

assert.Empty(t, c.TLSDomains)
assert.False(t, c.HasTLS())
assert.False(t, c.ForwardHeaders)
})
}

Expand Down
3 changes: 2 additions & 1 deletion internal/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ type HandlerOptions struct {
maxRequestBody int
targetUrl *url.URL
xSendfileEnabled bool
forwardHeaders bool
}

func NewHandler(options HandlerOptions) http.Handler {
handler := NewProxyHandler(options.targetUrl, options.badGatewayPage)
handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders)
handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler)
handler = NewSendfileHandler(options.xSendfileEnabled, handler)
handler = gzhttp.GzipHandler(handler)
Expand Down
27 changes: 25 additions & 2 deletions internal/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ func TestHandlerXForwardedHeadersWhenProxying(t *testing.T) {
h.ServeHTTP(w, r)
}

func TestHandlerXForwardedHeadersRespectExistingHeaders(t *testing.T) {
func TestHandlerXForwardedHeadersForwardsExistingHeadersWhenForwardingEnabled(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For"))
assert.Equal(t, "4.3.2.1, 1.2.3.4", r.Header.Get("X-Forwarded-For"))
assert.Equal(t, "other.example.com", r.Header.Get("X-Forwarded-Host"))
assert.Equal(t, "https", r.Header.Get("X-Forwarded-Proto"))
}))
Expand All @@ -211,6 +211,28 @@ func TestHandlerXForwardedHeadersRespectExistingHeaders(t *testing.T) {

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.org", nil)
r.Header.Set("X-Forwarded-For", "4.3.2.1")
r.Header.Set("X-Forwarded-Proto", "https")
r.Header.Set("X-Forwarded-Host", "other.example.com")
r.RemoteAddr = "1.2.3.4:1234"
h.ServeHTTP(w, r)
}

func TestHandlerXForwardedHeadersDropsExistingHeadersWhenForwardingNotEnabled(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For"))
assert.Equal(t, "example.org", r.Header.Get("X-Forwarded-Host"))
assert.Equal(t, "http", r.Header.Get("X-Forwarded-Proto"))
}))
defer upstream.Close()

options := handlerOptions(upstream.URL)
options.forwardHeaders = false
h := NewHandler(options)

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.org", nil)
r.Header.Set("X-Forwarded-For", "4.3.2.1")
r.Header.Set("X-Forwarded-Proto", "https")
r.Header.Set("X-Forwarded-Host", "other.example.com")
r.RemoteAddr = "1.2.3.4:1234"
Expand All @@ -228,5 +250,6 @@ func handlerOptions(targetUrl string) HandlerOptions {
xSendfileEnabled: true,
maxCacheableResponseBody: 1024,
badGatewayPage: "",
forwardHeaders: true,
}
}
26 changes: 15 additions & 11 deletions internal/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ import (
"os"
)

func NewProxyHandler(targetUrl *url.URL, badGatewayPage string) http.Handler {
func NewProxyHandler(targetUrl *url.URL, badGatewayPage string, forwardHeaders bool) http.Handler {
return &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(targetUrl)
r.Out.Host = r.In.Host
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
setXForwarded(r)
setXForwarded(r, forwardHeaders)
},
ErrorHandler: ProxyErrorHandler(badGatewayPage),
Transport: createProxyTransport(),
Expand Down Expand Up @@ -47,16 +46,21 @@ func ProxyErrorHandler(badGatewayPage string) func(w http.ResponseWriter, r *htt
}
}

func setXForwarded(r *httputil.ProxyRequest) {
// Populate new headers by default
func setXForwarded(r *httputil.ProxyRequest, forwardHeaders bool) {
if forwardHeaders {
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
}

r.SetXForwarded()

// Preserve original headers if we had them
if r.In.Header.Get("X-Forwarded-Host") != "" {
r.Out.Header.Set("X-Forwarded-Host", r.In.Header.Get("X-Forwarded-Host"))
}
if r.In.Header.Get("X-Forwarded-Proto") != "" {
r.Out.Header.Set("X-Forwarded-Proto", r.In.Header.Get("X-Forwarded-Proto"))
if forwardHeaders {
// Preserve original headers if we had them
if r.In.Header.Get("X-Forwarded-Host") != "" {
r.Out.Header.Set("X-Forwarded-Host", r.In.Header.Get("X-Forwarded-Host"))
}
if r.In.Header.Get("X-Forwarded-Proto") != "" {
r.Out.Header.Set("X-Forwarded-Proto", r.In.Header.Get("X-Forwarded-Proto"))
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (s *Server) Start() {
httpAddress := fmt.Sprintf(":%d", s.config.HttpPort)
httpsAddress := fmt.Sprintf(":%d", s.config.HttpsPort)

if len(s.config.TLSDomains) > 0 {
if s.config.HasTLS() {
manager := s.certManager()

s.httpServer = s.defaultHttpServer(httpAddress)
Expand Down
1 change: 1 addition & 0 deletions internal/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func (s *Service) Run() int {
maxCacheableResponseBody: s.config.MaxCacheItemSizeBytes,
maxRequestBody: s.config.MaxRequestBody,
badGatewayPage: s.config.BadGatewayPage,
forwardHeaders: s.config.ForwardHeaders,
}

handler := NewHandler(handlerOptions)
Expand Down

0 comments on commit f840043

Please sign in to comment.