Skip to content

Commit

Permalink
Client IP detection (#30) (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
mooselumph authored Nov 16, 2023
1 parent c2e40cb commit a5667d2
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 32 deletions.
56 changes: 26 additions & 30 deletions common/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,23 @@ type RateBucketParams struct {
LastRequestTime time.Time
}

func GetClientAddress(ctx context.Context, header string) (string, error) {
if header != "" {
// GetClientAddress returns the client address from the context. If the header is not empty, it will
// take the ip address located at the `numProxies“ position from the end of the header. If the ip address cannot be
// found in the header, it will use the connection ip if `alloweDirectionConnection` is true. Otherwise, it will return
// an error.
func GetClientAddress(ctx context.Context, header string, numProxies int, allowDirectConnectionFallback bool) (string, error) {

if header != "" && numProxies > 0 {
md, ok := metadata.FromIncomingContext(ctx)
if !ok || len(md.Get(header)) == 0 {
return "", fmt.Errorf("failed to get ip from header")
if ok && len(md.Get(header)) > 0 {
parts := splitHeader(md.Get(header))
if len(parts) >= numProxies {
return parts[len(parts)-numProxies], nil
}
}
return md.Get(header)[len(md.Get(header))-1], nil
} else {
}

if header == "" || allowDirectConnectionFallback {
p, ok := peer.FromContext(ctx)
if !ok {
return "", fmt.Errorf("failed to get peer from request")
Expand All @@ -65,32 +74,19 @@ func GetClientAddress(ctx context.Context, header string) (string, error) {
}
return host, nil
}

return "", fmt.Errorf("failed to get ip")
}

func GetClientAddressCloudfare(ctx context.Context, header string) (string, error) {
if header != "" {
md, ok := metadata.FromIncomingContext(ctx)
if !ok || len(md.Get(header)) == 0 {
return "", fmt.Errorf("failed to get ip from header")
func splitHeader(header []string) []string {
var result []string
for _, h := range header {
for _, p := range strings.Split(h, ",") {
trimmed := strings.TrimSpace(p)
if trimmed != "" {
result = append(result, trimmed)
}
}
addr := md.Get(header)[len(md.Get(header))-1]
// split the address
parts := strings.Split(addr, ",")
if len(parts) == 2 {
return parts[0], nil
}
return addr, nil

} else {
p, ok := peer.FromContext(ctx)
if !ok {
return "", fmt.Errorf("failed to get peer from request")
}
addr := p.Addr.String()
host, _, err := net.SplitHostPort(addr)
if err != nil {
return "", err
}
return host, nil
}
return result
}
54 changes: 54 additions & 0 deletions common/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package common_test

import (
"context"
"net"
"testing"

"github.com/Layr-Labs/eigenda/common"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
)

func TestGetClientAddress(t *testing.T) {

// Make test context
// Four proxies. The last proxy's IP address will be in the connection, not in the header
md := metadata.Pairs("x-forwarded-for", "dummyheader, clientip", "x-forwarded-for", "proxy1, proxy2", "x-forwarded-for", "proxy3")

ctx := peer.NewContext(context.Background(), &peer.Peer{
Addr: &net.TCPAddr{
IP: net.ParseIP("0.0.0.0"),
Port: 1234,
},
})

ctx = metadata.NewIncomingContext(ctx, md)
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
t.Fatal("failed to get metadata from context")
}
assert.Equal(t, []string{"dummyheader, clientip", "proxy1, proxy2", "proxy3"}, md.Get("x-forwarded-for"))

ip, err := common.GetClientAddress(ctx, "x-forwarded-for", 4, false)
assert.NoError(t, err)
assert.Equal(t, "clientip", ip)

ip, err = common.GetClientAddress(ctx, "x-forwarded-for", 7, false)
assert.Error(t, err)
assert.Equal(t, "", ip)

ip, err = common.GetClientAddress(ctx, "x-forwarded-for", 7, true)
assert.NoError(t, err)
assert.Equal(t, "0.0.0.0", ip)

ip, err = common.GetClientAddress(ctx, "", 0, true)
assert.NoError(t, err)
assert.Equal(t, "0.0.0.0", ip)

ip, err = common.GetClientAddress(ctx, "", 0, false)
assert.NoError(t, err)
assert.Equal(t, "0.0.0.0", ip)

}
2 changes: 1 addition & 1 deletion disperser/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob

blob := getBlobFromRequest(req)

origin, err := common.GetClientAddressCloudfare(ctx, s.rateConfig.ClientIPHeader)
origin, err := common.GetClientAddress(ctx, s.rateConfig.ClientIPHeader, 1, true)
if err != nil {
for _, param := range securityParams {
quorumId := string(uint8(param.GetQuorumId()))
Expand Down
2 changes: 1 addition & 1 deletion node/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (s *Server) RetrieveChunks(ctx context.Context, in *pb.RetrieveChunksReques
return nil, err
}

retrieverID, err := common.GetClientAddress(ctx, s.config.ClientIPHeader)
retrieverID, err := common.GetClientAddress(ctx, s.config.ClientIPHeader, 1, false)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit a5667d2

Please sign in to comment.