Skip to content

Commit

Permalink
Allowlist with specified rate limits (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
ian-shim authored Jan 8, 2024
1 parent 7de8438 commit e5165a3
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 64 deletions.
18 changes: 2 additions & 16 deletions common/ratelimit/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ratelimit

import (
"context"
"strings"
"time"

"github.com/Layr-Labs/eigenda/common"
Expand All @@ -12,35 +11,23 @@ type BucketStore = common.KVStore[common.RateBucketParams]

type rateLimiter struct {
globalRateParams common.GlobalRateParams

bucketStore BucketStore
allowlist []string
bucketStore BucketStore

logger common.Logger
}

func NewRateLimiter(rateParams common.GlobalRateParams, bucketStore BucketStore, allowlist []string, logger common.Logger) common.RateLimiter {
func NewRateLimiter(rateParams common.GlobalRateParams, bucketStore BucketStore, logger common.Logger) common.RateLimiter {
return &rateLimiter{
globalRateParams: rateParams,
bucketStore: bucketStore,
allowlist: allowlist,
logger: logger,
}
}

// Checks whether a request from the given requesterID is allowed
func (d *rateLimiter) AllowRequest(ctx context.Context, requesterID common.RequesterID, blobSize uint, rate common.RateParam) (bool, error) {
// TODO: temporary allowlist that unconditionally allows request
// for testing purposes only
for _, id := range d.allowlist {
if strings.Contains(requesterID, id) {
return true, nil
}
}

// Retrieve bucket params for the requester ID
// This will be from dynamo for Disperser and from local storage for DA node

bucketParams, err := d.bucketStore.GetItem(ctx, requesterID)
if err != nil {

Expand Down Expand Up @@ -68,7 +55,6 @@ func (d *rateLimiter) AllowRequest(ctx context.Context, requesterID common.Reque

// Update the bucket level
bucketParams.BucketLevels[i] = getBucketLevel(bucketParams.BucketLevels[i], size, interval, deduction)

allowed = allowed && bucketParams.BucketLevels[i] > 0
}

Expand Down
10 changes: 0 additions & 10 deletions common/ratelimit/limiter_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ const (
BucketMultipliersFlagName = "bucket-multipliers"
CountFailedFlagName = "count-failed"
BucketStoreSizeFlagName = "bucket-store-size"
AllowlistFlagName = "allowlist"
)

type Config struct {
common.GlobalRateParams
BucketStoreSize int
UniformRateParam common.RateParam
Allowlist []string
}

func RatelimiterCLIFlags(envPrefix string, flagPrefix string) []cli.Flag {
Expand Down Expand Up @@ -54,13 +52,6 @@ func RatelimiterCLIFlags(envPrefix string, flagPrefix string) []cli.Flag {
EnvVar: common.PrefixEnvVar(envPrefix, "BUCKET_STORE_SIZE"),
Required: false,
},
cli.StringSliceFlag{
Name: common.PrefixFlag(flagPrefix, AllowlistFlagName),
Usage: "Allowlist of IPs to bypass rate limiting",
EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"),
Required: false,
Value: &cli.StringSlice{},
},
}
}

Expand Down Expand Up @@ -106,7 +97,6 @@ func ReadCLIConfig(ctx *cli.Context, flagPrefix string) (Config, error) {
cfg.Multipliers = multipliers
cfg.GlobalRateParams.CountFailed = ctx.Bool(common.PrefixFlag(flagPrefix, CountFailedFlagName))
cfg.BucketStoreSize = ctx.Int(common.PrefixFlag(flagPrefix, BucketStoreSizeFlagName))
cfg.Allowlist = ctx.StringSlice(common.PrefixFlag(flagPrefix, AllowlistFlagName))

err := validateConfig(cfg)
if err != nil {
Expand Down
18 changes: 1 addition & 17 deletions common/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func makeTestRatelimiter() (common.RateLimiter, error) {
return nil, err
}

ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, []string{"testRetriever2"}, &mock.Logger{})
ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, &mock.Logger{})

return ratelimiter, nil

Expand All @@ -50,19 +50,3 @@ func TestRatelimit(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, false, allow)
}

func TestRatelimitAllowlist(t *testing.T) {
ratelimiter, err := makeTestRatelimiter()
assert.NoError(t, err)

ctx := context.Background()

retreiverID := "testRetriever2"

// 10x more requests allowed for allowlisted IDs
for i := 0; i < 100; i++ {
allow, err := ratelimiter.AllowRequest(ctx, retreiverID, 10, 100)
assert.NoError(t, err)
assert.Equal(t, true, allow)
}
}
51 changes: 51 additions & 0 deletions disperser/apiserver/rate_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package apiserver

import (
"fmt"
"log"
"strconv"
"strings"

"github.com/Layr-Labs/eigenda/common"
"github.com/Layr-Labs/eigenda/core"
Expand All @@ -16,6 +18,7 @@ const (
TotalUnauthBlobRateFlagName = "auth.total-unauth-blob-rate"
PerUserUnauthBlobRateFlagName = "auth.per-user-unauth-blob-rate"
ClientIPHeaderFlagName = "auth.client-ip-header"
AllowlistFlagName = "auth.allowlist"

// We allow the user to specify the blob rate in blobs/sec, but internally we use blobs/sec * 1e6 (i.e. blobs/microsec).
// This is because the rate limiter takes an integer rate.
Expand All @@ -29,9 +32,17 @@ type QuorumRateInfo struct {
TotalUnauthBlobRate common.RateParam
}

type PerUserRateInfo struct {
Throughput common.RateParam
BlobRate common.RateParam
}

type Allowlist = map[string]map[core.QuorumID]PerUserRateInfo

type RateConfig struct {
QuorumRateInfos map[core.QuorumID]QuorumRateInfo
ClientIPHeader string
Allowlist Allowlist
}

func CLIFlags(envPrefix string) []cli.Flag {
Expand Down Expand Up @@ -73,6 +84,13 @@ func CLIFlags(envPrefix string) []cli.Flag {
Value: "",
EnvVar: common.PrefixEnvVar(envPrefix, "CLIENT_IP_HEADER"),
},
cli.StringSliceFlag{
Name: AllowlistFlagName,
Usage: "Allowlist of IPs and corresponding blob/byte rates to bypass rate limiting. Format: <IP>:<quorum ID>:<blob rate>:<byte rate>. Example: 127.0.0.1:0:10:10485760",
EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"),
Required: false,
Value: &cli.StringSlice{},
},
}
}

Expand Down Expand Up @@ -112,8 +130,41 @@ func ReadCLIConfig(c *cli.Context) (RateConfig, error) {
}
}

// Parse allowlist
allowlist := make(Allowlist)
for _, allowlistEntry := range c.StringSlice(AllowlistFlagName) {
allowlistEntrySplit := strings.Split(allowlistEntry, ":")
if len(allowlistEntrySplit) != 4 {
log.Printf("invalid allowlist entry: entry should contain exactly 4 elements: %s", allowlistEntry)
continue
}
ip := allowlistEntrySplit[0]
quorumID, err := strconv.Atoi(allowlistEntrySplit[1])
if err != nil {
log.Printf("invalid allowlist entry: failed to convert quorum ID from string: %s", allowlistEntry)
continue
}
blobRate, err := strconv.ParseFloat(allowlistEntrySplit[2], 64)
if err != nil {
log.Printf("invalid allowlist entry: failed to convert blob rate from string: %s", allowlistEntry)
continue
}
byteRate, err := strconv.ParseFloat(allowlistEntrySplit[3], 64)
if err != nil {
log.Printf("invalid allowlist entry: failed to convert throughput from string: %s", allowlistEntry)
continue
}
allowlist[ip] = map[core.QuorumID]PerUserRateInfo{
core.QuorumID(quorumID): {
Throughput: common.RateParam(byteRate),
BlobRate: common.RateParam(blobRate * blobRateMultiplier),
},
}
}

return RateConfig{
QuorumRateInfos: quorumRateInfos,
ClientIPHeader: c.String(ClientIPHeaderFlagName),
Allowlist: allowlist,
}, nil
}
80 changes: 65 additions & 15 deletions disperser/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"strings"
"sync"
"time"

Expand All @@ -18,8 +19,10 @@ import (
"google.golang.org/grpc/reflection"
)

var errSystemRateLimit = fmt.Errorf("request ratelimited: system limit")
var errAccountRateLimit = fmt.Errorf("request ratelimited: account limit")
var errSystemBlobRateLimit = fmt.Errorf("request ratelimited: system blob limit")
var errSystemThroughputRateLimit = fmt.Errorf("request ratelimited: system throughput limit")
var errAccountBlobRateLimit = fmt.Errorf("request ratelimited: account blob limit")
var errAccountThroughputRateLimit = fmt.Errorf("request ratelimited: account throughput limit")

const systemAccountKey = "system"

Expand Down Expand Up @@ -55,6 +58,11 @@ func NewDispersalServer(
ratelimiter common.RateLimiter,
rateConfig RateConfig,
) *DispersalServer {
for ip, rateInfoByQuorum := range rateConfig.Allowlist {
for quorumID, rateInfo := range rateInfoByQuorum {
logger.Info("[Allowlist]", "ip", ip, "quorumID", quorumID, "throughput", rateInfo.Throughput, "blobRate", rateInfo.BlobRate)
}
}
return &DispersalServer{
config: config,
blobStore: store,
Expand Down Expand Up @@ -139,9 +147,9 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob
if err != nil {
for _, param := range securityParams {
quorumId := string(uint8(param.GetQuorumId()))
if errors.Is(err, errSystemRateLimit) {
if errors.Is(err, errSystemBlobRateLimit) || errors.Is(err, errSystemThroughputRateLimit) {
s.metrics.HandleSystemRateLimitedRequest(quorumId, blobSize, "DisperseBlob")
} else if errors.Is(err, errAccountRateLimit) {
} else if errors.Is(err, errAccountBlobRateLimit) || errors.Is(err, errAccountThroughputRateLimit) {
s.metrics.HandleAccountRateLimitedRequest(quorumId, blobSize, "DisperseBlob")
} else {
s.metrics.HandleFailedRequest(quorumId, blobSize, "DisperseBlob")
Expand Down Expand Up @@ -173,23 +181,65 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob
}, nil
}

func (s *DispersalServer) getAccountRate(origin string, quorumID core.QuorumID) (*PerUserRateInfo, error) {
unauthRates, ok := s.rateConfig.QuorumRateInfos[quorumID]
if !ok {
return nil, fmt.Errorf("no configured rate exists for quorum %d", quorumID)
}

for ip, rateInfoByQuorum := range s.rateConfig.Allowlist {
if !strings.Contains(origin, ip) {
continue
}

rateInfo, ok := rateInfoByQuorum[quorumID]
if !ok {
continue
}

throughput := unauthRates.PerUserUnauthThroughput
if rateInfo.Throughput > 0 {
throughput = rateInfo.Throughput
}

blobRate := unauthRates.PerUserUnauthBlobRate
if rateInfo.BlobRate > 0 {
blobRate = rateInfo.BlobRate
}

return &PerUserRateInfo{
Throughput: throughput,
BlobRate: blobRate,
}, nil
}

return &PerUserRateInfo{
Throughput: unauthRates.PerUserUnauthThroughput,
BlobRate: unauthRates.PerUserUnauthBlobRate,
}, nil
}

func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob *core.Blob, origin string) error {

// TODO(robert): Remove these locks once we have resolved ratelimiting approach
s.mu.Lock()
defer s.mu.Unlock()

for _, param := range blob.RequestHeader.SecurityParams {
for i, param := range blob.RequestHeader.SecurityParams {

rates, ok := s.rateConfig.QuorumRateInfos[param.QuorumID]
if !ok {
return fmt.Errorf("no configured rate exists for quorum %d", param.QuorumID)
}
accountRates, err := s.getAccountRate(origin, param.QuorumID)
if err != nil {
return err
}

// Get the encoded blob size from the blob header. Calculation is done in a way that nodes can replicate
blobSize := len(blob.Data)
length := core.GetBlobLength(uint(blobSize))
encodedLength := core.GetEncodedBlobLength(length, uint8(blob.RequestHeader.SecurityParams[param.QuorumID].QuorumThreshold), uint8(blob.RequestHeader.SecurityParams[param.QuorumID].AdversaryThreshold))
encodedLength := core.GetEncodedBlobLength(length, uint8(param.QuorumThreshold), uint8(param.AdversaryThreshold))
encodedSize := core.GetBlobSize(encodedLength)

s.logger.Debug("checking rate limits", "origin", origin, "quorum", param.QuorumID, "encodedSize", encodedSize, "blobSize", blobSize)
Expand All @@ -202,7 +252,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob *
}
if !allowed {
s.logger.Warn("system byte ratelimit exceeded", "systemQuorumKey", systemQuorumKey, "rate", rates.TotalUnauthThroughput)
return errSystemRateLimit
return errSystemThroughputRateLimit
}

systemQuorumKey = fmt.Sprintf("%s:%d-blobrate", systemAccountKey, param.QuorumID)
Expand All @@ -212,35 +262,35 @@ func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob *
}
if !allowed {
s.logger.Warn("system blob ratelimit exceeded", "systemQuorumKey", systemQuorumKey, "rate", float32(rates.TotalUnauthBlobRate)/blobRateMultiplier)
return errSystemRateLimit
return errSystemBlobRateLimit
}

// Check Account Ratelimit

blob.RequestHeader.AccountID = "ip:" + origin

userQuorumKey := fmt.Sprintf("%s:%d", blob.RequestHeader.AccountID, param.QuorumID)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, encodedSize, rates.PerUserUnauthThroughput)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, encodedSize, accountRates.Throughput)
if err != nil {
return fmt.Errorf("ratelimiter error: %v", err)
}
if !allowed {
s.logger.Warn("account byte ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", rates.PerUserUnauthThroughput)
return errAccountRateLimit
s.logger.Warn("account byte ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", accountRates.Throughput)
return errAccountThroughputRateLimit
}

userQuorumKey = fmt.Sprintf("%s:%d-blobrate", blob.RequestHeader.AccountID, param.QuorumID)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, blobRateMultiplier, rates.PerUserUnauthBlobRate)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, blobRateMultiplier, accountRates.BlobRate)
if err != nil {
return fmt.Errorf("ratelimiter error: %v", err)
}
if !allowed {
s.logger.Warn("account blob ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", float32(rates.PerUserUnauthBlobRate)/blobRateMultiplier)
return errAccountRateLimit
s.logger.Warn("account blob ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", float32(accountRates.BlobRate)/blobRateMultiplier)
return errAccountBlobRateLimit
}

// Update the quorum rate
blob.RequestHeader.SecurityParams[param.QuorumID].QuorumRate = rates.PerUserUnauthThroughput
blob.RequestHeader.SecurityParams[i].QuorumRate = accountRates.Throughput
}
return nil

Expand Down
Loading

0 comments on commit e5165a3

Please sign in to comment.