@@ -2,19 +2,25 @@ use std::sync::Arc;
22use std:: time:: Duration ;
33use axum:: { Extension , Json } ;
44use axum:: extract:: { Query , State } ;
5- use axum:: response:: { Sse } ;
5+ use axum:: extract:: ws:: { Message , WebSocket , WebSocketUpgrade } ;
6+ use axum:: response:: { IntoResponse , Sse } ;
67use axum:: response:: sse:: Event ;
8+ use bytes:: Bytes ;
79use chrono:: { DateTime , Utc } ;
810use futures:: Stream ;
9- use log:: error;
11+ use tokio:: time;
12+ use log:: { debug, error} ;
1013use serde:: Deserialize ;
14+ use tokio:: sync:: broadcast:: error:: RecvError ;
1115use tokio_stream:: wrappers:: BroadcastStream ;
1216use tokio_stream:: wrappers:: errors:: BroadcastStreamRecvError ;
17+ use tracing:: warn;
1318use uuid:: Uuid ;
1419use crate :: broadcast:: { BroadcastChannel , Notification } ;
1520use crate :: core:: AppState ;
1621use crate :: errors:: { AppError , AppResponse } ;
1722use crate :: keycloak:: decode:: KeycloakToken ;
23+ use crate :: keycloak:: layer:: KeycloakAuthLayer ;
1824
1925struct ConnectionGuard {
2026 user_id : Uuid ,
@@ -64,6 +70,69 @@ pub async fn stream_server_events(
6470 )
6571}
6672
73+
74+ pub async fn websocket_server_events (
75+ websocket : WebSocketUpgrade ,
76+ Extension ( token) : Extension < KeycloakToken < String > >
77+ ) -> impl IntoResponse {
78+
79+ websocket
80+ . on_failed_upgrade ( |error| warn ! ( "Error upgrading websocket: {}" , error) )
81+ . on_upgrade ( move |socket| handle_socket ( socket, token. subject . clone ( ) ) )
82+ }
83+
84+ async fn handle_socket ( mut socket : WebSocket , user_id : Uuid ) {
85+
86+ let mut broadcast_events = BroadcastChannel :: get ( ) . subscribe_to_user_events ( user_id. clone ( ) ) . await ;
87+ let _guard = ConnectionGuard { user_id } ;
88+ let mut ping_interval = time:: interval ( Duration :: from_secs ( 30 ) ) ;
89+
90+ loop {
91+ tokio:: select! {
92+ // 1. Handle new broadcasting event:
93+ notification_result = broadcast_events. recv( ) => {
94+ match notification_result {
95+ Ok ( event) => {
96+ let json_msg = serde_json:: to_string( & event) . unwrap( ) ;
97+ let ws_message = Message :: text( json_msg) ;
98+
99+ if socket. send( ws_message) . await . is_err( ) {
100+ error!( "Failed to send message to client" ) ;
101+ }
102+ }
103+ Err ( RecvError :: Closed ) => {
104+ debug!( "Client disconnected or channel closed" ) ;
105+ break ;
106+ }
107+ Err ( RecvError :: Lagged ( _) ) => {
108+ debug!( "Client is too slow!" )
109+ }
110+ }
111+ }
112+
113+ // 2. Regular ping from ism:
114+ _ = ping_interval. tick( ) => {
115+ if socket. send( Message :: Ping ( Bytes :: new( ) ) ) . await . is_err( ) { // connection is dead when we can't send ping
116+ break ;
117+ }
118+ }
119+
120+ // 3. Receive messages from the client:
121+ client_msg = socket. recv( ) => {
122+ match client_msg {
123+ Some ( Ok ( Message :: Close ( _) ) ) | None => break , //client is closing connection
124+ Some ( Err ( _) ) => break , //client error
125+ Some ( Ok ( Message :: Pong ( _) ) ) => {
126+ debug!( "Client has sent Pong" ) ;
127+ }
128+ _ => { } //for the future
129+ }
130+ }
131+ }
132+ }
133+ }
134+
135+
67136#[ derive( Deserialize ) ]
68137pub struct NotificationQueryParam {
69138 timestamp : DateTime < Utc >
0 commit comments