From 2b7cf285e4657eb8db05283cf91e624dd65c62e0 Mon Sep 17 00:00:00 2001 From: thielepaul Date: Tue, 21 Jun 2022 10:35:19 +0200 Subject: [PATCH] zmq4: add option for automatic reconnect --- options.go | 8 ++++++ socket.go | 31 +++++++++++++++-------- socket_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) diff --git a/options.go b/options.go index 6a4ead8..d85b7c5 100644 --- a/options.go +++ b/options.go @@ -59,6 +59,14 @@ func WithDialerMaxRetries(maxRetries int) Option { } } +// WithAutomaticReconnect allows to configure a socket to automatically +// reconnect on connection loss. +func WithAutomaticReconnect(automaticReconnect bool) Option { + return func(s *socket) { + s.autoReconnect = automaticReconnect + } +} + /* // TODO(sbinet) diff --git a/socket.go b/socket.go index db5b953..9e185be 100644 --- a/socket.go +++ b/socket.go @@ -31,14 +31,15 @@ var ( // socket implements the ZeroMQ socket interface type socket struct { - ep string // socket end-point - typ SocketType - id SocketIdentity - retry time.Duration - maxRetries int - sec Security - log *log.Logger - subTopics func() []string + ep string // socket end-point + typ SocketType + id SocketIdentity + retry time.Duration + maxRetries int + sec Security + log *log.Logger + subTopics func() []string + autoReconnect bool mu sync.RWMutex ids map[string]*Conn // ZMTP connection IDs @@ -53,8 +54,9 @@ type socket struct { listener net.Listener dialer net.Dialer - closedConns []*Conn - reaperCond *sync.Cond + closedConns []*Conn + reaperCond *sync.Cond + reaperStarted bool } func newDefaultSocket(ctx context.Context, sockType SocketType) *socket { @@ -271,7 +273,10 @@ connect: return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint) } - go sck.connReaper() + if !sck.reaperStarted { + go sck.connReaper() + sck.reaperStarted = true + } sck.addConn(zconn) return nil } @@ -330,6 +335,10 @@ func (sck *socket) scheduleRmConn(c *Conn) { sck.closedConns = append(sck.closedConns, c) sck.reaperCond.Signal() sck.reaperCond.L.Unlock() + + if sck.autoReconnect { + sck.Dial(sck.ep) + } } // Type returns the type of this Socket (PUB, SUB, ...) diff --git a/socket_test.go b/socket_test.go index 51048e2..764cbcc 100644 --- a/socket_test.go +++ b/socket_test.go @@ -239,7 +239,7 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) { if err := pub.Dial(endpoint); err != nil { t.Fatalf("Pub Dial failed: %v", err) } - wg := new(sync.WaitGroup) + var wg sync.WaitGroup defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -322,3 +322,70 @@ func TestConnMaxRetriesInfinite(t *testing.T) { t.Fatalf("Dial called %d times, expected at least %d", transport.dialCalledCount, atLeastExpectedRetries) } } + +func TestSocketAutomaticReconnect(t *testing.T) { + ep, err := EndPoint("tcp") + if err != nil { + t.Fatalf("could not find endpoint: %+v", err) + } + message := "test" + + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sendMessages := func(socket zmq4.Socket) { + wg.Add(1) + go func(t *testing.T) { + defer wg.Done() + for { + socket.Send(zmq4.NewMsgFromString([]string{message})) + if ctx.Err() != nil { + return + } + time.Sleep(1 * time.Millisecond) + } + }(t) + } + + sub := zmq4.NewSub(context.Background(), zmq4.WithAutomaticReconnect(true)) + defer sub.Close() + sub.SetOption(zmq4.OptionSubscribe, message) + pub := zmq4.NewPub(context.Background()) + if err := pub.Listen(ep); err != nil { + t.Fatalf("Pub Dial failed: %v", err) + } + if err := sub.Dial(ep); err != nil { + t.Fatalf("Sub Dial failed: %v", err) + } + + sendMessages(pub) + + checkConnectionWorking := func(socket zmq4.Socket) { + for { + msg, err := socket.Recv() + if errors.Is(err, io.EOF) { + continue + } + if err != nil { + t.Fatalf("Recv failed: %v", err) + } + if string(msg.Frames[0]) != message { + t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message) + } + return + } + } + + checkConnectionWorking(sub) + pub.Close() + + pub2 := zmq4.NewPub(context.Background()) + defer pub2.Close() + if err := pub2.Listen(ep); err != nil { + t.Fatalf("Sub Listen failed: %v", err) + } + sendMessages(pub2) + checkConnectionWorking(sub) +}