diff --git a/pkg/api/v3/p2p/.mockery.yml b/pkg/api/v3/p2p/dial/.mockery.yml similarity index 54% rename from pkg/api/v3/p2p/.mockery.yml rename to pkg/api/v3/p2p/dial/.mockery.yml index 3d946f5ce..369f0476f 100644 --- a/pkg/api/v3/p2p/.mockery.yml +++ b/pkg/api/v3/p2p/dial/.mockery.yml @@ -1,4 +1,4 @@ -name: ^(dialerHost|dialerPeers|MessageStreamHandler)$ +name: ^(Connector|Discoverer|MessageStreamHandler)$ inpackage: true testonly: true with-expecter: true diff --git a/pkg/api/v3/p2p/dial_test.go b/pkg/api/v3/p2p/dial/dial_test.go similarity index 73% rename from pkg/api/v3/p2p/dial_test.go rename to pkg/api/v3/p2p/dial/dial_test.go index 9688a20b9..a4389d384 100644 --- a/pkg/api/v3/p2p/dial_test.go +++ b/pkg/api/v3/p2p/dial/dial_test.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -package p2p +package dial import ( "context" @@ -29,9 +29,9 @@ import ( type fakeStream struct{} -func (fakeStream) Read([]byte) (int, error) { return 0, io.EOF } -func (fakeStream) Write([]byte) (int, error) { return 0, io.EOF } -func (fakeStream) Close() error { return nil } +func (fakeStream) Read() (message.Message, error) { return nil, io.EOF } +func (fakeStream) Write(message.Message) error { return io.EOF } +func (fakeStream) Close() error { return nil } func peerId(t testing.TB, s string) peer.ID { id, err := peer.Decode(s) @@ -54,12 +54,10 @@ func TestDialAddress(t *testing.T) { pid := peerId(t, "QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N") ch := make(chan peer.AddrInfo, 1) - host := newMockDialerHost(t) - host.EXPECT().selfID().Return("").Maybe() - host.EXPECT().getOwnService(mock.Anything, mock.Anything).Return(nil, false).Maybe() - host.EXPECT().getPeerService(mock.Anything, mock.Anything, mock.Anything).Return(fakeStream{}, nil).Maybe() - peers := newMockDialerPeers(t) - peers.EXPECT().getPeers(mock.Anything, mock.Anything, mock.Anything).Return(ch, nil).Run(func(context.Context, multiaddr.Multiaddr, int) { + host := NewMockConnector(t) + host.EXPECT().Connect(mock.Anything, mock.Anything).Return(fakeStream{}, nil).Maybe() + peers := NewMockDiscoverer(t) + peers.EXPECT().Discover(mock.Anything, mock.Anything).Return(DiscoveredPeers(ch), nil).Run(func(context.Context, *DiscoveryRequest) { ch <- peer.AddrInfo{ID: pid} }).Maybe() @@ -76,7 +74,7 @@ func TestDialAddress(t *testing.T) { "tcp-network-partition": {false, "/tcp/123/acc/foo/acc-svc/query:foo"}, } - dialer := &dialer{host: host, peers: peers, tracker: &simpleTracker{}} + dialer := &dialer{host: host, peers: peers, tracker: fakeTracker{}} for name, c := range cases { t.Run(name, func(t *testing.T) { _, err := dialer.Dial(context.Background(), addr(t, c.Addr)) @@ -89,40 +87,6 @@ func TestDialAddress(t *testing.T) { } } -func TestDialSelfPeer(t *testing.T) { - // Node dial requests that match the node's ID are processed directly - - done := make(chan struct{}) - handler := NewMockMessageStreamHandler(t) - handler.EXPECT().Execute(mock.Anything).Run(func(message.Stream) { close(done) }) - - pid := peerId(t, "QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N") - host := newMockDialerHost(t) - host.EXPECT().selfID().Return(pid) - host.EXPECT().getOwnService(mock.Anything, mock.Anything).Return(&serviceHandler{handler: handler.Execute}, true) - - dialer := &dialer{host: host, peers: nil, tracker: &simpleTracker{}} - _, err := dialer.Dial(context.Background(), addr(t, "/acc/foo/acc-svc/query:foo/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N")) - require.NoError(t, err) - <-done -} - -func TestDialSelfPartition(t *testing.T) { - // Partition dial requests that match a partition the node participates in are processed directly - - done := make(chan struct{}) - handler := NewMockMessageStreamHandler(t) - handler.EXPECT().Execute(mock.Anything).Run(func(message.Stream) { close(done) }) - - host := newMockDialerHost(t) - host.EXPECT().getOwnService(mock.Anything, mock.Anything).Return(&serviceHandler{handler: handler.Execute}, true) - - dialer := &dialer{host: host, peers: nil, tracker: &simpleTracker{}} - _, err := dialer.Dial(context.Background(), addr(t, "/acc/foo/acc-svc/query:foo")) - require.NoError(t, err) - <-done -} - func newPeer(t *testing.T, seed ...any) peer.ID { h := storage.MakeKey(seed...) std := ed25519.NewKeyFromSeed(h[:]) @@ -177,7 +141,7 @@ func TestDialServices1(t *testing.T) { } // Setup the peers mock - peers := staticPeers(func(ctx context.Context, ma multiaddr.Multiaddr) []peer.AddrInfo { + peers := staticPeers(func(ctx context.Context, req *DiscoveryRequest) []peer.AddrInfo { var all []peer.AddrInfo for _, p := range goodPeerIDs { all = append(all, peer.AddrInfo{ID: p}) @@ -188,7 +152,7 @@ func TestDialServices1(t *testing.T) { return all }) - tracker := new(simpleTracker) + tracker := new(SimpleTracker) dialer := &dialer{host: host, peers: peers, tracker: tracker} start := time.Now() @@ -216,19 +180,19 @@ func TestDialServices1(t *testing.T) { } } -type staticPeers func(ctx context.Context, ma multiaddr.Multiaddr) []peer.AddrInfo +type staticPeers func(ctx context.Context, req *DiscoveryRequest) []peer.AddrInfo -func (f staticPeers) getPeers(ctx context.Context, ma multiaddr.Multiaddr, limit int) (<-chan peer.AddrInfo, error) { - list := f(ctx, ma) - if len(list) > limit { - list = list[:limit] +func (f staticPeers) Discover(ctx context.Context, req *DiscoveryRequest) (DiscoveryResponse, error) { + list := f(ctx, req) + if len(list) > req.Limit { + list = list[:req.Limit] } ch := make(chan peer.AddrInfo, len(list)) for _, x := range list { ch <- x } close(ch) - return ch, nil + return DiscoveredPeers(ch), nil } type fakeHost struct { @@ -256,18 +220,10 @@ func (h *fakeHost) clear(id string) { delete(h.good, id) } -func (h *fakeHost) selfID() peer.ID { - return h.peerID -} - -func (h *fakeHost) getOwnService(network string, sa *api.ServiceAddress) (*serviceHandler, bool) { - return nil, false -} - -func (h *fakeHost) getPeerService(ctx context.Context, peer peer.ID, service *api.ServiceAddress) (io.ReadWriteCloser, error) { +func (h *fakeHost) Connect(ctx context.Context, req *ConnectionRequest) (message.Stream, error) { h.RLock() defer h.RUnlock() - if h.good[peer.String()+"|"+service.String()] { + if h.good[req.PeerID.String()+"|"+req.Service.String()] { return fakeStream{}, nil } // The specific error doesn't matter but this one won't get logged @@ -312,7 +268,7 @@ func TestDialServices2(t *testing.T) { } // Setup the peers mock - peers := staticPeers(func(ctx context.Context, ma multiaddr.Multiaddr) []peer.AddrInfo { + peers := staticPeers(func(ctx context.Context, req *DiscoveryRequest) []peer.AddrInfo { var all []peer.AddrInfo for _, p := range goodPeerIDs { all = append(all, peer.AddrInfo{ID: p}) @@ -323,7 +279,7 @@ func TestDialServices2(t *testing.T) { return all }) - dialer := &dialer{host: host, peers: peers, tracker: &simpleTracker{}} + dialer := &dialer{host: host, peers: peers, tracker: fakeTracker{}} start := time.Now() service := services[0] diff --git a/pkg/api/v3/p2p/dial.go b/pkg/api/v3/p2p/dial/dialer.go similarity index 70% rename from pkg/api/v3/p2p/dial.go rename to pkg/api/v3/p2p/dial/dialer.go index ac90f2021..c4e73d4a3 100644 --- a/pkg/api/v3/p2p/dial.go +++ b/pkg/api/v3/p2p/dial/dialer.go @@ -4,11 +4,10 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -package p2p +package dial import ( "context" - "io" "runtime/debug" "sync" "time" @@ -23,39 +22,42 @@ import ( "golang.org/x/exp/slog" ) -// dialer implements [message.MultiDialer]. -type dialer struct { - host dialerHost - peers dialerPeers - tracker peerTracker +type Discoverer interface { + Discover(context.Context, *DiscoveryRequest) (DiscoveryResponse, error) } -var _ message.MultiDialer = (*dialer)(nil) +type DiscoveryRequest struct { + Network string + Service *api.ServiceAddress + Limit int +} -// dialerHost are the parts of [Node] required by [dialer]. dialerHost exists to -// support testing with mocks. -type dialerHost interface { - selfID() peer.ID - getOwnService(network string, sa *api.ServiceAddress) (*serviceHandler, bool) - getPeerService(ctx context.Context, peer peer.ID, service *api.ServiceAddress) (io.ReadWriteCloser, error) +type DiscoveryResponse interface { + isDiscoveryResponse() } -// dialerPeers are the parts of [peerManager] required by [dialer]. dialerPeers -// exists to support testing with mocks. -type dialerPeers interface { - getPeers(ctx context.Context, ma multiaddr.Multiaddr, limit int) (<-chan peer.AddrInfo, error) +type DiscoveredPeers <-chan peer.AddrInfo +type DiscoveredLocal func(context.Context) (message.Stream, error) + +func (DiscoveredPeers) isDiscoveryResponse() {} +func (DiscoveredLocal) isDiscoveryResponse() {} + +type Tracker interface { + Mark(peer peer.ID, service multiaddr.Multiaddr, status api.KnownPeerStatus) + Status(peer peer.ID, service multiaddr.Multiaddr) api.KnownPeerStatus + Next(service multiaddr.Multiaddr, status api.KnownPeerStatus) (peer.ID, bool) + All(service multiaddr.Multiaddr, status api.KnownPeerStatus) []peer.ID } -// DialNetwork returns a [message.MultiDialer] that opens a stream to a node -// that can provides a given service. -func (n *Node) DialNetwork() message.MultiDialer { - return &dialer{ - host: n, - peers: n.peermgr, - tracker: n.tracker, - } +// dialer implements [message.MultiDialer]. +type dialer struct { + host Connector + peers Discoverer + tracker Tracker } +var _ message.MultiDialer = (*dialer)(nil) + // Dial dials the given address. The address must include an /acc component and // may include a /p2p component. Dial will return an error if the address // includes any other components. @@ -75,7 +77,11 @@ func (d *dialer) Dial(ctx context.Context, addr multiaddr.Multiaddr) (stream mes return d.newNetworkStream(ctx, sa, net, nil) } - return d.newPeerStream(ctx, sa, peer) + // Open a new stream + return openStreamFor(ctx, d.host, &ConnectionRequest{ + Service: sa, + PeerID: peer, + }) } // BadDial notifies the dialer that a transport error was encountered while @@ -85,33 +91,10 @@ func (d *dialer) BadDial(ctx context.Context, addr multiaddr.Multiaddr, s messag if !ok { return false } - d.tracker.markBad(ctx, ss.peer, addr) + d.tracker.Mark(ss.peer, addr, api.PeerStatusIsKnownBad) return true } -// newPeerStream dials the given partition of the given peer. If the peer is the -// current node, newPeerStream returns a pipe and spawns a goroutine to handle -// it as if it were an incoming stream. -// -// If the peer ID does not match a peer known by the node, or if the node does -// not have an address for the given peer, newPeerStream will fail. -func (d *dialer) newPeerStream(ctx context.Context, sa *api.ServiceAddress, peer peer.ID) (message.Stream, error) { - // If the peer ID is our ID - if d.host.selfID() == peer { - // Check if we have the service - s, ok := d.host.getOwnService("", sa) - if !ok { - return nil, errors.NotFound // TODO return protocol not supported - } - - // Create a pipe and handle it - return handleLocally(ctx, s), nil - } - - // Open a new stream - return openStreamFor(ctx, d.host, peer, sa) -} - // newNetworkStream opens a stream to the highest priority peer that // participates in the given partition. If the current node participates in the // partition, newNetworkStream returns a pipe and spawns a goroutine to handle @@ -124,11 +107,6 @@ func (d *dialer) newPeerStream(ctx context.Context, sa *api.ServiceAddress, peer // // The wait group is only used for testing. func (d *dialer) newNetworkStream(ctx context.Context, service *api.ServiceAddress, netName string, wg *sync.WaitGroup) (message.Stream, error) { - // Check if we participate in this partition - if h, ok := d.host.getOwnService(netName, service); ok { - return handleLocally(ctx, h), nil - } - // Construct an address for the service addr, err := service.MultiaddrFor(netName) if err != nil { @@ -138,11 +116,25 @@ func (d *dialer) newNetworkStream(ctx context.Context, service *api.ServiceAddre // Query the DHT for peers that provide the service callCtx, cancel := context.WithCancel(ctx) defer cancel() - peers, err := d.peers.getPeers(callCtx, addr, 10) + resp, err := d.peers.Discover(callCtx, &DiscoveryRequest{ + Network: netName, + Service: service, + Limit: 10, + }) if err != nil { return nil, errors.UnknownError.Wrap(err) } + var peers <-chan peer.AddrInfo + switch resp := resp.(type) { + case DiscoveredLocal: + return resp(ctx) + case DiscoveredPeers: + peers = resp + default: + panic("invalid discovery response") + } + // Check the remaining unknown peers from the DHT (non-blocking) defer func() { for i := 0; i < 10; { @@ -152,7 +144,7 @@ func (d *dialer) newNetworkStream(ctx context.Context, service *api.ServiceAddre return } - if d.tracker.status(ctx, peer.ID, addr) != api.PeerStatusIsUnknown { + if d.tracker.Status(peer.ID, addr) != api.PeerStatusIsUnknown { break } @@ -165,7 +157,7 @@ func (d *dialer) newNetworkStream(ctx context.Context, service *api.ServiceAddre }() // If there are at least 4 known-good peers, try those - if len(d.tracker.allGood(ctx, addr)) >= 4 { + if len(d.tracker.All(addr, api.PeerStatusIsKnownGood)) >= 4 { s := d.dialFromTracker(ctx, service, addr, wg) if s != nil { return s, nil @@ -176,7 +168,7 @@ func (d *dialer) newNetworkStream(ctx context.Context, service *api.ServiceAddre var bad []peer.ID for peer := range peers { // Skip known-bad peers - if d.tracker.status(ctx, peer.ID, addr) == api.PeerStatusIsKnownBad { + if d.tracker.Status(peer.ID, addr) == api.PeerStatusIsKnownBad { bad = append(bad, peer.ID) continue } @@ -207,7 +199,7 @@ func (d *dialer) newNetworkStream(ctx context.Context, service *api.ServiceAddre func (d *dialer) dialFromTracker(ctx context.Context, service *api.ServiceAddress, addr multiaddr.Multiaddr, wg *sync.WaitGroup) *stream { // Asynchronously retry a known-bad peer to check if it has recovered - if bad, ok := d.tracker.nextBad(ctx, addr); ok { + if bad, ok := d.tracker.Next(addr, api.PeerStatusIsKnownBad); ok { d.tryDial(bad, service, addr, wg) } @@ -216,7 +208,7 @@ func (d *dialer) dialFromTracker(ctx context.Context, service *api.ServiceAddres var first peer.ID for i := 0; i < 10; i++ { // Get the next - peer, ok := d.tracker.nextGood(ctx, addr) + peer, ok := d.tracker.Next(addr, api.PeerStatusIsKnownGood) if !ok { return nil } @@ -253,10 +245,7 @@ func (d *dialer) tryDial(peer peer.ID, service *api.ServiceAddress, addr multiad ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - s := d.dial(ctx, peer, service, addr) - if s != nil { - s.conn.Close() - } + d.dial(ctx, peer, service, addr) }() } @@ -269,16 +258,16 @@ func (d *dialer) dial(ctx context.Context, peer peer.ID, service *api.ServiceAdd }() // Open a stream - stream, err := openStreamFor(ctx, d.host, peer, service) + stream, err := openStreamFor(ctx, d.host, &ConnectionRequest{ + Service: service, + PeerID: peer, + }) if err == nil { // Mark the peer good - d.tracker.markGood(ctx, peer, addr) + d.tracker.Mark(peer, addr, api.PeerStatusIsKnownGood) return stream } - // Mark the peer bad - d.tracker.markBad(ctx, peer, addr) - // Log the error var timeoutError interface{ Timeout() bool } switch { @@ -289,18 +278,18 @@ func (d *dialer) dial(ctx context.Context, peer peer.ID, service *api.ServiceAdd case errors.Is(err, network.ErrNoConn), errors.Is(err, network.ErrNoRemoteAddrs): // Mark the peer as dead - d.tracker.markDead(ctx, peer) + d.tracker.Mark(peer, addr, api.PeerStatusIsUnknown) slog.InfoCtx(ctx, "Unable to dial peer", "peer", peer, "service", service, "error", err) case errors.Is(err, swarm.ErrDialBackoff), errors.As(err, &timeoutError) && timeoutError.Timeout(): // Mark the peer bad - d.tracker.markBad(ctx, peer, addr) + d.tracker.Mark(peer, addr, api.PeerStatusIsKnownBad) slog.DebugCtx(ctx, "Unable to dial peer", "peer", peer, "service", service, "error", err) default: // Mark the peer bad - d.tracker.markBad(ctx, peer, addr) + d.tracker.Mark(peer, addr, api.PeerStatusIsKnownBad) slog.WarnCtx(ctx, "Unknown error while dialing peer", "peer", peer, "service", service, "error", err) } return nil diff --git a/pkg/api/v3/p2p/dial/fake.go b/pkg/api/v3/p2p/dial/fake.go new file mode 100644 index 000000000..7ed0def97 --- /dev/null +++ b/pkg/api/v3/p2p/dial/fake.go @@ -0,0 +1,24 @@ +// Copyright 2023 The Accumulate Authors +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +package dial + +import ( + "github.com/libp2p/go-libp2p/core/peer" + "github.com/multiformats/go-multiaddr" + "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3" +) + +var FakeTracker fakeTracker + +type fakeTracker struct{} + +var _ Tracker = fakeTracker{} + +func (fakeTracker) Mark(peer.ID, multiaddr.Multiaddr, api.KnownPeerStatus) {} +func (fakeTracker) Status(peer.ID, multiaddr.Multiaddr) api.KnownPeerStatus { return 0 } +func (fakeTracker) Next(multiaddr.Multiaddr, api.KnownPeerStatus) (peer.ID, bool) { return "", false } +func (fakeTracker) All(multiaddr.Multiaddr, api.KnownPeerStatus) []peer.ID { return nil } diff --git a/pkg/api/v3/p2p/dial/mock_Connector_test.go b/pkg/api/v3/p2p/dial/mock_Connector_test.go new file mode 100644 index 000000000..080959d90 --- /dev/null +++ b/pkg/api/v3/p2p/dial/mock_Connector_test.go @@ -0,0 +1,93 @@ +// Code generated by mockery v2.23.1. DO NOT EDIT. + +package dial + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + message "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3/message" +) + +// MockConnector is an autogenerated mock type for the Connector type +type MockConnector struct { + mock.Mock +} + +type MockConnector_Expecter struct { + mock *mock.Mock +} + +func (_m *MockConnector) EXPECT() *MockConnector_Expecter { + return &MockConnector_Expecter{mock: &_m.Mock} +} + +// Connect provides a mock function with given fields: _a0, _a1 +func (_m *MockConnector) Connect(_a0 context.Context, _a1 *ConnectionRequest) (message.StreamOf[message.Message], error) { + ret := _m.Called(_a0, _a1) + + var r0 message.StreamOf[message.Message] + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *ConnectionRequest) (message.StreamOf[message.Message], error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *ConnectionRequest) message.StreamOf[message.Message]); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.StreamOf[message.Message]) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *ConnectionRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockConnector_Connect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Connect' +type MockConnector_Connect_Call struct { + *mock.Call +} + +// Connect is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *ConnectionRequest +func (_e *MockConnector_Expecter) Connect(_a0 interface{}, _a1 interface{}) *MockConnector_Connect_Call { + return &MockConnector_Connect_Call{Call: _e.mock.On("Connect", _a0, _a1)} +} + +func (_c *MockConnector_Connect_Call) Run(run func(_a0 context.Context, _a1 *ConnectionRequest)) *MockConnector_Connect_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*ConnectionRequest)) + }) + return _c +} + +func (_c *MockConnector_Connect_Call) Return(_a0 message.StreamOf[message.Message], _a1 error) *MockConnector_Connect_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockConnector_Connect_Call) RunAndReturn(run func(context.Context, *ConnectionRequest) (message.StreamOf[message.Message], error)) *MockConnector_Connect_Call { + _c.Call.Return(run) + return _c +} + +type mockConstructorTestingTNewMockConnector interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockConnector creates a new instance of MockConnector. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockConnector(t mockConstructorTestingTNewMockConnector) *MockConnector { + mock := &MockConnector{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/api/v3/p2p/dial/mock_Discoverer_test.go b/pkg/api/v3/p2p/dial/mock_Discoverer_test.go new file mode 100644 index 000000000..e5f2fbe7b --- /dev/null +++ b/pkg/api/v3/p2p/dial/mock_Discoverer_test.go @@ -0,0 +1,92 @@ +// Code generated by mockery v2.23.1. DO NOT EDIT. + +package dial + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MockDiscoverer is an autogenerated mock type for the Discoverer type +type MockDiscoverer struct { + mock.Mock +} + +type MockDiscoverer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDiscoverer) EXPECT() *MockDiscoverer_Expecter { + return &MockDiscoverer_Expecter{mock: &_m.Mock} +} + +// Discover provides a mock function with given fields: _a0, _a1 +func (_m *MockDiscoverer) Discover(_a0 context.Context, _a1 *DiscoveryRequest) (DiscoveryResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 DiscoveryResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *DiscoveryRequest) (DiscoveryResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *DiscoveryRequest) DiscoveryResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(DiscoveryResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *DiscoveryRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDiscoverer_Discover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Discover' +type MockDiscoverer_Discover_Call struct { + *mock.Call +} + +// Discover is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *DiscoveryRequest +func (_e *MockDiscoverer_Expecter) Discover(_a0 interface{}, _a1 interface{}) *MockDiscoverer_Discover_Call { + return &MockDiscoverer_Discover_Call{Call: _e.mock.On("Discover", _a0, _a1)} +} + +func (_c *MockDiscoverer_Discover_Call) Run(run func(_a0 context.Context, _a1 *DiscoveryRequest)) *MockDiscoverer_Discover_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*DiscoveryRequest)) + }) + return _c +} + +func (_c *MockDiscoverer_Discover_Call) Return(_a0 DiscoveryResponse, _a1 error) *MockDiscoverer_Discover_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDiscoverer_Discover_Call) RunAndReturn(run func(context.Context, *DiscoveryRequest) (DiscoveryResponse, error)) *MockDiscoverer_Discover_Call { + _c.Call.Return(run) + return _c +} + +type mockConstructorTestingTNewMockDiscoverer interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockDiscoverer creates a new instance of MockDiscoverer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockDiscoverer(t mockConstructorTestingTNewMockDiscoverer) *MockDiscoverer { + mock := &MockDiscoverer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/api/v3/p2p/dial/new.go b/pkg/api/v3/p2p/dial/new.go new file mode 100644 index 000000000..758282b4d --- /dev/null +++ b/pkg/api/v3/p2p/dial/new.go @@ -0,0 +1,54 @@ +// Copyright 2023 The Accumulate Authors +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +package dial + +import "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3/message" + +type Option interface { + apply(*dialer) +} + +type optionFunc func(*dialer) + +func (fn optionFunc) apply(d *dialer) { fn(d) } + +// New creates a new [message.Dialer]. New will panic if the connector, +// discoverer, or tracker is unspecified. +func New(opts ...Option) message.Dialer { + d := &dialer{} + for _, opt := range opts { + opt.apply(d) + } + if d.tracker == nil { + panic("missing tracker") + } + if d.host == nil { + panic("missing connector") + } + if d.peers == nil { + panic("missing discoverer") + } + return d +} + +func WithDiscoverer(v Discoverer) Option { + return optionFunc(func(d *dialer) { + d.peers = v + }) +} + +func WithConnector(v Connector) Option { + return optionFunc(func(d *dialer) { + d.host = v + }) +} + +func WithTracker(v Tracker) Option { + return optionFunc(func(d *dialer) { + d.tracker = v + }) +} diff --git a/pkg/api/v3/p2p/peer_queue.go b/pkg/api/v3/p2p/dial/peer_queue.go similarity index 99% rename from pkg/api/v3/p2p/peer_queue.go rename to pkg/api/v3/p2p/dial/peer_queue.go index 31a74a2e4..f4e7b343d 100644 --- a/pkg/api/v3/p2p/peer_queue.go +++ b/pkg/api/v3/p2p/dial/peer_queue.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -package p2p +package dial import ( "sync" diff --git a/pkg/api/v3/p2p/peer_queue_test.go b/pkg/api/v3/p2p/dial/peer_queue_test.go similarity index 98% rename from pkg/api/v3/p2p/peer_queue_test.go rename to pkg/api/v3/p2p/dial/peer_queue_test.go index 9bf760347..35d805bdd 100644 --- a/pkg/api/v3/p2p/peer_queue_test.go +++ b/pkg/api/v3/p2p/dial/peer_queue_test.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -package p2p +package dial import ( "testing" diff --git a/pkg/api/v3/p2p/dial/simple.go b/pkg/api/v3/p2p/dial/simple.go new file mode 100644 index 000000000..bbbc4a9e1 --- /dev/null +++ b/pkg/api/v3/p2p/dial/simple.go @@ -0,0 +1,117 @@ +// Copyright 2023 The Accumulate Authors +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +package dial + +import ( + "sync" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/multiformats/go-multiaddr" + "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3" + "golang.org/x/exp/slog" +) + +type SimpleTracker struct { + mu sync.RWMutex + good map[string]*peerQueue + bad map[string]*peerQueue +} + +var _ Tracker = (*SimpleTracker)(nil) + +func (t *SimpleTracker) get(service multiaddr.Multiaddr) (good, bad *peerQueue) { + key := service.String() + t.mu.RLock() + good, ok := t.good[key] + bad = t.bad[key] + t.mu.RUnlock() + if ok { + return good, bad + } + + t.mu.Lock() + defer t.mu.Unlock() + + good, ok = t.good[key] + bad = t.bad[key] + if ok { + return good, bad + } + + if t.good == nil { + t.good = map[string]*peerQueue{} + t.bad = map[string]*peerQueue{} + } + + good = new(peerQueue) + bad = new(peerQueue) + t.good[key] = good + t.bad[key] = bad + return good, bad +} + +func (t *SimpleTracker) Mark(peer peer.ID, service multiaddr.Multiaddr, status api.KnownPeerStatus) { + good, bad := t.get(service) + switch status { + case api.PeerStatusIsKnownGood: + ok1 := good.Add(peer) + ok2 := bad.Remove(peer) + if ok1 || ok2 { + slog.Debug("Marked peer good", "peer", peer, "service", service) + } + + case api.PeerStatusIsKnownBad: + ok1 := good.Remove(peer) + ok2 := bad.Add(peer) + if ok1 || ok2 { + slog.Debug("Marked peer bad", "peer", peer, "service", service) + } + + case api.PeerStatusIsUnknown: + ok1 := good.Remove(peer) + ok2 := bad.Remove(peer) + if ok1 || ok2 { + slog.Debug("Marked peer dead", "peer", peer, "service", service) + } + } +} + +func (t *SimpleTracker) Status(peer peer.ID, service multiaddr.Multiaddr) api.KnownPeerStatus { + good, bad := t.get(service) + switch { + case good.Has(peer): + return api.PeerStatusIsKnownGood + case bad.Has(peer): + return api.PeerStatusIsKnownBad + default: + return api.PeerStatusIsUnknown + } +} + +func (t *SimpleTracker) Next(service multiaddr.Multiaddr, status api.KnownPeerStatus) (peer.ID, bool) { + good, bad := t.get(service) + switch status { + case api.PeerStatusIsKnownGood: + return good.Next() + case api.PeerStatusIsKnownBad: + return bad.Next() + default: + return "", false + } +} + +func (t *SimpleTracker) All(service multiaddr.Multiaddr, status api.KnownPeerStatus) []peer.ID { + good, bad := t.get(service) + switch status { + case api.PeerStatusIsKnownGood: + return good.All() + case api.PeerStatusIsKnownBad: + return bad.All() + default: + return nil + } +} diff --git a/pkg/api/v3/p2p/stream.go b/pkg/api/v3/p2p/dial/stream.go similarity index 53% rename from pkg/api/v3/p2p/stream.go rename to pkg/api/v3/p2p/dial/stream.go index 8f4b19d93..8811869ea 100644 --- a/pkg/api/v3/p2p/stream.go +++ b/pkg/api/v3/p2p/dial/stream.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -package p2p +package dial import ( "context" @@ -17,58 +17,61 @@ import ( "gitlab.com/accumulatenetwork/accumulate/pkg/errors" ) -// stream is a [message.Stream] with an associated [peerState]. +type ConnectionRequest struct { + Service *api.ServiceAddress + PeerID peer.ID +} + +type Connector interface { + Connect(context.Context, *ConnectionRequest) (message.Stream, error) +} + type stream struct { peer peer.ID - conn io.ReadWriteCloser stream message.Stream } -func openStreamFor(ctx context.Context, host dialerHost, peer peer.ID, sa *api.ServiceAddress) (*stream, error) { - conn, err := host.getPeerService(ctx, peer, sa) +func openStreamFor(ctx context.Context, host Connector, req *ConnectionRequest) (*stream, error) { + if req.PeerID == "" { + return nil, errors.BadRequest.With("missing peer ID") + } + if req.Service == nil { + return nil, errors.BadRequest.With("missing service address") + } + + conn, err := host.Connect(ctx, req) if err != nil { // Do not wrap as it will clobber the error return nil, err } - // Close the stream when the context is canceled - go func() { <-ctx.Done(); _ = conn.Close() }() - s := new(stream) - s.peer = peer - s.conn = conn - s.stream = message.NewStream(conn) + s.peer = req.PeerID + s.stream = conn return s, nil } func (s *stream) Read() (message.Message, error) { - // Convert ErrReset and Canceled into EOF m, err := s.stream.Read() - switch { - case err == nil: - return m, nil - case isEOF(err): - return nil, io.EOF - default: - return nil, err - } + return m, streamError(err) } func (s *stream) Write(msg message.Message) error { - // Convert ErrReset and Canceled into EOF err := s.stream.Write(msg) + return streamError(err) +} + +// streamError converts network reset and canceled context errors into +// [io.EOF]. +func streamError(err error) error { switch { case err == nil: return nil - case isEOF(err): + case errors.Is(err, io.EOF), + errors.Is(err, network.ErrReset), + errors.Is(err, context.Canceled): return io.EOF default: return err } } - -func isEOF(err error) bool { - return errors.Is(err, io.EOF) || - errors.Is(err, network.ErrReset) || - errors.Is(err, context.Canceled) -} diff --git a/pkg/api/v3/p2p/dial/types.go b/pkg/api/v3/p2p/dial/types.go new file mode 100644 index 000000000..7cfee69dd --- /dev/null +++ b/pkg/api/v3/p2p/dial/types.go @@ -0,0 +1,10 @@ +// Copyright 2023 The Accumulate Authors +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +package dial + +//go:generate go run github.com/vektra/mockery/v2 +//go:generate go run github.com/rinchsan/gosimports/cmd/gosimports -w . diff --git a/pkg/api/v3/p2p/dial_network.go b/pkg/api/v3/p2p/dial_network.go new file mode 100644 index 000000000..74673aa1a --- /dev/null +++ b/pkg/api/v3/p2p/dial_network.go @@ -0,0 +1,75 @@ +// Copyright 2023 The Accumulate Authors +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +package p2p + +import ( + "context" + + "github.com/multiformats/go-multiaddr" + "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3" + "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3/message" + "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3/p2p/dial" + "gitlab.com/accumulatenetwork/accumulate/pkg/errors" +) + +// DialNetwork returns a [message.MultiDialer] that opens a stream to a node +// that can provides a given service. +func (n *Node) DialNetwork() message.Dialer { + return dial.New(n.dialOpts...) +} + +type discoverer Node + +func (d *discoverer) Discover(ctx context.Context, req *dial.DiscoveryRequest) (dial.DiscoveryResponse, error) { + var addr multiaddr.Multiaddr + if req.Network != "" { + c, err := multiaddr.NewComponent(api.N_ACC, req.Network) + if err != nil { + return nil, errors.BadRequest.WithFormat("create network multiaddr: %w", err) + } + addr = c + } + if req.Service != nil { + if req.Service.Type == api.ServiceTypeUnknown { + return nil, errors.BadRequest.With("missing service type") + } + c := req.Service.Multiaddr() + if addr == nil { + addr = c + } else { + addr = addr.Encapsulate(c) + } + } + if addr == nil { + return nil, errors.BadRequest.With("no network or service specified") + } + + s, ok := (*Node)(d).getOwnService(req.Network, req.Service) + if ok { + return dial.DiscoveredLocal(func(ctx context.Context) (message.Stream, error) { + return handleLocally(ctx, s), nil + }), nil + } + + ch, err := (*Node)(d).peermgr.getPeers(ctx, addr, req.Limit) + return dial.DiscoveredPeers(ch), err +} + +type connector Node + +func (c *connector) Connect(ctx context.Context, req *dial.ConnectionRequest) (message.Stream, error) { + if req.PeerID != c.host.ID() { + return (*Node)(c).getPeerService(ctx, req.PeerID, req.Service) + } + + s, ok := (*Node)(c).getOwnService("", req.Service) + if !ok { + return nil, errors.NotFound // TODO return protocol not supported + } + + return handleLocally(ctx, s), nil +} diff --git a/pkg/api/v3/p2p/dial_self.go b/pkg/api/v3/p2p/dial_self.go index d951a1b09..778d1c4de 100644 --- a/pkg/api/v3/p2p/dial_self.go +++ b/pkg/api/v3/p2p/dial_self.go @@ -31,8 +31,7 @@ func (d *selfDialer) Dial(ctx context.Context, addr multiaddr.Multiaddr) (messag return nil, errors.UnknownError.Wrap(err) } if peer != "" && peer != d.host.ID() { - s, err := openStreamFor(ctx, (*Node)(d), peer, sa) - return s, errors.UnknownError.Wrap(err) + return nil, errors.BadRequest.With("attempted to dial a different node") } // Check if we provide the service diff --git a/pkg/api/v3/p2p/mock_MessageStreamHandler_test.go b/pkg/api/v3/p2p/mock_MessageStreamHandler_test.go deleted file mode 100644 index 1c10f3e30..000000000 --- a/pkg/api/v3/p2p/mock_MessageStreamHandler_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Code generated by mockery v2.23.1. DO NOT EDIT. - -package p2p - -import ( - mock "github.com/stretchr/testify/mock" - message "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3/message" -) - -// MockMessageStreamHandler is an autogenerated mock type for the MessageStreamHandler type -type MockMessageStreamHandler struct { - mock.Mock -} - -type MockMessageStreamHandler_Expecter struct { - mock *mock.Mock -} - -func (_m *MockMessageStreamHandler) EXPECT() *MockMessageStreamHandler_Expecter { - return &MockMessageStreamHandler_Expecter{mock: &_m.Mock} -} - -// Execute provides a mock function with given fields: _a0 -func (_m *MockMessageStreamHandler) Execute(_a0 message.StreamOf[message.Message]) { - _m.Called(_a0) -} - -// MockMessageStreamHandler_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' -type MockMessageStreamHandler_Execute_Call struct { - *mock.Call -} - -// Execute is a helper method to define mock.On call -// - _a0 message.StreamOf[message.Message] -func (_e *MockMessageStreamHandler_Expecter) Execute(_a0 interface{}) *MockMessageStreamHandler_Execute_Call { - return &MockMessageStreamHandler_Execute_Call{Call: _e.mock.On("Execute", _a0)} -} - -func (_c *MockMessageStreamHandler_Execute_Call) Run(run func(_a0 message.StreamOf[message.Message])) *MockMessageStreamHandler_Execute_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(message.StreamOf[message.Message])) - }) - return _c -} - -func (_c *MockMessageStreamHandler_Execute_Call) Return() *MockMessageStreamHandler_Execute_Call { - _c.Call.Return() - return _c -} - -func (_c *MockMessageStreamHandler_Execute_Call) RunAndReturn(run func(message.StreamOf[message.Message])) *MockMessageStreamHandler_Execute_Call { - _c.Call.Return(run) - return _c -} - -type mockConstructorTestingTNewMockMessageStreamHandler interface { - mock.TestingT - Cleanup(func()) -} - -// NewMockMessageStreamHandler creates a new instance of MockMessageStreamHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockMessageStreamHandler(t mockConstructorTestingTNewMockMessageStreamHandler) *MockMessageStreamHandler { - mock := &MockMessageStreamHandler{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/pkg/api/v3/p2p/mock_dialerHost_test.go b/pkg/api/v3/p2p/mock_dialerHost_test.go deleted file mode 100644 index 6c42aa47b..000000000 --- a/pkg/api/v3/p2p/mock_dialerHost_test.go +++ /dev/null @@ -1,192 +0,0 @@ -// Code generated by mockery v2.23.1. DO NOT EDIT. - -package p2p - -import ( - context "context" - io "io" - - peer "github.com/libp2p/go-libp2p/core/peer" - mock "github.com/stretchr/testify/mock" - api "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3" -) - -// mockDialerHost is an autogenerated mock type for the dialerHost type -type mockDialerHost struct { - mock.Mock -} - -type mockDialerHost_Expecter struct { - mock *mock.Mock -} - -func (_m *mockDialerHost) EXPECT() *mockDialerHost_Expecter { - return &mockDialerHost_Expecter{mock: &_m.Mock} -} - -// getOwnService provides a mock function with given fields: network, sa -func (_m *mockDialerHost) getOwnService(network string, sa *api.ServiceAddress) (*serviceHandler, bool) { - ret := _m.Called(network, sa) - - var r0 *serviceHandler - var r1 bool - if rf, ok := ret.Get(0).(func(string, *api.ServiceAddress) (*serviceHandler, bool)); ok { - return rf(network, sa) - } - if rf, ok := ret.Get(0).(func(string, *api.ServiceAddress) *serviceHandler); ok { - r0 = rf(network, sa) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*serviceHandler) - } - } - - if rf, ok := ret.Get(1).(func(string, *api.ServiceAddress) bool); ok { - r1 = rf(network, sa) - } else { - r1 = ret.Get(1).(bool) - } - - return r0, r1 -} - -// mockDialerHost_getOwnService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getOwnService' -type mockDialerHost_getOwnService_Call struct { - *mock.Call -} - -// getOwnService is a helper method to define mock.On call -// - network string -// - sa *api.ServiceAddress -func (_e *mockDialerHost_Expecter) getOwnService(network interface{}, sa interface{}) *mockDialerHost_getOwnService_Call { - return &mockDialerHost_getOwnService_Call{Call: _e.mock.On("getOwnService", network, sa)} -} - -func (_c *mockDialerHost_getOwnService_Call) Run(run func(network string, sa *api.ServiceAddress)) *mockDialerHost_getOwnService_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(*api.ServiceAddress)) - }) - return _c -} - -func (_c *mockDialerHost_getOwnService_Call) Return(_a0 *serviceHandler, _a1 bool) *mockDialerHost_getOwnService_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *mockDialerHost_getOwnService_Call) RunAndReturn(run func(string, *api.ServiceAddress) (*serviceHandler, bool)) *mockDialerHost_getOwnService_Call { - _c.Call.Return(run) - return _c -} - -// getPeerService provides a mock function with given fields: ctx, _a1, service -func (_m *mockDialerHost) getPeerService(ctx context.Context, _a1 peer.ID, service *api.ServiceAddress) (io.ReadWriteCloser, error) { - ret := _m.Called(ctx, _a1, service) - - var r0 io.ReadWriteCloser - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, peer.ID, *api.ServiceAddress) (io.ReadWriteCloser, error)); ok { - return rf(ctx, _a1, service) - } - if rf, ok := ret.Get(0).(func(context.Context, peer.ID, *api.ServiceAddress) io.ReadWriteCloser); ok { - r0 = rf(ctx, _a1, service) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(io.ReadWriteCloser) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, peer.ID, *api.ServiceAddress) error); ok { - r1 = rf(ctx, _a1, service) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// mockDialerHost_getPeerService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getPeerService' -type mockDialerHost_getPeerService_Call struct { - *mock.Call -} - -// getPeerService is a helper method to define mock.On call -// - ctx context.Context -// - _a1 peer.ID -// - service *api.ServiceAddress -func (_e *mockDialerHost_Expecter) getPeerService(ctx interface{}, _a1 interface{}, service interface{}) *mockDialerHost_getPeerService_Call { - return &mockDialerHost_getPeerService_Call{Call: _e.mock.On("getPeerService", ctx, _a1, service)} -} - -func (_c *mockDialerHost_getPeerService_Call) Run(run func(ctx context.Context, _a1 peer.ID, service *api.ServiceAddress)) *mockDialerHost_getPeerService_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(peer.ID), args[2].(*api.ServiceAddress)) - }) - return _c -} - -func (_c *mockDialerHost_getPeerService_Call) Return(_a0 io.ReadWriteCloser, _a1 error) *mockDialerHost_getPeerService_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *mockDialerHost_getPeerService_Call) RunAndReturn(run func(context.Context, peer.ID, *api.ServiceAddress) (io.ReadWriteCloser, error)) *mockDialerHost_getPeerService_Call { - _c.Call.Return(run) - return _c -} - -// selfID provides a mock function with given fields: -func (_m *mockDialerHost) selfID() peer.ID { - ret := _m.Called() - - var r0 peer.ID - if rf, ok := ret.Get(0).(func() peer.ID); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(peer.ID) - } - - return r0 -} - -// mockDialerHost_selfID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'selfID' -type mockDialerHost_selfID_Call struct { - *mock.Call -} - -// selfID is a helper method to define mock.On call -func (_e *mockDialerHost_Expecter) selfID() *mockDialerHost_selfID_Call { - return &mockDialerHost_selfID_Call{Call: _e.mock.On("selfID")} -} - -func (_c *mockDialerHost_selfID_Call) Run(run func()) *mockDialerHost_selfID_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *mockDialerHost_selfID_Call) Return(_a0 peer.ID) *mockDialerHost_selfID_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *mockDialerHost_selfID_Call) RunAndReturn(run func() peer.ID) *mockDialerHost_selfID_Call { - _c.Call.Return(run) - return _c -} - -type mockConstructorTestingTnewMockDialerHost interface { - mock.TestingT - Cleanup(func()) -} - -// newMockDialerHost creates a new instance of mockDialerHost. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func newMockDialerHost(t mockConstructorTestingTnewMockDialerHost) *mockDialerHost { - mock := &mockDialerHost{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/pkg/api/v3/p2p/mock_dialerPeers_test.go b/pkg/api/v3/p2p/mock_dialerPeers_test.go deleted file mode 100644 index f20c21ff7..000000000 --- a/pkg/api/v3/p2p/mock_dialerPeers_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Code generated by mockery v2.23.1. DO NOT EDIT. - -package p2p - -import ( - context "context" - - peer "github.com/libp2p/go-libp2p/core/peer" - multiaddr "github.com/multiformats/go-multiaddr" - mock "github.com/stretchr/testify/mock" -) - -// mockDialerPeers is an autogenerated mock type for the dialerPeers type -type mockDialerPeers struct { - mock.Mock -} - -type mockDialerPeers_Expecter struct { - mock *mock.Mock -} - -func (_m *mockDialerPeers) EXPECT() *mockDialerPeers_Expecter { - return &mockDialerPeers_Expecter{mock: &_m.Mock} -} - -// getPeers provides a mock function with given fields: ctx, ma, limit -func (_m *mockDialerPeers) getPeers(ctx context.Context, ma multiaddr.Multiaddr, limit int) (<-chan peer.AddrInfo, error) { - ret := _m.Called(ctx, ma, limit) - - var r0 <-chan peer.AddrInfo - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, multiaddr.Multiaddr, int) (<-chan peer.AddrInfo, error)); ok { - return rf(ctx, ma, limit) - } - if rf, ok := ret.Get(0).(func(context.Context, multiaddr.Multiaddr, int) <-chan peer.AddrInfo); ok { - r0 = rf(ctx, ma, limit) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(<-chan peer.AddrInfo) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, multiaddr.Multiaddr, int) error); ok { - r1 = rf(ctx, ma, limit) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// mockDialerPeers_getPeers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getPeers' -type mockDialerPeers_getPeers_Call struct { - *mock.Call -} - -// getPeers is a helper method to define mock.On call -// - ctx context.Context -// - ma multiaddr.Multiaddr -// - limit int -func (_e *mockDialerPeers_Expecter) getPeers(ctx interface{}, ma interface{}, limit interface{}) *mockDialerPeers_getPeers_Call { - return &mockDialerPeers_getPeers_Call{Call: _e.mock.On("getPeers", ctx, ma, limit)} -} - -func (_c *mockDialerPeers_getPeers_Call) Run(run func(ctx context.Context, ma multiaddr.Multiaddr, limit int)) *mockDialerPeers_getPeers_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(multiaddr.Multiaddr), args[2].(int)) - }) - return _c -} - -func (_c *mockDialerPeers_getPeers_Call) Return(_a0 <-chan peer.AddrInfo, _a1 error) *mockDialerPeers_getPeers_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *mockDialerPeers_getPeers_Call) RunAndReturn(run func(context.Context, multiaddr.Multiaddr, int) (<-chan peer.AddrInfo, error)) *mockDialerPeers_getPeers_Call { - _c.Call.Return(run) - return _c -} - -type mockConstructorTestingTnewMockDialerPeers interface { - mock.TestingT - Cleanup(func()) -} - -// newMockDialerPeers creates a new instance of mockDialerPeers. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func newMockDialerPeers(t mockConstructorTestingTnewMockDialerPeers) *mockDialerPeers { - mock := &mockDialerPeers{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/pkg/api/v3/p2p/p2p.go b/pkg/api/v3/p2p/p2p.go index 1cb4c871d..98af9c7f4 100644 --- a/pkg/api/v3/p2p/p2p.go +++ b/pkg/api/v3/p2p/p2p.go @@ -9,7 +9,6 @@ package p2p import ( "context" "crypto/ed25519" - "io" "net" "strings" @@ -24,6 +23,7 @@ import ( sortutil "gitlab.com/accumulatenetwork/accumulate/internal/util/sort" "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3" "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3/message" + "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3/p2p/dial" "gitlab.com/accumulatenetwork/accumulate/pkg/errors" ) @@ -49,7 +49,8 @@ type Node struct { cancel context.CancelFunc peermgr *peerManager host host.Host - tracker peerTracker + dialOpts []dial.Option + tracker dial.Tracker services []*serviceHandler } @@ -89,9 +90,14 @@ func New(opts Options) (_ *Node, err error) { n.context, n.cancel = context.WithCancel(context.Background()) if opts.EnablePeerTracker { - n.tracker = new(simpleTracker) + n.tracker = new(dial.SimpleTracker) } else { - n.tracker = fakeTracker{} + n.tracker = dial.FakeTracker + } + n.dialOpts = []dial.Option{ + dial.WithConnector((*connector)(n)), + dial.WithDiscoverer((*discoverer)(n)), + dial.WithTracker(n.tracker), } // Cancel on fail @@ -208,14 +214,17 @@ func (n *Node) Close() error { return n.host.Close() } -// selfID returns the node's ID. -func (n *Node) selfID() peer.ID { - return n.host.ID() -} - // getPeerService returns a new stream for the given peer and service. -func (n *Node) getPeerService(ctx context.Context, peer peer.ID, service *api.ServiceAddress) (io.ReadWriteCloser, error) { - return n.host.NewStream(ctx, peer, idRpc(service)) +func (n *Node) getPeerService(ctx context.Context, peerID peer.ID, service *api.ServiceAddress) (message.Stream, error) { + s, err := n.host.NewStream(ctx, peerID, idRpc(service)) + if err != nil { + return nil, err + } + + // Close the stream when the context is canceled + go func() { <-ctx.Done(); _ = s.Close() }() + + return message.NewStream(s), nil } // getOwnService returns a service of this node. diff --git a/pkg/api/v3/p2p/peer_manager.go b/pkg/api/v3/p2p/peer_manager.go index dd0c395af..110e70912 100644 --- a/pkg/api/v3/p2p/peer_manager.go +++ b/pkg/api/v3/p2p/peer_manager.go @@ -138,7 +138,7 @@ func (m *peerManager) waitFor(ctx context.Context, addr multiaddr.Multiaddr) err wait := <-m.wait // Look for a peer - ch, err := m.routing.FindPeers(ctx, addr.String(), discovery.Limit(1)) + ch, err := m.getPeers(ctx, addr, 1) if err != nil { return err } diff --git a/pkg/api/v3/p2p/peer_tracker.go b/pkg/api/v3/p2p/peer_tracker.go deleted file mode 100644 index af8440493..000000000 --- a/pkg/api/v3/p2p/peer_tracker.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2023 The Accumulate Authors -// -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file or at -// https://opensource.org/licenses/MIT. - -package p2p - -import ( - "context" - "sync" - - "github.com/libp2p/go-libp2p/core/peer" - "github.com/multiformats/go-multiaddr" - "gitlab.com/accumulatenetwork/accumulate/pkg/api/v3" - "golang.org/x/exp/slog" -) - -type peerTracker interface { - // markGood marks a peer as good for a given service. - markGood(ctx context.Context, peer peer.ID, service multiaddr.Multiaddr) - - // markBad marks a peer as bad for a given service. - markBad(ctx context.Context, peer peer.ID, service multiaddr.Multiaddr) - - // markDead marks a peer as dead. - markDead(ctx context.Context, peer peer.ID) - - // status returns the status of a peer. - status(ctx context.Context, peer peer.ID, service multiaddr.Multiaddr) api.KnownPeerStatus - - // nextGood returns the next good peer for a service. - nextGood(ctx context.Context, service multiaddr.Multiaddr) (peer.ID, bool) - - // nextBad returns the next bad peer for a service. - nextBad(ctx context.Context, service multiaddr.Multiaddr) (peer.ID, bool) - - // allGood returns all good peers for a service. - allGood(ctx context.Context, service multiaddr.Multiaddr) []peer.ID - - // allBad returns all bad peers for a service. - allBad(ctx context.Context, service multiaddr.Multiaddr) []peer.ID -} - -type fakeTracker struct{} - -func (fakeTracker) markGood(context.Context, peer.ID, multiaddr.Multiaddr) {} -func (fakeTracker) markBad(context.Context, peer.ID, multiaddr.Multiaddr) {} -func (fakeTracker) markDead(context.Context, peer.ID) {} -func (fakeTracker) nextGood(context.Context, multiaddr.Multiaddr) (peer.ID, bool) { return "", false } -func (fakeTracker) nextBad(context.Context, multiaddr.Multiaddr) (peer.ID, bool) { return "", false } -func (fakeTracker) allGood(context.Context, multiaddr.Multiaddr) []peer.ID { return nil } -func (fakeTracker) allBad(context.Context, multiaddr.Multiaddr) []peer.ID { return nil } - -func (fakeTracker) status(context.Context, peer.ID, multiaddr.Multiaddr) api.KnownPeerStatus { - return 0 -} - -type simpleTracker struct { - mu sync.RWMutex - good map[string]*peerQueue - bad map[string]*peerQueue -} - -func (t *simpleTracker) get(service multiaddr.Multiaddr) (good, bad *peerQueue) { - key := service.String() - t.mu.RLock() - good, ok := t.good[key] - bad = t.bad[key] - t.mu.RUnlock() - if ok { - return good, bad - } - - t.mu.Lock() - defer t.mu.Unlock() - - good, ok = t.good[key] - bad = t.bad[key] - if ok { - return good, bad - } - - if t.good == nil { - t.good = map[string]*peerQueue{} - t.bad = map[string]*peerQueue{} - } - - good = new(peerQueue) - bad = new(peerQueue) - t.good[key] = good - t.bad[key] = bad - return good, bad -} - -func (t *simpleTracker) markGood(ctx context.Context, peer peer.ID, service multiaddr.Multiaddr) { - good, bad := t.get(service) - ok1 := good.Add(peer) - ok2 := bad.Remove(peer) - if ok1 || ok2 { - slog.DebugCtx(ctx, "Marked peer good", "peer", peer, "service", service) - } -} - -func (t *simpleTracker) markBad(ctx context.Context, peer peer.ID, service multiaddr.Multiaddr) { - good, bad := t.get(service) - ok1 := good.Remove(peer) - ok2 := bad.Add(peer) - if ok1 || ok2 { - slog.DebugCtx(ctx, "Marked peer bad", "peer", peer, "service", service) - } -} - -func (t *simpleTracker) markDead(ctx context.Context, peer peer.ID) { - t.mu.RLock() - defer t.mu.RUnlock() - var ok bool - for _, q := range t.good { - if q.Remove(peer) { - ok = true - } - } - for _, q := range t.bad { - if q.Remove(peer) { - ok = true - } - } - if ok { - slog.DebugCtx(ctx, "Marked peer dead", "peer", peer) - } -} - -func (t *simpleTracker) status(ctx context.Context, peer peer.ID, service multiaddr.Multiaddr) api.KnownPeerStatus { - good, bad := t.get(service) - switch { - case good.Has(peer): - return api.PeerStatusIsKnownGood - case bad.Has(peer): - return api.PeerStatusIsKnownBad - default: - return 0 - } -} - -func (t *simpleTracker) nextGood(ctx context.Context, service multiaddr.Multiaddr) (peer.ID, bool) { - good, _ := t.get(service) - return good.Next() -} - -func (t *simpleTracker) nextBad(ctx context.Context, service multiaddr.Multiaddr) (peer.ID, bool) { - _, bad := t.get(service) - return bad.Next() -} - -func (t *simpleTracker) allGood(ctx context.Context, service multiaddr.Multiaddr) []peer.ID { - good, _ := t.get(service) - return good.All() -} - -func (t *simpleTracker) allBad(ctx context.Context, service multiaddr.Multiaddr) []peer.ID { - _, bad := t.get(service) - return bad.All() -} diff --git a/pkg/api/v3/p2p/services.go b/pkg/api/v3/p2p/services.go index c806431d7..b0360e506 100644 --- a/pkg/api/v3/p2p/services.go +++ b/pkg/api/v3/p2p/services.go @@ -109,34 +109,56 @@ func (n *nodeService) FindService(ctx context.Context, opts api.FindServiceOptio return nil, errors.BadRequest.With("no network or service specified") } + var results []*api.FindServiceResult if opts.Known { // Find known peers - var results []*api.FindServiceResult - for _, peer := range n.tracker.allGood(ctx, addr) { - results = append(results, &api.FindServiceResult{ - PeerID: peer, - Status: api.PeerStatusIsKnownGood, - }) - } - for _, peer := range n.tracker.allBad(ctx, addr) { - results = append(results, &api.FindServiceResult{ - PeerID: peer, - Status: api.PeerStatusIsKnownBad, - }) + results = n.getKnownPeers(ctx, addr) + + } else { + // Discover peers + var err error + results, err = n.discoverPeers(ctx, addr) + if err != nil { + return nil, errors.UnknownError.Wrap(err) } - return results, nil } + // Return an empty array, not nil, because JSON-RPC handles that better + if results == nil { + return []*api.FindServiceResult{}, nil + } + return results, nil +} + +func (n *nodeService) getKnownPeers(ctx context.Context, addr multiaddr.Multiaddr) []*api.FindServiceResult { + // Find known peers + var results []*api.FindServiceResult + for _, peer := range n.tracker.All(addr, api.PeerStatusIsKnownGood) { + results = append(results, &api.FindServiceResult{ + PeerID: peer, + Status: api.PeerStatusIsKnownGood, + }) + } + for _, peer := range n.tracker.All(addr, api.PeerStatusIsKnownBad) { + results = append(results, &api.FindServiceResult{ + PeerID: peer, + Status: api.PeerStatusIsKnownBad, + }) + } + return results +} + +func (n *nodeService) discoverPeers(ctx context.Context, addr multiaddr.Multiaddr) ([]*api.FindServiceResult, error) { ch, err := n.peermgr.getPeers(ctx, addr, 100) if err != nil { return nil, err } - var results []*api.FindServiceResult + results := []*api.FindServiceResult{} for peer := range ch { results = append(results, &api.FindServiceResult{ PeerID: peer.ID, - Status: n.tracker.status(ctx, peer.ID, addr), + Status: n.tracker.Status(peer.ID, addr), }) } return results, nil diff --git a/pkg/api/v3/p2p/types.go b/pkg/api/v3/p2p/types.go index ec0e6483e..f13f13aae 100644 --- a/pkg/api/v3/p2p/types.go +++ b/pkg/api/v3/p2p/types.go @@ -17,8 +17,6 @@ type event interface { Type() eventType } -//go:generate go run github.com/vektra/mockery/v2 //go:generate go run gitlab.com/accumulatenetwork/accumulate/tools/cmd/gen-enum --package p2p enums.yml //go:generate go run gitlab.com/accumulatenetwork/accumulate/tools/cmd/gen-types --package p2p types.yml //go:generate go run gitlab.com/accumulatenetwork/accumulate/tools/cmd/gen-types --language go-union --package p2p --out unions_gen.go types.yml -//go:generate go run github.com/rinchsan/gosimports/cmd/gosimports -w .