diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index a0343eb8..6aa23593 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -1103,10 +1103,14 @@ impl RecordDecrypt for Engine { self.peer_encryption_enabled } - fn replay_check_and_update(&mut self, seq: Sequence) -> bool { + fn replay_check(&self, seq: Sequence) -> bool { // Only epoch 1 (encrypted) records reach here; epoch 0 records are // returned early by the DTLS 1.2 incoming parser. - self.replay.check_and_update(seq.sequence_number) + self.replay.check(seq.sequence_number) + } + + fn replay_update(&mut self, seq: Sequence) { + self.replay.update(seq.sequence_number); } fn decryption_aad_and_nonce(&self, dtls: &DTLSRecord, buf: &[u8]) -> (Aad, Nonce) { diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index 26ed3b0b..820c19ba 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -150,9 +150,10 @@ impl Record { // We need to decrypt the record and redo the parsing. let dtls = record.record(); + let sequence = dtls.sequence; - // Anti-replay check - if !decrypt.replay_check_and_update(dtls.sequence) { + // Anti-replay check (read-only, does not update window) + if !decrypt.replay_check(sequence) { return Ok(None); } @@ -180,6 +181,11 @@ impl Record { buffer.len() }; + // Decryption succeeded — now commit the replay window update. + // RFC 6347 §4.1.2.6: "The receive window is updated only if the + // MAC verification succeeds." + decrypt.replay_update(sequence); + // Update the length of the record. buffer[11] = (new_len >> 8) as u8; buffer[12] = new_len as u8; @@ -263,7 +269,8 @@ impl ParsedRecord { /// parsing to depend only on the cryptographic operations it actually uses. pub trait RecordDecrypt { fn is_peer_encryption_enabled(&self) -> bool; - fn replay_check_and_update(&mut self, seq: Sequence) -> bool; + fn replay_check(&self, seq: Sequence) -> bool; + fn replay_update(&mut self, seq: Sequence); fn decryption_aad_and_nonce(&self, dtls: &DTLSRecord, buf: &[u8]) -> (Aad, Nonce); fn explicit_nonce_len(&self) -> usize; fn decrypt_data( diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index d6c2e6fc..a973336c 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -2273,19 +2273,24 @@ impl RecordDecrypt for Engine { reconstruct_sequence(seq_bits, expected, bits) } - fn replay_check_and_update(&mut self, seq: Sequence) -> bool { + fn replay_check(&self, seq: Sequence) -> bool { // Route to the correct per-epoch replay window - let accepted = if seq.epoch == 2 { - self.hs_replay.check_and_update(seq.sequence_number) + if seq.epoch == 2 { + self.hs_replay.check(seq.sequence_number) } else { - match self.app_recv_keys.iter_mut().find(|e| e.epoch == seq.epoch) { - Some(entry) => entry.replay.check_and_update(seq.sequence_number), - None => return false, // no keys for this epoch + match self.app_recv_keys.iter().find(|e| e.epoch == seq.epoch) { + Some(entry) => entry.replay.check(seq.sequence_number), + None => false, // no keys for this epoch } - }; + } + } - if !accepted { - return false; + fn replay_update(&mut self, seq: Sequence) { + // Update the replay window for this epoch + if seq.epoch == 2 { + self.hs_replay.update(seq.sequence_number); + } else if let Some(entry) = self.app_recv_keys.iter_mut().find(|e| e.epoch == seq.epoch) { + entry.replay.update(seq.sequence_number); } // Advance expected receive sequence for this epoch @@ -2304,8 +2309,6 @@ impl RecordDecrypt for Engine { } } } - - true } fn decrypt_record( diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index b20559df..79a72cf3 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -245,8 +245,8 @@ impl Record { sequence_number: full_seq, }; - // Anti-replay check - if !decrypt.replay_check_and_update(full_sequence) { + // Anti-replay check (read-only, does not update window) + if !decrypt.replay_check(full_sequence) { return Ok(None); } @@ -278,6 +278,11 @@ impl Record { buffer.len() }; + // Decryption succeeded — now commit the replay window update. + // RFC 9147 §4.5.1: "The window MUST NOT be updated due to a received + // record until that record has been deprotected successfully." + decrypt.replay_update(full_sequence); + // Recover inner content type from DTLSInnerPlaintext let decrypted = &buffer[header_end..header_end + new_len]; let (inner_content_type, content_len) = match recover_inner_content_type(decrypted) { @@ -395,7 +400,8 @@ pub trait RecordDecrypt { fn is_peer_encryption_enabled(&self) -> bool; fn resolve_epoch(&self, epoch_bits: u8) -> u16; fn resolve_sequence(&self, epoch: u16, seq_bits: u64, s_flag: bool) -> u64; - fn replay_check_and_update(&mut self, seq: Sequence) -> bool; + fn replay_check(&self, seq: Sequence) -> bool; + fn replay_update(&mut self, seq: Sequence); fn decrypt_record( &mut self, header: &[u8], diff --git a/src/window.rs b/src/window.rs index d53c7046..66b93e09 100644 --- a/src/window.rs +++ b/src/window.rs @@ -16,9 +16,24 @@ impl ReplayWindow { Self::default() } - /// Check if the given sequence number is acceptable and update the window state. - /// Returns true if fresh/acceptable, false if duplicate/too old. - pub fn check_and_update(&mut self, seqno: u64) -> bool { + /// Check if the given sequence number is acceptable (not a replay, not too old). + /// Read-only: does not modify the window state. + pub fn check(&self, seqno: u64) -> bool { + if seqno > self.max_seq { + true + } else { + let offset = self.max_seq - seqno; + if offset >= 64 { + return false; // too old + } + let mask = 1u64 << offset; + (self.window & mask) == 0 // false if duplicate + } + } + + /// Update the window state to record that `seqno` has been received. + /// Must only be called after the record has been authenticated (decrypted successfully). + pub fn update(&mut self, seqno: u64) { if seqno > self.max_seq { let delta = seqno - self.max_seq; if delta > 63 { @@ -29,18 +44,11 @@ impl ReplayWindow { self.window |= 1; // mark newest as seen } self.max_seq = seqno; - true } else { let offset = self.max_seq - seqno; - if offset >= 64 { - return false; // too old - } - let mask = 1u64 << offset; - if (self.window & mask) != 0 { - return false; // duplicate + if offset < 64 { + self.window |= 1u64 << offset; } - self.window |= mask; - true } } } @@ -49,54 +57,87 @@ impl ReplayWindow { mod tests { use super::*; + /// Helper: check and update in one step (simulates authenticated record). + fn check_and_update(w: &mut ReplayWindow, seqno: u64) -> bool { + if w.check(seqno) { + w.update(seqno); + true + } else { + false + } + } + #[test] fn accepts_fresh_and_rejects_duplicate() { let mut w = ReplayWindow::new(); - assert!(w.check_and_update(1)); - assert!(!w.check_and_update(1)); // duplicate - assert!(w.check_and_update(2)); // next fresh + assert!(check_and_update(&mut w, 1)); + assert!(!check_and_update(&mut w, 1)); // duplicate + assert!(check_and_update(&mut w, 2)); // next fresh } #[test] fn accepts_out_of_order_within_window() { let mut w = ReplayWindow::new(); - assert!(w.check_and_update(10)); // establish max=10 - assert!(w.check_and_update(8)); // unseen within 64 - assert!(!w.check_and_update(8)); // duplicate now - assert!(w.check_and_update(9)); // unseen within 64 + assert!(check_and_update(&mut w, 10)); // establish max=10 + assert!(check_and_update(&mut w, 8)); // unseen within 64 + assert!(!check_and_update(&mut w, 8)); // duplicate now + assert!(check_and_update(&mut w, 9)); // unseen within 64 } #[test] fn rejects_too_old() { let mut w = ReplayWindow::new(); - assert!(w.check_and_update(100)); + assert!(check_and_update(&mut w, 100)); // offset = 64 -> too old - assert!(!w.check_and_update(36)); + assert!(!check_and_update(&mut w, 36)); // offset = 63 -> allowed once - assert!(w.check_and_update(37)); + assert!(check_and_update(&mut w, 37)); } #[test] fn handles_large_jump_and_window_shift() { let mut w = ReplayWindow::new(); - assert!(w.check_and_update(1)); + assert!(check_and_update(&mut w, 1)); // Large forward jump clears the window entirely - assert!(w.check_and_update(80)); + assert!(check_and_update(&mut w, 80)); // Within window of new max and unseen - assert!(w.check_and_update(79)); + assert!(check_and_update(&mut w, 79)); // Too old relative to new max - assert!(!w.check_and_update(15)); + assert!(!check_and_update(&mut w, 15)); } #[test] fn large_jump_does_not_leave_stale_bits() { let mut w = ReplayWindow::new(); - assert!(w.check_and_update(0)); + assert!(check_and_update(&mut w, 0)); // Jump of 200 exceeds window size (64). The window must be fully // cleared so no stale bits from seq 0 remain. - assert!(w.check_and_update(200)); + assert!(check_and_update(&mut w, 200)); // seq 137 is within the window (offset = 200 - 137 = 63) and was // never seen, so it must be accepted. - assert!(w.check_and_update(137)); + assert!(check_and_update(&mut w, 137)); + } + + #[test] + fn check_does_not_modify_window() { + let mut w = ReplayWindow::new(); + w.update(10); + // check alone should not change state + assert!(w.check(11)); + assert!(w.check(11)); // still acceptable because update was never called + w.update(11); + assert!(!w.check(11)); // now it's a duplicate + } + + #[test] + fn failed_auth_does_not_advance_window() { + let mut w = ReplayWindow::new(); + w.update(5); + // Simulate receiving seq 200 that passes check but fails authentication + assert!(w.check(200)); + // Do NOT call update (authentication failed) + // Legitimate packet at seq 6 should still be accepted + assert!(w.check(6)); + w.update(6); } }