diff --git a/pkg/solana/client/multinode/client.go b/pkg/solana/client/multinode/client.go index bf7db4e23..62349fbef 100644 --- a/pkg/solana/client/multinode/client.go +++ b/pkg/solana/client/multinode/client.go @@ -37,6 +37,19 @@ type MultiNodeClient[RPC any, HEAD Head] struct { latestChainInfo ChainInfo } +// WrappedSubscription is used to ensure that the subscription is removed from the client when unsubscribed +type WrappedSubscription struct { + Subscription + removeSub func(sub Subscription) +} + +func (w *WrappedSubscription) Unsubscribe() { + w.Subscription.Unsubscribe() + if w.removeSub != nil { + w.removeSub(w) + } +} + func NewMultiNodeClient[RPC any, HEAD Head]( cfg *mnCfg.MultiNodeConfig, rpc *RPC, ctxTimeout time.Duration, log logger.Logger, latestBlock func(ctx context.Context, rpc *RPC) (HEAD, error), @@ -60,24 +73,14 @@ func (m *MultiNodeClient[RPC, HEAD]) LenSubs() int { return len(m.subs) } -// removeClosedSubscriptions removes any subscriptions that have been closed -func (m *MultiNodeClient[RPC, HEAD]) removeClosedSubscriptions() { +func (m *MultiNodeClient[RPC, HEAD]) removeSubscription(sub Subscription) { m.subsSliceMu.Lock() defer m.subsSliceMu.Unlock() - for sub := range m.subs { - select { - case _, ok := <-sub.Err(): - if !ok { - delete(m.subs, sub) - } - default: - } - } + delete(m.subs, sub) } -// RegisterSub adds the sub to the rpcClient list -func (m *MultiNodeClient[RPC, HEAD]) RegisterSub(sub Subscription, stopInFLightCh chan struct{}) error { - defer m.removeClosedSubscriptions() +// registerSub adds the sub to the rpcClient list +func (m *MultiNodeClient[RPC, HEAD]) registerSub(sub Subscription, stopInFLightCh chan struct{}) error { m.subsSliceMu.Lock() defer m.subsSliceMu.Unlock() // ensure that the `sub` belongs to current life cycle of the `rpcClient` and it should not be killed due to @@ -148,13 +151,18 @@ func (m *MultiNodeClient[RPC, HEAD]) SubscribeToHeads(ctx context.Context) (<-ch return nil, nil, err } - err := m.RegisterSub(&poller, chStopInFlight) + sub := &WrappedSubscription{ + Subscription: &poller, + removeSub: m.removeSubscription, + } + + err := m.registerSub(sub, chStopInFlight) if err != nil { - poller.Unsubscribe() + sub.Unsubscribe() return nil, nil, err } - return channel, &poller, nil + return channel, sub, nil } func (m *MultiNodeClient[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Context) (<-chan HEAD, Subscription, error) { @@ -176,13 +184,18 @@ func (m *MultiNodeClient[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Conte return nil, nil, err } - err := m.RegisterSub(&poller, chStopInFlight) + sub := &WrappedSubscription{ + Subscription: &poller, + removeSub: m.removeSubscription, + } + + err := m.registerSub(sub, chStopInFlight) if err != nil { poller.Unsubscribe() return nil, nil, err } - return channel, &poller, nil + return channel, sub, nil } func (m *MultiNodeClient[RPC, HEAD]) OnNewHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) { diff --git a/pkg/solana/client/multinode/client_test.go b/pkg/solana/client/multinode/client_test.go index 7b1992d20..0869824c0 100644 --- a/pkg/solana/client/multinode/client_test.go +++ b/pkg/solana/client/multinode/client_test.go @@ -95,17 +95,18 @@ func TestMultiNodeClient_HeadSubscriptions(t *testing.T) { } }) - t.Run("Remove Closed Subscriptions", func(t *testing.T) { + t.Run("Remove Subscription on Unsubscribe", func(t *testing.T) { _, sub1, err := c.SubscribeToHeads(tests.Context(t)) require.NoError(t, err) require.Equal(t, 1, c.LenSubs()) - sub1.Unsubscribe() - - _, sub2, err := c.SubscribeToHeads(tests.Context(t)) + _, sub2, err := c.SubscribeToFinalizedHeads(tests.Context(t)) require.NoError(t, err) - defer sub2.Unsubscribe() - // Ensure sub1 was removed since it was closed + require.Equal(t, 2, c.LenSubs()) + + sub1.Unsubscribe() require.Equal(t, 1, c.LenSubs()) + sub2.Unsubscribe() + require.Equal(t, 0, c.LenSubs()) }) } @@ -129,7 +130,7 @@ func TestMultiNodeClient_RegisterSubs(t *testing.T) { t.Run("registerSub", func(t *testing.T) { sub := newMockSub() - err := c.RegisterSub(sub, make(chan struct{})) + err := c.registerSub(sub, make(chan struct{})) require.NoError(t, err) require.Equal(t, 1, c.LenSubs()) c.UnsubscribeAllExcept() @@ -139,7 +140,7 @@ func TestMultiNodeClient_RegisterSubs(t *testing.T) { chStopInFlight := make(chan struct{}) close(chStopInFlight) sub := newMockSub() - err := c.RegisterSub(sub, chStopInFlight) + err := c.registerSub(sub, chStopInFlight) require.Error(t, err) require.Equal(t, true, sub.unsubscribed) }) @@ -148,9 +149,9 @@ func TestMultiNodeClient_RegisterSubs(t *testing.T) { chStopInFlight := make(chan struct{}) sub1 := newMockSub() sub2 := newMockSub() - err := c.RegisterSub(sub1, chStopInFlight) + err := c.registerSub(sub1, chStopInFlight) require.NoError(t, err) - err = c.RegisterSub(sub2, chStopInFlight) + err = c.registerSub(sub2, chStopInFlight) require.NoError(t, err) require.Equal(t, 2, c.LenSubs())