@@ -10,7 +10,9 @@ import (
1010 "io"
1111 "io/ioutil"
1212 "net/http"
13+ "strconv"
1314 "strings"
15+ "sync"
1416 "sync/atomic"
1517 "time"
1618
@@ -32,7 +34,8 @@ type API struct {
3234
3335 CancelPreviouseCommunicationChan chan struct {}
3436 WebsocketReq chan []byte
35- WebsocketResp chan []byte
37+ WebsocketRespLock sync.Mutex
38+ WebsocketResp []chan []byte
3639
3740 Cache map [string ]time.Time
3841}
@@ -42,7 +45,7 @@ func NewAPI() *API {
4245 return & API {
4346 CancelPreviouseCommunicationChan : make (chan struct {}),
4447 WebsocketReq : make (chan []byte ),
45- WebsocketResp : make ( chan []byte ) ,
48+ WebsocketResp : [] chan []byte {} ,
4649
4750 Cache : map [string ]time.Time {},
4851 }
@@ -251,19 +254,62 @@ type WSMsg[T any] struct {
251254 Data T `json:"data"`
252255}
253256
254- // ConnectToWS connects to the rtcv websocket
255- func (a * API ) ConnectToWS () {
257+ // HandleWebsocketResponse handles a websocket response
258+ // This decodes the payload and checks to which connected websocket it should be sent
259+ func (a * API ) HandleWebsocketResponse (payload []byte ) {
256260 if a .MockMode {
257- go func () {
258- for {
259- resp := <- a .WebsocketResp
260- fmt .Println ("got websocket response but we are in mock mode" , string (resp ))
261- }
262- }()
263261 return
264262 }
265263
266- server := a .connections [0 ]
264+ data := WSMsg [json.RawMessage ]{}
265+ err := json .Unmarshal (payload , & data )
266+ if err != nil {
267+ fmt .Println ("error un-marshaling websocket response, error:" , err )
268+ return
269+ }
270+
271+ idParts := strings .SplitN (data .ID , "-" , 2 )
272+ if len (idParts ) != 2 {
273+ fmt .Println ("error invalid id in websocket response, expected 2 parts but got 1" )
274+ return
275+ }
276+
277+ idx , err := strconv .Atoi (idParts [0 ])
278+ if err != nil {
279+ fmt .Println ("error invalid id connection index in websocket response, error:" , err )
280+ return
281+ }
282+
283+ data .ID = idParts [1 ]
284+
285+ // Re-encode data with the new ID
286+ payload , err = json .Marshal (data )
287+ if err != nil {
288+ fmt .Println ("error marshaling websocket response, error:" , err )
289+ return
290+ }
291+
292+ a .WebsocketRespLock .Lock ()
293+ // Data sending of the channel is thread safe but fething the array index is not hence why we lock WebsocketRespLock
294+ a .WebsocketResp [idx ] <- payload
295+ a .WebsocketRespLock .Unlock ()
296+ }
297+
298+ // ConnectToAllWebsockets connects to all the conenctions their websocket
299+ func (a * API ) ConnectToAllWebsockets () {
300+ if a .MockMode {
301+ return
302+ }
303+
304+ a .WebsocketResp = make ([]chan []byte , len (a .connections ))
305+ for idx := 0 ; idx < len (a .connections ); idx ++ {
306+ go a .connectToWS (idx )
307+ }
308+ }
309+
310+ // connectToWS connects to the rtcv websocket
311+ func (a * API ) connectToWS (idx int ) {
312+ server := a .connections [idx ]
267313
268314 url := server .serverLocation
269315 url = strings .Replace (url , "http://" , "ws://" , 1 )
@@ -277,17 +323,22 @@ func (a *API) ConnectToWS() {
277323 }
278324 }()
279325
280- go func () {
326+ a .WebsocketRespLock .Lock ()
327+ a .WebsocketResp [idx ] = make (chan []byte )
328+ listenChan := & a .WebsocketResp [idx ]
329+ a .WebsocketRespLock .Unlock ()
330+
331+ go func (ws * chan []byte ) {
281332 for {
282333 // TODO: if the response fails to send data might get lost.
283334 // It would be nice if the response is retried when WriteMessage fails
284- resp := <- a .WebsocketResp
335+ resp := <- a .WebsocketResp [ idx ]
285336 err := c .WriteMessage (1 , resp )
286337 if err != nil {
287338 fmt .Println ("unable to write ws response:" , err )
288339 }
289340 }
290- }()
341+ }(listenChan )
291342
292343 firstMessage := true
293344 var aMessageWasHandled atomic.Bool
@@ -314,7 +365,17 @@ func (a *API) ConnectToWS() {
314365 msg := WSMsg [json.RawMessage ]{}
315366 err = json .Unmarshal (msgBytes , & msg )
316367 if err != nil {
317- fmt .Println ("error unmarshaling web socket message:" , err )
368+ fmt .Println ("error un-marshaling web socket message:" , err )
369+ continue
370+ }
371+
372+ // We inject the index of the server connection into the message id so we know where to send the response to later
373+ // See the /server_response for how we handle the response
374+ msg .ID = fmt .Sprintf ("%d-%s" , idx , msg .ID )
375+
376+ msgBytes , err = json .Marshal (msg )
377+ if err != nil {
378+ fmt .Println ("error marshaling web socket message:" , err )
318379 continue
319380 }
320381
0 commit comments