Skip to content

Commit a09ec69

Browse files
committed
restrict ServiceFlags api
1 parent 9959201 commit a09ec69

5 files changed

Lines changed: 92 additions & 109 deletions

File tree

dash-spv/src/network/handshake.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::net::SocketAddr;
44
use std::time::{Duration, SystemTime, UNIX_EPOCH};
55

66
use dashcore::network::constants;
7-
use dashcore::network::constants::{ServiceFlags, NODE_HEADERS_COMPRESSED};
7+
use dashcore::network::constants::ServiceFlags;
88
use dashcore::network::message::NetworkMessage;
99
use dashcore::network::message_network::VersionMessage;
1010
use dashcore::Network;
@@ -36,7 +36,7 @@ pub struct HandshakeManager {
3636
state: HandshakeState,
3737
our_version: u32,
3838
peer_version: Option<u32>,
39-
peer_services: Option<ServiceFlags>,
39+
peer_services: ServiceFlags,
4040
version_received: bool,
4141
verack_received: bool,
4242
version_sent: bool,
@@ -56,7 +56,7 @@ impl HandshakeManager {
5656
state: HandshakeState::Init,
5757
our_version: constants::PROTOCOL_VERSION,
5858
peer_version: None,
59-
peer_services: None,
59+
peer_services: ServiceFlags::NONE,
6060
version_received: false,
6161
verack_received: false,
6262
version_sent: false,
@@ -157,7 +157,7 @@ impl HandshakeManager {
157157
version_msg
158158
);
159159
self.peer_version = Some(version_msg.version);
160-
self.peer_services = Some(version_msg.services);
160+
self.peer_services = version_msg.services;
161161
self.version_received = true;
162162

163163
// Update connection's peer information
@@ -261,7 +261,7 @@ impl HandshakeManager {
261261
.as_secs() as i64;
262262

263263
// Advertise headers2 support (NODE_HEADERS_COMPRESSED)
264-
let services = ServiceFlags::NONE | NODE_HEADERS_COMPRESSED;
264+
let services = ServiceFlags::NODE_HEADERS_COMPRESSED;
265265

266266
// Parse the local address safely
267267
let local_addr = "127.0.0.1:0"
@@ -313,7 +313,7 @@ impl HandshakeManager {
313313

314314
/// Check if peer supports headers2 compression.
315315
pub fn peer_supports_headers2(&self) -> bool {
316-
self.peer_services.map(|services| services.has(NODE_HEADERS_COMPRESSED)).unwrap_or(false)
316+
self.peer_services.has(ServiceFlags::NODE_HEADERS_COMPRESSED)
317317
}
318318

319319
/// Negotiate headers2 support with the peer after handshake completion.

dash-spv/src/network/peer.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub struct Peer {
4040
pending_pings: HashMap<u64, SystemTime>, // nonce -> sent_time
4141
// Peer information from Version message
4242
version: Option<u32>,
43-
services: Option<u64>,
43+
services: ServiceFlags,
4444
user_agent: Option<String>,
4545
best_height: Option<u32>,
4646
relay: Option<bool>,
@@ -68,7 +68,7 @@ impl Peer {
6868
last_pong_received: None,
6969
pending_pings: HashMap::new(),
7070
version: None,
71-
services: None,
71+
services: ServiceFlags::NONE,
7272
user_agent: None,
7373
best_height: None,
7474
relay: None,
@@ -115,7 +115,7 @@ impl Peer {
115115
last_pong_received: None,
116116
pending_pings: HashMap::new(),
117117
version: None,
118-
services: None,
118+
services: ServiceFlags::NONE,
119119
user_agent: None,
120120
best_height: None,
121121
relay: None,
@@ -144,7 +144,7 @@ impl Peer {
144144
}
145145

146146
pub fn has_service(&self, flags: ServiceFlags) -> bool {
147-
self.services.map(|s| ServiceFlags::from(s).has(flags)).unwrap_or(false)
147+
self.services.has(flags)
148148
}
149149

150150
/// Connect to the peer (instance method for compatibility).
@@ -273,7 +273,7 @@ impl Peer {
273273

274274
// All validations passed, update peer info
275275
self.version = Some(version_msg.version);
276-
self.services = Some(version_msg.services.as_u64());
276+
self.services = version_msg.services;
277277
self.user_agent = Some(version_msg.user_agent.clone());
278278
self.best_height = Some(version_msg.start_height as u32);
279279
self.relay = Some(version_msg.relay);
@@ -824,12 +824,7 @@ impl Peer {
824824
// We can request headers2 if peer has the service flag for headers2 support
825825
// Note: We don't wait for SendHeaders2 from peer as that creates a race condition
826826
// during initial sync. The service flag is sufficient to know they support headers2.
827-
if let Some(services) = self.services {
828-
dashcore::network::constants::ServiceFlags::from(services)
829-
.has(dashcore::network::constants::NODE_HEADERS_COMPRESSED)
830-
} else {
831-
false
832-
}
827+
self.services.has(ServiceFlags::NODE_HEADERS_COMPRESSED)
833828
}
834829
}
835830

dash/src/network/address.rs

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,10 @@ impl Encodable for AddrV2Message {
308308
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
309309
let mut len = 0;
310310
len += self.time.consensus_encode(w)?;
311-
len += VarInt(self.services.as_u64()).consensus_encode(w)?;
311+
// This msg encodes ServiceFlags as a VarInt, so we need to
312+
// use the specialized method for it. Don't use consensus_encode
313+
// since it encodes as a u64, not a VarInt.
314+
len += self.services.consensus_encode_as_var_int(w)?;
312315
len += self.addr.consensus_encode(w)?;
313316

314317
w.write_all(&self.port.to_be_bytes())?;
@@ -322,7 +325,10 @@ impl Decodable for AddrV2Message {
322325
fn consensus_decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, encode::Error> {
323326
Ok(AddrV2Message {
324327
time: Decodable::consensus_decode(r)?,
325-
services: ServiceFlags::from(VarInt::consensus_decode(r)?.0),
328+
// This msg encodes ServiceFlags as a VarInt, so we need to
329+
// use the specialized method for it. Don't use consensus_decode
330+
// since it decodes as a u64, not a VarInt.
331+
services: ServiceFlags::consensus_decode_from_var_int(r)?,
326332
addr: Decodable::consensus_decode(r)?,
327333
port: u16::swap_bytes(Decodable::consensus_decode(r)?),
328334
})
@@ -365,12 +371,13 @@ mod test {
365371

366372
#[test]
367373
fn debug_format_test() {
368-
let mut flags = ServiceFlags::NETWORK;
374+
let mut services = ServiceFlags::NETWORK;
375+
services.add(ServiceFlags::WITNESS);
369376
assert_eq!(
370377
format!(
371378
"The address is: {:?}",
372379
Address {
373-
services: flags.add(ServiceFlags::WITNESS),
380+
services,
374381
address: [0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001],
375382
port: 8333
376383
}
@@ -412,16 +419,20 @@ mod test {
412419

413420
#[test]
414421
fn test_socket_addr() {
422+
let mut services = ServiceFlags::NETWORK;
423+
services.add(ServiceFlags::WITNESS);
424+
415425
let s4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(111, 222, 123, 4)), 5555);
416-
let a4 = Address::new(&s4, ServiceFlags::NETWORK | ServiceFlags::WITNESS);
426+
let a4 = Address::new(&s4, services);
417427
assert_eq!(a4.socket_addr().unwrap(), s4);
428+
418429
let s6 = SocketAddr::new(
419430
IpAddr::V6(Ipv6Addr::new(
420431
0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666, 0x7777, 0x8888,
421432
)),
422433
9999,
423434
);
424-
let a6 = Address::new(&s6, ServiceFlags::NETWORK | ServiceFlags::WITNESS);
435+
let a6 = Address::new(&s6, services);
425436
assert_eq!(a6.socket_addr().unwrap(), s6);
426437
}
427438

@@ -577,19 +588,23 @@ mod test {
577588
let raw = hex!("0261bc6649019902abab208d79627683fd4804010409090909208d");
578589
let addresses: Vec<AddrV2Message> = deserialize(&raw).unwrap();
579590

591+
let services1 = ServiceFlags::NETWORK;
592+
593+
let mut services2 = ServiceFlags::NETWORK_LIMITED;
594+
services2.add(ServiceFlags::WITNESS);
595+
services2.add(ServiceFlags::COMPACT_FILTERS);
596+
580597
assert_eq!(
581598
addresses,
582599
vec![
583600
AddrV2Message {
584-
services: ServiceFlags::NETWORK,
601+
services: services1,
585602
time: 0x4966bc61,
586603
port: 8333,
587604
addr: AddrV2::Unknown(153, hex!("abab"))
588605
},
589606
AddrV2Message {
590-
services: ServiceFlags::NETWORK_LIMITED
591-
| ServiceFlags::WITNESS
592-
| ServiceFlags::COMPACT_FILTERS,
607+
services: services2,
593608
time: 0x83766279,
594609
port: 8333,
595610
addr: AddrV2::Ipv4(Ipv4Addr::new(9, 9, 9, 9))

dash/src/network/constants.rs

Lines changed: 45 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,12 @@
3333
//! assert_eq!(&bytes[..], &[0xBF, 0x0C, 0x6B, 0xBD]);
3434
//! ```
3535
36-
use core::convert::From;
37-
use core::{fmt, ops};
36+
use core::fmt;
3837

3938
use hashes::Hash;
4039

4140
use crate::consensus::encode::{self, Decodable, Encodable};
42-
use crate::{BlockHash, io};
43-
44-
// Re-export NODE_HEADERS_COMPRESSED for convenience
45-
pub const NODE_HEADERS_COMPRESSED: ServiceFlags = ServiceFlags::NODE_HEADERS_COMPRESSED;
41+
use crate::{BlockHash, VarInt, io};
4642

4743
/// Version of the protocol as appearing in network message headers
4844
/// This constant is used to signal to other peers which features you support.
@@ -231,30 +227,44 @@ impl ServiceFlags {
231227
// NOTE: When adding new flags, remember to update the Display impl accordingly.
232228

233229
/// Add [ServiceFlags] together.
234-
///
235-
/// Returns itself.
236-
pub fn add(&mut self, other: ServiceFlags) -> ServiceFlags {
230+
pub fn add(&mut self, other: ServiceFlags) {
237231
self.0 |= other.0;
238-
*self
239232
}
240233

241234
/// Remove [ServiceFlags] from this.
242-
///
243-
/// Returns itself.
244-
pub fn remove(&mut self, other: ServiceFlags) -> ServiceFlags {
235+
pub fn remove(&mut self, other: ServiceFlags) {
245236
self.0 ^= other.0;
246-
*self
247237
}
248238

249239
/// Check whether [ServiceFlags] are included in this one.
250-
pub fn has(self, flags: ServiceFlags) -> bool {
240+
pub fn has(&self, flags: ServiceFlags) -> bool {
251241
(self.0 | flags.0) == self.0
252242
}
253243

254244
/// Get the integer representation of this [ServiceFlags].
255-
pub fn as_u64(self) -> u64 {
245+
pub fn as_u64(&self) -> u64 {
256246
self.0
257247
}
248+
249+
// This struct is weird in the dash protocol, sometime services are encoded as u64
250+
// and sometimes as a VarInt. While the Encodable/Decodable encodes and decodes the u64
251+
// as usual, this methods use VarInt to satisfy the protocol
252+
253+
#[inline]
254+
pub fn consensus_encode_as_var_int<W: io::Write + ?Sized>(
255+
&self,
256+
w: &mut W,
257+
) -> Result<usize, io::Error> {
258+
self.0.consensus_encode(w)
259+
}
260+
261+
#[inline]
262+
pub fn consensus_decode_from_var_int<R: io::Read + ?Sized>(
263+
r: &mut R,
264+
) -> Result<Self, encode::Error> {
265+
let services = VarInt::consensus_decode(r)?;
266+
Ok(ServiceFlags(services.0))
267+
}
258268
}
259269

260270
impl fmt::LowerHex for ServiceFlags {
@@ -307,54 +317,20 @@ impl fmt::Display for ServiceFlags {
307317
}
308318
}
309319

310-
impl From<u64> for ServiceFlags {
311-
fn from(f: u64) -> Self {
312-
ServiceFlags(f)
313-
}
314-
}
315-
316-
impl From<ServiceFlags> for u64 {
317-
fn from(val: ServiceFlags) -> Self {
318-
val.0
319-
}
320-
}
321-
322-
impl ops::BitOr for ServiceFlags {
323-
type Output = Self;
324-
325-
fn bitor(mut self, rhs: Self) -> Self {
326-
self.add(rhs)
327-
}
328-
}
329-
330-
impl ops::BitOrAssign for ServiceFlags {
331-
fn bitor_assign(&mut self, rhs: Self) {
332-
self.add(rhs);
333-
}
334-
}
335-
336-
impl ops::BitXor for ServiceFlags {
337-
type Output = Self;
338-
339-
fn bitxor(mut self, rhs: Self) -> Self {
340-
self.remove(rhs)
341-
}
342-
}
343-
344-
impl ops::BitXorAssign for ServiceFlags {
345-
fn bitxor_assign(&mut self, rhs: Self) {
346-
self.remove(rhs);
347-
}
348-
}
349-
350320
impl Encodable for ServiceFlags {
321+
/// Encodes the service flags as a u64, not a VarInt. Services are usually encoded as a u64
322+
/// but there are some messages that encode them as a VarInt instead. For those use the
323+
/// specialized method `consensus_encode_as_var_int`.
351324
#[inline]
352325
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
353326
self.0.consensus_encode(w)
354327
}
355328
}
356329

357330
impl Decodable for ServiceFlags {
331+
/// Decodes the service flags as a u64, not a VarInt. Services are usually decoded as a u64
332+
/// but there are some messages that decode them as a VarInt instead. For those use the
333+
/// specialized method `consensus_decode_as_var_int`.
358334
#[inline]
359335
fn consensus_decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, encode::Error> {
360336
Ok(ServiceFlags(Decodable::consensus_decode(r)?))
@@ -434,27 +410,28 @@ mod tests {
434410
assert!(!flags.has(*f));
435411
}
436412

437-
flags |= ServiceFlags::WITNESS;
413+
flags.add(ServiceFlags::WITNESS);
438414
assert_eq!(flags, ServiceFlags::WITNESS);
439415

440-
let mut flags2 = flags | ServiceFlags::GETUTXO;
416+
flags.add(ServiceFlags::GETUTXO);
441417
for f in all.iter() {
442-
assert_eq!(flags2.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO);
418+
assert_eq!(flags.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO);
443419
}
444420

445-
flags2 ^= ServiceFlags::WITNESS;
446-
assert_eq!(flags2, ServiceFlags::GETUTXO);
421+
flags.remove(ServiceFlags::WITNESS);
422+
assert_eq!(flags, ServiceFlags::GETUTXO);
447423

448-
flags2 |= ServiceFlags::COMPACT_FILTERS;
449-
flags2 ^= ServiceFlags::GETUTXO;
450-
assert_eq!(flags2, ServiceFlags::COMPACT_FILTERS);
424+
flags.add(ServiceFlags::COMPACT_FILTERS);
425+
flags.remove(ServiceFlags::GETUTXO);
426+
assert_eq!(flags, ServiceFlags::COMPACT_FILTERS);
451427

452428
// Test formatting.
453429
assert_eq!("ServiceFlags(NONE)", ServiceFlags::NONE.to_string());
454430
assert_eq!("ServiceFlags(WITNESS)", ServiceFlags::WITNESS.to_string());
455-
let flag = ServiceFlags::WITNESS | ServiceFlags::BLOOM | ServiceFlags::NETWORK;
456-
assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flag.to_string());
457-
let flag = ServiceFlags::WITNESS | 0xf0.into();
458-
assert_eq!("ServiceFlags(WITNESS|COMPACT_FILTERS|0xb0)", flag.to_string());
431+
432+
let mut flags = ServiceFlags::WITNESS;
433+
flags.add(ServiceFlags::BLOOM);
434+
flags.add(ServiceFlags::NETWORK);
435+
assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flags.to_string());
459436
}
460437
}

0 commit comments

Comments
 (0)