Skip to content

Commit

Permalink
polygon/p2p: add request chunking to FetchHeaders (#9536)
Browse files Browse the repository at this point in the history
This PR adds chunking logic to FetchHeaders and corresponding unit tests
so that we stay within the soft limits:
1. 2 MB response size
2. 1024 headers
  • Loading branch information
taratorio authored Feb 29, 2024
1 parent 6267419 commit d18fdaa
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 57 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ require (
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1
modernc.org/mathutil v1.6.0
modernc.org/sqlite v1.28.0
pgregory.net/rapid v1.1.0
sigs.k8s.io/yaml v1.4.0
Expand Down Expand Up @@ -280,7 +281,6 @@ require (
modernc.org/cc/v3 v3.41.0 // indirect
modernc.org/ccgo/v3 v3.16.15 // indirect
modernc.org/libc v1.29.0 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.7.2 // indirect
modernc.org/opt v0.1.3 // indirect
modernc.org/strutil v1.2.0 // indirect
Expand Down
111 changes: 68 additions & 43 deletions polygon/p2p/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (
"time"

"github.com/ledgerwatch/log/v3"
"modernc.org/mathutil"

libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/gointerfaces/sentry"
"github.com/ledgerwatch/erigon/core/types"
"github.com/ledgerwatch/erigon/eth/protocols/eth"
Expand Down Expand Up @@ -55,38 +56,82 @@ func (f *fetcher) FetchHeaders(ctx context.Context, start uint64, end uint64, pe
}
}

// Soft response limits are:
// 1. 2 MB size
// 2. 1024 headers
//
// A header is approximately 500 bytes, hence 1024 headers should be less than 2 MB.
// As a simplification we can only use MaxHeadersServe for chunking.
amount := end - start
requestId := f.requestIdGenerator()
observer := make(ChanMessageObserver[*sentry.InboundMessage])
chunks := amount / eth.MaxHeadersServe
if amount%eth.MaxHeadersServe > 0 {
chunks++
}

headers := make([]*types.Header, 0, amount)
observer := make(ChanMessageObserver[*sentry.InboundMessage])
f.messageListener.RegisterBlockHeadersObserver(observer)
defer f.messageListener.UnregisterBlockHeadersObserver(observer)

//
// TODO 1) chunk request into smaller ranges if needed to fit in the 2 MiB response size soft limit
// and also 1024 max headers server (check AnswerGetBlockHeadersQuery)
err := f.messageSender.SendGetBlockHeaders(ctx, peerId, eth.GetBlockHeadersPacket66{
RequestId: requestId,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Number: start,
for i := uint64(0); i < chunks; i++ {
chunkStart := start + i*eth.MaxHeadersServe
chunkAmount := mathutil.MinUint64(end-chunkStart, eth.MaxHeadersServe)
requestId := f.requestIdGenerator()

err := f.messageSender.SendGetBlockHeaders(ctx, peerId, eth.GetBlockHeadersPacket66{
RequestId: requestId,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Number: chunkStart,
},
Amount: chunkAmount,
},
Amount: amount,
},
})
if err != nil {
})
if err != nil {
return nil, err
}

headerChunk, err := f.awaitHeadersResponse(ctx, requestId, peerId, observer)
if err != nil {
return nil, err
}

headers = append(headers, headerChunk...)
}

if err := f.validateHeadersResponse(headers, start, end, amount); err != nil {
shouldPenalize := errors.Is(err, &ErrIncorrectOriginHeader{}) ||
errors.Is(err, &ErrTooManyHeaders{}) ||
errors.Is(err, &ErrDisconnectedHeaders{})

if shouldPenalize {
f.logger.Debug("penalizing peer", "peerId", peerId, "err", err.Error())

penalizeErr := f.peerPenalizer.Penalize(ctx, peerId)
if penalizeErr != nil {
err = fmt.Errorf("%w: %w", penalizeErr, err)
}
}

return nil, err
}

return headers, nil
}

func (f *fetcher) awaitHeadersResponse(
ctx context.Context,
requestId uint64,
peerId PeerId,
observer ChanMessageObserver[*sentry.InboundMessage],
) ([]*types.Header, error) {
ctx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()

var headers []*types.Header
var requestReceived bool
for !requestReceived {
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("interrupted while waiting for msg from peer: %w", ctx.Err())
return nil, fmt.Errorf("await headers response interrupted: %w", ctx.Err())
case msg := <-observer:
msgPeerId := PeerIdFromH512(msg.PeerId)
if msgPeerId != peerId {
Expand All @@ -110,29 +155,9 @@ func (f *fetcher) FetchHeaders(ctx context.Context, start uint64, end uint64, pe
continue
}

headers = pkt.BlockHeadersPacket
requestReceived = true
}
}

if err = f.validateHeadersResponse(headers, start, end, amount); err != nil {
shouldPenalize := errors.Is(err, &ErrIncorrectOriginHeader{}) ||
errors.Is(err, &ErrTooManyHeaders{}) ||
errors.Is(err, &ErrDisconnectedHeaders{})

if shouldPenalize {
f.logger.Debug("penalizing peer", "peerId", peerId, "err", err.Error())

penalizeErr := f.peerPenalizer.Penalize(ctx, peerId)
if penalizeErr != nil {
err = fmt.Errorf("%w: %w", penalizeErr, err)
}
return pkt.BlockHeadersPacket, nil
}

return nil, err
}

return headers, nil
}

func (f *fetcher) validateHeadersResponse(headers []*types.Header, start, end, amount uint64) error {
Expand Down Expand Up @@ -244,10 +269,10 @@ func (e ErrTooManyHeaders) Is(err error) bool {
}

type ErrDisconnectedHeaders struct {
currentHash libcommon.Hash
currentParentHash libcommon.Hash
currentHash common.Hash
currentParentHash common.Hash
currentNum uint64
parentHash libcommon.Hash
parentHash common.Hash
parentNum uint64
}

Expand Down
70 changes: 57 additions & 13 deletions polygon/p2p/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"math/big"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -41,14 +40,13 @@ func newServiceTest(t *testing.T, requestIdGenerator RequestIdGenerator) *servic
}

type serviceTest struct {
ctx context.Context
ctxCancel context.CancelFunc
t *testing.T
sentryClient *direct.MockSentryClient
service Service
headersRequestResponseMocksMu sync.Mutex
headersRequestResponseMocks map[uint64]requestResponseMock
peerEvents chan *sentry.PeerEvent
ctx context.Context
ctxCancel context.CancelFunc
t *testing.T
sentryClient *direct.MockSentryClient
service Service
headersRequestResponseMocks map[uint64]requestResponseMock
peerEvents chan *sentry.PeerEvent
}

// run is needed so that we can properly shut down tests involving the p2p service due to how the sentry multi
Expand Down Expand Up @@ -112,8 +110,6 @@ func (st *serviceTest) mockSentryStreams(mocks ...requestResponseMock) {
}

func (st *serviceTest) mockSentryInboundMessagesStream(mocks ...requestResponseMock) {
st.headersRequestResponseMocksMu.Lock()
defer st.headersRequestResponseMocksMu.Unlock()
for _, mock := range mocks {
st.headersRequestResponseMocks[mock.requestId] = mock
}
Expand Down Expand Up @@ -142,13 +138,12 @@ func (st *serviceTest) mockSentryInboundMessagesStream(mocks ...requestResponseM
return nil, err
}

st.headersRequestResponseMocksMu.Lock()
defer st.headersRequestResponseMocksMu.Unlock()
mock, ok := st.headersRequestResponseMocks[pkt.RequestId]
if !ok {
return nil, fmt.Errorf("unexpected request id: %d", pkt.RequestId)
}

delete(st.headersRequestResponseMocks, pkt.RequestId)
reqPeerId := PeerIdFromH512(req.PeerId)
if mock.wantRequestPeerId != reqPeerId {
return nil, fmt.Errorf("wantRequestPeerId != reqPeerId - %v vs %v", mock.wantRequestPeerId, reqPeerId)
Expand Down Expand Up @@ -371,6 +366,55 @@ func TestServiceFetchHeaders(t *testing.T) {
})
}

func TestServiceFetchHeadersWithChunking(t *testing.T) {
t.Parallel()

peerId := PeerIdFromUint64(1)
mockHeaders := newMockBlockHeaders(1999)
requestId1 := uint64(1234)
mockInboundMessages1 := []*sentry.InboundMessage{
{
Id: sentry.MessageId_BLOCK_HEADERS_66,
PeerId: peerId.H512(),
// 1024 headers in first response
Data: blockHeadersPacket66Bytes(t, requestId1, mockHeaders[:1025]),
},
}
mockRequestResponse1 := requestResponseMock{
requestId: requestId1,
mockResponseInboundMessages: mockInboundMessages1,
wantRequestPeerId: peerId,
wantRequestOriginNumber: 1,
wantRequestAmount: 1024,
}
requestId2 := uint64(1235)
mockInboundMessages2 := []*sentry.InboundMessage{
{
Id: sentry.MessageId_BLOCK_HEADERS_66,
PeerId: peerId.H512(),
// remaining 975 headers in second response
Data: blockHeadersPacket66Bytes(t, requestId2, mockHeaders[1025:]),
},
}
mockRequestResponse2 := requestResponseMock{
requestId: requestId2,
mockResponseInboundMessages: mockInboundMessages2,
wantRequestPeerId: peerId,
wantRequestOriginNumber: 1025,
wantRequestAmount: 975,
}

test := newServiceTest(t, newMockRequestGenerator(requestId1, requestId2))
test.mockSentryStreams(mockRequestResponse1, mockRequestResponse2)
test.run(func(ctx context.Context, t *testing.T) {
headers, err := test.service.FetchHeaders(ctx, 1, 2000, peerId)
require.NoError(t, err)
require.Len(t, headers, 1999)
require.Equal(t, uint64(1), headers[0].Number.Uint64())
require.Equal(t, uint64(1999), headers[len(headers)-1].Number.Uint64())
})
}

func TestServiceErrInvalidFetchHeadersRange(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit d18fdaa

Please sign in to comment.