Skip to content

Commit 41278d6

Browse files
committed
refactor: improve MQTT client implementation and error handling
1 parent c6b09ec commit 41278d6

3 files changed

Lines changed: 40 additions & 55 deletions

File tree

mqtt/client.go

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package mqtt
22

33
import (
4+
"context"
45
"log"
5-
"time"
66

77
mqtt "github.com/eclipse/paho.mqtt.golang"
88
"github.com/pmoscode/go-common/shutdown"
@@ -13,7 +13,7 @@ const notConnected = "Mqtt Client not connected! Call 'connect' method..."
1313

1414
// Client wraps the paho MQTT client with convenience methods for connect, publish, subscribe, and disconnect.
1515
type Client struct {
16-
client *mqtt.Client
16+
client mqtt.Client
1717
options *mqtt.ClientOptions
1818
}
1919

@@ -34,15 +34,15 @@ func (c *Client) connect(client mqtt.Client) error {
3434
return token.Error()
3535
}
3636

37-
c.client = &client
37+
c.client = client
3838
log.Println("Mqtt connected to", c.options.Servers[0])
3939

4040
return nil
4141
}
4242

4343
// Disconnect does a clean disconnect from the mqtt broker.
4444
func (c *Client) Disconnect() error {
45-
(*c.client).Disconnect(100)
45+
c.client.Disconnect(100)
4646

4747
return nil
4848
}
@@ -53,7 +53,7 @@ func (c *Client) Publish(message *Message) {
5353
if c.client == nil {
5454
log.Println(notConnected)
5555
} else {
56-
token := (*c.client).Publish(message.Topic, 2, false, message.FromJson())
56+
token := c.client.Publish(message.Topic, 2, false, message.FromJson())
5757
token.Wait()
5858
}
5959
}
@@ -64,7 +64,7 @@ func (c *Client) Subscribe(topic string, fn func(message Message)) {
6464
if c.client == nil {
6565
log.Println(notConnected)
6666
} else {
67-
(*c.client).Subscribe(topic, 2, func(client mqtt.Client, msg mqtt.Message) {
67+
c.client.Subscribe(topic, 2, func(client mqtt.Client, msg mqtt.Message) {
6868
message := Message{
6969
Topic: msg.Topic(),
7070
Value: msg.Payload(),
@@ -74,15 +74,13 @@ func (c *Client) Subscribe(topic string, fn func(message Message)) {
7474
}
7575
}
7676

77-
// LoopForever Halts the current thread.
78-
func (c *Client) LoopForever() {
77+
// LoopForever halts the current thread until the context is cancelled.
78+
// On cancellation, the client will be disconnected cleanly.
79+
func (c *Client) LoopForever(ctx context.Context) {
7980
if c.client == nil {
8081
log.Println(notConnected)
8182
} else {
8283
shutdown.GetObserver().AddCommand(c.Disconnect)
83-
84-
for {
85-
time.Sleep(10 * time.Second)
86-
}
84+
<-ctx.Done()
8785
}
8886
}

mqtt/message-type.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package mqtt
33
import (
44
"encoding/json"
55
"fmt"
6-
"log"
76
)
87

98
// Message defines the default message which is received and sent by the mqtt client.
@@ -24,16 +23,24 @@ func (m Message) FromJson() string {
2423
return string(marshal)
2524
}
2625

27-
// ToStruct converts the received JSON message to the given target struct
28-
func (m Message) ToStruct(target any) {
29-
message := m.ToRawString()
30-
err := json.Unmarshal([]byte(message), &target)
26+
// ToStruct converts the received JSON message to the given target struct.
27+
// Returns an error if the message value is not a byte slice or not valid JSON.
28+
func (m Message) ToStruct(target any) error {
29+
message, err := m.ToRawString()
3130
if err != nil {
32-
log.Println("Message is not a valid Json: ", message)
31+
return err
3332
}
33+
if err := json.Unmarshal([]byte(message), &target); err != nil {
34+
return fmt.Errorf("message is not valid JSON: %w", err)
35+
}
36+
return nil
3437
}
3538

36-
// ToRawString converts the received JSON string to an ordinary string
37-
func (m Message) ToRawString() string {
38-
return string(m.Value.([]uint8))
39+
// ToRawString converts the received message value to an ordinary string.
40+
// Returns an error if the value is not a byte slice.
41+
func (m Message) ToRawString() (string, error) {
42+
if v, ok := m.Value.([]uint8); ok {
43+
return string(v), nil
44+
}
45+
return "", fmt.Errorf("message value is not a byte slice, got %T", m.Value)
3946
}

mqtt/mqtt_test.go

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"time"
66

77
mqtt "github.com/eclipse/paho.mqtt.golang"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
810
)
911

1012
func TestConnection(t *testing.T) {
@@ -15,9 +17,7 @@ func TestConnection(t *testing.T) {
1517
WithUsernameAndPassword("test", "pwd")
1618

1719
brokerOpt, err := hostConfig.Build()
18-
if err != nil {
19-
t.Fatal(err)
20-
}
20+
require.NoError(t, err)
2121

2222
client := NewClient(
2323
WithClientId("test"),
@@ -26,21 +26,13 @@ func TestConnection(t *testing.T) {
2626

2727
mockClientImpl := &mockClient{}
2828
err = client.connect(mockClientImpl)
29-
if err != nil {
30-
t.Fatal(err)
31-
}
29+
require.NoError(t, err)
3230

3331
err = client.Disconnect()
34-
if err != nil {
35-
t.Fatal(err)
36-
}
32+
require.NoError(t, err)
3733

38-
if mockClientImpl.cntConnect != 1 {
39-
t.Fatal("Wrong connect count: ", mockClientImpl.cntConnect, " should be ", 1)
40-
}
41-
if mockClientImpl.cntDisconnect != 1 {
42-
t.Fatal("Wrong disconnect count: ", mockClientImpl.cntDisconnect, " should be ", 1)
43-
}
34+
assert.Equal(t, 1, mockClientImpl.cntConnect)
35+
assert.Equal(t, 1, mockClientImpl.cntDisconnect)
4436
}
4537

4638
func TestPublish(t *testing.T) {
@@ -50,9 +42,7 @@ func TestPublish(t *testing.T) {
5042
WithProtocol(MqttTcp)
5143

5244
brokerOpt, err := hostConfig.Build()
53-
if err != nil {
54-
t.Fatal(err)
55-
}
45+
require.NoError(t, err)
5646

5747
client := NewClient(
5848
WithClientId("test"),
@@ -61,29 +51,19 @@ func TestPublish(t *testing.T) {
6151

6252
mockClientImpl := &mockClient{}
6353
err = client.connect(mockClientImpl)
64-
if err != nil {
65-
t.Fatal(err)
66-
}
54+
require.NoError(t, err)
6755

6856
client.Publish(&Message{
6957
Topic: "/test/testMessage",
7058
Value: "{'test': 2}",
7159
})
7260

7361
err = client.Disconnect()
74-
if err != nil {
75-
t.Fatal(err)
76-
}
77-
78-
if mockClientImpl.cntConnect != 1 {
79-
t.Fatal("Wrong connect count: ", mockClientImpl.cntConnect, " should be ", 1)
80-
}
81-
if mockClientImpl.cntDisconnect != 1 {
82-
t.Fatal("Wrong disconnect count: ", mockClientImpl.cntDisconnect, " should be ", 1)
83-
}
84-
if mockClientImpl.cntPublish != 1 {
85-
t.Fatal("Wrong publish count: ", mockClientImpl.cntPublish, " should be ", 1)
86-
}
62+
require.NoError(t, err)
63+
64+
assert.Equal(t, 1, mockClientImpl.cntConnect)
65+
assert.Equal(t, 1, mockClientImpl.cntDisconnect)
66+
assert.Equal(t, 1, mockClientImpl.cntPublish)
8767
}
8868

8969
type mockClient struct {

0 commit comments

Comments
 (0)