Skip to content

Commit

Permalink
[CloudSync] Adding TLS for Secure MQTT Broker Client Communication
Browse files Browse the repository at this point in the history
Signed-off-by: Nitu Gupta <[email protected]>
  • Loading branch information
nitu-s-gupta committed Feb 3, 2022
1 parent 0600306 commit 3c7bf28
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 14 deletions.
44 changes: 39 additions & 5 deletions internal/common/mqtt/mqttconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,21 @@
package mqtt

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"strings"
"sync"
"time"

MQTT "github.com/eclipse/paho.mqtt.golang"
)

const mqttPort = 1883
const (
edgeDir = "/var/edge-orchestration"
caCertConfig = edgeDir + "/mqtt/certs/cacert.pem"
)

// Client is a wrapper on top of `MQTT.Client`
type Client struct {
Expand All @@ -39,6 +45,7 @@ type Client struct {
sync.RWMutex
ClientOptions *MQTT.ClientOptions
MQTT.Client
protocol string
}

//Message is used to wrap the app id and payload into one and publish to broker
Expand Down Expand Up @@ -80,6 +87,14 @@ func InitClientData() {
clientData = make(map[string]*Client)
}

func (c *Client) setProtocol() {
if c.Port == 8883 {
c.protocol = "tcps"
} else {
c.protocol = "tcp"
}
}

// CheckifClientExist used to check if the client conn object exist
func CheckifClientExist(clientID string) *Client {
client := clientData[clientID]
Expand All @@ -92,21 +107,36 @@ func addClientData(client *Client, clientID string) {
}

//SetBrokerURL returns the broker url for connection
func (c *Client) SetBrokerURL(protocol string) string {
return fmt.Sprintf("%s://%s:%d", protocol, c.Host, c.Port)
func (c *Client) SetBrokerURL() string {
return fmt.Sprintf("%s://%s:%d", c.protocol, c.Host, c.Port)
}

func checkforConnection(brokerURL string, mqttClient *Client) int {
func checkforConnection(brokerURL string, mqttClient *Client, mqttPort uint) int {
if mqttClient == nil {
return 0
}
log.Info(logPrefix, mqttClient.URL)
connURL := fmt.Sprintf("%s://%s:%d", "tcp", brokerURL, mqttPort)
connURL := fmt.Sprintf("%s://%s:%d", mqttClient.protocol, brokerURL, mqttPort)
return strings.Compare(connURL, mqttClient.URL)
}

//NewTLSConfig creates a tls config for mqtt client
func NewTLSConfig() *tls.Config {
certpool := x509.NewCertPool()
ca, err := ioutil.ReadFile(caCertConfig)
if err != nil {
log.Warn(err.Error())
return nil
}
certpool.AppendCertsFromPEM(ca)
return &tls.Config{
RootCAs: certpool,
}
}

// NewClient returns a configured `Client`. Is mandatory
func NewClient(configs ...Config) (*Client, error) {
tlsconfig := NewTLSConfig()
client := &Client{
Qos: byte(1),
}
Expand All @@ -120,6 +150,10 @@ func NewClient(configs ...Config) (*Client, error) {
copts.SetMaxReconnectInterval(1 * time.Second)
copts.SetOnConnectHandler(client.onConnect())
copts.SetConnectionLostHandler(client.onConnectionLost())
copts.SetTLSConfig(tlsconfig)
// TODO Use Username and password to provide authorization to MQTT Broker zuncomment the below two lines to enable authorization after creating password file for each user
/*copts.SetUsername("nitu")
copts.SetPassword("nitu")*/
client.ClientOptions = copts

return client, nil
Expand Down
7 changes: 4 additions & 3 deletions internal/common/mqtt/mqttconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ func (client *Client) Disconnect(quiesce uint) {
}

// StartMQTTClient is used to initiate the client and set the configuration
func StartMQTTClient(brokerURL string, clientID string) string {
func StartMQTTClient(brokerURL string, clientID string, mqttPort uint) string {
log.Info(logPrefix, "Starting the MQTT Client")
//Check if the client connection exist
mqttClient := CheckifClientExist(clientID)
// Check if the connection exist for same url
ifConn := checkforConnection(brokerURL, mqttClient)
ifConn := checkforConnection(brokerURL, mqttClient, mqttPort)
if mqttClient != nil && ifConn == 0 {
log.Info(logPrefix, "Connection Object exist", mqttClient)
if mqttClient.IsConnected() {
Expand Down Expand Up @@ -115,7 +115,8 @@ func StartMQTTClient(brokerURL string, clientID string) string {
return err.Error()
}
clientConfig.ClientOptions.SetOnConnectHandler(clientConfig.onConnect())
URL := clientConfig.SetBrokerURL("tcp")
clientConfig.setProtocol()
URL := clientConfig.SetBrokerURL()
log.Info(logPrefix, " The broker is", URL)
clientConfig.URL = URL
clientConfig.ClientOptions.AddBroker(URL)
Expand Down
12 changes: 7 additions & 5 deletions internal/common/mqtt/mqttconnection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ import (
)

const Host = "broker.hivemq.com"
const port = "1883"

var port uint = 1883

const InvalidHost = "invalid"

func initializeTest(Host string, clientID string) {
InitClientData()
StartMQTTClient(Host, clientID)
StartMQTTClient(Host, clientID, port)
}

func TestStartMQTTClient(t *testing.T) {
Expand All @@ -36,15 +38,15 @@ func TestStartMQTTClient(t *testing.T) {
client := clientData["testClient"]
message := Message{"appid", "Test data for testing"}
client.Publish(message, "home/livingroom")
err := StartMQTTClient(Host, "testClient")
err := StartMQTTClient(Host, "testClient", port)
expected := ""
if !strings.Contains(err, expected) {
t.Error("Unexpected err", err)
}
})
t.Run("Fail", func(t *testing.T) {
InitClientData()
err := StartMQTTClient(InvalidHost, "testClientFailure")
err := StartMQTTClient(InvalidHost, "testClientFailure", port)
expected := "dial tcp: lookup invalid: Temporary failure in name resolution"
if !strings.Contains(err, expected) {
t.Error("Unexpected err", err)
Expand All @@ -56,7 +58,7 @@ func TestCheckforConnection(t *testing.T) {
t.Run("Success", func(t *testing.T) {
initializeTest(Host, "CheckConnection")
client := clientData["CheckConnection"]
isConnected := checkforConnection(Host, client)
isConnected := checkforConnection(Host, client, port)
expected := 0
if isConnected != expected {
t.Errorf("Expected %d But received %d", expected, isConnected)
Expand Down
11 changes: 10 additions & 1 deletion internal/controller/cloudsyncmgr/cloudsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package cloudsyncmgr

import (
"fmt"
"os"
"strings"
"sync"

Expand All @@ -46,6 +47,7 @@ var (
log = logmgr.GetInstance()
mqttClient *mqttmgr.Client
isCloudSyncSet bool
mqttPort uint = 1883
)

func init() {
Expand All @@ -65,6 +67,13 @@ func (c *CloudSyncImpl) InitiateCloudSync(isCloudSet string) (err error) {
if strings.Compare(strings.ToLower(isCloudSet), "true") == 0 {
log.Println("CloudSync init set")
isCloudSyncSet = true
secure := os.Getenv("SECURE")
if len(secure) > 0 {
if strings.Compare(strings.ToLower(secure), "true") == 0 {
log.Println(logPrefix, "Orchestration init with secure option")
mqttPort = 8883
}
}
//Intialize the client and hashmap storing client data
mqttmgr.InitClientData()
}
Expand All @@ -87,7 +96,7 @@ func (c *CloudSyncImpl) RequestCloudSyncConf(host string, clientID string, messa
wg.Add(1)
errs := make(chan string, 1)
go func() {
errs <- mqttmgr.StartMQTTClient(host, clientID)
errs <- mqttmgr.StartMQTTClient(host, clientID, mqttPort)
resp = <-errs
wg.Done()

Expand Down

0 comments on commit 3c7bf28

Please sign in to comment.