diff --git a/common/ratelimit.go b/common/ratelimit.go index 8731d18f61..bd647a21cc 100644 --- a/common/ratelimit.go +++ b/common/ratelimit.go @@ -46,14 +46,23 @@ type RateBucketParams struct { LastRequestTime time.Time } -func GetClientAddress(ctx context.Context, header string) (string, error) { - if header != "" { +// GetClientAddress returns the client address from the context. If the header is not empty, it will +// take the ip address located at the `numProxies“ position from the end of the header. If the ip address cannot be +// found in the header, it will use the connection ip if `alloweDirectionConnection` is true. Otherwise, it will return +// an error. +func GetClientAddress(ctx context.Context, header string, numProxies int, allowDirectConnectionFallback bool) (string, error) { + + if header != "" && numProxies > 0 { md, ok := metadata.FromIncomingContext(ctx) - if !ok || len(md.Get(header)) == 0 { - return "", fmt.Errorf("failed to get ip from header") + if ok && len(md.Get(header)) > 0 { + parts := splitHeader(md.Get(header)) + if len(parts) >= numProxies { + return parts[len(parts)-numProxies], nil + } } - return md.Get(header)[len(md.Get(header))-1], nil - } else { + } + + if header == "" || allowDirectConnectionFallback { p, ok := peer.FromContext(ctx) if !ok { return "", fmt.Errorf("failed to get peer from request") @@ -65,32 +74,19 @@ func GetClientAddress(ctx context.Context, header string) (string, error) { } return host, nil } + + return "", fmt.Errorf("failed to get ip") } -func GetClientAddressCloudfare(ctx context.Context, header string) (string, error) { - if header != "" { - md, ok := metadata.FromIncomingContext(ctx) - if !ok || len(md.Get(header)) == 0 { - return "", fmt.Errorf("failed to get ip from header") +func splitHeader(header []string) []string { + var result []string + for _, h := range header { + for _, p := range strings.Split(h, ",") { + trimmed := strings.TrimSpace(p) + if trimmed != "" { + result = append(result, trimmed) + } } - addr := md.Get(header)[len(md.Get(header))-1] - // split the address - parts := strings.Split(addr, ",") - if len(parts) == 2 { - return parts[0], nil - } - return addr, nil - - } else { - p, ok := peer.FromContext(ctx) - if !ok { - return "", fmt.Errorf("failed to get peer from request") - } - addr := p.Addr.String() - host, _, err := net.SplitHostPort(addr) - if err != nil { - return "", err - } - return host, nil } + return result } diff --git a/common/ratelimit_test.go b/common/ratelimit_test.go new file mode 100644 index 0000000000..e3d6f97cbd --- /dev/null +++ b/common/ratelimit_test.go @@ -0,0 +1,54 @@ +package common_test + +import ( + "context" + "net" + "testing" + + "github.com/Layr-Labs/eigenda/common" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +func TestGetClientAddress(t *testing.T) { + + // Make test context + // Four proxies. The last proxy's IP address will be in the connection, not in the header + md := metadata.Pairs("x-forwarded-for", "dummyheader, clientip", "x-forwarded-for", "proxy1, proxy2", "x-forwarded-for", "proxy3") + + ctx := peer.NewContext(context.Background(), &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("0.0.0.0"), + Port: 1234, + }, + }) + + ctx = metadata.NewIncomingContext(ctx, md) + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + t.Fatal("failed to get metadata from context") + } + assert.Equal(t, []string{"dummyheader, clientip", "proxy1, proxy2", "proxy3"}, md.Get("x-forwarded-for")) + + ip, err := common.GetClientAddress(ctx, "x-forwarded-for", 4, false) + assert.NoError(t, err) + assert.Equal(t, "clientip", ip) + + ip, err = common.GetClientAddress(ctx, "x-forwarded-for", 7, false) + assert.Error(t, err) + assert.Equal(t, "", ip) + + ip, err = common.GetClientAddress(ctx, "x-forwarded-for", 7, true) + assert.NoError(t, err) + assert.Equal(t, "0.0.0.0", ip) + + ip, err = common.GetClientAddress(ctx, "", 0, true) + assert.NoError(t, err) + assert.Equal(t, "0.0.0.0", ip) + + ip, err = common.GetClientAddress(ctx, "", 0, false) + assert.NoError(t, err) + assert.Equal(t, "0.0.0.0", ip) + +} diff --git a/disperser/apiserver/server.go b/disperser/apiserver/server.go index b0349cfaee..3394952529 100644 --- a/disperser/apiserver/server.go +++ b/disperser/apiserver/server.go @@ -114,7 +114,7 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob blob := getBlobFromRequest(req) - origin, err := common.GetClientAddressCloudfare(ctx, s.rateConfig.ClientIPHeader) + origin, err := common.GetClientAddress(ctx, s.rateConfig.ClientIPHeader, 1, true) if err != nil { for _, param := range securityParams { quorumId := string(uint8(param.GetQuorumId())) diff --git a/node/grpc/server.go b/node/grpc/server.go index ddbd78f7ff..6894bd1530 100644 --- a/node/grpc/server.go +++ b/node/grpc/server.go @@ -184,7 +184,7 @@ func (s *Server) RetrieveChunks(ctx context.Context, in *pb.RetrieveChunksReques return nil, err } - retrieverID, err := common.GetClientAddress(ctx, s.config.ClientIPHeader) + retrieverID, err := common.GetClientAddress(ctx, s.config.ClientIPHeader, 1, false) if err != nil { return nil, err }