diff --git a/arbitrum/handler_p2p.go b/arbitrum/handler_p2p.go new file mode 100644 index 0000000000..9a6146b35e --- /dev/null +++ b/arbitrum/handler_p2p.go @@ -0,0 +1,563 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package arbitrum + +import ( + "fmt" + "sync" + "sync/atomic" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state/snapshot" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/eth/protocols/arb" + "github.com/ethereum/go-ethereum/eth/protocols/eth" + "github.com/ethereum/go-ethereum/eth/protocols/snap" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/trie" +) + +type SyncHelper interface { + LastConfirmed() (*types.Header, uint64, error) + LastCheckpoint() (*types.Header, error) + CheckpointSupported(*types.Header) (bool, error) + ValidateConfirmed(*types.Header, uint64) (bool, error) +} + +type Peer struct { + mutex sync.Mutex + arb *arb.Peer + eth *eth.Peer + snap *snap.Peer +} + +func NewPeer() *Peer { + return &Peer{} +} + +type protocolHandler struct { + chain *core.BlockChain + eventMux *event.TypeMux + downloader *downloader.Downloader + db ethdb.Database + helper SyncHelper + + peersLock sync.RWMutex + peers map[string]*Peer + + beaconBackFiller downloader.Backfiller + + confirmed *types.Header + checkpoint *types.Header + syncedBlockNum uint64 // blocks that were synced by skeleton-downloader + syncedCond *sync.Cond + headersLock sync.RWMutex + + syncing atomic.Bool +} + +func NewProtocolHandler(db ethdb.Database, bc *core.BlockChain, helper SyncHelper, syncing bool) *protocolHandler { + evMux := new(event.TypeMux) + p := &protocolHandler{ + chain: bc, + eventMux: evMux, + db: db, + helper: helper, + peers: make(map[string]*Peer), + } + p.syncedCond = sync.NewCond(&p.headersLock) + p.syncing.Store(syncing) + backfillerCreator := func(dl *downloader.Downloader) downloader.Backfiller { + success := func() { + p.syncing.Store(false) + log.Info("DOWNLOADER DONE") + } + p.beaconBackFiller = downloader.NewBeaconBackfiller(dl, success) + return (*filler)(p) + } + p.downloader = downloader.New(db, evMux, bc, nil, p.peerDrop, backfillerCreator) + return p +} + +func (h *protocolHandler) MakeProtocols(dnsdisc enode.Iterator) []p2p.Protocol { + protos := eth.MakeProtocols((*ethHandler)(h), h.chain.Config().ChainID.Uint64(), nil) + protos = append(protos, snap.MakeProtocols((*snapHandler)(h), nil)...) + protos = append(protos, arb.MakeProtocols((*arbHandler)(h), dnsdisc)...) + return protos +} + +func (h *protocolHandler) getCreatePeer(id string) *Peer { + h.peersLock.Lock() + defer h.peersLock.Unlock() + peer := h.peers[id] + if peer != nil { + return peer + } + peer = NewPeer() + h.peers[id] = peer + return peer +} + +func (h *protocolHandler) waitBlockSync(num uint64) error { + h.headersLock.Lock() + defer h.headersLock.Unlock() + for { + if h.syncedBlockNum >= num { + break + } + h.syncedCond.Wait() + } + return nil +} + +func (h *protocolHandler) getRemovePeer(id string) *Peer { + h.peersLock.Lock() + defer h.peersLock.Unlock() + peer := h.peers[id] + if peer != nil { + h.peers[id] = nil + } + return peer +} + +func (h *protocolHandler) getPeer(id string) *Peer { + h.peersLock.RLock() + defer h.peersLock.RUnlock() + return h.peers[id] +} + +func (h *protocolHandler) peerDrop(id string) { + log.Info("dropping peer", "id", id) + hPeer := h.getRemovePeer(id) + if hPeer == nil { + return + } + hPeer.mutex.Lock() + defer hPeer.mutex.Unlock() + hPeer.arb = nil + if hPeer.eth != nil { + hPeer.eth.Disconnect(p2p.DiscUselessPeer) + err := h.downloader.UnregisterPeer(id) + if err != nil { + log.Warn("failed deregistering peer from downloader", "err", err) + } + hPeer.eth = nil + } + if hPeer.snap != nil { + err := h.downloader.SnapSyncer.Unregister(id) + if err != nil { + log.Warn("failed deregistering peer from downloader", "err", err) + } + } +} + +func (h *protocolHandler) getHeaders() (*types.Header, *types.Header) { + h.peersLock.RLock() + defer h.peersLock.RUnlock() + return h.checkpoint, h.confirmed +} + +func (h *protocolHandler) advanceCheckpoint(checkpoint *types.Header) { + h.peersLock.Lock() + defer h.peersLock.Unlock() + if h.checkpoint != nil { + compare := h.checkpoint.Number.Cmp(checkpoint.Number) + if compare > 0 { + return + } + if compare == 0 { + if h.checkpoint.Hash() != checkpoint.Hash() { + log.Error("arbitrum_p2p: hash for checkpoint changed", "number", checkpoint.Number, "old", h.checkpoint.Hash(), "new", checkpoint.Hash()) + } else { + return + } + } + } + if h.confirmed == nil || checkpoint.Number.Cmp(h.confirmed.Number) > 0 { + confirmedNum := common.Big0 + if h.confirmed != nil { + confirmedNum = h.confirmed.Number + } + log.Error("arbitrum_p2p: trying to move checkpont ahead of confirmed", "number", checkpoint.Number, "confirmed", confirmedNum) + return + } + h.checkpoint = checkpoint + log.Info("arbitrum_p2p: checkpoint", "number", checkpoint.Number, "hash", checkpoint.Hash()) + h.downloader.PivotSync(h.confirmed, h.checkpoint) +} + +func (h *protocolHandler) advanceConfirmed(confirmed *types.Header) { + h.peersLock.Lock() + defer h.peersLock.Unlock() + if h.confirmed != nil { + compare := h.confirmed.Number.Cmp(confirmed.Number) + if compare > 0 { + return + } + if compare == 0 { + if h.confirmed.Hash() != confirmed.Hash() { + log.Error("arbitrum_p2p: hash for confirmed changed", "number", confirmed.Number, "old", h.confirmed.Hash(), "new", confirmed.Hash()) + } else { + return + } + } + } + h.confirmed = confirmed + log.Info("arbitrum_p2p: confirmed", "number", confirmed.Number, "hash", confirmed.Hash()) + h.downloader.PivotSync(h.confirmed, h.checkpoint) +} + +type filler protocolHandler + +func (h *filler) Suspend() *types.Header { + h.headersLock.Lock() + defer h.headersLock.Unlock() + if h.syncedBlockNum > 0 && h.syncing.Load() { + log.Warn("arbitrum_p2p: suspend while syncing", "head", h.syncedBlockNum) + } + return h.beaconBackFiller.Suspend() +} + +func (h *filler) Resume() { + defer h.beaconBackFiller.Resume() + head, err := h.downloader.SkeletonHead() + if err != nil || head == nil { + log.Error("arbitrum_p2p: error from SkeletonHead", "err", err) + return + } + if !head.Number.IsUint64() { + log.Error("arbitrum_p2p: syncedBlockNum bad number", "num", head.Number) + return + } + h.headersLock.Lock() + if h.confirmed.Number.Cmp(head.Number) < 0 { + // confirmed only moves forward and is used as head for sync.. somerthing bad already happened + log.Error("arbitrum_p2p: skeleton head ahead of confirmed", "skeleton", head.Number, "confirmed", h.confirmed.Number) + } + h.syncedBlockNum = head.Number.Uint64() + h.syncedCond.Broadcast() + h.headersLock.Unlock() + log.Trace("arbitrum_p2p: resume", "skeletonhead", h.syncedBlockNum) +} + +func (h *filler) SetMode(mode downloader.SyncMode) { + h.beaconBackFiller.SetMode(mode) +} + +type arbHandler protocolHandler + +func (h *arbHandler) PeerInfo(id enode.ID) interface{} { + return nil +} + +func (h *arbHandler) HandleLastConfirmed(peer *arb.Peer, confirmed *types.Header, node uint64) { + protoHandler := (*protocolHandler)(h) + validated := false + valid := false + current, _ := protoHandler.getHeaders() + if current != nil { + if confirmed.Number.Cmp(current.Number) == 0 { + validated = true + valid = current.Hash() == confirmed.Hash() + } + } + if !validated { + var err error + valid, err = h.helper.ValidateConfirmed(confirmed, node) + if err != nil { + log.Error("error in validate confirmed", "id", peer.ID(), "err", err) + return + } + } + if !valid { + protoHandler.peerDrop(peer.ID()) + return + } + hPeer := protoHandler.getPeer(peer.ID()) + if hPeer == nil { + log.Warn("hPeer not found on HandleLastConfirmed") + return + } + peer.RequestCheckpoint(nil) + protoHandler.advanceConfirmed(confirmed) +} + +func (h *arbHandler) HandleCheckpoint(peer *arb.Peer, checkpoint *types.Header, supported bool) { + protoHandler := (*protocolHandler)(h) + log.Error("got checkpoint", "from", peer.ID(), "checkpoint", checkpoint, "supported", supported) + if !supported { + return + } + if !h.syncing.Load() { + return + } + if !checkpoint.Number.IsUint64() { + log.Warn("got bad header from peer - number not uint64", "peer", peer.ID()) + protoHandler.peerDrop(peer.ID()) + return + } + number := checkpoint.Number.Uint64() + log.Info("handler_p2p: handle checkpoint - before", "peer", peer.ID()) + protoHandler.waitBlockSync(number) + log.Info("handler_p2p: handle checkpoint - after", "peer", peer.ID()) + if !h.syncing.Load() { + return + } + canonical := rawdb.ReadCanonicalHash(h.db, number) + if canonical == (common.Hash{}) { + skeleton := rawdb.ReadSkeletonHeader(h.db, number) + if skeleton == nil { + log.Error("arbitrum handler_p2p: canonical not found", "number", number, "peer", peer.ID()) + return + } + canonical = skeleton.Hash() + } + if canonical != checkpoint.Hash() { + log.Warn("got bad header from peer - bad hash", "peer", peer.ID(), "number", number, "expected", canonical, "peer", checkpoint.Hash()) + protoHandler.peerDrop(peer.ID()) + return + } + protoHandler.advanceCheckpoint(checkpoint) +} + +func (h *arbHandler) LastConfirmed() (*types.Header, uint64, error) { + return h.helper.LastConfirmed() +} + +func (h *arbHandler) LastCheckpoint() (*types.Header, error) { + return h.helper.LastCheckpoint() +} + +func (h *arbHandler) CheckpointSupported(checkpoint *types.Header) (bool, error) { + return h.helper.CheckpointSupported(checkpoint) +} + +func (h *arbHandler) RunPeer(peer *arb.Peer, handler arb.Handler) error { + //id := h.peers[] + hPeer := (*protocolHandler)(h).getCreatePeer(peer.ID()) + hPeer.mutex.Lock() + if hPeer.arb != nil { + hPeer.mutex.Unlock() + return fmt.Errorf("peer id already known") + } + hPeer.arb = peer + hPeer.mutex.Unlock() + if h.syncing.Load() { + err := peer.RequestLastConfirmed() + if err != nil { + return err + } + } + return handler(peer) +} + +// ethHandler implements the eth.Backend interface to handle the various network +// packets that are sent as replies or broadcasts. +type ethHandler protocolHandler + +func (h *ethHandler) Chain() *core.BlockChain { return h.chain } + +type dummyTxPool struct{} + +func (d *dummyTxPool) Get(hash common.Hash) *types.Transaction { + return nil +} + +func (h *ethHandler) TxPool() eth.TxPool { return &dummyTxPool{} } + +// RunPeer is invoked when a peer joins on the `eth` protocol. +func (h *ethHandler) RunPeer(peer *eth.Peer, hand eth.Handler) error { + hPeer := (*protocolHandler)(h).getCreatePeer(peer.ID()) + hPeer.mutex.Lock() + if hPeer.eth != nil { + hPeer.mutex.Unlock() + return fmt.Errorf("peer id already known") + } + hPeer.eth = peer + err := h.downloader.RegisterPeer(peer.ID(), peer.Version(), peer) + hPeer.mutex.Unlock() + if err != nil { + peer.Log().Error("Failed to register peer in eth syncer", "err", err) + return err + } + return hand(peer) +} + +// PeerInfo retrieves all known `eth` information about a peer. +func (h *ethHandler) PeerInfo(id enode.ID) interface{} { + return nil +} + +// AcceptTxs retrieves whether transaction processing is enabled on the node +// or if inbound transactions should simply be dropped. +func (h *ethHandler) AcceptTxs() bool { + return false +} + +// Handle is invoked from a peer's message handler when it receives a new remote +// message that the handler couldn't consume and serve itself. +func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error { + // Consume any broadcasts and announces, forwarding the rest to the downloader + switch packet := packet.(type) { + case *eth.NewBlockHashesPacket: + return fmt.Errorf("unexpected eth packet type for nitro: %T", packet) + + case *eth.NewBlockPacket: + return fmt.Errorf("unexpected eth packet type for nitro: %T", packet) + + case *eth.NewPooledTransactionHashesPacket: + return fmt.Errorf("unexpected eth packet type for nitro: %T", packet) + + case *eth.TransactionsPacket: + return fmt.Errorf("unexpected eth packet type for nitro: %T", packet) + + case *eth.PooledTransactionsPacket: + return fmt.Errorf("unexpected eth packet type for nitro: %T", packet) + default: + return fmt.Errorf("unexpected eth packet type for nitro: %T", packet) + } +} + +type snapHandler protocolHandler + +func (h *snapHandler) ContractCodeWithPrefix(codeHash common.Hash) ([]byte, error) { + return h.chain.ContractCodeWithPrefix(codeHash) +} + +func (h *snapHandler) TrieDB() *trie.Database { + return h.chain.StateCache().TrieDB() +} + +func (h *snapHandler) Snapshot(root common.Hash) snapshot.Snapshot { + return nil +} + +type trieIteratorWrapper struct { + iter *trie.Iterator + triedb *trie.Database +} + +func (i trieIteratorWrapper) Next() bool { return i.iter.Next() } +func (i trieIteratorWrapper) Error() error { return i.iter.Err } +func (i trieIteratorWrapper) Hash() common.Hash { return common.BytesToHash(i.iter.Key) } +func (i trieIteratorWrapper) Release() { i.triedb.Close() } + +type trieAccountIterator struct { + trieIteratorWrapper +} + +func (i trieAccountIterator) Account() []byte { return i.iter.Value } + +func (h *snapHandler) AccountIterator(root, account common.Hash) (snapshot.AccountIterator, error) { + triedb := trie.NewDatabase(h.db, h.chain.CacheConfig().TriedbConfig()) + t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) + if err != nil { + log.Error("Failed to open trie", "root", root, "err", err) + return nil, err + } + accIter, err := t.NodeIterator(account[:]) + if err != nil { + log.Error("Failed to open nodeIterator for trie", "root", root, "err", err) + return nil, err + } + return trieAccountIterator{trieIteratorWrapper{ + iter: trie.NewIterator((accIter)), + triedb: triedb, + }}, nil +} + +type trieStoreageIterator struct { + trieIteratorWrapper +} + +func (i trieStoreageIterator) Slot() []byte { return i.iter.Value } + +type nilStoreageIterator struct{} + +func (i nilStoreageIterator) Next() bool { return false } +func (i nilStoreageIterator) Error() error { return nil } +func (i nilStoreageIterator) Hash() common.Hash { return types.EmptyRootHash } +func (i nilStoreageIterator) Release() {} +func (i nilStoreageIterator) Slot() []byte { return nil } + +func (h *snapHandler) StorageIterator(root, account, origin common.Hash) (snapshot.StorageIterator, error) { + triedb := trie.NewDatabase(h.db, h.chain.CacheConfig().TriedbConfig()) + t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) + if err != nil { + log.Error("Failed to open trie", "root", root, "err", err) + return nil, err + } + acc, err := t.GetAccountByHash(account) + if err != nil { + log.Error("Failed to find account in trie", "root", root, "account", account, "err", err) + return nil, err + } + if acc.Root == types.EmptyRootHash { + return nilStoreageIterator{}, nil + } + id := trie.StorageTrieID(root, account, acc.Root) + storageTrie, err := trie.NewStateTrie(id, triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return nil, err + } + nodeIter, err := storageTrie.NodeIterator(origin[:]) + if err != nil { + log.Error("Failed creating node iterator to open storage trie", "root", acc.Root, "err", err) + return nil, err + } + return trieStoreageIterator{trieIteratorWrapper{ + iter: trie.NewIterator(nodeIter), + triedb: triedb, + }}, nil +} + +// RunPeer is invoked when a peer joins on the `snap` protocol. +func (h *snapHandler) RunPeer(peer *snap.Peer, hand snap.Handler) error { + hPeer := (*protocolHandler)(h).getCreatePeer(peer.ID()) + hPeer.mutex.Lock() + if hPeer.snap != nil { + hPeer.mutex.Unlock() + return fmt.Errorf("peer id already known") + } + hPeer.snap = peer + err := h.downloader.SnapSyncer.Register(peer) + hPeer.mutex.Unlock() + if err != nil { + peer.Log().Error("Failed to register peer in snap syncer", "err", err) + return err + } + return hand(peer) +} + +// PeerInfo retrieves all known `snap` information about a peer. +func (h *snapHandler) PeerInfo(id enode.ID) interface{} { + return nil +} + +// Handle is invoked from a peer's message handler when it receives a new remote +// message that the handler couldn't consume and serve itself. +func (h *snapHandler) Handle(peer *snap.Peer, packet snap.Packet) error { + return h.downloader.DeliverSnapPacket(peer, packet) +} diff --git a/arbitrum/sync_test.go b/arbitrum/sync_test.go new file mode 100644 index 0000000000..15838b3b76 --- /dev/null +++ b/arbitrum/sync_test.go @@ -0,0 +1,372 @@ +package arbitrum + +import ( + "encoding/hex" + "math/big" + "net" + "os" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/params" +) + +type dummyIterator struct { + lock sync.Mutex + nodes []*enode.Node //first one is never used +} + +func (i *dummyIterator) Next() bool { // moves to next node + i.lock.Lock() + defer i.lock.Unlock() + + if len(i.nodes) == 0 { + log.Info("dummy iterator: done") + return false + } + i.nodes = i.nodes[1:] + return len(i.nodes) > 0 +} + +func (i *dummyIterator) Node() *enode.Node { // returns current node + i.lock.Lock() + defer i.lock.Unlock() + if len(i.nodes) == 0 { + return nil + } + if i.nodes[0] != nil { + log.Info("dummy iterator: emit", "id", i.nodes[0].ID(), "ip", i.nodes[0].IP(), "tcp", i.nodes[0].TCP(), "udp", i.nodes[0].UDP()) + } + return i.nodes[0] +} + +func (i *dummyIterator) Close() { // ends the iterator + i.nodes = nil +} + +type dummySyncHelper struct { + confirmed *types.Header + checkpoint *types.Header +} + +func (d *dummySyncHelper) LastConfirmed() (*types.Header, uint64, error) { + return d.confirmed, 0, nil +} + +func (d *dummySyncHelper) LastCheckpoint() (*types.Header, error) { + if d.confirmed == nil { + return nil, nil + } + return d.checkpoint, nil +} + +func (d *dummySyncHelper) CheckpointSupported(*types.Header) (bool, error) { + return true, nil +} + +func (d *dummySyncHelper) ValidateConfirmed(header *types.Header, node uint64) (bool, error) { + if d.confirmed == nil { + return true, nil + } + if header == nil { + return false, nil + } + if d.confirmed.Hash() == header.Hash() { + return true, nil + } + return false, nil +} + +func testHasBlock(t *testing.T, chain *core.BlockChain, block *types.Block, shouldHaveState bool) { + t.Helper() + hasHeader := chain.GetHeaderByNumber(block.NumberU64()) + if hasHeader == nil { + t.Fatal("block not found") + } + if hasHeader.Hash() != block.Hash() { + t.Fatal("wrong block in blockchain") + } + _, err := chain.StateAt(hasHeader.Root) + if err != nil && shouldHaveState { + t.Fatal("should have state, but doesn't") + } + if err == nil && !shouldHaveState { + t.Fatal("should not have state, but does") + } +} + +func portFromAddress(address string) (int, error) { + splitAddr := strings.Split(address, ":") + return strconv.Atoi(splitAddr[len(splitAddr)-1]) +} + +func TestSimpleSync(t *testing.T) { + const pivotBlockNum = 50 + const syncBlockNum = 70 + const extraBlocks = 200 + + log.SetDefault(log.NewLogger(log.NewTerminalHandlerWithLevel(os.Stderr, log.LevelTrace, false))) + + // key for source node p2p + sourceKey, err := crypto.GenerateKey() + if err != nil { + t.Fatal("generate key err:", err) + } + + // key for dest node p2p + destKey, err := crypto.GenerateKey() + if err != nil { + t.Fatal("generate key err:", err) + } + + // key for bad node p2p + badNodeKey, err := crypto.GenerateKey() + if err != nil { + t.Fatal("generate key err:", err) + } + + // source node + sourceStackConf := node.DefaultConfig + sourceStackConf.DataDir = t.TempDir() + sourceStackConf.P2P.DiscoveryV4 = false + sourceStackConf.P2P.DiscoveryV5 = false + sourceStackConf.P2P.ListenAddr = "127.0.0.1:0" + sourceStackConf.P2P.PrivateKey = sourceKey + + sourceStack, err := node.New(&sourceStackConf) + if err != nil { + t.Fatal(err) + } + sourceDb, err := sourceStack.OpenDatabaseWithFreezer("l2chaindata", 2048, 512, "", "", false) + if err != nil { + t.Fatal(err) + } + + // create and populate chain + + // code for contractcodehex below: + // pragma solidity ^0.8.20; + // + // contract Temmp { + // uint256[0x10000] private store; + // + // fallback(bytes calldata data) external payable returns (bytes memory) { + // uint16 index = uint16(uint256(bytes32(data[0:32]))); + // store[index] += 1; + // return ""; + // } + // } + contractCodeHex := "608060405234801561001057600080fd5b50610218806100206000396000f3fe608060405260003660606000838360009060209261001f9392919061008a565b9061002a91906100e7565b60001c9050600160008261ffff1662010000811061004b5761004a610146565b5b01600082825461005b91906101ae565b9250508190555060405180602001604052806000815250915050915050805190602001f35b600080fd5b600080fd5b6000808585111561009e5761009d610080565b5b838611156100af576100ae610085565b5b6001850283019150848603905094509492505050565b600082905092915050565b6000819050919050565b600082821b905092915050565b60006100f383836100c5565b826100fe81356100d0565b9250602082101561013e576101397fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff836020036008026100da565b831692505b505092915050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052603260045260246000fd5b6000819050919050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052601160045260246000fd5b60006101b982610175565b91506101c483610175565b92508282019050808211156101dc576101db61017f565b5b9291505056fea26469706673582212202777d6cb94519b9aa7026cf6dad162739731e124c6379b15c343ff1c6e84a5f264736f6c63430008150033" + contractCode, err := hex.DecodeString(contractCodeHex) + if err != nil { + t.Fatal("decode contract error:", err) + } + testUser, err := crypto.GenerateKey() + if err != nil { + t.Fatal("generate key err:", err) + } + testUserAddress := crypto.PubkeyToAddress(testUser.PublicKey) + + testUser2, err := crypto.GenerateKey() + if err != nil { + t.Fatal("generate key err:", err) + } + testUser2Address := crypto.PubkeyToAddress(testUser2.PublicKey) + + gspec := &core.Genesis{ + Config: params.TestChainConfig, + Alloc: core.GenesisAlloc{ + testUserAddress: {Balance: new(big.Int).Lsh(big.NewInt(1), 250)}, + testUser2Address: {Balance: new(big.Int).Lsh(big.NewInt(1), 250)}, + }, + } + sourceChain, _ := core.NewBlockChain(sourceDb, nil, nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil) + signer := types.MakeSigner(sourceChain.Config(), big.NewInt(1), 0) + + firstAddress := common.Address{} + _, blocks, allReceipts := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), syncBlockNum+extraBlocks, func(i int, gen *core.BlockGen) { + creationNonce := gen.TxNonce(testUser2Address) + tx, err := types.SignTx(types.NewContractCreation(creationNonce, new(big.Int), 1000000, gen.BaseFee(), contractCode), signer, testUser2) + if err != nil { + t.Fatalf("failed to create contract: %v", err) + } + gen.AddTx(tx) + + contractAddress := crypto.CreateAddress(testUser2Address, creationNonce) + + nonce := gen.TxNonce(testUserAddress) + tx, err = types.SignNewTx(testUser, signer, &types.LegacyTx{ + Nonce: nonce, + GasPrice: gen.BaseFee(), + Gas: uint64(1000001), + }) + if err != nil { + t.Fatalf("failed to create tx: %v", err) + } + gen.AddTx(tx) + + iterHash := common.BigToHash(big.NewInt(int64(i))) + tx, err = types.SignNewTx(testUser, signer, &types.LegacyTx{ + To: &contractAddress, + Nonce: nonce + 1, + GasPrice: gen.BaseFee(), + Gas: uint64(1000001), + Data: iterHash[:], + }) + if err != nil { + t.Fatalf("failed to create tx: %v", err) + } + gen.AddTx(tx) + + if firstAddress == (common.Address{}) { + firstAddress = contractAddress + } + + tx, err = types.SignNewTx(testUser, signer, &types.LegacyTx{ + To: &firstAddress, + Nonce: nonce + 2, + GasPrice: gen.BaseFee(), + Gas: uint64(1000001), + Data: iterHash[:], + }) + if err != nil { + t.Fatalf("failed to create tx: %v", err) + } + gen.AddTx(tx) + }) + + for _, receipts := range allReceipts { + if len(receipts) < 3 { + t.Fatal("missing receipts") + } + for _, receipt := range receipts { + if receipt.Status == 0 { + t.Fatal("failed transaction") + } + } + } + pivotBlock := blocks[pivotBlockNum-1] + syncBlock := blocks[syncBlockNum-1] + if _, err := sourceChain.InsertChain(blocks[:pivotBlockNum]); err != nil { + t.Fatal(err) + } + sourceChain.TrieDB().Commit(blocks[pivotBlockNum-1].Root(), true) + if _, err := sourceChain.InsertChain(blocks[pivotBlockNum:]); err != nil { + t.Fatal(err) + } + + // should have state of pivot but nothing around + testHasBlock(t, sourceChain, blocks[pivotBlockNum-2], false) + testHasBlock(t, sourceChain, blocks[pivotBlockNum-1], true) + testHasBlock(t, sourceChain, blocks[pivotBlockNum], false) + + // source node + sourceHandler := NewProtocolHandler(sourceDb, sourceChain, &dummySyncHelper{syncBlock.Header(), pivotBlock.Header()}, false) + sourceStack.RegisterProtocols(sourceHandler.MakeProtocols(&dummyIterator{})) + if err := sourceStack.Start(); err != nil { + t.Fatal(err) + } + + // bad node (on wrong blockchain) + _, badBlocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), syncBlockNum+extraBlocks, func(i int, gen *core.BlockGen) { + creationNonce := gen.TxNonce(testUser2Address) + tx, err := types.SignTx(types.NewContractCreation(creationNonce, new(big.Int), 1000000, gen.BaseFee(), contractCode), signer, testUser2) + if err != nil { + t.Fatalf("failed to create contract: %v", err) + } + gen.AddTx(tx) + }) + badStackConf := sourceStackConf + badStackConf.DataDir = t.TempDir() + badStackConf.P2P.PrivateKey = badNodeKey + badStack, err := node.New(&badStackConf) + if err != nil { + t.Fatal(err) + } + + badDb, err := badStack.OpenDatabaseWithFreezer("l2chaindata", 2048, 512, "", "", false) + if err != nil { + t.Fatal(err) + } + badChain, _ := core.NewBlockChain(badDb, nil, nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil) + if _, err := badChain.InsertChain(badBlocks[:pivotBlockNum]); err != nil { + t.Fatal(err) + } + badChain.TrieDB().Commit(badBlocks[pivotBlockNum-1].Root(), true) + if _, err := badChain.InsertChain(badBlocks[pivotBlockNum:]); err != nil { + t.Fatal(err) + } + badHandler := NewProtocolHandler(badDb, badChain, &dummySyncHelper{blocks[syncBlockNum-1].Header(), badBlocks[pivotBlockNum-1].Header()}, false) + badStack.RegisterProtocols(badHandler.MakeProtocols(&dummyIterator{})) + if err := badStack.Start(); err != nil { + t.Fatal(err) + } + + // figure out port of the source node and create dummy iter that points to it + sourcePort, err := portFromAddress(sourceStack.Server().Config.ListenAddr) + if err != nil { + t.Fatal(err) + } + badNodePort, err := portFromAddress(badStack.Server().Config.ListenAddr) + if err != nil { + t.Fatal(err) + } + badEnode := enode.NewV4(&badNodeKey.PublicKey, net.IPv4(127, 0, 0, 1), badNodePort, 0) + sourceEnode := enode.NewV4(&sourceKey.PublicKey, net.IPv4(127, 0, 0, 1), sourcePort, 0) + iter := &dummyIterator{ + nodes: []*enode.Node{nil, badEnode, sourceEnode}, + } + + // dest node + destStackConf := sourceStackConf + destStackConf.DataDir = t.TempDir() + destStackConf.P2P.PrivateKey = destKey + destStack, err := node.New(&destStackConf) + if err != nil { + t.Fatal(err) + } + + destDb, err := destStack.OpenDatabaseWithFreezer("l2chaindata", 2048, 512, "", "", false) + if err != nil { + t.Fatal(err) + } + destChain, _ := core.NewBlockChain(destDb, nil, nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil) + destHandler := NewProtocolHandler(destDb, destChain, &dummySyncHelper{syncBlock.Header(), nil}, true) + destStack.RegisterProtocols(destHandler.MakeProtocols(iter)) + + // start sync + log.Info("dest listener", "address", destStack.Server().Config.ListenAddr) + log.Info("initial source", "head", sourceChain.CurrentBlock()) + log.Info("initial dest", "head", destChain.CurrentBlock()) + log.Info("pivot", "head", pivotBlock.Header()) + if err := destStack.Start(); err != nil { + t.Fatal(err) + } + + <-time.After(time.Second * 5) + + log.Info("final source", "head", sourceChain.CurrentBlock()) + log.Info("final dest", "head", destChain.CurrentBlock()) + log.Info("sync block", "header", syncBlock.Header()) + + // check sync + if destChain.CurrentBlock().Number.Cmp(syncBlock.Number()) != 0 { + t.Fatal("did not sync to sync block") + } + + testHasBlock(t, destChain, syncBlock, true) + testHasBlock(t, destChain, pivotBlock, true) + testHasBlock(t, destChain, blocks[pivotBlockNum-2], false) +} diff --git a/core/blockchain.go b/core/blockchain.go index 98b5cbc55c..35b5394676 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -176,6 +176,11 @@ func (c *CacheConfig) triedbConfig() *trie.Config { return config } +// arbitrum: expose triedbConfig +func (c *CacheConfig) TriedbConfig() *trie.Config { + return c.triedbConfig() +} + // defaultCacheConfig are the default caching values if none are specified by the // user (also used during testing). var defaultCacheConfig = &CacheConfig{ @@ -2579,3 +2584,7 @@ func (bc *BlockChain) SetTrieFlushInterval(interval time.Duration) { func (bc *BlockChain) GetTrieFlushInterval() time.Duration { return time.Duration(bc.flushInterval.Load()) } + +func (bc *BlockChain) CacheConfig() *CacheConfig { + return bc.cacheConfig +} diff --git a/eth/downloader/beaconsync.go b/eth/downloader/beaconsync.go index d3f75c8527..c004c378f5 100644 --- a/eth/downloader/beaconsync.go +++ b/eth/downloader/beaconsync.go @@ -17,6 +17,7 @@ package downloader import ( + "errors" "fmt" "sync" "time" @@ -41,8 +42,8 @@ type beaconBackfiller struct { lock sync.Mutex // Mutex protecting the sync lock } -// newBeaconBackfiller is a helper method to create the backfiller. -func newBeaconBackfiller(dl *Downloader, success func()) backfiller { +// NewBeaconBackfiller is a helper method to create the backfiller. +func NewBeaconBackfiller(dl *Downloader, success func()) Backfiller { return &beaconBackfiller{ downloader: dl, success: success, @@ -52,7 +53,7 @@ func newBeaconBackfiller(dl *Downloader, success func()) backfiller { // suspend cancels any background downloader threads and returns the last header // that has been successfully backfilled (potentially in a previous run), or the // genesis. -func (b *beaconBackfiller) suspend() *types.Header { +func (b *beaconBackfiller) Suspend() *types.Header { // If no filling is running, don't waste cycles b.lock.Lock() filling := b.filling @@ -80,7 +81,7 @@ func (b *beaconBackfiller) suspend() *types.Header { } // resume starts the downloader threads for backfilling state and chain data. -func (b *beaconBackfiller) resume() { +func (b *beaconBackfiller) Resume() { b.lock.Lock() if b.filling { // If a previous filling cycle is still running, just ignore this start @@ -120,7 +121,7 @@ func (b *beaconBackfiller) resume() { // setMode updates the sync mode from the current one to the requested one. If // there's an active sync in progress, it will be cancelled and restarted. -func (b *beaconBackfiller) setMode(mode SyncMode) { +func (b *beaconBackfiller) SetMode(mode SyncMode) { // Update the old sync mode and track if it was changed b.lock.Lock() updated := b.syncMode != mode @@ -134,8 +135,8 @@ func (b *beaconBackfiller) setMode(mode SyncMode) { return } log.Error("Downloader sync mode changed mid-run", "old", mode.String(), "new", mode.String()) - b.suspend() - b.resume() + b.Suspend() + b.Resume() } // SetBadBlockCallback sets the callback to run when a bad block is hit by the @@ -165,6 +166,18 @@ func (d *Downloader) BeaconExtend(mode SyncMode, head *types.Header) error { return d.beaconSync(mode, head, nil, false) } +// PivotSync sets an explicit pivot and syncs from there. Pivot state will be read from peers. +func (d *Downloader) PivotSync(head *types.Header, pivot *types.Header) error { + if pivot != nil && head.Number.Cmp(pivot.Number) < 0 { + return errors.New("pivot must be behind head") + } + d.pivotLock.Lock() + d.pivotHeader = pivot + d.pivotExplicit = true + d.pivotLock.Unlock() + return d.beaconSync(SnapSync, head, nil, true) +} + // beaconSync is the post-merge version of the chain synchronization, where the // chain is not downloaded from genesis onward, rather from trusted head announces // backwards. @@ -178,7 +191,7 @@ func (d *Downloader) beaconSync(mode SyncMode, head *types.Header, final *types. // // Super crazy dangerous type cast. Should be fine (TM), we're only using a // different backfiller implementation for skeleton tests. - d.skeleton.filler.(*beaconBackfiller).setMode(mode) + d.skeleton.filler.SetMode(mode) // Signal the skeleton sync to switch to a new head, however it wants if err := d.skeleton.Sync(head, final, force); err != nil { @@ -268,6 +281,14 @@ func (d *Downloader) findBeaconAncestor() (uint64, error) { return start, nil } +func (d *Downloader) SkeletonHead() (*types.Header, error) { + head, _, _, err := d.skeleton.Bounds() + if err != nil { + return nil, err + } + return head, nil +} + // fetchBeaconHeaders feeds skeleton headers to the downloader queue for scheduling // until sync errors or is finished. func (d *Downloader) fetchBeaconHeaders(from uint64) error { @@ -299,7 +320,7 @@ func (d *Downloader) fetchBeaconHeaders(from uint64) error { // If the pivot became stale (older than 2*64-8 (bit of wiggle room)), // move it ahead to HEAD-64 d.pivotLock.Lock() - if d.pivotHeader != nil { + if d.pivotHeader != nil && !d.pivotExplicit { if head.Number.Uint64() > d.pivotHeader.Number.Uint64()+2*uint64(fsMinFullBlocks)-8 { // Retrieve the next pivot header, either from skeleton chain // or the filled chain diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 8d449246a6..9543461463 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -131,8 +131,9 @@ type Downloader struct { skeleton *skeleton // Header skeleton to backfill the chain with (eth2 mode) // State sync - pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root - pivotLock sync.RWMutex // Lock protecting pivot header reads from updates + pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root + pivotExplicit bool // arbitrum: pivot is set explicitly only + pivotLock sync.RWMutex // Lock protecting pivot header reads from updates SnapSyncer *snap.Syncer // TODO(karalabe): make private! hack for now stateSyncStart chan *stateSync @@ -216,7 +217,7 @@ type BlockChain interface { } // New creates a new downloader to fetch hashes and blocks from remote peers. -func New(stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn, success func()) *Downloader { +func New(stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn, backFillerCreator func(*Downloader) Backfiller) *Downloader { if lightchain == nil { lightchain = chain } @@ -235,7 +236,7 @@ func New(stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchai syncStartBlock: chain.CurrentSnapBlock().Number.Uint64(), } // Create the post-merge skeleton syncer and start the process - dl.skeleton = newSkeleton(stateDb, dl.peers, dropPeer, newBeaconBackfiller(dl, success)) + dl.skeleton = newSkeleton(stateDb, dl.peers, dropPeer, backFillerCreator(dl)) go dl.stateFetcher() return dl @@ -476,7 +477,29 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd * // Look up the sync boundaries: the common ancestor and the target block var latest, pivot, final *types.Header - if !beaconMode { + var pivotExplicit bool + d.pivotLock.Lock() + if d.pivotExplicit { + pivotExplicit = true + pivot = d.pivotHeader + } + d.pivotLock.Unlock() + if pivotExplicit { + latest, _, _, err = d.skeleton.Bounds() + if err != nil { + return err + } + if pivot != nil { + localPivot := d.skeleton.Header(pivot.Number.Uint64()) + if localPivot == nil { + return fmt.Errorf("pivot not in skeleton chain") + } + if localPivot.Hash() != pivot.Hash() { + return fmt.Errorf("pivot disagrees with skeleton") + } + final = localPivot + } + } else if !beaconMode { // In legacy mode, use the master peer to retrieve the headers from latest, pivot, err = d.fetchHead(p) if err != nil { @@ -517,7 +540,7 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd * // threshold (i.e. new chain). In that case we won't really snap sync // anyway, but still need a valid pivot block to avoid some code hitting // nil panics on access. - if mode == SnapSync && pivot == nil { + if mode == SnapSync && pivot == nil && !pivotExplicit { pivot = d.blockchain.CurrentBlock() } height := latest.Number.Uint64() @@ -545,7 +568,11 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd * // Ensure our origin point is below any snap sync pivot point if mode == SnapSync { - if height <= uint64(fsMinFullBlocks) { + if pivotExplicit { + if pivot != nil { + rawdb.WriteLastPivotNumber(d.stateDB, pivot.Nonce.Uint64()) + } + } else if height <= uint64(fsMinFullBlocks) { origin = 0 } else { pivotNumber := pivot.Number.Uint64() @@ -558,7 +585,7 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd * } } d.committed.Store(true) - if mode == SnapSync && pivot.Number.Uint64() != 0 { + if mode == SnapSync && pivot != nil && pivot.Number.Uint64() != 0 { d.committed.Store(false) } if mode == SnapSync { @@ -635,10 +662,19 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd * } if mode == SnapSync { d.pivotLock.Lock() - d.pivotHeader = pivot + if !d.pivotExplicit { + d.pivotHeader = pivot + } d.pivotLock.Unlock() - - fetchers = append(fetchers, func() error { return d.processSnapSyncContent() }) + if pivot != nil { + fetchers = append(fetchers, func() error { return d.processSnapSyncContent() }) + } else { + // no pivot yet - cannot complete this sync + fetchers = append(fetchers, func() error { + <-d.cancelCh + return errCanceled + }) + } } else if mode == FullSync { fetchers = append(fetchers, func() error { return d.processFullSyncContent(ttd, beaconMode) }) } @@ -1034,8 +1070,13 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, head uint64) e case pivoting: d.pivotLock.RLock() pivot := d.pivotHeader.Number.Uint64() + pivotExplicit := d.pivotExplicit d.pivotLock.RUnlock() + if pivotExplicit { + pivoting = false + continue + } p.log.Trace("Fetching next pivot header", "number", pivot+uint64(fsMinFullBlocks)) headers, hashes, err = d.fetchHeadersByNumber(p, pivot+uint64(fsMinFullBlocks), 2, fsMinFullBlocks-9, false) // move +64 when it's 2x64-8 deep @@ -1080,6 +1121,9 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, head uint64) e if d.pivotHeader != nil { pivot = d.pivotHeader.Number.Uint64() } + if d.pivotExplicit { + pivoting = false + } d.pivotLock.RUnlock() if pivoting { @@ -1599,6 +1643,7 @@ func (d *Downloader) processSnapSyncContent() error { // notifications from the header downloader d.pivotLock.RLock() pivot := d.pivotHeader + pivotExplicit := d.pivotExplicit d.pivotLock.RUnlock() if oldPivot == nil { // no results piling up, we can move the pivot @@ -1623,7 +1668,7 @@ func (d *Downloader) processSnapSyncContent() error { // Note, we have `reorgProtHeaderDelay` number of blocks withheld, Those // need to be taken into account, otherwise we're detecting the pivot move // late and will drop peers due to unavailable state!!! - if height := latest.Number.Uint64(); height >= pivot.Number.Uint64()+2*uint64(fsMinFullBlocks)-uint64(reorgProtHeaderDelay) { + if height := latest.Number.Uint64(); height >= pivot.Number.Uint64()+2*uint64(fsMinFullBlocks)-uint64(reorgProtHeaderDelay) && !pivotExplicit { log.Warn("Pivot became stale, moving", "old", pivot.Number.Uint64(), "new", height-uint64(fsMinFullBlocks)+uint64(reorgProtHeaderDelay)) pivot = results[len(results)-1-fsMinFullBlocks+reorgProtHeaderDelay].Header // must exist as lower old pivot is uncommitted diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index f8abfd2b22..90587e3bf7 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -31,12 +31,14 @@ import ( "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/eth/protocols/snap" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" @@ -81,7 +83,8 @@ func newTesterWithNotification(t *testing.T, success func()) *downloadTester { chain: chain, peers: make(map[string]*downloadTesterPeer), } - tester.downloader = New(db, new(event.TypeMux), tester.chain, nil, tester.dropPeer, success) + backfillerCreator := func(dl *Downloader) Backfiller { return NewBeaconBackfiller(dl, success) } + tester.downloader = New(db, new(event.TypeMux), tester.chain, nil, tester.dropPeer, backfillerCreator) return tester } @@ -153,6 +156,38 @@ type downloadTesterPeer struct { withholdHeaders map[common.Hash]struct{} } +type snapBackend struct { + chain *core.BlockChain +} + +func (d *snapBackend) ContractCodeWithPrefix(codeHash common.Hash) ([]byte, error) { + return d.chain.ContractCodeWithPrefix(codeHash) +} + +func (d *snapBackend) TrieDB() *trie.Database { + return d.chain.TrieDB() +} + +func (d *snapBackend) Snapshot(root common.Hash) snapshot.Snapshot { + return d.chain.Snapshots().Snapshot(root) +} + +func (d *snapBackend) AccountIterator(root, account common.Hash) (snapshot.AccountIterator, error) { + return d.chain.Snapshots().AccountIterator(root, account) +} + +func (d *snapBackend) StorageIterator(root, account, origin common.Hash) (snapshot.StorageIterator, error) { + return d.chain.Snapshots().StorageIterator(root, account, origin) +} + +func (d *snapBackend) RunPeer(*snap.Peer, snap.Handler) error { return nil } +func (d *snapBackend) PeerInfo(enode.ID) interface{} { return "Foo" } +func (d *snapBackend) Handle(*snap.Peer, snap.Packet) error { return nil } + +func (dlp *downloadTesterPeer) SnapBackend() *snapBackend { + return &snapBackend{dlp.chain} +} + // Head constructs a function to retrieve a peer's current head hash // and total difficulty. func (dlp *downloadTesterPeer) Head() (common.Hash, *big.Int) { @@ -344,7 +379,7 @@ func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limi Limit: limit, Bytes: bytes, } - slimaccs, proofs := snap.ServiceGetAccountRangeQuery(dlp.chain, req) + slimaccs, proofs := snap.ServiceGetAccountRangeQuery(dlp.SnapBackend(), req) // We need to convert to non-slim format, delegate to the packet code res := &snap.AccountRangePacket{ @@ -371,7 +406,7 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, Limit: limit, Bytes: bytes, } - storage, proofs := snap.ServiceGetStorageRangesQuery(dlp.chain, req) + storage, proofs := snap.ServiceGetStorageRangesQuery(dlp.SnapBackend(), req) // We need to convert to demultiplex, delegate to the packet code res := &snap.StorageRangesPacket{ @@ -392,7 +427,7 @@ func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, Hashes: hashes, Bytes: bytes, } - codes := snap.ServiceGetByteCodesQuery(dlp.chain, req) + codes := snap.ServiceGetByteCodesQuery(dlp.SnapBackend(), req) go dlp.dl.downloader.SnapSyncer.OnByteCodes(dlp, id, codes) return nil } @@ -406,7 +441,7 @@ func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, pat Paths: paths, Bytes: bytes, } - nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now()) + nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.SnapBackend(), req, time.Now()) go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes) return nil } diff --git a/eth/downloader/skeleton.go b/eth/downloader/skeleton.go index 873ee950b6..0d7d68ddd7 100644 --- a/eth/downloader/skeleton.go +++ b/eth/downloader/skeleton.go @@ -154,7 +154,7 @@ type headerResponse struct { // backfiller is a callback interface through which the skeleton sync can tell // the downloader that it should suspend or resume backfilling on specific head // events (e.g. suspend on forks or gaps, resume on successful linkups). -type backfiller interface { +type Backfiller interface { // suspend requests the backfiller to abort any running full or snap sync // based on the skeleton chain as it might be invalid. The backfiller should // gracefully handle multiple consecutive suspends without a resume, even @@ -162,13 +162,15 @@ type backfiller interface { // // The method should return the last block header that has been successfully // backfilled (in the current or a previous run), falling back to the genesis. - suspend() *types.Header + Suspend() *types.Header // resume requests the backfiller to start running fill or snap sync based on // the skeleton chain as it has successfully been linked. Appending new heads // to the end of the chain will not result in suspend/resume cycles. // leaking too much sync logic out to the filler. - resume() + Resume() + + SetMode(mode SyncMode) } // skeleton represents a header chain synchronized after the merge where blocks @@ -200,7 +202,7 @@ type backfiller interface { // for now. type skeleton struct { db ethdb.Database // Database backing the skeleton - filler backfiller // Chain syncer suspended/resumed by head events + filler Backfiller // Chain syncer suspended/resumed by head events peers *peerSet // Set of peers we can sync from idles map[string]*peerConnection // Set of idle peers in the current sync cycle @@ -227,7 +229,7 @@ type skeleton struct { // newSkeleton creates a new sync skeleton that tracks a potentially dangling // header chain until it's linked into an existing set of blocks. -func newSkeleton(db ethdb.Database, peers *peerSet, drop peerDropFn, filler backfiller) *skeleton { +func newSkeleton(db ethdb.Database, peers *peerSet, drop peerDropFn, filler Backfiller) *skeleton { sk := &skeleton{ db: db, filler: filler, @@ -372,7 +374,7 @@ func (s *skeleton) sync(head *types.Header) (*types.Header, error) { rawdb.HasBody(s.db, s.progress.Subchains[0].Next, s.scratchHead) && rawdb.HasReceipts(s.db, s.progress.Subchains[0].Next, s.scratchHead) if linked { - s.filler.resume() + s.filler.Resume() } defer func() { // The filler needs to be suspended, but since it can block for a while @@ -382,7 +384,7 @@ func (s *skeleton) sync(head *types.Header) (*types.Header, error) { done := make(chan struct{}) go func() { defer close(done) - filled := s.filler.suspend() + filled := s.filler.Suspend() if filled == nil { log.Error("Latest filled block is not available") return @@ -486,7 +488,7 @@ func (s *skeleton) sync(head *types.Header) (*types.Header, error) { // is still running, it will pick it up. If it already terminated, // a new cycle needs to be spun up. if linked { - s.filler.resume() + s.filler.Resume() } case req := <-requestFails: @@ -1198,14 +1200,14 @@ func (s *skeleton) cleanStales(filled *types.Header) error { // Bounds retrieves the current head and tail tracked by the skeleton syncer // and optionally the last known finalized header if any was announced and if -// it is still in the sync range. This method is used by the backfiller, whose +// it is still in the sync range. This method is used by the Backfiller, whose // life cycle is controlled by the skeleton syncer. // // Note, the method will not use the internal state of the skeleton, but will // rather blindly pull stuff from the database. This is fine, because the back- // filler will only run when the skeleton chain is fully downloaded and stable. // There might be new heads appended, but those are atomic from the perspective -// of this method. Any head reorg will first tear down the backfiller and only +// of this method. Any head reorg will first tear down the Backfiller and only // then make the modification. func (s *skeleton) Bounds() (head *types.Header, tail *types.Header, final *types.Header, err error) { // Read the current sync progress from disk and figure out the current head. @@ -1238,7 +1240,7 @@ func (s *skeleton) Bounds() (head *types.Header, tail *types.Header, final *type } // Header retrieves a specific header tracked by the skeleton syncer. This method -// is meant to be used by the backfiller, whose life cycle is controlled by the +// is meant to be used by the Backfiller, whose life cycle is controlled by the // skeleton syncer. // // Note, outside the permitted runtimes, this method might return nil results and diff --git a/eth/downloader/skeleton_test.go b/eth/downloader/skeleton_test.go index 2b108dfe93..bfadb49acd 100644 --- a/eth/downloader/skeleton_test.go +++ b/eth/downloader/skeleton_test.go @@ -46,7 +46,7 @@ type hookedBackfiller struct { // newHookedBackfiller creates a hooked backfiller with all callbacks disabled, // essentially acting as a noop. -func newHookedBackfiller() backfiller { +func newHookedBackfiller() Backfiller { return new(hookedBackfiller) } @@ -54,7 +54,7 @@ func newHookedBackfiller() backfiller { // based on the skeleton chain as it might be invalid. The backfiller should // gracefully handle multiple consecutive suspends without a resume, even // on initial startup. -func (hf *hookedBackfiller) suspend() *types.Header { +func (hf *hookedBackfiller) Suspend() *types.Header { if hf.suspendHook != nil { return hf.suspendHook() } @@ -64,12 +64,14 @@ func (hf *hookedBackfiller) suspend() *types.Header { // resume requests the backfiller to start running fill or snap sync based on // the skeleton chain as it has successfully been linked. Appending new heads // to the end of the chain will not result in suspend/resume cycles. -func (hf *hookedBackfiller) resume() { +func (hf *hookedBackfiller) Resume() { if hf.resumeHook != nil { hf.resumeHook() } } +func (hf *hookedBackfiller) SetMode(SyncMode) {} + // skeletonTestPeer is a mock peer that can only serve header requests from a // pre-perated header chain (which may be arbitrarily wrong for testing). // diff --git a/eth/handler.go b/eth/handler.go index a327af6113..861741bc10 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -182,8 +182,11 @@ func newHandler(config *handlerConfig) (*handler, error) { if h.snapSync.Load() && config.Chain.Snapshots() == nil { return nil, errors.New("snap sync not supported with snapshots disabled") } + backfillerCreator := func(dl *downloader.Downloader) downloader.Backfiller { + return downloader.NewBeaconBackfiller(dl, h.enableSyncedFeatures) + } // Construct the downloader (long sync) - h.downloader = downloader.New(config.Database, h.eventMux, h.chain, nil, h.removePeer, h.enableSyncedFeatures) + h.downloader = downloader.New(config.Database, h.eventMux, h.chain, nil, h.removePeer, backfillerCreator) if ttd := h.chain.Config().TerminalTotalDifficulty; ttd != nil { if h.chain.Config().TerminalTotalDifficultyPassed { log.Info("Chain post-merge, sync via beacon client") diff --git a/eth/handler_snap.go b/eth/handler_snap.go index 767416ffd6..7db98e266a 100644 --- a/eth/handler_snap.go +++ b/eth/handler_snap.go @@ -17,16 +17,36 @@ package eth import ( - "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/eth/protocols/snap" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/trie" ) // snapHandler implements the snap.Backend interface to handle the various network // packets that are sent as replies or broadcasts. type snapHandler handler -func (h *snapHandler) Chain() *core.BlockChain { return h.chain } +func (h *snapHandler) ContractCodeWithPrefix(codeHash common.Hash) ([]byte, error) { + return (*handler)(h).chain.ContractCodeWithPrefix(codeHash) +} + +func (h *snapHandler) TrieDB() *trie.Database { + return (*handler)(h).chain.StateCache().TrieDB() +} + +func (h *snapHandler) Snapshot(root common.Hash) snapshot.Snapshot { + return (*handler)(h).chain.Snapshots().Snapshot(root) +} + +func (h *snapHandler) AccountIterator(root, account common.Hash) (snapshot.AccountIterator, error) { + return (*handler)(h).chain.Snapshots().AccountIterator(root, account) +} + +func (h *snapHandler) StorageIterator(root, account, origin common.Hash) (snapshot.StorageIterator, error) { + return (*handler)(h).chain.Snapshots().StorageIterator(root, account, origin) +} // RunPeer is invoked when a peer joins on the `snap` protocol. func (h *snapHandler) RunPeer(peer *snap.Peer, hand snap.Handler) error { diff --git a/eth/protocols/arb/enr.go b/eth/protocols/arb/enr.go new file mode 100644 index 0000000000..1da4cdce69 --- /dev/null +++ b/eth/protocols/arb/enr.go @@ -0,0 +1,14 @@ +package arb + +import "github.com/ethereum/go-ethereum/rlp" + +// enrEntry is the ENR entry which advertises `snap` protocol on the discovery. +type enrEntry struct { + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` +} + +// ENRKey implements enr.Entry. +func (e enrEntry) ENRKey() string { + return "arb" +} diff --git a/eth/protocols/arb/handler.go b/eth/protocols/arb/handler.go new file mode 100644 index 0000000000..8a2f99b8e4 --- /dev/null +++ b/eth/protocols/arb/handler.go @@ -0,0 +1,139 @@ +package arb + +import ( + "fmt" + "time" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +type Peer struct { + p2pPeer *p2p.Peer + rw p2p.MsgReadWriter +} + +func NewPeer(p2pPeer *p2p.Peer, rw p2p.MsgReadWriter) *Peer { + return &Peer{ + p2pPeer: p2pPeer, + rw: rw, + } +} + +func (p *Peer) RequestCheckpoint(header *types.Header) error { + if header == nil { + return p2p.Send(p.rw, GetLastCheckpointMsg, struct{}{}) + } + return p2p.Send(p.rw, CheckpointQueryMsg, &CheckpointQueryPacket{ + Header: header, + }) +} + +func (p *Peer) RequestLastConfirmed() error { + return p2p.Send(p.rw, GetLastConfirmedMsg, struct{}{}) +} + +func (p *Peer) ID() string { + return p.p2pPeer.ID().String() +} + +func (p *Peer) Node() *enode.Node { + return p.p2pPeer.Node() +} + +// Handle is the callback invoked to manage the life cycle of a `snap` peer. +// When this function terminates, the peer is disconnected. +func Handle(backend Backend, peer *Peer) error { + for { + if err := HandleMessage(backend, peer); err != nil { + log.Debug("Message handling failed in `arb`", "err", err) + return err + } + } +} + +// HandleMessage is invoked whenever an inbound message is received from a +// remote peer on the `snap` protocol. The remote connection is torn down upon +// returning any error. +func HandleMessage(backend Backend, peer *Peer) error { + // Read the next message from the remote peer, and ensure it's fully consumed + msg, err := peer.rw.ReadMsg() + if err != nil { + return err + } + // if msg.Size > maxMessageSize { + // return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize) + // } + defer msg.Discard() + start := time.Now() + // Track the amount of time it takes to serve the request and run the handler + if metrics.Enabled { + h := fmt.Sprintf("%s/%s/%d/%#02x", p2p.HandleHistName, ProtocolName, ARB1, msg.Code) + defer func(start time.Time) { + sampler := func() metrics.Sample { + return metrics.NewBoundedHistogramSample() + } + metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(start).Microseconds()) + }(start) + } + switch { + case msg.Code == GetLastConfirmedMsg: + confirmed, node, err := backend.LastConfirmed() + if err != nil || confirmed == nil { + return err + } + response := LastConfirmedMsgPacket{ + Header: confirmed, + Node: node, + } + return p2p.Send(peer.rw, LastConfirmedMsg, &response) + case msg.Code == LastConfirmedMsg: + var incoming LastConfirmedMsgPacket + err := msg.Decode(&incoming) + if err != nil { + return err + } + if incoming.Header == nil { + return nil + } + backend.HandleLastConfirmed(peer, incoming.Header, incoming.Node) + return nil + case msg.Code == GetLastCheckpointMsg: + checkpoint, err := backend.LastCheckpoint() + if err != nil { + return err + } + response := CheckpointMsgPacket{ + Header: checkpoint, + HasState: true, + } + return p2p.Send(peer.rw, CheckpointMsg, &response) + case msg.Code == CheckpointQueryMsg: + incoming := CheckpointQueryPacket{} + err := msg.Decode(&incoming) + if err != nil { + return err + } + hasState, err := backend.CheckpointSupported(incoming.Header) + if err != nil { + return err + } + response := CheckpointMsgPacket{ + Header: incoming.Header, + HasState: hasState, + } + return p2p.Send(peer.rw, CheckpointMsg, &response) + case msg.Code == CheckpointMsg: + incoming := CheckpointMsgPacket{} + err := msg.Decode(&incoming) + if err != nil { + return err + } + backend.HandleCheckpoint(peer, incoming.Header, incoming.HasState) + return nil + } + return fmt.Errorf("Invalid message: %v", msg.Code) +} diff --git a/eth/protocols/arb/protocol.go b/eth/protocols/arb/protocol.go new file mode 100644 index 0000000000..62a377f104 --- /dev/null +++ b/eth/protocols/arb/protocol.go @@ -0,0 +1,103 @@ +package arb + +import ( + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +// Constants to match up protocol versions and messages +const ( + ARB1 = 1 +) + +// ProtocolName is the official short name of the `snap` protocol used during +// devp2p capability negotiation. +const ProtocolName = "arb" + +// ProtocolVersions are the supported versions of the `snap` protocol (first +// is primary). +var ProtocolVersions = []uint{ARB1} + +// protocolLengths are the number of implemented message corresponding to +// different protocol versions. +var protocolLengths = map[uint]uint64{ARB1: ProtocolLenArb1} + +const ( + GetLastConfirmedMsg = 0x00 + LastConfirmedMsg = 0x01 + GetLastCheckpointMsg = 0x02 + CheckpointQueryMsg = 0x03 + CheckpointMsg = 0x04 + ProtocolLenArb1 = 5 +) + +type LastConfirmedMsgPacket struct { + Header *types.Header + Node uint64 +} + +type CheckpointMsgPacket struct { + Header *types.Header + HasState bool +} + +type CheckpointQueryPacket struct { + Header *types.Header +} + +// NodeInfo represents a short summary of the `arb` sub-protocol metadata +// known about the host peer. +type NodeInfo struct{} + +// nodeInfo retrieves some `arb` protocol metadata about the running host node. +func nodeInfo() *NodeInfo { + return &NodeInfo{} +} + +type Handler func(peer *Peer) error + +// Backend defines the data retrieval methods to serve remote requests and the +// callback methods to invoke on remote deliveries. +type Backend interface { + PeerInfo(id enode.ID) interface{} + HandleLastConfirmed(peer *Peer, confirmed *types.Header, node uint64) + HandleCheckpoint(peer *Peer, header *types.Header, supported bool) + LastConfirmed() (*types.Header, uint64, error) + LastCheckpoint() (*types.Header, error) + CheckpointSupported(*types.Header) (bool, error) + // RunPeer is invoked when a peer joins on the `eth` protocol. The handler + // should do any peer maintenance work, handshakes and validations. If all + // is passed, control should be given back to the `handler` to process the + // inbound messages going forward. + RunPeer(peer *Peer, handler Handler) error +} + +func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol { + protocols := make([]p2p.Protocol, len(ProtocolVersions)) + for i, version := range ProtocolVersions { + version := version // Closure + + protocols[i] = p2p.Protocol{ + Name: ProtocolName, + Version: version, + Length: protocolLengths[version], + Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := NewPeer(p, rw) + return backend.RunPeer(peer, func(peer *Peer) error { + return Handle(backend, peer) + }) + }, + NodeInfo: func() interface{} { + return nodeInfo() + }, + PeerInfo: func(id enode.ID) interface{} { + return backend.PeerInfo(id) + }, + Attributes: []enr.Entry{&enrEntry{}}, + DialCandidates: dnsdisc, + } + } + return protocols +} diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 968fcfbfa5..06993b4905 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -22,7 +22,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" @@ -64,7 +64,11 @@ type Handler func(peer *Peer) error // callback methods to invoke on remote deliveries. type Backend interface { // Chain retrieves the blockchain object to serve data. - Chain() *core.BlockChain + ContractCodeWithPrefix(codeHash common.Hash) ([]byte, error) + TrieDB() *trie.Database + Snapshot(root common.Hash) snapshot.Snapshot + AccountIterator(root, account common.Hash) (snapshot.AccountIterator, error) + StorageIterator(root, account, origin common.Hash) (snapshot.StorageIterator, error) // RunPeer is invoked when a peer joins on the `eth` protocol. The handler // should do any peer maintenance work, handshakes and validations. If all @@ -84,10 +88,12 @@ type Backend interface { // MakeProtocols constructs the P2P protocol definitions for `snap`. func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol { // Filter the discovery iterator for nodes advertising snap support. - dnsdisc = enode.Filter(dnsdisc, func(n *enode.Node) bool { - var snap enrEntry - return n.Load(&snap) == nil - }) + if dnsdisc != nil { + dnsdisc = enode.Filter(dnsdisc, func(n *enode.Node) bool { + var snap enrEntry + return n.Load(&snap) == nil + }) + } protocols := make([]p2p.Protocol, len(ProtocolVersions)) for i, version := range ProtocolVersions { @@ -103,7 +109,7 @@ func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol { }) }, NodeInfo: func() interface{} { - return nodeInfo(backend.Chain()) + return nodeInfo() }, PeerInfo: func(id enode.ID) interface{} { return backend.PeerInfo(id) @@ -159,7 +165,7 @@ func HandleMessage(backend Backend, peer *Peer) error { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } // Service the request, potentially returning nothing in case of errors - accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req) + accounts, proofs := ServiceGetAccountRangeQuery(backend, &req) // Send back anything accumulated (or empty in case of errors) return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ @@ -191,7 +197,7 @@ func HandleMessage(backend Backend, peer *Peer) error { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } // Service the request, potentially returning nothing in case of errors - slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req) + slots, proofs := ServiceGetStorageRangesQuery(backend, &req) // Send back anything accumulated (or empty in case of errors) return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ @@ -225,7 +231,7 @@ func HandleMessage(backend Backend, peer *Peer) error { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } // Service the request, potentially returning nothing in case of errors - codes := ServiceGetByteCodesQuery(backend.Chain(), &req) + codes := ServiceGetByteCodesQuery(backend, &req) // Send back anything accumulated (or empty in case of errors) return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{ @@ -250,7 +256,7 @@ func HandleMessage(backend Backend, peer *Peer) error { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } // Service the request, potentially returning nothing in case of errors - nodes, err := ServiceGetTrieNodesQuery(backend.Chain(), &req, start) + nodes, err := ServiceGetTrieNodesQuery(backend, &req, start) if err != nil { return err } @@ -277,16 +283,16 @@ func HandleMessage(backend Backend, peer *Peer) error { // ServiceGetAccountRangeQuery assembles the response to an account range query. // It is exposed to allow external packages to test protocol behavior. -func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) { +func ServiceGetAccountRangeQuery(backend Backend, req *GetAccountRangePacket) ([]*AccountData, [][]byte) { if req.Bytes > softResponseLimit { req.Bytes = softResponseLimit } // Retrieve the requested state and bail out if non existent - tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB()) + tr, err := trie.New(trie.StateTrieID(req.Root), backend.TrieDB()) if err != nil { return nil, nil } - it, err := chain.Snapshots().AccountIterator(req.Root, req.Origin) + it, err := backend.AccountIterator(req.Root, req.Origin) if err != nil { return nil, nil } @@ -337,7 +343,7 @@ func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePac return accounts, proofs } -func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) { +func ServiceGetStorageRangesQuery(backend Backend, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) { if req.Bytes > softResponseLimit { req.Bytes = softResponseLimit } @@ -370,7 +376,7 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP limit, req.Limit = common.BytesToHash(req.Limit), nil } // Retrieve the requested state and bail out if non existent - it, err := chain.Snapshots().StorageIterator(req.Root, account, origin) + it, err := backend.StorageIterator(req.Root, account, origin) if err != nil { return nil, nil } @@ -412,7 +418,7 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP if origin != (common.Hash{}) || (abort && len(storage) > 0) { // Request started at a non-zero hash or was capped prematurely, add // the endpoint Merkle proofs - accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB()) + accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), backend.TrieDB()) if err != nil { return nil, nil } @@ -421,7 +427,7 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP return nil, nil } id := trie.StorageTrieID(req.Root, account, acc.Root) - stTrie, err := trie.NewStateTrie(id, chain.TrieDB()) + stTrie, err := trie.NewStateTrie(id, backend.TrieDB()) if err != nil { return nil, nil } @@ -450,7 +456,7 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP // ServiceGetByteCodesQuery assembles the response to a byte codes query. // It is exposed to allow external packages to test protocol behavior. -func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte { +func ServiceGetByteCodesQuery(backend Backend, req *GetByteCodesPacket) [][]byte { if req.Bytes > softResponseLimit { req.Bytes = softResponseLimit } @@ -467,7 +473,7 @@ func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [ // Peers should not request the empty code, but if they do, at // least sent them back a correct response without db lookups codes = append(codes, []byte{}) - } else if blob, err := chain.ContractCodeWithPrefix(hash); err == nil { + } else if blob, err := backend.ContractCodeWithPrefix(hash); err == nil { codes = append(codes, blob) bytes += uint64(len(blob)) } @@ -480,12 +486,12 @@ func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [ // ServiceGetTrieNodesQuery assembles the response to a trie nodes query. // It is exposed to allow external packages to test protocol behavior. -func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, start time.Time) ([][]byte, error) { +func ServiceGetTrieNodesQuery(backend Backend, req *GetTrieNodesPacket, start time.Time) ([][]byte, error) { if req.Bytes > softResponseLimit { req.Bytes = softResponseLimit } // Make sure we have the state associated with the request - triedb := chain.TrieDB() + triedb := backend.TrieDB() accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb) if err != nil { @@ -493,7 +499,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s return nil, nil } // The 'snap' might be nil, in which case we cannot serve storage slots. - snap := chain.Snapshots().Snapshot(req.Root) + snap := backend.Snapshot(req.Root) // Retrieve trie nodes until the packet size limit is reached var ( nodes [][]byte @@ -570,6 +576,6 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s type NodeInfo struct{} // nodeInfo retrieves some `snap` protocol metadata about the running host node. -func nodeInfo(chain *core.BlockChain) *NodeInfo { +func nodeInfo() *NodeInfo { return &NodeInfo{} } diff --git a/eth/protocols/snap/handler_fuzzing_test.go b/eth/protocols/snap/handler_fuzzing_test.go index ddc7a44a2a..c6a2f6950b 100644 --- a/eth/protocols/snap/handler_fuzzing_test.go +++ b/eth/protocols/snap/handler_fuzzing_test.go @@ -28,11 +28,13 @@ import ( "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" fuzz "github.com/google/gofuzz" ) @@ -136,6 +138,26 @@ type dummyBackend struct { chain *core.BlockChain } +func (d *dummyBackend) ContractCodeWithPrefix(codeHash common.Hash) ([]byte, error) { + return d.chain.ContractCodeWithPrefix(codeHash) +} + +func (d *dummyBackend) TrieDB() *trie.Database { + return d.chain.TrieDB() +} + +func (d *dummyBackend) Snapshot(root common.Hash) snapshot.Snapshot { + return d.chain.Snapshots().Snapshot(root) +} + +func (d *dummyBackend) AccountIterator(root, account common.Hash) (snapshot.AccountIterator, error) { + return d.chain.Snapshots().AccountIterator(root, account) +} + +func (d *dummyBackend) StorageIterator(root, account, origin common.Hash) (snapshot.StorageIterator, error) { + return d.chain.Snapshots().StorageIterator(root, account, origin) +} + func (d *dummyBackend) Chain() *core.BlockChain { return d.chain } func (d *dummyBackend) RunPeer(*Peer, Handler) error { return nil } func (d *dummyBackend) PeerInfo(enode.ID) interface{} { return "Foo" }