Skip to content

Commit 843bdf4

Browse files
committed
Fix: WebSocket handling to use a safe connection wrapper, preventing concurrent write issues and improving thread safety.
1 parent 5faee9a commit 843bdf4

1 file changed

Lines changed: 43 additions & 14 deletions

File tree

internal/api/websocket.go

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,47 @@ var upgrader = websocket.Upgrader{
1919
},
2020
}
2121

22+
// safeConn wraps a WebSocket connection with a write mutex to prevent concurrent writes
23+
type safeConn struct {
24+
conn *websocket.Conn
25+
writeMu sync.Mutex
26+
}
27+
28+
// WriteJSON safely writes JSON to the WebSocket connection
29+
func (sc *safeConn) WriteJSON(v interface{}) error {
30+
sc.writeMu.Lock()
31+
defer sc.writeMu.Unlock()
32+
return sc.conn.WriteJSON(v)
33+
}
34+
35+
// WriteMessage safely writes a message to the WebSocket connection
36+
func (sc *safeConn) WriteMessage(messageType int, data []byte) error {
37+
sc.writeMu.Lock()
38+
defer sc.writeMu.Unlock()
39+
return sc.conn.WriteMessage(messageType, data)
40+
}
41+
42+
// Close closes the underlying connection
43+
func (sc *safeConn) Close() error {
44+
return sc.conn.Close()
45+
}
46+
2247
// WebSocketHandler handles WebSocket connections for progress updates
2348
type WebSocketHandler struct {
24-
clients map[*websocket.Conn]bool
49+
clients map[*safeConn]bool
2550
broadcast chan *worker.ProgressUpdate
26-
register chan *websocket.Conn
27-
unregister chan *websocket.Conn
51+
register chan *safeConn
52+
unregister chan *safeConn
2853
mu sync.RWMutex
2954
}
3055

3156
// NewWebSocketHandler creates a new WebSocket handler
3257
func NewWebSocketHandler() *WebSocketHandler {
3358
handler := &WebSocketHandler{
34-
clients: make(map[*websocket.Conn]bool),
59+
clients: make(map[*safeConn]bool),
3560
broadcast: make(chan *worker.ProgressUpdate, 100),
36-
register: make(chan *websocket.Conn),
37-
unregister: make(chan *websocket.Conn),
61+
register: make(chan *safeConn),
62+
unregister: make(chan *safeConn),
3863
}
3964

4065
go handler.run()
@@ -62,16 +87,19 @@ func (h *WebSocketHandler) run() {
6287

6388
case update := <-h.broadcast:
6489
h.mu.RLock()
90+
clients := make([]*safeConn, 0, len(h.clients))
6591
for client := range h.clients {
66-
// Send update to client
92+
clients = append(clients, client)
93+
}
94+
h.mu.RUnlock()
95+
96+
// Write to clients outside the lock to avoid deadlocks
97+
for _, client := range clients {
6798
if err := client.WriteJSON(update); err != nil {
6899
log.Printf("WebSocket write error: %v", err)
69-
h.mu.RUnlock()
70100
h.unregister <- client
71-
h.mu.RLock()
72101
}
73102
}
74-
h.mu.RUnlock()
75103
}
76104
}
77105
}
@@ -84,16 +112,17 @@ func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
84112
return
85113
}
86114

87-
h.register <- conn
115+
sc := &safeConn{conn: conn}
116+
h.register <- sc
88117

89118
// Keep connection alive with ping/pong
90119
go func() {
91120
ticker := time.NewTicker(30 * time.Second)
92121
defer ticker.Stop()
93122

94123
for range ticker.C {
95-
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
96-
h.unregister <- conn
124+
if err := sc.WriteMessage(websocket.PingMessage, nil); err != nil {
125+
h.unregister <- sc
97126
return
98127
}
99128
}
@@ -103,7 +132,7 @@ func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
103132
for {
104133
_, _, err := conn.ReadMessage()
105134
if err != nil {
106-
h.unregister <- conn
135+
h.unregister <- sc
107136
break
108137
}
109138
}

0 commit comments

Comments
 (0)