Skip to content

Commit

Permalink
Avoid RetryClient task processing before base client connect (#174)
Browse files Browse the repository at this point in the history
* Add test cases
* Pause RetryClient task runner loop until the base client is connected
* Stop task runner on Disconnect
* Fix to use RLock as possible
  • Loading branch information
at-wat authored Jun 8, 2021
1 parent 01dbf48 commit 41f1d05
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 65 deletions.
74 changes: 62 additions & 12 deletions retryclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ var ErrClosedClient = errors.New("operation on closed client")

// RetryClient queues unacknowledged messages and retry on reconnect.
type RetryClient struct {
cli *BaseClient
cli *BaseClient
chConnectErr chan error
chConnSwitch chan struct{}

retryQueue []retryFn
subEstablished subscriptions // acknoledged subscriptions
mu sync.Mutex
mu sync.RWMutex
handler Handler
chTask chan struct{}
taskQueue []func(ctx context.Context, cli *BaseClient)
Expand All @@ -48,9 +50,9 @@ func (c *RetryClient) Handle(handler Handler) {
// Publish tries to publish the message and immediately returns.
// If it is not acknowledged to be published, the message will be queued.
func (c *RetryClient) Publish(ctx context.Context, message *Message) error {
c.mu.Lock()
c.mu.RLock()
cli := c.cli
c.mu.Unlock()
c.mu.RUnlock()

if cli != nil {
if err := cli.ValidateMessage(message); err != nil {
Expand Down Expand Up @@ -170,21 +172,24 @@ func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics .
func (c *RetryClient) Disconnect(ctx context.Context) error {
return wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) {
cli.Disconnect(ctx)
c.mu.Lock()
close(c.chTask)
c.mu.Unlock()
}), "retryclient: disconnecting")
}

// Ping to the broker.
func (c *RetryClient) Ping(ctx context.Context) error {
c.mu.Lock()
c.mu.RLock()
cli := c.cli
c.mu.Unlock()
c.mu.RUnlock()
return wrapError(cli.Ping(ctx), "retryclient: pinging")
}

// Client returns the base client.
func (c *RetryClient) Client() *BaseClient {
c.mu.Lock()
defer c.mu.Unlock()
c.mu.RLock()
defer c.mu.RUnlock()
return c.cli
}

Expand All @@ -194,6 +199,11 @@ func (c *RetryClient) Client() *BaseClient {
func (c *RetryClient) SetClient(ctx context.Context, cli *BaseClient) {
c.mu.Lock()
c.cli = cli
c.chConnectErr = make(chan error, 1)
if c.chConnSwitch != nil {
close(c.chConnSwitch)
}
c.chConnSwitch = make(chan struct{})
c.mu.Unlock()

if c.chTask != nil {
Expand All @@ -202,20 +212,55 @@ func (c *RetryClient) SetClient(ctx context.Context, cli *BaseClient) {

c.chTask = make(chan struct{}, 1)
go func() {
connected := false
ctx := context.Background()

L_TASK:
for {
if !connected {
// Wait Connect if Client was replaced by SetClient.
for {
c.mu.RLock()
chConnectErr := c.chConnectErr
chConnSwitch := c.chConnSwitch
c.mu.RUnlock()
select {
case _, ok := <-chConnectErr:
if !ok {
connected = true
continue L_TASK
}
case <-chConnSwitch:
}
}
}

c.mu.Lock()
chConnSwitch := c.chConnSwitch
select {
case <-chConnSwitch:
c.mu.Unlock()
connected = false
continue
default:
}

if len(c.taskQueue) == 0 {
c.mu.Unlock()
_, ok := <-c.chTask
if !ok {
return

select {
case _, ok := <-c.chTask:
if !ok {
return
}
case <-chConnSwitch:
connected = false
}
continue
}
cli := c.cli
task := c.taskQueue[0]
c.taskQueue = c.taskQueue[1:]
cli := c.cli
c.mu.Unlock()

task(ctx, cli)
Expand Down Expand Up @@ -248,9 +293,14 @@ func (c *RetryClient) Connect(ctx context.Context, clientID string, opts ...Conn
c.mu.Lock()
cli := c.cli
cli.Handle(c.handler)
chConnectErr := c.chConnectErr
c.mu.Unlock()

present, err := cli.Connect(ctx, clientID, opts...)
if err != nil {
chConnectErr <- err
}
close(chConnectErr)

return present, wrapError(err, "retryclient: connecting")
}
Expand Down
215 changes: 162 additions & 53 deletions retryclient_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package mqtt
import (
"context"
"crypto/tls"
"sync/atomic"
"testing"
"time"

"github.com/at-wat/mqtt-go/internal/filteredpipe"
)

func TestIntegration_RetryClient(t *testing.T) {
Expand Down Expand Up @@ -118,69 +121,175 @@ func TestIntegration_RetryClient_Cancel(t *testing.T) {
}

func TestIntegration_RetryClient_TaskQueue(t *testing.T) {
cliBase, err := Dial(urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
type pubTiming string
const (
pubBeforeSetClient pubTiming = "BeforeSetClient"
pubBeforeConnect pubTiming = "BeforeConnect"
pubAfterConnect pubTiming = "AfterConnect"
)
pubTimings := []pubTiming{
pubBeforeSetClient, pubBeforeConnect, pubAfterConnect,
}

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
for _, withWait := range []bool{true, false} {
name := "WithoutWait"
if withWait {
name = "WithWait"
}
withWait := withWait
t.Run(name, func(t *testing.T) {
for _, pubAt := range pubTimings {
pubAt := pubAt
t.Run(string(pubAt), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ctxDone, done := context.WithCancel(context.Background())
defer done()

var cli RetryClient
cli.SetClient(ctx, cliBase)
var cnt int
const expectedCount = 100

if _, err := cli.Connect(ctx, "RetryClientQueue"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
cliRecv, err := Dial(urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
if _, err := cliRecv.Connect(ctx, "RetryClientQueueRecv"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

ctxDone, done := context.WithCancel(context.Background())
defer done()

var cnt int
cli.Handle(HandlerFunc(func(msg *Message) {
if err := cli.Publish(ctx, &Message{
Topic: "test/queue_response",
QoS: QoS1,
Payload: []byte("message"),
}); err != nil {
t.Errorf("Unexpected error: '%v'", err)
return
}
cnt++
if cnt == 100 {
done()
}
}))
if _, err := cli.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil {
t.Fatal(err)
if _, err := cliRecv.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil {
t.Fatal(err)
}
cliRecv.Handle(HandlerFunc(func(*Message) {
cnt++
if cnt == expectedCount {
done()
}
}))

cliBase, err := Dial(urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

var cli RetryClient
publish := func() {
for i := 0; i < expectedCount; i++ {
if err := cli.Publish(ctx, &Message{
Topic: "test/queue",
QoS: QoS1,
Payload: []byte("message"),
}); err != nil {
t.Errorf("Unexpected error: '%v' (cnt=%d)", err, cnt)
return
}
select {
case <-ctx.Done():
t.Errorf("Timeout (cnt=%d)", cnt)
default:
}
}
}

if pubAt == pubBeforeSetClient {
publish()
}
if withWait {
time.Sleep(50 * time.Millisecond)
}
cli.SetClient(ctx, cliBase)

if withWait {
time.Sleep(50 * time.Millisecond)
}
// Ensure there is no deadlock when SetClient before Connect.
cli.SetClient(ctx, cliBase)

if pubAt == pubBeforeConnect {
publish()
}
if withWait {
time.Sleep(50 * time.Millisecond)
}

if _, err := cli.Connect(ctx, "RetryClientQueue"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if pubAt == pubAfterConnect {
publish()
}

select {
case <-ctx.Done():
t.Errorf("Timeout (cnt=%d)", cnt)
case <-ctxDone.Done():
}

if err := cli.Disconnect(ctx); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
})
}
})
}
time.Sleep(10 * time.Millisecond)
}

func() {
for i := 0; i < 100; i++ {
if err := cli.Publish(ctx, &Message{
Topic: "test/queue",
QoS: QoS1,
Payload: []byte("message"),
}); err != nil {
t.Errorf("Unexpected error: '%v' (cnt=%d)", err, cnt)
return
func TestIntegration_RetryClient_RetryInitialRequest(t *testing.T) {
for name, url := range urls {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

topic := "test/RetryInitialReq" + name
var sw int32

cli, err := NewReconnectClient(
DialerFunc(func() (*BaseClient, error) {
cli, err := Dial(url,
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
)
if err != nil {
return nil, err
}
ca, cb := filteredpipe.DetectAndClosePipe(
newOnOffFilter(&sw),
newOnOffFilter(&sw),
)
filteredpipe.Connect(ca, cli.Transport)
cli.Transport = cb
return cli, nil
}),
WithReconnectWait(50*time.Millisecond, 200*time.Millisecond),
WithPingInterval(250*time.Millisecond),
WithTimeout(250*time.Millisecond),
)
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
select {
case <-ctx.Done():
t.Errorf("Timeout (cnt=%d)", cnt)
default:

if _, err := cli.Subscribe(ctx, Subscription{Topic: topic, QoS: QoS1}); err != nil {
t.Fatal(err)
}
}
}()
time.Sleep(100 * time.Millisecond)

select {
case <-ctx.Done():
t.Errorf("Timeout (cnt=%d)", cnt)
case <-ctxDone.Done():
}
// Disconnect
atomic.StoreInt32(&sw, 1)
go func() {
time.Sleep(300 * time.Millisecond)
// Connect
atomic.StoreInt32(&sw, 0)
}()

if err := cli.Disconnect(ctx); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
if _, err := cli.Connect(ctx, "RetryInitialReq"+name); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if err := ctx.Err(); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

cli.Disconnect(ctx)
})
}
}

0 comments on commit 41f1d05

Please sign in to comment.