@@ -19,22 +19,47 @@ var upgrader = websocket.Upgrader{
1919 },
2020}
2121
22+ // safeConn wraps a WebSocket connection with a write mutex to prevent concurrent writes
23+ type safeConn struct {
24+ conn * websocket.Conn
25+ writeMu sync.Mutex
26+ }
27+
28+ // WriteJSON safely writes JSON to the WebSocket connection
29+ func (sc * safeConn ) WriteJSON (v interface {}) error {
30+ sc .writeMu .Lock ()
31+ defer sc .writeMu .Unlock ()
32+ return sc .conn .WriteJSON (v )
33+ }
34+
35+ // WriteMessage safely writes a message to the WebSocket connection
36+ func (sc * safeConn ) WriteMessage (messageType int , data []byte ) error {
37+ sc .writeMu .Lock ()
38+ defer sc .writeMu .Unlock ()
39+ return sc .conn .WriteMessage (messageType , data )
40+ }
41+
42+ // Close closes the underlying connection
43+ func (sc * safeConn ) Close () error {
44+ return sc .conn .Close ()
45+ }
46+
2247// WebSocketHandler handles WebSocket connections for progress updates
2348type WebSocketHandler struct {
24- clients map [* websocket. Conn ]bool
49+ clients map [* safeConn ]bool
2550 broadcast chan * worker.ProgressUpdate
26- register chan * websocket. Conn
27- unregister chan * websocket. Conn
51+ register chan * safeConn
52+ unregister chan * safeConn
2853 mu sync.RWMutex
2954}
3055
3156// NewWebSocketHandler creates a new WebSocket handler
3257func NewWebSocketHandler () * WebSocketHandler {
3358 handler := & WebSocketHandler {
34- clients : make (map [* websocket. Conn ]bool ),
59+ clients : make (map [* safeConn ]bool ),
3560 broadcast : make (chan * worker.ProgressUpdate , 100 ),
36- register : make (chan * websocket. Conn ),
37- unregister : make (chan * websocket. Conn ),
61+ register : make (chan * safeConn ),
62+ unregister : make (chan * safeConn ),
3863 }
3964
4065 go handler .run ()
@@ -62,16 +87,19 @@ func (h *WebSocketHandler) run() {
6287
6388 case update := <- h .broadcast :
6489 h .mu .RLock ()
90+ clients := make ([]* safeConn , 0 , len (h .clients ))
6591 for client := range h .clients {
66- // Send update to client
92+ clients = append (clients , client )
93+ }
94+ h .mu .RUnlock ()
95+
96+ // Write to clients outside the lock to avoid deadlocks
97+ for _ , client := range clients {
6798 if err := client .WriteJSON (update ); err != nil {
6899 log .Printf ("WebSocket write error: %v" , err )
69- h .mu .RUnlock ()
70100 h .unregister <- client
71- h .mu .RLock ()
72101 }
73102 }
74- h .mu .RUnlock ()
75103 }
76104 }
77105}
@@ -84,16 +112,17 @@ func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
84112 return
85113 }
86114
87- h .register <- conn
115+ sc := & safeConn {conn : conn }
116+ h .register <- sc
88117
89118 // Keep connection alive with ping/pong
90119 go func () {
91120 ticker := time .NewTicker (30 * time .Second )
92121 defer ticker .Stop ()
93122
94123 for range ticker .C {
95- if err := conn .WriteMessage (websocket .PingMessage , nil ); err != nil {
96- h .unregister <- conn
124+ if err := sc .WriteMessage (websocket .PingMessage , nil ); err != nil {
125+ h .unregister <- sc
97126 return
98127 }
99128 }
@@ -103,7 +132,7 @@ func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
103132 for {
104133 _ , _ , err := conn .ReadMessage ()
105134 if err != nil {
106- h .unregister <- conn
135+ h .unregister <- sc
107136 break
108137 }
109138 }
0 commit comments