Skip to content

Commit 2eb00a7

Browse files
committed
fix(websocket): respond to getState requests directly over WebSocket
The getState handler was calling app.emit("timer:state-query", ...) which fires a Tauri frontend IPC event instead of writing back through the WebSocket connection. The requesting client received nothing. The fix introduces a tokio mpsc channel per connection so the receive task can push direct replies to the send task, which holds the WebSocket sender. handle_client_message is refactored to accept Option<TimerSnapshot> and an unbounded sender, removing the AppHandle dependency and making the function unit-testable without a full Tauri app. Adds five unit tests and one network-level integration test (tokio-tungstenite). Closes #415
1 parent 98d9ee5 commit 2eb00a7

4 files changed

Lines changed: 145 additions & 39 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
### Bug Fixes
44

5+
- **WebSocket `getState` returns no response** — sending `{"type":"getState"}` over the WebSocket API would receive no reply. The handler was routing the response to a Tauri frontend IPC event (`timer:state-query`) instead of writing it back through the WebSocket connection. The fix introduces a per-connection `tokio::sync::mpsc` channel so the receive task can deliver direct replies through the send task, which holds the WebSocket sender. `getState` now correctly responds with `{"type":"state","payload":{...}}` to the requesting client only.
56
- **Timer not restarting correctly after quickly starting the next round** — when a round completed and the user clicked Start before the engine's follow-up duration update arrived, the update (a `Reconfigure` command) would force the engine back to Idle, cancelling the freshly started timer. The follow-up is now sent as a lighter-weight `Prime` command that updates the stored duration in place without affecting the running phase. Contributed by [@SeanTong11](https://github.com/SeanTong11).
67
- **Timer completing instantly when a stale duration update arrives mid-round** — in a rare race, the engine could receive a `Prime` command carrying a duration shorter than the already-elapsed time (e.g. if the round duration was shortened in settings while a timer was running). Without a guard this caused the timer to complete on the very next tick. The `Prime` handler now clamps the new duration to at least one tick beyond the current elapsed position so the timer always advances at least once before completing.
78

src-tauri/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src-tauri/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ rodio = { version = "0.22", default-features = false, features = ["playback", "w
4040
tiny-skia = "0.12"
4141
notify = "8"
4242

43+
[dev-dependencies]
44+
tokio-tungstenite = "0.29"
45+
4346
[target.'cfg(target_os = "macos")'.dependencies]
4447
objc2 = "0.6"
4548
raw-window-handle = "0.6"

src-tauri/src/websocket/mod.rs

Lines changed: 140 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use futures_util::{SinkExt, StreamExt};
2929
use tauri::{AppHandle, Emitter, Manager};
3030
use tokio::{
3131
net::TcpListener,
32-
sync::broadcast,
32+
sync::{broadcast, mpsc},
3333
task::JoinHandle,
3434
};
3535

@@ -68,8 +68,8 @@ pub enum WsEvent {
6868

6969
#[derive(Clone)]
7070
struct ServerState {
71-
app: AppHandle,
7271
broadcast_tx: broadcast::Sender<WsEvent>,
72+
snapshot_fn: Arc<dyn Fn() -> Option<TimerSnapshot> + Send + Sync>,
7373
}
7474

7575
// ---------------------------------------------------------------------------
@@ -113,9 +113,16 @@ pub async fn start(port: u16, app: AppHandle, state: &Arc<WsState>) {
113113
}
114114
};
115115

116+
let app_clone = app.clone();
117+
let snapshot_fn = Arc::new(move || {
118+
app_clone
119+
.try_state::<TimerController>()
120+
.map(|t| t.get_snapshot())
121+
}) as Arc<dyn Fn() -> Option<TimerSnapshot> + Send + Sync>;
122+
116123
let server_state = ServerState {
117-
app: app.clone(),
118124
broadcast_tx: state.broadcast_tx.clone(),
125+
snapshot_fn,
119126
};
120127

121128
let router = Router::new()
@@ -154,28 +161,33 @@ async fn handle_socket(socket: WebSocket, state: ServerState) {
154161
log::debug!("[ws] client connected");
155162
let (mut sender, mut receiver) = socket.split();
156163
let mut rx = state.broadcast_tx.subscribe();
164+
let (direct_tx, mut direct_rx) = mpsc::unbounded_channel::<String>();
157165

158-
// Task: forward broadcast events to this client.
166+
// Task: forward broadcast events and direct replies to this client.
159167
let mut send_task = tokio::spawn(async move {
160-
while let Ok(event) = rx.recv().await {
161-
let json = match serde_json::to_string(&event) {
162-
Ok(s) => s,
163-
Err(_) => continue,
164-
};
165-
if sender.send(Message::Text(json.into())).await.is_err() {
166-
break;
168+
loop {
169+
tokio::select! {
170+
result = rx.recv() => {
171+
let Ok(event) = result else { break };
172+
let Ok(json) = serde_json::to_string(&event) else { continue };
173+
if sender.send(Message::Text(json.into())).await.is_err() { break }
174+
}
175+
msg = direct_rx.recv() => {
176+
let Some(json) = msg else { break };
177+
if sender.send(Message::Text(json.into())).await.is_err() { break }
178+
}
167179
}
168180
}
169181
});
170182

171183
// Main loop: handle incoming messages from this client.
172-
let app = state.app.clone();
173-
let broadcast_tx = state.broadcast_tx.clone();
184+
let snapshot_fn = Arc::clone(&state.snapshot_fn);
174185
let mut recv_task = tokio::spawn(async move {
175186
while let Some(Ok(msg)) = receiver.next().await {
176187
match msg {
177188
Message::Text(text) => {
178-
handle_client_message(&text, &app, &broadcast_tx).await;
189+
let snapshot = (snapshot_fn)();
190+
handle_client_message(&text, snapshot, &direct_tx).await;
179191
}
180192
Message::Close(_) => break,
181193
_ => {}
@@ -193,24 +205,19 @@ async fn handle_socket(socket: WebSocket, state: ServerState) {
193205

194206
async fn handle_client_message(
195207
text: &str,
196-
app: &AppHandle,
197-
_broadcast_tx: &broadcast::Sender<WsEvent>,
208+
snapshot: Option<TimerSnapshot>,
209+
direct_tx: &mpsc::UnboundedSender<String>,
198210
) {
199211
let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) else {
200212
return;
201213
};
202214

203215
if let Some("getState") = msg.get("type").and_then(|t| t.as_str()) {
204-
if let Some(timer) = app.try_state::<TimerController>() {
205-
let snapshot = timer.get_snapshot();
206-
let response = serde_json::json!({
207-
"type": "state",
208-
"payload": snapshot,
209-
});
210-
// Note: we can't send directly here without the sender;
211-
// the client will receive state via the next broadcast.
212-
// For an immediate reply, broadcast it.
213-
let _ = app.emit("timer:state-query", response);
216+
if let Some(snap) = snapshot {
217+
let json = serde_json::to_string(
218+
&serde_json::json!({ "type": "state", "payload": snap })
219+
).unwrap_or_default();
220+
let _ = direct_tx.send(json);
214221
}
215222
}
216223
}
@@ -252,17 +259,8 @@ pub fn broadcast_reset(state: &Arc<WsState>) {
252259
mod tests {
253260
use super::*;
254261

255-
#[test]
256-
fn ws_state_can_be_created() {
257-
let state = WsState::new();
258-
// broadcast_tx should have 0 receivers initially.
259-
assert_eq!(state.broadcast_tx.receiver_count(), 0);
260-
}
261-
262-
#[test]
263-
fn ws_event_serializes_correctly() {
264-
use crate::timer::TimerSnapshot;
265-
let snap = TimerSnapshot {
262+
fn make_snapshot() -> TimerSnapshot {
263+
TimerSnapshot {
266264
round_type: "work".into(),
267265
previous_round_type: "short-break".into(),
268266
elapsed_secs: 60,
@@ -272,8 +270,20 @@ mod tests {
272270
work_round_number: 1,
273271
work_rounds_total: 4,
274272
session_work_count: 1,
275-
};
276-
let event = WsEvent::RoundChange { payload: snap };
273+
}
274+
}
275+
276+
// -- existing serialization tests --
277+
278+
#[test]
279+
fn ws_state_can_be_created() {
280+
let state = WsState::new();
281+
assert_eq!(state.broadcast_tx.receiver_count(), 0);
282+
}
283+
284+
#[test]
285+
fn ws_event_serializes_correctly() {
286+
let event = WsEvent::RoundChange { payload: make_snapshot() };
277287
let json = serde_json::to_string(&event).unwrap();
278288
assert!(json.contains("\"type\":\"roundChange\""));
279289
assert!(json.contains("\"elapsed_secs\":60"));
@@ -309,4 +319,95 @@ mod tests {
309319
let json = serde_json::to_string(&event).unwrap();
310320
assert_eq!(json, r#"{"type":"reset"}"#);
311321
}
322+
323+
// -- handle_client_message unit tests --
324+
325+
#[tokio::test]
326+
async fn getstate_sends_state_reply() {
327+
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
328+
handle_client_message(r#"{"type":"getState"}"#, Some(make_snapshot()), &tx).await;
329+
let reply = rx.try_recv().expect("expected a reply on direct channel");
330+
let val: serde_json::Value = serde_json::from_str(&reply).unwrap();
331+
assert_eq!(val["type"], "state");
332+
assert_eq!(val["payload"]["elapsed_secs"], 60);
333+
assert_eq!(val["payload"]["total_secs"], 1500);
334+
assert_eq!(val["payload"]["round_type"], "work");
335+
assert_eq!(val["payload"]["is_running"], true);
336+
}
337+
338+
#[tokio::test]
339+
async fn getstate_no_timer_state_sends_nothing() {
340+
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
341+
handle_client_message(r#"{"type":"getState"}"#, None, &tx).await;
342+
assert!(rx.try_recv().is_err(), "expected no reply when snapshot is None");
343+
}
344+
345+
#[tokio::test]
346+
async fn malformed_json_is_silently_ignored() {
347+
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
348+
handle_client_message("not valid json {{{", Some(make_snapshot()), &tx).await;
349+
assert!(rx.try_recv().is_err(), "expected no reply for malformed JSON");
350+
}
351+
352+
#[tokio::test]
353+
async fn unknown_message_type_is_ignored() {
354+
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
355+
handle_client_message(r#"{"type":"unknownCommand"}"#, Some(make_snapshot()), &tx).await;
356+
assert!(rx.try_recv().is_err(), "expected no reply for unknown message type");
357+
}
358+
359+
#[tokio::test]
360+
async fn reply_uses_direct_channel_not_broadcast() {
361+
let (broadcast_tx, _) = broadcast::channel::<WsEvent>(8);
362+
let (direct_tx, mut direct_rx) = mpsc::unbounded_channel::<String>();
363+
handle_client_message(r#"{"type":"getState"}"#, Some(make_snapshot()), &direct_tx).await;
364+
// Reply appeared on the direct channel
365+
assert!(direct_rx.try_recv().is_ok(), "expected reply on direct channel");
366+
// Nothing sent to the broadcast channel
367+
assert_eq!(broadcast_tx.receiver_count(), 0);
368+
}
369+
370+
// -- network-level integration test --
371+
372+
#[tokio::test]
373+
async fn integration_getstate_round_trip() {
374+
use axum::Router;
375+
use tokio::net::TcpListener;
376+
use tokio_tungstenite::connect_async;
377+
use tokio_tungstenite::tungstenite::Message as TungMessage;
378+
use futures_util::{SinkExt, StreamExt};
379+
380+
let snap = make_snapshot();
381+
let snap_clone = snap.clone();
382+
let snapshot_fn = Arc::new(move || Some(snap_clone.clone()))
383+
as Arc<dyn Fn() -> Option<TimerSnapshot> + Send + Sync>;
384+
385+
let (broadcast_tx, _) = broadcast::channel::<WsEvent>(8);
386+
let server_state = ServerState { broadcast_tx, snapshot_fn };
387+
388+
let router = Router::new()
389+
.route("/ws", get(ws_handler))
390+
.with_state(server_state);
391+
392+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
393+
let port = listener.local_addr().unwrap().port();
394+
395+
tokio::spawn(async move {
396+
axum::serve(listener, router).await.unwrap();
397+
});
398+
399+
let url = format!("ws://127.0.0.1:{port}/ws");
400+
let (mut ws, _) = connect_async(&url).await.expect("WebSocket connect failed");
401+
402+
ws.send(TungMessage::Text(r#"{"type":"getState"}"#.into())).await.unwrap();
403+
404+
let msg = ws.next().await.expect("expected a message").unwrap();
405+
let TungMessage::Text(text) = msg else { panic!("expected text frame") };
406+
let val: serde_json::Value = serde_json::from_str(&text).unwrap();
407+
408+
assert_eq!(val["type"], "state", "response type should be 'state'");
409+
assert_eq!(val["payload"]["elapsed_secs"], snap.elapsed_secs);
410+
assert_eq!(val["payload"]["total_secs"], snap.total_secs);
411+
assert_eq!(val["payload"]["round_type"], snap.round_type);
412+
}
312413
}

0 commit comments

Comments
 (0)