Skip to content

Commit

Permalink
Add wrapped subscription
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanTinianov committed Dec 5, 2024
1 parent b158e7a commit 16416ef
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 29 deletions.
51 changes: 32 additions & 19 deletions pkg/solana/client/multinode/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ type MultiNodeClient[RPC any, HEAD Head] struct {
latestChainInfo ChainInfo
}

// WrappedSubscription is used to ensure that the subscription is removed from the client when unsubscribed
type WrappedSubscription struct {
Subscription
removeSub func(sub Subscription)
}

func (w *WrappedSubscription) Unsubscribe() {
w.Subscription.Unsubscribe()
if w.removeSub != nil {
w.removeSub(w)
}
}

func NewMultiNodeClient[RPC any, HEAD Head](
cfg *mnCfg.MultiNodeConfig, rpc *RPC, ctxTimeout time.Duration, log logger.Logger,
latestBlock func(ctx context.Context, rpc *RPC) (HEAD, error),
Expand All @@ -60,24 +73,14 @@ func (m *MultiNodeClient[RPC, HEAD]) LenSubs() int {
return len(m.subs)
}

// removeClosedSubscriptions removes any subscriptions that have been closed
func (m *MultiNodeClient[RPC, HEAD]) removeClosedSubscriptions() {
func (m *MultiNodeClient[RPC, HEAD]) removeSubscription(sub Subscription) {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
for sub := range m.subs {
select {
case _, ok := <-sub.Err():
if !ok {
delete(m.subs, sub)
}
default:
}
}
delete(m.subs, sub)
}

// RegisterSub adds the sub to the rpcClient list
func (m *MultiNodeClient[RPC, HEAD]) RegisterSub(sub Subscription, stopInFLightCh chan struct{}) error {
defer m.removeClosedSubscriptions()
// registerSub adds the sub to the rpcClient list
func (m *MultiNodeClient[RPC, HEAD]) registerSub(sub Subscription, stopInFLightCh chan struct{}) error {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
// ensure that the `sub` belongs to current life cycle of the `rpcClient` and it should not be killed due to
Expand Down Expand Up @@ -148,13 +151,18 @@ func (m *MultiNodeClient[RPC, HEAD]) SubscribeToHeads(ctx context.Context) (<-ch
return nil, nil, err
}

err := m.RegisterSub(&poller, chStopInFlight)
sub := &WrappedSubscription{
Subscription: &poller,
removeSub: m.removeSubscription,
}

err := m.registerSub(sub, chStopInFlight)
if err != nil {
poller.Unsubscribe()
sub.Unsubscribe()
return nil, nil, err
}

return channel, &poller, nil
return channel, sub, nil
}

func (m *MultiNodeClient[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Context) (<-chan HEAD, Subscription, error) {
Expand All @@ -176,13 +184,18 @@ func (m *MultiNodeClient[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Conte
return nil, nil, err
}

err := m.RegisterSub(&poller, chStopInFlight)
sub := &WrappedSubscription{
Subscription: &poller,
removeSub: m.removeSubscription,
}

err := m.registerSub(sub, chStopInFlight)
if err != nil {
poller.Unsubscribe()
return nil, nil, err
}

return channel, &poller, nil
return channel, sub, nil
}

func (m *MultiNodeClient[RPC, HEAD]) OnNewHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
Expand Down
21 changes: 11 additions & 10 deletions pkg/solana/client/multinode/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,18 @@ func TestMultiNodeClient_HeadSubscriptions(t *testing.T) {
}
})

t.Run("Remove Closed Subscriptions", func(t *testing.T) {
t.Run("Remove Subscription on Unsubscribe", func(t *testing.T) {
_, sub1, err := c.SubscribeToHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
sub1.Unsubscribe()

_, sub2, err := c.SubscribeToHeads(tests.Context(t))
_, sub2, err := c.SubscribeToFinalizedHeads(tests.Context(t))
require.NoError(t, err)
defer sub2.Unsubscribe()
// Ensure sub1 was removed since it was closed
require.Equal(t, 2, c.LenSubs())

sub1.Unsubscribe()
require.Equal(t, 1, c.LenSubs())
sub2.Unsubscribe()
require.Equal(t, 0, c.LenSubs())
})
}

Expand All @@ -129,7 +130,7 @@ func TestMultiNodeClient_RegisterSubs(t *testing.T) {

t.Run("registerSub", func(t *testing.T) {
sub := newMockSub()
err := c.RegisterSub(sub, make(chan struct{}))
err := c.registerSub(sub, make(chan struct{}))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
c.UnsubscribeAllExcept()
Expand All @@ -139,7 +140,7 @@ func TestMultiNodeClient_RegisterSubs(t *testing.T) {
chStopInFlight := make(chan struct{})
close(chStopInFlight)
sub := newMockSub()
err := c.RegisterSub(sub, chStopInFlight)
err := c.registerSub(sub, chStopInFlight)
require.Error(t, err)
require.Equal(t, true, sub.unsubscribed)
})
Expand All @@ -148,9 +149,9 @@ func TestMultiNodeClient_RegisterSubs(t *testing.T) {
chStopInFlight := make(chan struct{})
sub1 := newMockSub()
sub2 := newMockSub()
err := c.RegisterSub(sub1, chStopInFlight)
err := c.registerSub(sub1, chStopInFlight)
require.NoError(t, err)
err = c.RegisterSub(sub2, chStopInFlight)
err = c.registerSub(sub2, chStopInFlight)
require.NoError(t, err)
require.Equal(t, 2, c.LenSubs())

Expand Down

0 comments on commit 16416ef

Please sign in to comment.