diff --git a/.changeset/breezy-suits-float.md b/.changeset/breezy-suits-float.md new file mode 100644 index 00000000000..60e061223d8 --- /dev/null +++ b/.changeset/breezy-suits-float.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +#added address book remove feature diff --git a/deployment/address_book.go b/deployment/address_book.go index 8385bc0e9f1..076e2a235d6 100644 --- a/deployment/address_book.go +++ b/deployment/address_book.go @@ -3,6 +3,9 @@ package deployment import ( "fmt" "strings" + "sync" + + "golang.org/x/exp/maps" "github.com/Masterminds/semver/v3" "github.com/ethereum/go-ethereum/common" @@ -82,14 +85,16 @@ type AddressBook interface { AddressesForChain(chain uint64) (map[string]TypeAndVersion, error) // Allows for merging address books (e.g. new deployments with existing ones) Merge(other AddressBook) error + Remove(ab AddressBook) error } type AddressBookMap struct { - AddressesByChain map[uint64]map[string]TypeAndVersion + addressesByChain map[uint64]map[string]TypeAndVersion + mtx sync.RWMutex } -// Save will save an address for a given chain selector. It will error if there is a conflicting existing address. -func (m *AddressBookMap) Save(chainSelector uint64, address string, typeAndVersion TypeAndVersion) error { +// save will save an address for a given chain selector. It will error if there is a conflicting existing address. +func (m *AddressBookMap) save(chainSelector uint64, address string, typeAndVersion TypeAndVersion) error { _, exists := chainsel.ChainBySelector(chainSelector) if !exists { return errors.Wrapf(ErrInvalidChainSelector, "chain selector %d", chainSelector) @@ -106,19 +111,34 @@ func (m *AddressBookMap) Save(chainSelector uint64, address string, typeAndVersi if typeAndVersion.Type == "" { return fmt.Errorf("type cannot be empty") } - if _, exists := m.AddressesByChain[chainSelector]; !exists { + + if _, exists := m.addressesByChain[chainSelector]; !exists { // First time chain add, create map - m.AddressesByChain[chainSelector] = make(map[string]TypeAndVersion) + m.addressesByChain[chainSelector] = make(map[string]TypeAndVersion) } - if _, exists := m.AddressesByChain[chainSelector][address]; exists { + if _, exists := m.addressesByChain[chainSelector][address]; exists { return fmt.Errorf("address %s already exists for chain %d", address, chainSelector) } - m.AddressesByChain[chainSelector][address] = typeAndVersion + m.addressesByChain[chainSelector][address] = typeAndVersion return nil } +// Save will save an address for a given chain selector. It will error if there is a conflicting existing address. +// thread safety version of the save method +func (m *AddressBookMap) Save(chainSelector uint64, address string, typeAndVersion TypeAndVersion) error { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.save(chainSelector, address, typeAndVersion) +} + func (m *AddressBookMap) Addresses() (map[uint64]map[string]TypeAndVersion, error) { - return m.AddressesByChain, nil + m.mtx.RLock() + defer m.mtx.RUnlock() + + // maps are mutable and pass via a pointer + // creating a copy of the map to prevent concurrency + // read and changes outside object-bound + return m.cloneAddresses(m.addressesByChain), nil } func (m *AddressBookMap) AddressesForChain(chainSelector uint64) (map[string]TypeAndVersion, error) { @@ -126,10 +146,18 @@ func (m *AddressBookMap) AddressesForChain(chainSelector uint64) (map[string]Typ if !exists { return nil, errors.Wrapf(ErrInvalidChainSelector, "chain selector %d", chainSelector) } - if _, exists := m.AddressesByChain[chainSelector]; !exists { + + m.mtx.RLock() + defer m.mtx.RUnlock() + + if _, exists := m.addressesByChain[chainSelector]; !exists { return nil, errors.Wrapf(ErrChainNotFound, "chain selector %d", chainSelector) } - return m.AddressesByChain[chainSelector], nil + + // maps are mutable and pass via a pointer + // creating a copy of the map to prevent concurrency + // read and changes outside object-bound + return maps.Clone(m.addressesByChain[chainSelector]), nil } // Merge will merge the addresses from another address book into this one. @@ -139,9 +167,13 @@ func (m *AddressBookMap) Merge(ab AddressBook) error { if err != nil { return err } - for chain, chainAddresses := range addresses { - for address, typeAndVersions := range chainAddresses { - if err := m.Save(chain, address, typeAndVersions); err != nil { + + m.mtx.Lock() + defer m.mtx.Unlock() + + for chainSelector, chainAddresses := range addresses { + for address, typeAndVersion := range chainAddresses { + if err := m.save(chainSelector, address, typeAndVersion); err != nil { return err } } @@ -149,18 +181,57 @@ func (m *AddressBookMap) Merge(ab AddressBook) error { return nil } +// Remove removes the address book addresses specified via the argument from the AddressBookMap. +// Errors if all the addresses in the given address book are not contained in the AddressBookMap. +func (m *AddressBookMap) Remove(ab AddressBook) error { + addresses, err := ab.Addresses() + if err != nil { + return err + } + + m.mtx.Lock() + defer m.mtx.Unlock() + + // State of m.addressesByChain storage must not be changed in case of an error + // need to do double iteration over the address book. First validation, second actual deletion + for chainSelector, chainAddresses := range addresses { + for address, _ := range chainAddresses { + if _, exists := m.addressesByChain[chainSelector][address]; !exists { + return errors.New("AddressBookMap does not contain address from the given address book") + } + } + } + + for chainSelector, chainAddresses := range addresses { + for address, _ := range chainAddresses { + delete(m.addressesByChain[chainSelector], address) + } + } + + return nil +} + +// cloneAddresses creates a deep copy of map[uint64]map[string]TypeAndVersion object +func (m *AddressBookMap) cloneAddresses(input map[uint64]map[string]TypeAndVersion) map[uint64]map[string]TypeAndVersion { + result := make(map[uint64]map[string]TypeAndVersion) + for chainSelector, chainAddresses := range input { + result[chainSelector] = maps.Clone(chainAddresses) + } + return result +} + // TODO: Maybe could add an environment argument // which would ensure only mainnet/testnet chain selectors are used // for further safety? func NewMemoryAddressBook() *AddressBookMap { return &AddressBookMap{ - AddressesByChain: make(map[uint64]map[string]TypeAndVersion), + addressesByChain: make(map[uint64]map[string]TypeAndVersion), } } func NewMemoryAddressBookFromMap(addressesByChain map[uint64]map[string]TypeAndVersion) *AddressBookMap { return &AddressBookMap{ - AddressesByChain: addressesByChain, + addressesByChain: addressesByChain, } } diff --git a/deployment/address_book_test.go b/deployment/address_book_test.go index bf3d2ad965c..9040902a169 100644 --- a/deployment/address_book_test.go +++ b/deployment/address_book_test.go @@ -2,6 +2,8 @@ package deployment import ( "errors" + "math/big" + "sync" "testing" "github.com/ethereum/go-ethereum/common" @@ -118,3 +120,122 @@ func TestAddressBook_Merge(t *testing.T) { }, }) } + +func TestAddressBook_Remove(t *testing.T) { + onRamp100 := NewTypeAndVersion("OnRamp", Version1_0_0) + onRamp110 := NewTypeAndVersion("OnRamp", Version1_1_0) + addr1 := common.HexToAddress("0x1").String() + addr2 := common.HexToAddress("0x2").String() + addr3 := common.HexToAddress("0x3").String() + + baseAB := NewMemoryAddressBookFromMap(map[uint64]map[string]TypeAndVersion{ + chainsel.TEST_90000001.Selector: { + addr1: onRamp100, + addr2: onRamp100, + }, + chainsel.TEST_90000002.Selector: { + addr1: onRamp110, + addr3: onRamp110, + }, + }) + + copyOfBaseAB := NewMemoryAddressBookFromMap(baseAB.cloneAddresses(baseAB.addressesByChain)) + + // this address book shouldn't be removed (state of baseAB not changed, error thrown) + failAB := NewMemoryAddressBookFromMap(map[uint64]map[string]TypeAndVersion{ + chainsel.TEST_90000001.Selector: { + addr1: onRamp100, + addr3: onRamp100, // doesn't exist in TEST_90000001.Selector + }, + }) + require.Error(t, baseAB.Remove(failAB)) + require.EqualValues(t, baseAB, copyOfBaseAB) + + // this Address book should be removed without error + successAB := NewMemoryAddressBookFromMap(map[uint64]map[string]TypeAndVersion{ + chainsel.TEST_90000002.Selector: { + addr3: onRamp100, + }, + chainsel.TEST_90000001.Selector: { + addr2: onRamp100, + }, + }) + + expectingAB := NewMemoryAddressBookFromMap(map[uint64]map[string]TypeAndVersion{ + chainsel.TEST_90000001.Selector: { + addr1: onRamp100, + }, + chainsel.TEST_90000002.Selector: { + addr1: onRamp110}, + }) + + require.NoError(t, baseAB.Remove(successAB)) + require.EqualValues(t, baseAB, expectingAB) +} + +func TestAddressBook_ConcurrencyAndDeadlock(t *testing.T) { + onRamp100 := NewTypeAndVersion("OnRamp", Version1_0_0) + onRamp110 := NewTypeAndVersion("OnRamp", Version1_1_0) + + baseAB := NewMemoryAddressBookFromMap(map[uint64]map[string]TypeAndVersion{ + chainsel.TEST_90000001.Selector: { + common.BigToAddress(big.NewInt(1)).String(): onRamp100, + }, + }) + + // concurrent writes + var i int64 + wg := sync.WaitGroup{} + for i = 2; i < 1000; i++ { + wg.Add(1) + go func(input int64) { + require.NoError(t, baseAB.Save( + chainsel.TEST_90000001.Selector, + common.BigToAddress(big.NewInt(input)).String(), + onRamp100, + )) + wg.Done() + }(i) + } + + // concurrent reads + for i = 0; i < 100; i++ { + wg.Add(1) + go func(input int64) { + addresses, err := baseAB.Addresses() + require.NoError(t, err) + for chainSelector, chainAddresses := range addresses { + // concurrent read chainAddresses from Addresses() method + for address, _ := range chainAddresses { + addresses[chainSelector][address] = onRamp110 + } + + // concurrent read chainAddresses from AddressesForChain() method + chainAddresses, err = baseAB.AddressesForChain(chainSelector) + require.NoError(t, err) + for address, _ := range chainAddresses { + _ = addresses[chainSelector][address] + } + } + require.NoError(t, err) + wg.Done() + }(i) + } + + // concurrent merges, starts from 1001 to avoid address conflicts + for i = 1001; i < 1100; i++ { + wg.Add(1) + go func(input int64) { + // concurrent merge + additionalAB := NewMemoryAddressBookFromMap(map[uint64]map[string]TypeAndVersion{ + chainsel.TEST_90000002.Selector: { + common.BigToAddress(big.NewInt(input)).String(): onRamp100, + }, + }) + require.NoError(t, baseAB.Merge(additionalAB)) + wg.Done() + }(i) + } + + wg.Wait() +}