Skip to content

Commit

Permalink
ConnectPacketBuilder - can noe return error (prevents connection atte…
Browse files Browse the repository at this point in the history
…mpt)

`ConnectPacketBuilder` may be used to retrieve a token for use in auth, it's may not always be possible to retrieve this token so this change enables `ConnectPacketBuilder` to return an error (which prevents the current attempt, another attempt will be made after a delay).
  • Loading branch information
MattBrittan authored Oct 13, 2024
2 parents f1fe38b + b7d62b9 commit 7474a8a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 50 deletions.
16 changes: 10 additions & 6 deletions autopaho/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type ClientConfig struct {
WillMessage *paho.WillMessage
WillProperties *paho.WillProperties

ConnectPacketBuilder func(*paho.Connect, *url.URL) *paho.Connect // called prior to connection allowing customisation of the CONNECT packet
ConnectPacketBuilder func(*paho.Connect, *url.URL) (*paho.Connect, error) // called prior to connection allowing customisation of the CONNECT packet

// DisconnectPacketBuilder - called prior to disconnection allowing customisation of the DISCONNECT
// packet. If the function returns nil, then no DISCONNECT packet will be passed; if nil a default packet is sent.
Expand Down Expand Up @@ -179,8 +179,8 @@ func (cfg *ClientConfig) SetWillMessage(topic string, payload []byte, qos byte,
//
// Deprecated: Set ConnectPacketBuilder directly instead. This function exists for
// backwards compatibility only (and may be removed in the future).
func (cfg *ClientConfig) SetConnectPacketConfigurator(fn func(*paho.Connect) *paho.Connect) bool {
cfg.ConnectPacketBuilder = func(pc *paho.Connect, u *url.URL) *paho.Connect {
func (cfg *ClientConfig) SetConnectPacketConfigurator(fn func(*paho.Connect) (*paho.Connect, error)) bool {
cfg.ConnectPacketBuilder = func(pc *paho.Connect, u *url.URL) (*paho.Connect, error) {
return fn(pc)
}
return fn != nil
Expand All @@ -198,7 +198,7 @@ func (cfg *ClientConfig) SetDisConnectPacketConfigurator(fn func() *paho.Disconn

// buildConnectPacket constructs a Connect packet for the paho client, based on staged configuration.
// If the program uses SetConnectPacketConfigurator, the provided callback will be executed with the preliminary Connect packet representation.
func (cfg *ClientConfig) buildConnectPacket(firstConnection bool, serverURL *url.URL) *paho.Connect {
func (cfg *ClientConfig) buildConnectPacket(firstConnection bool, serverURL *url.URL) (*paho.Connect, error) {

cp := &paho.Connect{
KeepAlive: cfg.KeepAlive,
Expand Down Expand Up @@ -230,10 +230,14 @@ func (cfg *ClientConfig) buildConnectPacket(firstConnection bool, serverURL *url
}

if cfg.ConnectPacketBuilder != nil {
cp = cfg.ConnectPacketBuilder(cp, serverURL)
var err error
cp, err = cfg.ConnectPacketBuilder(cp, serverURL)
if err != nil {
return nil, err
}
}

return cp
return cp, nil
}

// NewConnection creates a connection manager and begins the connection process (will retry until the context is cancelled)
Expand Down
16 changes: 8 additions & 8 deletions autopaho/auto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ func TestClientConfig_buildConnectPacket(t *testing.T) {
}

// Validate initial state
cp := config.buildConnectPacket(true, nil)
cp, _ := config.buildConnectPacket(true, nil)

if !cp.CleanStart {
t.Errorf("Expected Clean Start to be true")
Expand All @@ -573,7 +573,7 @@ func TestClientConfig_buildConnectPacket(t *testing.T) {
config.SetUsernamePassword("testuser", []byte("testpassword"))
config.SetWillMessage(fmt.Sprintf("client/%s/state", config.ClientID), []byte("disconnected"), 1, true)

cp = config.buildConnectPacket(false, nil)
cp, _ = config.buildConnectPacket(false, nil)
if cp.CleanStart {
t.Errorf("Expected Clean Start to be false")
}
Expand Down Expand Up @@ -609,14 +609,14 @@ func TestClientConfig_buildConnectPacket(t *testing.T) {
}

// Set an override method for the CONNECT packet
config.SetConnectPacketConfigurator(func(c *paho.Connect) *paho.Connect {
config.SetConnectPacketConfigurator(func(c *paho.Connect) (*paho.Connect, error) {
delay := uint32(200)
c.WillProperties.WillDelayInterval = &delay
return c
return c, nil
})

testUrl, _ := url.Parse("mqtt://mqtt_user:[email protected]:1883")
cp = config.buildConnectPacket(false, testUrl)
cp, _ = config.buildConnectPacket(false, testUrl)

if *(cp.WillProperties.WillDelayInterval) != 200 { // verifies the override
t.Errorf("Will message Delay Interval did not match expected [200]: found [%v]", *(cp.Properties.WillDelayInterval))
Expand All @@ -634,15 +634,15 @@ func ExampleClientConfig_ConnectPacketBuilder() {
ClientID: "test",
},
}
config.ConnectPacketBuilder = func(c *paho.Connect, u *url.URL) *paho.Connect {
config.ConnectPacketBuilder = func(c *paho.Connect, u *url.URL) (*paho.Connect, error) {
// Extracting password from URL
c.Username = u.User.Username()
// up to user to catch empty password passed via URL
p, _ := u.User.Password()
c.Password = []byte(p)
return c
return c, nil
}
cp := config.buildConnectPacket(false, serverURL)
cp, _ := config.buildConnectPacket(false, serverURL)
fmt.Printf("user: %s, pass: %s", cp.Username, string(cp.Password))
// Output: user: mqtt_user, pass: mqtt_pass
}
74 changes: 38 additions & 36 deletions autopaho/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import (
// context is cancelled (in which case nil will be returned).
func establishServerConnection(ctx context.Context, cfg ClientConfig, firstConnection bool) (*paho.Client, *paho.Connack) {
// Note: We do not touch b.cli in order to avoid adding thread safety issues.
var err error

var attempt int = 0
for {
Expand All @@ -53,48 +52,51 @@ func establishServerConnection(ctx context.Context, cfg ClientConfig, firstConne
return nil, nil
}
for _, u := range cfg.ServerUrls {
connectionCtx, cancelConnCtx := context.WithTimeout(ctx, cfg.ConnectTimeout)

if cfg.AttemptConnection != nil { // Use custom function if it is provided
cfg.Conn, err = cfg.AttemptConnection(ctx, cfg, u)
} else {
switch strings.ToLower(u.Scheme) {
case "mqtt", "tcp", "":
cfg.Conn, err = attemptTCPConnection(connectionCtx, u.Host)
case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps":
cfg.Conn, err = attemptTLSConnection(connectionCtx, cfg.TlsCfg, u.Host)
case "ws":
cfg.Conn, err = attemptWebsocketConnection(connectionCtx, nil, cfg.WebSocketCfg, u)
case "wss":
cfg.Conn, err = attemptWebsocketConnection(connectionCtx, cfg.TlsCfg, cfg.WebSocketCfg, u)
default:
if cfg.OnConnectError != nil {
cfg.OnConnectError(fmt.Errorf("unsupported scheme (%s) user in url %s", u.Scheme, u.String()))
}
cancelConnCtx()
continue
}
}

var connack *paho.Connack

cp, err := cfg.buildConnectPacket(firstConnection, u)
if err == nil {
cli := paho.NewClient(cfg.ClientConfig)
if cfg.PahoDebug != nil {
cli.SetDebugLogger(cfg.PahoDebug)
connectionCtx, cancelConnCtx := context.WithTimeout(ctx, cfg.ConnectTimeout)

if cfg.AttemptConnection != nil { // Use custom function if it is provided
cfg.Conn, err = cfg.AttemptConnection(ctx, cfg, u)
} else {
switch strings.ToLower(u.Scheme) {
case "mqtt", "tcp", "":
cfg.Conn, err = attemptTCPConnection(connectionCtx, u.Host)
case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps":
cfg.Conn, err = attemptTLSConnection(connectionCtx, cfg.TlsCfg, u.Host)
case "ws":
cfg.Conn, err = attemptWebsocketConnection(connectionCtx, nil, cfg.WebSocketCfg, u)
case "wss":
cfg.Conn, err = attemptWebsocketConnection(connectionCtx, cfg.TlsCfg, cfg.WebSocketCfg, u)
default:
if cfg.OnConnectError != nil {
cfg.OnConnectError(fmt.Errorf("unsupported scheme (%s) user in url %s", u.Scheme, u.String()))
}
cancelConnCtx()
continue
}
}

if cfg.PahoErrors != nil {
cli.SetErrorLogger(cfg.PahoErrors)
}
if err == nil {
cli := paho.NewClient(cfg.ClientConfig)
if cfg.PahoDebug != nil {
cli.SetDebugLogger(cfg.PahoDebug)
}

cp := cfg.buildConnectPacket(firstConnection, u)
connack, err = cli.Connect(connectionCtx, cp) // will return an error if the connection is unsuccessful (checks the reason code)
if err == nil { // Successfully connected
cancelConnCtx()
return cli, connack
if cfg.PahoErrors != nil {
cli.SetErrorLogger(cfg.PahoErrors)
}

connack, err = cli.Connect(connectionCtx, cp) // will return an error if the connection is unsuccessful (checks the reason code)
if err == nil { // Successfully connected
cancelConnCtx()
return cli, connack
}
}
cancelConnCtx()
}
cancelConnCtx()

// Possible failure was due to outer context being cancelled
if ctx.Err() != nil {
Expand Down

0 comments on commit 7474a8a

Please sign in to comment.