diff --git a/eth/stagedsync/stage_polygon_sync.go b/eth/stagedsync/stage_polygon_sync.go index 9c9e94ac7d9..00d921d34f0 100644 --- a/eth/stagedsync/stage_polygon_sync.go +++ b/eth/stagedsync/stage_polygon_sync.go @@ -1514,7 +1514,11 @@ func (e *polygonSyncStageExecutionEngine) UpdateForkChoice(ctx context.Context, case <-ctx.Done(): return common.Hash{}, ctx.Err() case result := <-resultCh: - return result.latestValidHash, result.validationErr + err := result.validationErr + if err != nil { + err = fmt.Errorf("%w: %w", polygonsync.ErrForkChoiceUpdateBadBlock, err) + } + return result.latestValidHash, err } } diff --git a/polygon/sync/canonical_chain_builder.go b/polygon/sync/canonical_chain_builder.go index 62e43a00eec..8321246e783 100644 --- a/polygon/sync/canonical_chain_builder.go +++ b/polygon/sync/canonical_chain_builder.go @@ -36,8 +36,10 @@ type CanonicalChainBuilder interface { Tip() *types.Header Root() *types.Header HeadersInRange(start uint64, count uint64) []*types.Header - Prune(newRootNum uint64) error + PruneRoot(newRootNum uint64) error + PruneNode(hash libcommon.Hash) error Connect(ctx context.Context, headers []*types.Header) (newConnectedHeaders []*types.Header, err error) + LowestCommonAncestor(a, b libcommon.Hash) (*types.Header, bool) } type producerSlotIndex uint64 @@ -156,9 +158,9 @@ func (ccb *canonicalChainBuilder) HeadersInRange(start uint64, count uint64) []* return headers[offset : offset+count] } -func (ccb *canonicalChainBuilder) Prune(newRootNum uint64) error { +func (ccb *canonicalChainBuilder) PruneRoot(newRootNum uint64) error { if (newRootNum < ccb.root.header.Number.Uint64()) || (newRootNum > ccb.Tip().Number.Uint64()) { - return errors.New("canonicalChainBuilder.Prune: newRootNum outside of the canonical chain") + return errors.New("canonicalChainBuilder.PruneRoot: newRootNum outside of the canonical chain") } newRoot := ccb.tip @@ -170,6 +172,35 @@ func (ccb *canonicalChainBuilder) Prune(newRootNum uint64) error { return nil } +func (ccb *canonicalChainBuilder) PruneNode(hash libcommon.Hash) error { + if ccb.root.headerHash == hash { + return errors.New("canonicalChainBuilder.PruneNode: can't prune root node") + } + + var exists bool + ccb.enumerate(func(node *forkTreeNode) bool { + if node.headerHash != hash { + return true + } + + for idx, parentChild := range node.parent.children { + if parentChild.headerHash == hash { + exists = true + delete(node.parent.children, idx) + break + } + } + + return false + }) + if !exists { + return errors.New("canonicalChainBuilder.PruneNode: could not find node to prune") + } + + ccb.tip = ccb.recalcTip() // tip may have changed after prunning, re-calc + return nil +} + // compareForkTreeNodes compares 2 fork tree nodes. // It returns a positive number if the chain ending at node1 is "better" than the chain ending at node2. // The better node belongs to the canonical chain, and it has: @@ -195,6 +226,23 @@ func (ccb *canonicalChainBuilder) updateTipIfNeeded(tipCandidate *forkTreeNode) } } +func (ccb *canonicalChainBuilder) recalcTip() *forkTreeNode { + var tip *forkTreeNode + ccb.enumerate(func(node *forkTreeNode) bool { + if tip == nil { + tip = node + return true + } + + if compareForkTreeNodes(tip, node) < 0 { + tip = node + } + + return true + }) + return tip +} + // Connect connects a list of headers to the canonical chain builder tree. // Returns the list of newly connected headers (filtering out headers that already exist in the tree) // or an error in case the header is invalid or the header chain cannot reach any of the nodes in the tree. @@ -291,3 +339,59 @@ func (ccb *canonicalChainBuilder) Connect(ctx context.Context, headers []*types. return headers, nil } + +func (ccb *canonicalChainBuilder) LowestCommonAncestor(a, b libcommon.Hash) (*types.Header, bool) { + pathA := ccb.pathToRoot(a) + if len(pathA) == 0 { + // 'a' doesn't exist in the tree + return nil, false + } + + pathB := ccb.pathToRoot(b) + if len(pathB) == 0 { + // 'b' doesn't exist in the tree + return nil, false + } + + heightA := pathA[0].header.Number.Uint64() + heightB := pathB[0].header.Number.Uint64() + for heightA != heightB { + if heightA < heightB { + pathB = pathB[1:] + heightB = pathB[0].header.Number.Uint64() + } else if heightA > heightB { + pathA = pathA[1:] + heightA = pathA[0].header.Number.Uint64() + } + } + + for i := 0; i < len(pathA); i++ { + if pathA[i].headerHash == pathB[i].headerHash { + return pathA[i].header, true + } + } + + return nil, false +} + +func (ccb *canonicalChainBuilder) pathToRoot(from libcommon.Hash) []*forkTreeNode { + path := make([]*forkTreeNode, 0, ccb.Tip().Number.Uint64()-ccb.Root().Number.Uint64()) + pathToRootRec(ccb.root, from, &path) + return path +} + +func pathToRootRec(node *forkTreeNode, from libcommon.Hash, path *[]*forkTreeNode) bool { + if node.headerHash == from { + *path = append(*path, node) + return true + } + + for _, child := range node.children { + if pathToRootRec(child, from, path) { + *path = append(*path, node) + return true + } + } + + return false +} diff --git a/polygon/sync/canonical_chain_builder_mock.go b/polygon/sync/canonical_chain_builder_mock.go index d6259d7cc13..e64604fce94 100644 --- a/polygon/sync/canonical_chain_builder_mock.go +++ b/polygon/sync/canonical_chain_builder_mock.go @@ -42,11 +42,12 @@ func (m *MockCanonicalChainBuilder) EXPECT() *MockCanonicalChainBuilderMockRecor } // Connect mocks base method. -func (m *MockCanonicalChainBuilder) Connect(arg0 context.Context, arg1 []*types.Header) error { +func (m *MockCanonicalChainBuilder) Connect(arg0 context.Context, arg1 []*types.Header) ([]*types.Header, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Connect", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].([]*types.Header) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Connect indicates an expected call of Connect. @@ -62,19 +63,19 @@ type MockCanonicalChainBuilderConnectCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockCanonicalChainBuilderConnectCall) Return(arg0 error) *MockCanonicalChainBuilderConnectCall { - c.Call = c.Call.Return(arg0) +func (c *MockCanonicalChainBuilderConnectCall) Return(arg0 []*types.Header, arg1 error) *MockCanonicalChainBuilderConnectCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockCanonicalChainBuilderConnectCall) Do(f func(context.Context, []*types.Header) error) *MockCanonicalChainBuilderConnectCall { +func (c *MockCanonicalChainBuilderConnectCall) Do(f func(context.Context, []*types.Header) ([]*types.Header, error)) *MockCanonicalChainBuilderConnectCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockCanonicalChainBuilderConnectCall) DoAndReturn(f func(context.Context, []*types.Header) error) *MockCanonicalChainBuilderConnectCall { +func (c *MockCanonicalChainBuilderConnectCall) DoAndReturn(f func(context.Context, []*types.Header) ([]*types.Header, error)) *MockCanonicalChainBuilderConnectCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -155,40 +156,117 @@ func (c *MockCanonicalChainBuilderHeadersInRangeCall) DoAndReturn(f func(uint64, return c } -// Prune mocks base method. -func (m *MockCanonicalChainBuilder) Prune(arg0 uint64) error { +// LowestCommonAncestor mocks base method. +func (m *MockCanonicalChainBuilder) LowestCommonAncestor(arg0, arg1 common.Hash) (*types.Header, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LowestCommonAncestor", arg0, arg1) + ret0, _ := ret[0].(*types.Header) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// LowestCommonAncestor indicates an expected call of LowestCommonAncestor. +func (mr *MockCanonicalChainBuilderMockRecorder) LowestCommonAncestor(arg0, arg1 any) *MockCanonicalChainBuilderLowestCommonAncestorCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LowestCommonAncestor", reflect.TypeOf((*MockCanonicalChainBuilder)(nil).LowestCommonAncestor), arg0, arg1) + return &MockCanonicalChainBuilderLowestCommonAncestorCall{Call: call} +} + +// MockCanonicalChainBuilderLowestCommonAncestorCall wrap *gomock.Call +type MockCanonicalChainBuilderLowestCommonAncestorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCanonicalChainBuilderLowestCommonAncestorCall) Return(arg0 *types.Header, arg1 bool) *MockCanonicalChainBuilderLowestCommonAncestorCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCanonicalChainBuilderLowestCommonAncestorCall) Do(f func(common.Hash, common.Hash) (*types.Header, bool)) *MockCanonicalChainBuilderLowestCommonAncestorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCanonicalChainBuilderLowestCommonAncestorCall) DoAndReturn(f func(common.Hash, common.Hash) (*types.Header, bool)) *MockCanonicalChainBuilderLowestCommonAncestorCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// PruneNode mocks base method. +func (m *MockCanonicalChainBuilder) PruneNode(arg0 common.Hash) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PruneNode", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// PruneNode indicates an expected call of PruneNode. +func (mr *MockCanonicalChainBuilderMockRecorder) PruneNode(arg0 any) *MockCanonicalChainBuilderPruneNodeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PruneNode", reflect.TypeOf((*MockCanonicalChainBuilder)(nil).PruneNode), arg0) + return &MockCanonicalChainBuilderPruneNodeCall{Call: call} +} + +// MockCanonicalChainBuilderPruneNodeCall wrap *gomock.Call +type MockCanonicalChainBuilderPruneNodeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCanonicalChainBuilderPruneNodeCall) Return(arg0 error) *MockCanonicalChainBuilderPruneNodeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCanonicalChainBuilderPruneNodeCall) Do(f func(common.Hash) error) *MockCanonicalChainBuilderPruneNodeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCanonicalChainBuilderPruneNodeCall) DoAndReturn(f func(common.Hash) error) *MockCanonicalChainBuilderPruneNodeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// PruneRoot mocks base method. +func (m *MockCanonicalChainBuilder) PruneRoot(arg0 uint64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Prune", arg0) + ret := m.ctrl.Call(m, "PruneRoot", arg0) ret0, _ := ret[0].(error) return ret0 } -// Prune indicates an expected call of Prune. -func (mr *MockCanonicalChainBuilderMockRecorder) Prune(arg0 any) *MockCanonicalChainBuilderPruneCall { +// PruneRoot indicates an expected call of PruneRoot. +func (mr *MockCanonicalChainBuilderMockRecorder) PruneRoot(arg0 any) *MockCanonicalChainBuilderPruneRootCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prune", reflect.TypeOf((*MockCanonicalChainBuilder)(nil).Prune), arg0) - return &MockCanonicalChainBuilderPruneCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PruneRoot", reflect.TypeOf((*MockCanonicalChainBuilder)(nil).PruneRoot), arg0) + return &MockCanonicalChainBuilderPruneRootCall{Call: call} } -// MockCanonicalChainBuilderPruneCall wrap *gomock.Call -type MockCanonicalChainBuilderPruneCall struct { +// MockCanonicalChainBuilderPruneRootCall wrap *gomock.Call +type MockCanonicalChainBuilderPruneRootCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockCanonicalChainBuilderPruneCall) Return(arg0 error) *MockCanonicalChainBuilderPruneCall { +func (c *MockCanonicalChainBuilderPruneRootCall) Return(arg0 error) *MockCanonicalChainBuilderPruneRootCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockCanonicalChainBuilderPruneCall) Do(f func(uint64) error) *MockCanonicalChainBuilderPruneCall { +func (c *MockCanonicalChainBuilderPruneRootCall) Do(f func(uint64) error) *MockCanonicalChainBuilderPruneRootCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockCanonicalChainBuilderPruneCall) DoAndReturn(f func(uint64) error) *MockCanonicalChainBuilderPruneCall { +func (c *MockCanonicalChainBuilderPruneRootCall) DoAndReturn(f func(uint64) error) *MockCanonicalChainBuilderPruneRootCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/polygon/sync/canonical_chain_builder_test.go b/polygon/sync/canonical_chain_builder_test.go index 1f019e0f6fe..6978b6dd690 100644 --- a/polygon/sync/canonical_chain_builder_test.go +++ b/polygon/sync/canonical_chain_builder_test.go @@ -127,6 +127,7 @@ func (test *connectCCBTest) testConnect( } func TestCCBEmptyState(t *testing.T) { + t.Parallel() test, root := newConnectCCBTest(t) tip := test.builder.Tip() @@ -138,6 +139,7 @@ func TestCCBEmptyState(t *testing.T) { } func TestCCBConnectEmpty(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -146,6 +148,7 @@ func TestCCBConnectEmpty(t *testing.T) { // connect 0 to 0 func TestCCBConnectRoot(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -154,6 +157,7 @@ func TestCCBConnectRoot(t *testing.T) { // connect 1 to 0 func TestCCBConnectOneToRoot(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -163,6 +167,7 @@ func TestCCBConnectOneToRoot(t *testing.T) { // connect 1-2-3 to 0 func TestCCBConnectSomeToRoot(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -172,6 +177,7 @@ func TestCCBConnectSomeToRoot(t *testing.T) { // connect any subset of 0-1-2-3 to 0-1-2-3 func TestCCBConnectOverlapsFull(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -193,6 +199,7 @@ func TestCCBConnectOverlapsFull(t *testing.T) { // connect 0-1 to 0 func TestCCBConnectOverlapPartialOne(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -202,6 +209,7 @@ func TestCCBConnectOverlapPartialOne(t *testing.T) { // connect 2-3-4-5 to 0-1-2-3 func TestCCBConnectOverlapPartialSome(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -219,6 +227,7 @@ func TestCCBConnectOverlapPartialSome(t *testing.T) { // connect 2 to 0-1 at 0, then connect 10 to 0-1 func TestCCBConnectAltMainBecomesFork(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -237,6 +246,7 @@ func TestCCBConnectAltMainBecomesFork(t *testing.T) { // connect 1 to 0-2 at 0, then connect 10 to 0-1 func TestCCBConnectAltForkBecomesMain(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -255,6 +265,7 @@ func TestCCBConnectAltForkBecomesMain(t *testing.T) { // connect 10 and 11 to 1, then 20 and 22 to 2 one by one starting from a [0-1, 0-2] tree func TestCCBConnectAltForksAtLevel2(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -280,6 +291,7 @@ func TestCCBConnectAltForksAtLevel2(t *testing.T) { // connect 11 and 10 to 1, then 22 and 20 to 2 one by one starting from a [0-1, 0-2] tree // then connect 100 to 10, and 200 to 20 func TestCCBConnectAltForksAtLevel2Reverse(t *testing.T) { + t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) test, root := newConnectCCBTest(t) @@ -306,3 +318,212 @@ func TestCCBConnectAltForksAtLevel2Reverse(t *testing.T) { test.testConnect(ctx, []*types.Header{header100}, header100, []*types.Header{root, header1, header10, header100}, []*types.Header{header100}) test.testConnect(ctx, []*types.Header{header200}, header200, []*types.Header{root, header2, header20, header200}, []*types.Header{header200}) } + +func TestCCBPruneNode(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + // R(td:0) -> A(td:1) -> B(td:2) + // | + // +--------> X(td:2) -> Y(td:4) -> Z(td:6) + // | | + // | +--------> K(td:5) + // | + // +--------> P(td:3) + type example struct { + ccb CanonicalChainBuilder + headerR *types.Header + headerA *types.Header + headerB *types.Header + headerX *types.Header + headerY *types.Header + headerZ *types.Header + headerK *types.Header + headerP *types.Header + } + constructExample := func() example { + test, headerR := newConnectCCBTest(t) + ccb := test.builder + headerA := test.makeHeader(headerR, 1) + headerB := test.makeHeader(headerA, 1) + _, err := ccb.Connect(ctx, []*types.Header{headerA, headerB}) + require.NoError(t, err) + headerX := test.makeHeader(headerR, 2) + headerY := test.makeHeader(headerX, 2) + headerZ := test.makeHeader(headerY, 2) + _, err = ccb.Connect(ctx, []*types.Header{headerX, headerY, headerZ}) + require.NoError(t, err) + headerK := test.makeHeader(headerX, 3) + _, err = ccb.Connect(ctx, []*types.Header{headerK}) + require.NoError(t, err) + headerP := test.makeHeader(headerR, 3) + _, err = ccb.Connect(ctx, []*types.Header{headerP}) + require.NoError(t, err) + require.Equal(t, headerZ, ccb.Tip()) + return example{ + ccb: ccb, + headerR: headerR, + headerA: headerA, + headerB: headerB, + headerX: headerX, + headerY: headerY, + headerZ: headerZ, + headerK: headerK, + headerP: headerP, + } + } + t.Run("unknown hash", func(t *testing.T) { + ex := constructExample() + headerU := &types.Header{Number: big.NewInt(777)} + err := ex.ccb.PruneNode(headerU.Hash()) + require.Error(t, err) + require.Contains(t, err.Error(), "could not find node to prune") + }) + t.Run("can't prune root", func(t *testing.T) { + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerR.Hash()) + require.Error(t, err) + require.Contains(t, err.Error(), "can't prune root node") + }) + t.Run("prune Z - change of tip", func(t *testing.T) { + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerZ.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerK, ex.ccb.Tip()) + }) + t.Run("prune Y - change of tip", func(t *testing.T) { + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerY.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerK, ex.ccb.Tip()) + }) + t.Run("prune K - no change of tip", func(t *testing.T) { + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerK.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerZ, ex.ccb.Tip()) + }) + t.Run("prune X - no change of tip", func(t *testing.T) { + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerX.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerP, ex.ccb.Tip()) + }) + t.Run("prune P - no change of tip", func(t *testing.T) { + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerP.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerZ, ex.ccb.Tip()) + }) + t.Run("prune A - no change of tip", func(t *testing.T) { + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerA.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerZ, ex.ccb.Tip()) + }) + t.Run("prune P, prune Y, prune K, prune X, prune A", func(t *testing.T) { + // prune P - no change (tip Z) + ex := constructExample() + err := ex.ccb.PruneNode(ex.headerP.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerZ, ex.ccb.Tip()) + // prune Y - change (tip K) + err = ex.ccb.PruneNode(ex.headerY.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerK, ex.ccb.Tip()) + // prune K - change (tip X) + err = ex.ccb.PruneNode(ex.headerK.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerX, ex.ccb.Tip()) + // prune X - change (tip B) + err = ex.ccb.PruneNode(ex.headerX.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerB, ex.ccb.Tip()) + // prune A - change (tip R) - only root left + err = ex.ccb.PruneNode(ex.headerA.Hash()) + require.NoError(t, err) + require.Equal(t, ex.headerR, ex.ccb.Tip()) + require.Equal(t, ex.headerR, ex.ccb.Root()) + }) +} + +func TestCCBLowestCommonAncestor(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + // R(td:0) -> A(td:1) -> B(td:2) + // | + // +--------> X(td:2) -> Y(td:4) -> Z(td:6) + // | | + // | +--------> K(td:5) + // | + // +--------> P(td:3) + test, headerR := newConnectCCBTest(t) + ccb := test.builder + headerA := test.makeHeader(headerR, 1) + headerB := test.makeHeader(headerA, 1) + _, err := ccb.Connect(ctx, []*types.Header{headerA, headerB}) + require.NoError(t, err) + headerX := test.makeHeader(headerR, 2) + headerY := test.makeHeader(headerX, 2) + headerZ := test.makeHeader(headerY, 2) + _, err = ccb.Connect(ctx, []*types.Header{headerX, headerY, headerZ}) + require.NoError(t, err) + headerK := test.makeHeader(headerX, 3) + _, err = ccb.Connect(ctx, []*types.Header{headerK}) + require.NoError(t, err) + headerP := test.makeHeader(headerR, 3) + _, err = ccb.Connect(ctx, []*types.Header{headerP}) + require.NoError(t, err) + require.Equal(t, headerZ, ccb.Tip()) + headerU := &types.Header{Number: big.NewInt(777)} + headerU2 := &types.Header{Number: big.NewInt(999)} + t.Run("LCA(R,U)=nil,false", func(t *testing.T) { + assertLca(t, ccb, headerR, headerU, nil, false) + }) + t.Run("LCA(U,R)=nil,false", func(t *testing.T) { + assertLca(t, ccb, headerU, headerR, nil, false) + }) + t.Run("LCA(U,U)=nil,false", func(t *testing.T) { + assertLca(t, ccb, headerU, headerU, nil, false) + }) + t.Run("LCA(U2,U)=nil,false", func(t *testing.T) { + assertLca(t, ccb, headerU2, headerU, nil, false) + }) + t.Run("LCA(R,R)=R", func(t *testing.T) { + assertLca(t, ccb, headerR, headerR, headerR, true) + }) + t.Run("LCA(Y,Y)=Y", func(t *testing.T) { + assertLca(t, ccb, headerY, headerY, headerY, true) + }) + t.Run("LCA(Y,Z)=Y", func(t *testing.T) { + assertLca(t, ccb, headerY, headerZ, headerY, true) + }) + t.Run("LCA(X,Y)=X", func(t *testing.T) { + assertLca(t, ccb, headerX, headerY, headerX, true) + }) + t.Run("LCA(R,Z)=R", func(t *testing.T) { + assertLca(t, ccb, headerR, headerZ, headerR, true) + }) + t.Run("LCA(R,A)=R", func(t *testing.T) { + assertLca(t, ccb, headerR, headerA, headerR, true) + }) + t.Run("LCA(R,P)=R", func(t *testing.T) { + assertLca(t, ccb, headerR, headerP, headerR, true) + }) + t.Run("LCA(K,B)=R", func(t *testing.T) { + assertLca(t, ccb, headerK, headerB, headerR, true) + }) + t.Run("LCA(X,A)=R", func(t *testing.T) { + assertLca(t, ccb, headerX, headerA, headerR, true) + }) + t.Run("LCA(Z,K)=X", func(t *testing.T) { + assertLca(t, ccb, headerZ, headerK, headerX, true) + }) +} + +func assertLca(t *testing.T, ccb CanonicalChainBuilder, a, b, wantLca *types.Header, wantOk bool) { + lca, ok := ccb.LowestCommonAncestor(a.Hash(), b.Hash()) + require.Equal(t, wantOk, ok) + require.Equal(t, wantLca, lca) +} diff --git a/polygon/sync/execution_client.go b/polygon/sync/execution_client.go index 6e5bca2d4c0..585eeef6173 100644 --- a/polygon/sync/execution_client.go +++ b/polygon/sync/execution_client.go @@ -32,7 +32,8 @@ import ( eth1utils "github.com/erigontech/erigon/turbo/execution/eth1/eth1_utils" ) -var errForkChoiceUpdateFailure = errors.New("fork choice update failed") +var ErrForkChoiceUpdateFailure = errors.New("fork choice update failure") +var ErrForkChoiceUpdateBadBlock = errors.New("fork choice update bad block") type ExecutionClient interface { InsertBlocks(ctx context.Context, blocks []*types.Block) error @@ -102,15 +103,24 @@ func (e *executionClient) UpdateForkChoice(ctx context.Context, tip *types.Heade latestValidHash = gointerfaces.ConvertH256ToHash(response.LatestValidHash) } - if len(response.ValidationError) > 0 { - return latestValidHash, fmt.Errorf("%w: validationErr=%s", errForkChoiceUpdateFailure, response.Status) + switch response.Status { + case executionproto.ExecutionStatus_Success: + return latestValidHash, nil + case executionproto.ExecutionStatus_BadBlock: + return latestValidHash, fmt.Errorf( + "%w: status=%d, validationErr='%s'", + ErrForkChoiceUpdateBadBlock, + response.Status, + response.ValidationError, + ) + default: + return latestValidHash, fmt.Errorf( + "%w: status=%d, validationErr='%s'", + ErrForkChoiceUpdateFailure, + response.Status, + response.ValidationError, + ) } - - if response.Status != executionproto.ExecutionStatus_Success { - return latestValidHash, fmt.Errorf("%w: status=%s", errForkChoiceUpdateFailure, response.Status) - } - - return latestValidHash, nil } func (e *executionClient) CurrentHeader(ctx context.Context) (*types.Header, error) { diff --git a/polygon/sync/sync.go b/polygon/sync/sync.go index f16eef31676..bee44daf9c9 100644 --- a/polygon/sync/sync.go +++ b/polygon/sync/sync.go @@ -22,6 +22,8 @@ import ( "fmt" "time" + "github.com/hashicorp/golang-lru/v2/simplelru" + "github.com/erigontech/erigon-lib/common" "github.com/erigontech/erigon-lib/log/v3" "github.com/erigontech/erigon/core/types" @@ -38,6 +40,7 @@ type heimdallSynchronizer interface { type bridgeSynchronizer interface { Synchronize(ctx context.Context, blockNum uint64) error Unwind(ctx context.Context, blockNum uint64) error + ProcessNewBlocks(ctx context.Context, blocks []*types.Block) error } type Sync struct { @@ -51,6 +54,7 @@ type Sync struct { heimdallSync heimdallSynchronizer bridgeSync bridgeSynchronizer events <-chan Event + badBlocks *simplelru.LRU[common.Hash, struct{}] logger log.Logger } @@ -67,6 +71,11 @@ func NewSync( events <-chan Event, logger log.Logger, ) *Sync { + badBlocksLru, err := simplelru.NewLRU[common.Hash, struct{}](1024, nil) + if err != nil { + panic(err) + } + return &Sync{ store: store, execution: execution, @@ -78,6 +87,7 @@ func NewSync( heimdallSync: heimdallSync, bridgeSync: bridgeSync, events: events, + badBlocks: badBlocksLru, logger: logger, } } @@ -146,8 +156,10 @@ func (s *Sync) handleMilestoneTipMismatch( ) } - if err = s.commitExecution(ctx, newTip, newTip); err != nil { - return err + if err := s.commitExecution(ctx, newTip, newTip); err != nil { + // note: if we face a failure during execution of finalized waypoints blocks, it means that + // we're wrong and the blocks are not considered as bad blocks, so we should terminate + return s.handleWaypointExecutionErr(ctx, ccBuilder.Root(), err) } ccBuilder.Reset(newTip) @@ -177,7 +189,7 @@ func (s *Sync) applyNewMilestoneOnTip( return s.handleMilestoneTipMismatch(ctx, ccBuilder, milestone) } - return ccBuilder.Prune(milestone.EndBlock().Uint64()) + return ccBuilder.PruneRoot(milestone.EndBlock().Uint64()) } func (s *Sync) applyNewBlockOnTip( @@ -193,6 +205,28 @@ func (s *Sync) applyNewBlockOnTip( return nil } + if s.badBlocks.Contains(newBlockHeaderHash) { + s.logger.Warn(syncLogPrefix("bad block received from peer"), + "blockHash", newBlockHeaderHash, + "blockNum", newBlockHeaderNum, + "peerId", event.PeerId, + ) + s.maybePenalizePeerOnBadBlockEvent(ctx, event) + return nil + } + + if s.badBlocks.Contains(newBlockHeader.ParentHash) { + s.logger.Warn(syncLogPrefix("block with bad parent received from peer"), + "blockHash", newBlockHeaderHash, + "blockNum", newBlockHeaderNum, + "parentHash", newBlockHeader.ParentHash, + "peerId", event.PeerId, + ) + s.badBlocks.Add(newBlockHeaderHash, struct{}{}) + s.maybePenalizePeerOnBadBlockEvent(ctx, event) + return nil + } + s.logger.Debug( syncLogPrefix("applying new block event"), "blockNum", newBlockHeaderNum, @@ -279,11 +313,9 @@ func (s *Sync) applyNewBlockOnTip( } newTip := ccBuilder.Tip() - firstConnectedHeader := newConnectedHeaders[0] - if newTip != oldTip && oldTip.Hash() != firstConnectedHeader.ParentHash { - // forks have changed, we need to unwind unwindable data - blockNum := max(1, firstConnectedHeader.Number.Uint64()) - 1 - if err := s.bridgeSync.Unwind(ctx, blockNum); err != nil { + firstNewConnectedHeader := newConnectedHeaders[0] + if newTip != oldTip && oldTip.Hash() != firstNewConnectedHeader.ParentHash { + if err := s.handleBridgeOnForkChange(ctx, ccBuilder, oldTip); err != nil { return err } } @@ -306,10 +338,19 @@ func (s *Sync) applyNewBlockOnTip( } if newTip == oldTip { + lastConnectedNum := newConnectedHeaders[len(newConnectedHeaders)-1].Number.Uint64() + if tipNum := newTip.Number.Uint64(); lastConnectedNum > tipNum { + return s.handleBridgeOnBlocksInsertAheadOfTip(ctx, tipNum, lastConnectedNum) + } + return nil } if err := s.commitExecution(ctx, newTip, ccBuilder.Root()); err != nil { + if errors.Is(err, ErrForkChoiceUpdateBadBlock) { + return s.handleBadBlockErr(ctx, ccBuilder, event, firstNewConnectedHeader, oldTip, err) + } + return err } @@ -339,6 +380,17 @@ func (s *Sync) applyNewBlockHashesOnTip( continue } + if s.badBlocks.Contains(hashOrNum.Hash) { + // note: we do not penalize peer for bad blocks on new block hash events since they have + // not necessarily been executed by the peer but just propagated as per the devp2p spec + s.logger.Warn(syncLogPrefix("bad block hash received from peer"), + "blockHash", hashOrNum.Hash, + "blockNum", hashOrNum.Number, + "peerId", event.PeerId, + ) + return nil + } + s.logger.Debug( syncLogPrefix("applying new block hash event"), "blockNum", hashOrNum.Number, @@ -391,6 +443,157 @@ func (s *Sync) publishNewBlock(ctx context.Context, block *types.Block) { s.p2pService.PublishNewBlock(block, td) } +func (s *Sync) handleBridgeOnForkChange(ctx context.Context, ccb CanonicalChainBuilder, oldTip *types.Header) error { + // forks have changed, we need to unwind unwindable data + newTip := ccb.Tip() + s.logger.Debug( + syncLogPrefix("handling bridge on fork change"), + "oldNum", oldTip.Number.Uint64(), + "oldHash", oldTip.Hash(), + "newNum", newTip.Number.Uint64(), + "newHash", newTip.Hash(), + ) + + // Find unwind point + lca, ok := ccb.LowestCommonAncestor(newTip.Hash(), oldTip.Hash()) + if !ok { + return errors.New("could not find lowest common ancestor of old and new tip") + } + + return s.reorganiseBridge(ctx, ccb, lca) +} + +func (s *Sync) reorganiseBridge(ctx context.Context, ccb CanonicalChainBuilder, forksLca *types.Header) error { + newTip := ccb.Tip() + newTipNum := ccb.Tip().Number.Uint64() + unwindPoint := forksLca.Number.Uint64() + s.logger.Debug( + syncLogPrefix("reorganise bridge"), + "newTip", newTipNum, + "newTipHash", newTip.Hash(), + "unwindPointNum", unwindPoint, + "unwindPointHash", forksLca.Hash(), + ) + + if newTipNum < unwindPoint { // defensive check against underflow & unexpected newTipNum and unwindPoint + return fmt.Errorf("unexpected newTipNum <= unwindPoint: %d < %d", newTipNum, unwindPoint) + } + + // 1. Do the unwind from the old tip (on the old canonical fork) to the unwindPoint + if err := s.bridgeSync.Unwind(ctx, unwindPoint); err != nil { + return err + } + + // 2. Replay the new canonical blocks from the unwindPoint+1 to the new tip (on the new canonical fork). Note, + // that there may be a case where the newTip == unwindPoint in which case the below will be a no-op. + if newTipNum == unwindPoint { + return nil + } + + start := unwindPoint + 1 + amount := newTipNum - start + 1 + canonicalHeaders := ccb.HeadersInRange(start, amount) + canonicalBlocks := make([]*types.Block, len(canonicalHeaders)) + for i, header := range canonicalHeaders { + canonicalBlocks[i] = types.NewBlockWithHeader(header) + } + if err := s.bridgeSync.ProcessNewBlocks(ctx, canonicalBlocks); err != nil { + return err + } + + return s.bridgeSync.Synchronize(ctx, newTipNum) +} + +func (s *Sync) handleBridgeOnBlocksInsertAheadOfTip(ctx context.Context, tipNum, lastInsertedNum uint64) error { + // this is a hack that should disappear when changing the bridge to not track blocks (future work) + // make sure the bridge does not go past the tip (it may happen when we insert blocks from another fork that + // has a higher block number than the canonical tip but lower difficulty) - this is to prevent the bridge + // from recording incorrect bor txn hashes + s.logger.Debug( + syncLogPrefix("unwinding bridge due to inserting headers past the tip"), + "tip", tipNum, + "lastInsertedNum", lastInsertedNum, + ) + + // wait for the insert blocks flush + if err := s.store.Flush(ctx); err != nil { + return err + } + + // wait for the bridge processing + if err := s.bridgeSync.Synchronize(ctx, lastInsertedNum); err != nil { + return err + } + + return s.bridgeSync.Unwind(ctx, tipNum) +} + +func (s *Sync) handleBadBlockErr( + ctx context.Context, + ccb CanonicalChainBuilder, + event EventNewBlock, + firstNewConnectedHeader *types.Header, + oldTip *types.Header, + badBlockErr error, +) error { + badTip := ccb.Tip() + badTipHash := badTip.Hash() + oldTipNum := oldTip.Number.Uint64() + oldTipHash := oldTip.Hash() + s.logger.Warn( + syncLogPrefix("handling bad block after execution"), + "peerId", event.PeerId, + "badTipNum", badTip.Number.Uint64(), + "badTipHash", badTipHash, + "oldTipNum", oldTipNum, + "oldTipHash", oldTipHash, + "firstNewConnectedNum", firstNewConnectedHeader.Number.Uint64(), + "firstNewConnectedHash", firstNewConnectedHeader.Hash(), + "err", badBlockErr, + ) + + // 1. Mark block as bad and penalize peer + s.badBlocks.Add(event.NewBlock.Hash(), struct{}{}) + s.maybePenalizePeerOnBadBlockEvent(ctx, event) + + // 2. Find unwind point + lca, ok := ccb.LowestCommonAncestor(oldTipHash, badTip.Hash()) + if !ok { + return errors.New("could not find lowest common ancestor of old and new tip") + } + + // 3. Prune newly inserted nodes in the tree => should roll back the ccb to the old tip + if err := ccb.PruneNode(firstNewConnectedHeader.Hash()); err != nil { + return err + } + + newTip := ccb.Tip() + newTipNum := newTip.Number.Uint64() + newTipHash := newTip.Hash() + if oldTipHash != newTipHash { // defensive check for unexpected behaviour + return fmt.Errorf( + "old tip hash does not match new tip hash (%d,%s) vs (%d, %s)", + oldTipNum, oldTipHash, newTipNum, newTipHash, + ) + } + + // 4. Update bridge + return s.reorganiseBridge(ctx, ccb, lca) +} + +func (s *Sync) maybePenalizePeerOnBadBlockEvent(ctx context.Context, event EventNewBlock) { + if event.Source == EventSourceP2PNewBlockHashes { + // note: we do not penalize peer for bad blocks on new block hash events since they have + // not necessarily been executed by the peer but just propagated as per the devp2p spec + return + } + + s.logger.Debug(syncLogPrefix("penalizing peer for bad block"), "peerId", event.PeerId) + if err := s.p2pService.Penalize(ctx, event.PeerId); err != nil { + s.logger.Debug(syncLogPrefix("issue with penalizing peer for bad block"), "peerId", event.PeerId, "err", err) + } +} + // // TODO (subsequent PRs) - unit test initial sync + on new event cases // @@ -408,6 +611,10 @@ func (s *Sync) Run(ctx context.Context) error { return err } + inactivityDuration := 30 * time.Second + lastProcessedEventTime := time.Now() + inactivityTicker := time.NewTicker(inactivityDuration) + defer inactivityTicker.Stop() for { select { case event := <-s.events: @@ -424,7 +631,17 @@ func (s *Sync) Run(ctx context.Context) error { if err = s.applyNewBlockHashesOnTip(ctx, event.AsNewBlockHashes(), ccBuilder); err != nil { return err } + default: + panic(fmt.Sprintf("unexpected event type: %v", event.Type)) + } + + lastProcessedEventTime = time.Now() + case <-inactivityTicker.C: + if time.Since(lastProcessedEventTime) < inactivityDuration { + continue } + + s.logger.Info(syncLogPrefix("waiting for chain tip events...")) case <-ctx.Done(): return ctx.Err() } @@ -552,15 +769,40 @@ func (s *Sync) sync(ctx context.Context, tip *types.Header, tipDownloader tipDow break } - tip = newResult.latestTip - if err = s.commitExecution(ctx, tip, tip); err != nil { + newTip := newResult.latestTip + if err := s.commitExecution(ctx, newTip, newTip); err != nil { + // note: if we face a failure during execution of finalized waypoints blocks, it means that + // we're wrong and the blocks are not considered as bad blocks, so we should terminate + err = s.handleWaypointExecutionErr(ctx, tip, err) return syncToTipResult{}, err } + + tip = newTip } return syncToTipResult{latestTip: tip, latestWaypoint: latestWaypoint}, nil } +func (s *Sync) handleWaypointExecutionErr(ctx context.Context, lastCorrectTip *types.Header, execErr error) error { + s.logger.Debug( + syncLogPrefix("waypoint execution err"), + "lastCorrectTipNum", lastCorrectTip.Number.Uint64(), + "lastCorrectTipHash", lastCorrectTip.Hash(), + "execErr", execErr, + ) + + if !errors.Is(execErr, ErrForkChoiceUpdateBadBlock) { + return execErr + } + + // if it is a bad block try to unwind the bridge to the last known tip so we leave it in a good state + if bridgeUnwindErr := s.bridgeSync.Unwind(ctx, lastCorrectTip.Number.Uint64()); bridgeUnwindErr != nil { + return fmt.Errorf("%w: %w", bridgeUnwindErr, execErr) + } + + return execErr +} + func (s *Sync) ignoreFetchBlocksErrOnTipEvent(err error) bool { return errors.Is(err, &p2p.ErrIncompleteHeaders{}) || errors.Is(err, &p2p.ErrNonSequentialHeaderNumbers{}) ||