Skip to content

Commit 1ffc8f3

Browse files
committed
hot fix
1 parent 1e1a165 commit 1ffc8f3

4 files changed

Lines changed: 96 additions & 59 deletions

File tree

src/assembler.rs

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,32 @@ use wg_internal::{network::NodeId, packet::Fragment};
66

77
#[derive(Debug, Default)]
88
pub struct FragmentAssembler {
9-
pub fragments: HashMap<(u64, NodeId), Vec<Fragment>>, // session_id -> data buffer
10-
pub expected_fragments: HashMap<(u64, NodeId), u64>, // session_id -> total_fragments
11-
pub received_fragments: HashMap<(u64, NodeId), Vec<bool>>, // session_id -> received status
9+
pub fragments: HashMap<(u64, NodeId), (u64, Vec<Fragment>)>, // session_id -> data buffer
1210
}
1311

1412
impl FragmentAssembler {
1513
pub fn add_fragment(&mut self, fragment: Fragment, session_id: u64, sender: NodeId) -> Option<Vec<u8>> {
1614
let communication_id = ( session_id, sender );
17-
#[allow(clippy::cast_possible_truncation)]
18-
let index = fragment.fragment_index as usize;
19-
2015
if let Vacant(entry) = self.fragments.entry(communication_id) {
21-
self.expected_fragments.insert(communication_id, fragment.total_n_fragments);
22-
self.received_fragments.insert(communication_id, vec![false; index]);
23-
entry.insert(vec![fragment]);
24-
}
25-
26-
{
27-
let received = self.received_fragments.get_mut(&communication_id)?;
28-
received[index] = true;
16+
17+
entry.insert((fragment.total_n_fragments, vec![fragment]));
2918
}
30-
31-
let expected = self.expected_fragments.get(&communication_id)?;
32-
let received = self.received_fragments.get(&communication_id)?;
19+
3320
let fragments = self.fragments.get(&communication_id)?;
3421

3522
// check if all fragments has been received
36-
if fragments.len() as u64 == *expected && received.iter().all(|f| *f){
23+
if fragments.0 == fragments.1.len() as u64 {
3724
let fragments = self.fragments.get_mut(&communication_id)?;
38-
fragments.sort_by(|t, n| t.fragment_index.cmp(&n.fragment_index));
25+
fragments.1.sort_by(|t, n| t.fragment_index.cmp(&n.fragment_index));
3926
let mut data = vec![];
40-
for f in fragments {
41-
data.copy_from_slice(&f.data);
27+
for f in &fragments.1 {
28+
data.extend_from_slice(&f.data);
29+
}
30+
if let Some(pos) = data.iter().position(|&b| b == 0) {
31+
data.truncate(pos);
4232
}
33+
4334
let _ = self.fragments.remove(&communication_id);
44-
let _ = self.received_fragments.remove(&communication_id);
45-
let _ = self.expected_fragments.remove(&communication_id);
4635
return Some(data);
4736
}
4837
None

src/packet_processor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub trait Processor: Send {
5151
}
5252

5353
fn run(&mut self) {
54-
let _ = self.routing_handler().start_flood();
54+
let _ = self.routing_handler().start_flood(None);
5555
loop {
5656
select_biased! {
5757
recv(self.controller_recv()) -> cmd => {

src/routing_handler.rs

Lines changed: 77 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::types::SerializedRequest;
12
use crate::{
23
network::{Network, NetworkError, Node},
34
types::{Event, NodeEvent},
@@ -14,13 +15,15 @@ struct Buffer {
1415
// represents packets which reached the destination
1516
packets_received: HashMap<(u64, NodeId), Vec<(bool, Packet)>>,
1617
packets_to_send: Vec<Packet>,
18+
pending_ser_requests: HashSet<SerializedRequest>,
1719
}
1820

1921
impl Buffer {
2022
fn new() -> Self {
2123
Self {
2224
packets_received: HashMap::new(),
2325
packets_to_send: Vec::new(),
26+
pending_ser_requests: HashSet::new(),
2427
}
2528
}
2629

@@ -84,6 +87,7 @@ pub struct RoutingHandler {
8487
flood_counter: u64,
8588
controller_send: Sender<Box<dyn Event>>,
8689
buffer: Buffer,
90+
node_type: NodeType
8791
}
8892

8993
impl RoutingHandler {
@@ -103,6 +107,7 @@ impl RoutingHandler {
103107
flood_seen: HashSet::new(),
104108
controller_send,
105109
buffer: Buffer::new(),
110+
node_type
106111
}
107112
}
108113

@@ -123,7 +128,10 @@ impl RoutingHandler {
123128
/// and notifying the controller about the flood start.
124129
/// # Errors
125130
/// Returns an error if sending the packet to the controller fails or if sending to any neighbor fails.
126-
pub fn start_flood(&mut self) -> Result<(), NetworkError> {
131+
pub fn start_flood(
132+
&mut self,
133+
pending_request: Option<SerializedRequest>,
134+
) -> Result<(), NetworkError> {
127135
self.session_counter += 1;
128136
self.flood_counter += 1;
129137
let packet = Packet::new_flood_request(
@@ -142,6 +150,10 @@ impl RoutingHandler {
142150
self.remove_neighbor(*node_id);
143151
}
144152
}
153+
154+
if let Some(req) = pending_request {
155+
self.buffer.pending_ser_requests.insert(req);
156+
}
145157
Ok(())
146158
}
147159

@@ -161,12 +173,16 @@ impl RoutingHandler {
161173
/// Handle `flood_response`
162174
/// # Errors
163175
/// Returns error if can't send the packet
164-
pub fn handle_flood_response (
176+
pub fn handle_flood_response(
165177
&mut self,
166178
flood_response: &FloodResponse,
167179
) -> Result<(), NetworkError> {
168180
if flood_response.flood_id == self.flood_counter {
169181
self.update_network_view(&flood_response.path_trace);
182+
let requests = self.buffer.pending_ser_requests.drain().collect::<Vec<_>>();
183+
for req in requests {
184+
self.send_message(&req.data, req.to, None)?;
185+
}
170186
for packet in self.buffer.get_packets_to_send() {
171187
self.try_send(packet)?;
172188
}
@@ -215,7 +231,7 @@ impl RoutingHandler {
215231
.last()
216232
.map_or(flood_request.initiator_id, |x| x.0);
217233

218-
flood_request.path_trace.push((self.id, NodeType::Drone));
234+
flood_request.path_trace.push((self.id, self.node_type));
219235

220236
let flood_session = (flood_request.flood_id, flood_request.initiator_id);
221237

@@ -279,7 +295,7 @@ impl RoutingHandler {
279295
match nack.nack_type {
280296
NackType::ErrorInRouting(id) => {
281297
self.remove_neighbor(id);
282-
self.start_flood()?;
298+
self.start_flood(None)?;
283299
}
284300

285301
NackType::Dropped => {}
@@ -353,7 +369,7 @@ impl RoutingHandler {
353369
match self.try_find_path(destination) {
354370
Ok(shr) => packet.routing_header = shr,
355371
Err(NetworkError::PathNotFound(_)) => {
356-
self.start_flood()?;
372+
self.start_flood(None)?;
357373
self.buffer.add_pending_packet(packet.clone());
358374
}
359375
Err(e) => return Err(e),
@@ -378,39 +394,66 @@ impl RoutingHandler {
378394
pub fn send_message(
379395
&mut self,
380396
message: &[u8],
381-
destination: NodeId,
397+
destination: Option<NodeId>,
382398
session_id: Option<u64>,
383399
) -> Result<(), NetworkError> {
384-
let chunks: Vec<&[u8]> = message.chunks(128).collect();
385-
let total_n_fragments = chunks.len();
400+
// Split into 128-byte chunks
401+
let chunks = message.chunks(128);
402+
let total_n_fragments = chunks.len() as u64;
386403

387-
if session_id.is_none() {
404+
// Decide session id
405+
let session_id = session_id.unwrap_or_else(|| {
388406
self.session_counter += 1;
389-
}
390-
391-
let shr = self.try_find_path(destination)?;
392-
393-
for (i, chunk) in chunks.into_iter().enumerate() {
394-
// Pad/truncate to exactly 128 bytes
395-
let mut arr = [0u8; 128];
396-
arr[..chunk.len()].copy_from_slice(chunk);
407+
self.session_counter
408+
});
409+
410+
if let Some(destination) = destination {
411+
// Try to send directly
412+
if let Ok(shr) = self.try_find_path(destination) {
413+
for (i, chunk) in chunks.enumerate() {
414+
let fragment =
415+
Fragment::new(i as u64, total_n_fragments, Self::pad_chunk(chunk));
416+
let packet = Packet::new_fragment(shr.clone(), session_id, fragment);
417+
self.try_send(packet)?;
418+
419+
}
420+
421+
self.controller_send.send(Box::new(NodeEvent::MessageSent {
422+
notification_from: self.id,
423+
to: destination,
424+
})).map_err(|_e| NetworkError::ControllerDisconnected)?;
425+
426+
return Ok(());
427+
}
397428

398-
let fragment = Fragment::new(i as u64, total_n_fragments as u64, arr);
429+
// Path not found, try flooding passing the pending request
430+
self.start_flood(Some(SerializedRequest {
431+
to: Some(destination),
432+
data: message.to_vec(),
433+
}))?;
399434

400-
let packet = Packet::new_fragment(
401-
shr.clone(),
402-
if let Some(id) = session_id {
403-
id
404-
} else {
405-
self.session_counter
406-
},
407-
fragment,
408-
);
435+
return Ok(());
436+
}
409437

410-
self.try_send(packet)?;
438+
// No explicit destination
439+
if let Some(servers) = self.get_servers() {
440+
for server in servers {
441+
self.send_message(message, Some(server), Some(session_id))?;
442+
}
443+
return Ok(());
411444
}
412445

413-
Ok(())
446+
// Fallback: flooding
447+
self.start_flood(Some(SerializedRequest {
448+
to: None,
449+
data: message.to_vec(),
450+
}))
451+
}
452+
453+
fn pad_chunk(chunk: &[u8]) -> [u8; 128] {
454+
let mut arr = [0u8; 128];
455+
arr[..chunk.len()].copy_from_slice(chunk);
456+
arr
414457
}
415458

416459
pub fn handle_ack(&mut self, ack: &Ack, session_id: u64, from: NodeId) {
@@ -461,7 +504,7 @@ impl RoutingHandler {
461504
#[cfg(test)]
462505
mod routing_handler_tests {
463506
use super::*;
464-
use crossbeam_channel::{unbounded, Receiver};
507+
use crossbeam_channel::{Receiver, unbounded};
465508
use wg_internal::packet::PacketType;
466509

467510
#[test]
@@ -500,7 +543,7 @@ mod routing_handler_tests {
500543
let (neighbor_sender, neighbor_receiver) = unbounded();
501544
handler.add_neighbor(2, neighbor_sender);
502545

503-
handler.start_flood().unwrap();
546+
handler.start_flood(None).unwrap();
504547

505548
let packet = receiver.try_recv().unwrap();
506549
let packet = packet.into_any();
@@ -542,7 +585,7 @@ mod routing_handler_tests {
542585
handler.add_neighbor(2, neighbor_sender);
543586

544587
let message = b"Hello world".to_vec(); // 128 bytes total
545-
handler.send_message(&message, 2, None).unwrap();
588+
handler.send_message(&message, Some(2), None).unwrap();
546589

547590
let packet = neighbor_receiver.try_recv().unwrap();
548591
assert!(matches!(packet.pack_type, PacketType::MsgFragment(_)));
@@ -558,7 +601,7 @@ mod routing_handler_tests {
558601
handler.add_neighbor(2, neighbor_sender);
559602

560603
let message = b"Hello, world!".to_vec();
561-
handler.send_message(&message, 2, None).unwrap();
604+
handler.send_message(&message, Some(2), None).unwrap();
562605

563606
let ack = Ack { fragment_index: 0 };
564607
handler.handle_ack(&ack, 1, 2);
@@ -621,7 +664,7 @@ mod routing_handler_tests {
621664
.network_view
622665
.add_node(Node::new(2, NodeType::Server, vec![1]));
623666
let large_message = b"A".repeat(500);
624-
let _result = handler.send_message(&large_message, 2, None);
667+
let _result = handler.send_message(&large_message, Some(2), None);
625668
//assert!(result.is_ok());
626669
//assert!(handler.buffer.packets_received.len() > 0);
627670
// todo!() asserts fail because of Err(PathNotFound(2))

src/types.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ use uuid::Uuid;
88
use wg_internal::{network::NodeId, packet::Packet};
99
pub type Bytes = Vec<u8>;
1010

11+
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
12+
pub struct SerializedRequest {
13+
pub to: Option<NodeId>,
14+
pub data: Vec<u8>,
15+
}
16+
1117
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
1218
pub struct MediaReference {
1319
location: NodeId,
@@ -310,7 +316,6 @@ pub enum ChatEvent {
310316
},
311317
ErrorClientNotFound {
312318
notification_from: NodeId,
313-
location: NodeId,
314319
not_found: NodeId,
315320
},
316321

0 commit comments

Comments
 (0)