Skip to content

Commit

Permalink
Fix default port handling (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Jul 12, 2021
1 parent c43cdb1 commit d80a538
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 47 deletions.
44 changes: 28 additions & 16 deletions client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,25 +67,37 @@ func ExampleClient() {
}

func TestIntegration_Connect(t *testing.T) {
for name, url := range urls {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cli, err := DialContext(ctx, url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
// Overwrite default port to avoid using privileged port during test.
defaultPorts["ws"] = 9001
defaultPorts["wss"] = 9443

test := func(t *testing.T, urls map[string]string) {
for name, url := range urls {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cli, err := DialContext(ctx, url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if _, err := cli.Connect(ctx, "Client"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
if _, err := cli.Connect(ctx, "Client"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if err := cli.Disconnect(ctx); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
})
if err := cli.Disconnect(ctx); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
})
}
}
t.Run("WithPort", func(t *testing.T) {
test(t, urls)
})
t.Run("WithoutPort", func(t *testing.T) {
test(t, urlsWithoutPort)
})
}

func TestIntegration_Publish(t *testing.T) {
Expand Down
54 changes: 33 additions & 21 deletions conn_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
)

func TestIntegration_WithTLSCertFiles(t *testing.T) {
// Overwrite default port to avoid using privileged port during test.
defaultPorts["ws"] = 9001
defaultPorts["wss"] = 9443

cases := map[string]struct {
opt DialOption
expectError bool
Expand Down Expand Up @@ -52,31 +56,39 @@ func TestIntegration_WithTLSCertFiles(t *testing.T) {
false,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
test := func(t *testing.T, urls map[string]string) {
for name, c := range cases {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cli, err := DialContext(ctx, urls["MQTTs"], c.opt)
cli, err := DialContext(ctx, urls["MQTTs"], c.opt)

if err != nil {
if c.expectError {
return
if err != nil {
if c.expectError {
return
}
t.Fatal(err)
}
t.Fatal(err)
}
defer cli.Close()
defer cli.Close()

if c.expectError {
t.Fatal("Expected error but succeeded")
}
if c.expectError {
t.Fatal("Expected error but succeeded")
}

if _, err := cli.Connect(ctx, "TestConnTLS", WithCleanSession(true)); err != nil {
t.Error(err)
}
if err := cli.Disconnect(ctx); err != nil {
t.Error(err)
}
})
if _, err := cli.Connect(ctx, "TestConnTLS", WithCleanSession(true)); err != nil {
t.Error(err)
}
if err := cli.Disconnect(ctx); err != nil {
t.Error(err)
}
})
}
}
t.Run("WithPort", func(t *testing.T) {
test(t, urls)
})
t.Run("WithoutPort", func(t *testing.T) {
test(t, urlsWithoutPort)
})
}
39 changes: 29 additions & 10 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ import (
"golang.org/x/net/websocket"
)

var defaultPorts = map[string]uint16{
"mqtt": 1883,
"mqtts": 8883,
"wss": 443,
"ws": 80,
}

// ErrUnsupportedProtocol means that the specified scheme in the URL is not supported.
var ErrUnsupportedProtocol = errors.New("unsupported protocol")

Expand Down Expand Up @@ -76,15 +83,25 @@ func (d *URLDialer) DialContext(ctx context.Context) (*BaseClient, error) {

// DialContext creates MQTT client using URL string.
func DialContext(ctx context.Context, urlStr string, opts ...DialOption) (*BaseClient, error) {
u, err := url.Parse(urlStr)
if err != nil {
return nil, err
}
o := &DialOptions{
Dialer: &net.Dialer{},
}
switch u.Scheme {
case "tls", "ssl", "mqtts", "wss":
o.TLSConfig = &tls.Config{
ServerName: u.Hostname(),
}
}
for _, opt := range opts {
if err := opt(o); err != nil {
return nil, err
}
}
return o.dial(ctx, urlStr)
return o.dial(ctx, u)
}

// DialOption sets option for Dial.
Expand Down Expand Up @@ -155,25 +172,27 @@ func WithConnStateHandler(handler func(ConnState, error)) DialOption {
}
}

func (d *DialOptions) dial(ctx context.Context, urlStr string) (*BaseClient, error) {
func (d *DialOptions) dial(ctx context.Context, u *url.URL) (*BaseClient, error) {
c := &BaseClient{
ConnState: d.ConnState,
MaxPayloadLen: d.MaxPayloadLen,
}

u, err := url.Parse(urlStr)
if err != nil {
return nil, err
}
switch u.Scheme {
case "tcp", "mqtt", "tls", "ssl", "mqtts", "wss", "ws":
default:
return nil, wrapErrorf(ErrUnsupportedProtocol, "protocol %s", u.Scheme)
}
hostWithPort := u.Host
if u.Port() == "" {
if port, ok := defaultPorts[u.Scheme]; ok {
hostWithPort += fmt.Sprintf(":%d", port)
}
}

baseConn, err := d.Dialer.DialContext(ctx, "tcp", u.Host)
baseConn, err := d.Dialer.DialContext(ctx, "tcp", hostWithPort)
if err != nil {
return nil, err
return nil, wrapError(err, "dialing tcp")
}
switch u.Scheme {
case "tcp", "mqtt":
Expand All @@ -186,13 +205,13 @@ func (d *DialOptions) dial(ctx context.Context, urlStr string) (*BaseClient, err
case "ws":
wsc, err := websocket.NewConfig(u.String(), fmt.Sprintf("https://%s", u.Host))
if err != nil {
return nil, err
return nil, wrapError(err, "configuring websocket")
}
wsc.Protocol = append(wsc.Protocol, "mqtt")
wsc.TlsConfig = d.TLSConfig
ws, err := websocket.NewClient(wsc, baseConn)
if err != nil {
return nil, err
return nil, wrapError(err, "dialing websocket")
}
ws.PayloadType = websocket.BinaryFrame
c.Transport = ws
Expand Down
6 changes: 6 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ var (
"WebSocket": "ws://localhost:9001",
"WebSockets": "wss://localhost:9443",
}
urlsWithoutPort = map[string]string{
"MQTT": "mqtt://localhost",
"MQTTs": "mqtts://localhost",
"WebSocket": "ws://localhost",
"WebSockets": "wss://localhost",
}
)

func TestDialOptionError(t *testing.T) {
Expand Down

0 comments on commit d80a538

Please sign in to comment.