Skip to content

Commit

Permalink
zmq4: add option for automatic reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
thielepaul authored Jun 21, 2022
1 parent ae18bc0 commit 2b7cf28
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 12 deletions.
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ func WithDialerMaxRetries(maxRetries int) Option {
}
}

// WithAutomaticReconnect allows to configure a socket to automatically
// reconnect on connection loss.
func WithAutomaticReconnect(automaticReconnect bool) Option {
return func(s *socket) {
s.autoReconnect = automaticReconnect
}
}

/*
// TODO(sbinet)
Expand Down
31 changes: 20 additions & 11 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
maxRetries int
sec Security
log *log.Logger
subTopics func() []string
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
maxRetries int
sec Security
log *log.Logger
subTopics func() []string
autoReconnect bool

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand All @@ -53,8 +54,9 @@ type socket struct {
listener net.Listener
dialer net.Dialer

closedConns []*Conn
reaperCond *sync.Cond
closedConns []*Conn
reaperCond *sync.Cond
reaperStarted bool
}

func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
Expand Down Expand Up @@ -271,7 +273,10 @@ connect:
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
}

go sck.connReaper()
if !sck.reaperStarted {
go sck.connReaper()
sck.reaperStarted = true
}
sck.addConn(zconn)
return nil
}
Expand Down Expand Up @@ -330,6 +335,10 @@ func (sck *socket) scheduleRmConn(c *Conn) {
sck.closedConns = append(sck.closedConns, c)
sck.reaperCond.Signal()
sck.reaperCond.L.Unlock()

if sck.autoReconnect {
sck.Dial(sck.ep)
}
}

// Type returns the type of this Socket (PUB, SUB, ...)
Expand Down
69 changes: 68 additions & 1 deletion socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) {
if err := pub.Dial(endpoint); err != nil {
t.Fatalf("Pub Dial failed: %v", err)
}
wg := new(sync.WaitGroup)
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -322,3 +322,70 @@ func TestConnMaxRetriesInfinite(t *testing.T) {
t.Fatalf("Dial called %d times, expected at least %d", transport.dialCalledCount, atLeastExpectedRetries)
}
}

func TestSocketAutomaticReconnect(t *testing.T) {
ep, err := EndPoint("tcp")
if err != nil {
t.Fatalf("could not find endpoint: %+v", err)
}
message := "test"

var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sendMessages := func(socket zmq4.Socket) {
wg.Add(1)
go func(t *testing.T) {
defer wg.Done()
for {
socket.Send(zmq4.NewMsgFromString([]string{message}))
if ctx.Err() != nil {
return
}
time.Sleep(1 * time.Millisecond)
}
}(t)
}

sub := zmq4.NewSub(context.Background(), zmq4.WithAutomaticReconnect(true))
defer sub.Close()
sub.SetOption(zmq4.OptionSubscribe, message)
pub := zmq4.NewPub(context.Background())
if err := pub.Listen(ep); err != nil {
t.Fatalf("Pub Dial failed: %v", err)
}
if err := sub.Dial(ep); err != nil {
t.Fatalf("Sub Dial failed: %v", err)
}

sendMessages(pub)

checkConnectionWorking := func(socket zmq4.Socket) {
for {
msg, err := socket.Recv()
if errors.Is(err, io.EOF) {
continue
}
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
if string(msg.Frames[0]) != message {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
return
}
}

checkConnectionWorking(sub)
pub.Close()

pub2 := zmq4.NewPub(context.Background())
defer pub2.Close()
if err := pub2.Listen(ep); err != nil {
t.Fatalf("Sub Listen failed: %v", err)
}
sendMessages(pub2)
checkConnectionWorking(sub)
}

0 comments on commit 2b7cf28

Please sign in to comment.