Skip to content

Commit

Permalink
Separate dialing and tracking from p2p [#3443]
Browse files Browse the repository at this point in the history
  • Loading branch information
firelizzard18 committed Oct 9, 2023
1 parent 8930c82 commit c75200b
Show file tree
Hide file tree
Showing 22 changed files with 642 additions and 720 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: ^(dialerHost|dialerPeers|MessageStreamHandler)$
name: ^(Connector|Discoverer|MessageStreamHandler)$
inpackage: true
testonly: true
with-expecter: true
Expand Down
86 changes: 21 additions & 65 deletions pkg/api/v3/p2p/dial_test.go → pkg/api/v3/p2p/dial/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

package p2p
package dial

import (
"context"
Expand All @@ -29,9 +29,9 @@ import (

type fakeStream struct{}

func (fakeStream) Read([]byte) (int, error) { return 0, io.EOF }
func (fakeStream) Write([]byte) (int, error) { return 0, io.EOF }
func (fakeStream) Close() error { return nil }
func (fakeStream) Read() (message.Message, error) { return nil, io.EOF }
func (fakeStream) Write(message.Message) error { return io.EOF }
func (fakeStream) Close() error { return nil }

func peerId(t testing.TB, s string) peer.ID {
id, err := peer.Decode(s)
Expand All @@ -54,12 +54,10 @@ func TestDialAddress(t *testing.T) {
pid := peerId(t, "QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N")
ch := make(chan peer.AddrInfo, 1)

host := newMockDialerHost(t)
host.EXPECT().selfID().Return("").Maybe()
host.EXPECT().getOwnService(mock.Anything, mock.Anything).Return(nil, false).Maybe()
host.EXPECT().getPeerService(mock.Anything, mock.Anything, mock.Anything).Return(fakeStream{}, nil).Maybe()
peers := newMockDialerPeers(t)
peers.EXPECT().getPeers(mock.Anything, mock.Anything, mock.Anything).Return(ch, nil).Run(func(context.Context, multiaddr.Multiaddr, int) {
host := NewMockConnector(t)
host.EXPECT().Connect(mock.Anything, mock.Anything).Return(fakeStream{}, nil).Maybe()
peers := NewMockDiscoverer(t)
peers.EXPECT().Discover(mock.Anything, mock.Anything).Return(DiscoveredPeers(ch), nil).Run(func(context.Context, *DiscoveryRequest) {
ch <- peer.AddrInfo{ID: pid}
}).Maybe()

Expand All @@ -76,7 +74,7 @@ func TestDialAddress(t *testing.T) {
"tcp-network-partition": {false, "/tcp/123/acc/foo/acc-svc/query:foo"},
}

dialer := &dialer{host: host, peers: peers, tracker: &simpleTracker{}}
dialer := &dialer{host: host, peers: peers, tracker: fakeTracker{}}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
_, err := dialer.Dial(context.Background(), addr(t, c.Addr))
Expand All @@ -89,40 +87,6 @@ func TestDialAddress(t *testing.T) {
}
}

func TestDialSelfPeer(t *testing.T) {
// Node dial requests that match the node's ID are processed directly

done := make(chan struct{})
handler := NewMockMessageStreamHandler(t)
handler.EXPECT().Execute(mock.Anything).Run(func(message.Stream) { close(done) })

pid := peerId(t, "QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N")
host := newMockDialerHost(t)
host.EXPECT().selfID().Return(pid)
host.EXPECT().getOwnService(mock.Anything, mock.Anything).Return(&serviceHandler{handler: handler.Execute}, true)

dialer := &dialer{host: host, peers: nil, tracker: &simpleTracker{}}
_, err := dialer.Dial(context.Background(), addr(t, "/acc/foo/acc-svc/query:foo/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N"))
require.NoError(t, err)
<-done
}

func TestDialSelfPartition(t *testing.T) {
// Partition dial requests that match a partition the node participates in are processed directly

done := make(chan struct{})
handler := NewMockMessageStreamHandler(t)
handler.EXPECT().Execute(mock.Anything).Run(func(message.Stream) { close(done) })

host := newMockDialerHost(t)
host.EXPECT().getOwnService(mock.Anything, mock.Anything).Return(&serviceHandler{handler: handler.Execute}, true)

dialer := &dialer{host: host, peers: nil, tracker: &simpleTracker{}}
_, err := dialer.Dial(context.Background(), addr(t, "/acc/foo/acc-svc/query:foo"))
require.NoError(t, err)
<-done
}

func newPeer(t *testing.T, seed ...any) peer.ID {
h := storage.MakeKey(seed...)
std := ed25519.NewKeyFromSeed(h[:])
Expand Down Expand Up @@ -177,7 +141,7 @@ func TestDialServices1(t *testing.T) {
}

// Setup the peers mock
peers := staticPeers(func(ctx context.Context, ma multiaddr.Multiaddr) []peer.AddrInfo {
peers := staticPeers(func(ctx context.Context, req *DiscoveryRequest) []peer.AddrInfo {
var all []peer.AddrInfo
for _, p := range goodPeerIDs {
all = append(all, peer.AddrInfo{ID: p})
Expand All @@ -188,7 +152,7 @@ func TestDialServices1(t *testing.T) {
return all
})

tracker := new(simpleTracker)
tracker := new(SimpleTracker)
dialer := &dialer{host: host, peers: peers, tracker: tracker}

start := time.Now()
Expand Down Expand Up @@ -216,19 +180,19 @@ func TestDialServices1(t *testing.T) {
}
}

type staticPeers func(ctx context.Context, ma multiaddr.Multiaddr) []peer.AddrInfo
type staticPeers func(ctx context.Context, req *DiscoveryRequest) []peer.AddrInfo

func (f staticPeers) getPeers(ctx context.Context, ma multiaddr.Multiaddr, limit int) (<-chan peer.AddrInfo, error) {
list := f(ctx, ma)
if len(list) > limit {
list = list[:limit]
func (f staticPeers) Discover(ctx context.Context, req *DiscoveryRequest) (DiscoveryResponse, error) {
list := f(ctx, req)
if len(list) > req.Limit {
list = list[:req.Limit]
}
ch := make(chan peer.AddrInfo, len(list))
for _, x := range list {
ch <- x
}
close(ch)
return ch, nil
return DiscoveredPeers(ch), nil
}

type fakeHost struct {
Expand Down Expand Up @@ -256,18 +220,10 @@ func (h *fakeHost) clear(id string) {
delete(h.good, id)
}

func (h *fakeHost) selfID() peer.ID {
return h.peerID
}

func (h *fakeHost) getOwnService(network string, sa *api.ServiceAddress) (*serviceHandler, bool) {
return nil, false
}

func (h *fakeHost) getPeerService(ctx context.Context, peer peer.ID, service *api.ServiceAddress) (io.ReadWriteCloser, error) {
func (h *fakeHost) Connect(ctx context.Context, req *ConnectionRequest) (message.Stream, error) {
h.RLock()
defer h.RUnlock()
if h.good[peer.String()+"|"+service.String()] {
if h.good[req.PeerID.String()+"|"+req.Service.String()] {
return fakeStream{}, nil
}
// The specific error doesn't matter but this one won't get logged
Expand Down Expand Up @@ -312,7 +268,7 @@ func TestDialServices2(t *testing.T) {
}

// Setup the peers mock
peers := staticPeers(func(ctx context.Context, ma multiaddr.Multiaddr) []peer.AddrInfo {
peers := staticPeers(func(ctx context.Context, req *DiscoveryRequest) []peer.AddrInfo {
var all []peer.AddrInfo
for _, p := range goodPeerIDs {
all = append(all, peer.AddrInfo{ID: p})
Expand All @@ -323,7 +279,7 @@ func TestDialServices2(t *testing.T) {
return all
})

dialer := &dialer{host: host, peers: peers, tracker: &simpleTracker{}}
dialer := &dialer{host: host, peers: peers, tracker: fakeTracker{}}

start := time.Now()
service := services[0]
Expand Down
Loading

0 comments on commit c75200b

Please sign in to comment.