Skip to content

Commit b2ed663

Browse files
committed
[Rust] Ensure proper lifetime management of WebsocketClientCallback objects
1 parent ba9c3c5 commit b2ed663

File tree

2 files changed

+186
-78
lines changed

2 files changed

+186
-78
lines changed

rust/src/websocket/client.rs

Lines changed: 109 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,31 @@ use crate::rc::{Ref, RefCountable};
22
use crate::string::{BnString, IntoCStr};
33
use binaryninjacore_sys::*;
44
use std::ffi::{c_char, c_void, CStr};
5+
use std::marker::PhantomData;
6+
use std::ops::Deref;
57
use std::ptr::NonNull;
68

79
pub trait WebsocketClientCallback: Sync + Send {
8-
fn connected(&mut self) -> bool;
10+
/// Receive a notification that the websocket connection has been connected successfully.
11+
///
12+
/// Return `false` if you would like to terminate the connection early.
13+
fn connected(&self) -> bool;
914

10-
fn disconnected(&mut self);
15+
/// Receive a notification that the websocket connection has been terminated.
16+
///
17+
/// For implementations, you must call this at the end of the websocket connection lifecycle
18+
/// even if you notify the client of an error.
19+
fn disconnected(&self);
1120

12-
fn error(&mut self, msg: &str);
21+
/// Receive an error from the websocket connection.
22+
///
23+
/// For implementations, you typically write the data to an interior mutable buffer on the instance.
24+
fn error(&self, msg: &str);
1325

14-
fn read(&mut self, data: &[u8]) -> bool;
26+
/// Receive data from the websocket connection.
27+
///
28+
/// For implementations, you typically write the data to an interior mutable buffer on the instance.
29+
fn read(&self, data: &[u8]) -> bool;
1530
}
1631

1732
pub trait WebsocketClient: Sync + Send {
@@ -27,7 +42,32 @@ pub trait WebsocketClient: Sync + Send {
2742
fn disconnect(&self) -> bool;
2843
}
2944

45+
/// Represents a live websocket connection.
46+
///
47+
/// This manages the lifetime of the callback, ensuring it outlives the connection.
48+
pub struct ActiveConnection<'a, C: WebsocketClientCallback> {
49+
pub client: Ref<CoreWebsocketClient>,
50+
_callback: PhantomData<&'a mut C>,
51+
}
52+
53+
impl<'a, C: WebsocketClientCallback> Deref for ActiveConnection<'a, C> {
54+
type Target = CoreWebsocketClient;
55+
56+
fn deref(&self) -> &Self::Target {
57+
&self.client
58+
}
59+
}
60+
61+
impl<'a, C: WebsocketClientCallback> Drop for ActiveConnection<'a, C> {
62+
fn drop(&mut self) {
63+
self.client.disconnect();
64+
}
65+
}
66+
3067
/// Implements a websocket client.
68+
///
69+
/// To connect, use [`Ref<CoreWebsocketClient>::connect`] which will return an [`ActiveConnection`]
70+
/// which manages the lifecycle of the websocket connection.
3171
#[repr(transparent)]
3272
pub struct CoreWebsocketClient {
3373
pub(crate) handle: NonNull<BNWebsocketClient>,
@@ -43,7 +83,50 @@ impl CoreWebsocketClient {
4383
&mut *self.handle.as_ptr()
4484
}
4585

46-
/// Initializes the web socket connection.
86+
/// Call the connect callback function, forward the callback returned value
87+
pub fn notify_connected(&self) -> bool {
88+
unsafe { BNNotifyWebsocketClientConnect(self.handle.as_ptr()) }
89+
}
90+
91+
/// Notify the callback function of a disconnect. This must be called at the end of an active
92+
/// websocket connection lifecycle, to free resources.
93+
///
94+
/// NOTE: This does not actually disconnect, use the [Self::disconnect] function for that.
95+
pub fn notify_disconnected(&self) {
96+
unsafe { BNNotifyWebsocketClientDisconnect(self.handle.as_ptr()) }
97+
}
98+
99+
/// Call the error callback function, this is not a terminating request you must use
100+
/// [`CoreWebsocketClient::notify_disconnected`] to terminate the connection.
101+
pub fn notify_error(&self, msg: &str) {
102+
let error = msg.to_cstr();
103+
unsafe { BNNotifyWebsocketClientError(self.handle.as_ptr(), error.as_ptr()) }
104+
}
105+
106+
/// Call the read callback function, forward the callback returned value.
107+
pub fn notify_read(&self, data: &[u8]) -> bool {
108+
unsafe {
109+
BNNotifyWebsocketClientReadData(
110+
self.handle.as_ptr(),
111+
data.as_ptr() as *mut _,
112+
data.len().try_into().unwrap(),
113+
)
114+
}
115+
}
116+
117+
pub fn write(&self, data: &[u8]) -> bool {
118+
let len = u64::try_from(data.len()).unwrap();
119+
unsafe { BNWriteWebsocketClientData(self.as_raw(), data.as_ptr(), len) != 0 }
120+
}
121+
122+
pub fn disconnect(&self) -> bool {
123+
unsafe { BNDisconnectWebsocketClient(self.as_raw()) }
124+
}
125+
}
126+
127+
impl Ref<CoreWebsocketClient> {
128+
/// Initializes the web socket connection, returning the [`ActiveConnection`], once dropped the
129+
/// connection will be disconnected.
47130
///
48131
/// Connect to a given url, asynchronously. The connection will be run in a
49132
/// separate thread managed by the websocket provider.
@@ -54,20 +137,22 @@ impl CoreWebsocketClient {
54137
/// If the connection succeeds, [WebsocketClientCallback::connected] will be called. On normal
55138
/// termination, [WebsocketClientCallback::disconnected] will be called.
56139
///
57-
/// If the connection succeeds, but later fails, [WebsocketClientCallback::disconnected] will not
58-
/// be called, and [WebsocketClientCallback::error] will be called instead.
140+
/// If the connection succeeds but later fails, [`WebsocketClientCallback::error`] will be called
141+
/// and shortly thereafter [`WebsocketClientCallback::disconnected`] will be called.
59142
///
60-
/// If the connection fails, neither [WebsocketClientCallback::connected] nor
61-
/// [WebsocketClientCallback::disconnected] will be called, and [WebsocketClientCallback::error]
62-
/// will be called instead.
63-
///
64-
/// If [WebsocketClientCallback::connected] or [WebsocketClientCallback::read] return false, the
143+
/// If [`WebsocketClientCallback::connected`] or [`WebsocketClientCallback::read`] return false, the
65144
/// connection will be aborted.
66145
///
67146
/// * `host` - Full url with scheme, domain, optionally port, and path
68147
/// * `headers` - HTTP header keys and values
69148
/// * `callback` - Callbacks for various websocket events
70-
pub fn initialize_connection<I, C>(&self, host: &str, headers: I, callbacks: &mut C) -> bool
149+
#[must_use]
150+
pub fn connect<'a, I, C>(
151+
self,
152+
host: &str,
153+
headers: I,
154+
callbacks: &'a C,
155+
) -> Option<ActiveConnection<'a, C>>
71156
where
72157
I: IntoIterator<Item = (String, String)>,
73158
C: WebsocketClientCallback,
@@ -79,16 +164,16 @@ impl CoreWebsocketClient {
79164
.unzip();
80165
let header_keys: Vec<*const c_char> = header_keys.iter().map(|k| k.as_ptr()).collect();
81166
let header_values: Vec<*const c_char> = header_values.iter().map(|v| v.as_ptr()).collect();
82-
// SAFETY: This context will only be live for the duration of BNConnectWebsocketClient
167+
// SAFETY: This context will live for as long as the `ActiveConnection` is alive.
83168
// SAFETY: Any subsequent call to BNConnectWebsocketClient will write over the context.
84169
let mut output_callbacks = BNWebsocketClientOutputCallbacks {
85-
context: callbacks as *mut C as *mut c_void,
170+
context: callbacks as *const C as *mut C as *mut c_void,
86171
connectedCallback: Some(cb_connected::<C>),
87172
disconnectedCallback: Some(cb_disconnected::<C>),
88173
errorCallback: Some(cb_error::<C>),
89174
readCallback: Some(cb_read::<C>),
90175
};
91-
unsafe {
176+
let success = unsafe {
92177
BNConnectWebsocketClient(
93178
self.handle.as_ptr(),
94179
url.as_ptr(),
@@ -97,46 +182,17 @@ impl CoreWebsocketClient {
97182
header_values.as_ptr(),
98183
&mut output_callbacks,
99184
)
100-
}
101-
}
102-
103-
/// Call the connect callback function, forward the callback returned value
104-
pub fn notify_connected(&self) -> bool {
105-
unsafe { BNNotifyWebsocketClientConnect(self.handle.as_ptr()) }
106-
}
107-
108-
/// Notify the callback function of a disconnect,
109-
///
110-
/// NOTE: This does not actually disconnect, use the [Self::disconnect] function for that.
111-
pub fn notify_disconnected(&self) {
112-
unsafe { BNNotifyWebsocketClientDisconnect(self.handle.as_ptr()) }
113-
}
114-
115-
/// Call the error callback function
116-
pub fn notify_error(&self, msg: &str) {
117-
let error = msg.to_cstr();
118-
unsafe { BNNotifyWebsocketClientError(self.handle.as_ptr(), error.as_ptr()) }
119-
}
185+
};
120186

121-
/// Call the read callback function, forward the callback returned value
122-
pub fn notify_read(&self, data: &[u8]) -> bool {
123-
unsafe {
124-
BNNotifyWebsocketClientReadData(
125-
self.handle.as_ptr(),
126-
data.as_ptr() as *mut _,
127-
data.len().try_into().unwrap(),
128-
)
187+
if success {
188+
Some(ActiveConnection {
189+
client: self,
190+
_callback: PhantomData,
191+
})
192+
} else {
193+
None
129194
}
130195
}
131-
132-
pub fn write(&self, data: &[u8]) -> bool {
133-
let len = u64::try_from(data.len()).unwrap();
134-
unsafe { BNWriteWebsocketClientData(self.as_raw(), data.as_ptr(), len) != 0 }
135-
}
136-
137-
pub fn disconnect(&self) -> bool {
138-
unsafe { BNDisconnectWebsocketClient(self.as_raw()) }
139-
}
140196
}
141197

142198
unsafe impl Sync for CoreWebsocketClient {}

rust/tests/websocket.rs

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use binaryninja::websocket::{
44
register_websocket_provider, CoreWebsocketClient, CoreWebsocketProvider, WebsocketClient,
55
WebsocketClientCallback, WebsocketProvider,
66
};
7+
use std::sync::atomic::{AtomicBool, Ordering};
8+
use std::sync::RwLock;
79

810
struct MyWebsocketProvider {
911
core: CoreWebsocketProvider,
@@ -55,27 +57,50 @@ impl WebsocketClient for MyWebsocketClient {
5557

5658
#[derive(Default)]
5759
struct MyClientCallbacks {
58-
data_read: Vec<u8>,
59-
did_disconnect: bool,
60-
did_error: bool,
60+
data_read: RwLock<Vec<u8>>,
61+
did_disconnect: AtomicBool,
62+
did_error: AtomicBool,
6163
}
6264

6365
impl WebsocketClientCallback for MyClientCallbacks {
64-
fn connected(&mut self) -> bool {
66+
fn connected(&self) -> bool {
6567
true
6668
}
6769

68-
fn disconnected(&mut self) {
69-
self.did_disconnect = true;
70+
fn disconnected(&self) {
71+
self.did_disconnect.store(true, Ordering::Relaxed);
7072
}
7173

72-
fn error(&mut self, msg: &str) {
74+
fn error(&self, msg: &str) {
7375
assert_eq!(msg, "error");
74-
self.did_error = true;
76+
self.did_error.store(true, Ordering::Relaxed);
7577
}
7678

77-
fn read(&mut self, data: &[u8]) -> bool {
78-
self.data_read.extend_from_slice(data);
79+
fn read(&self, data: &[u8]) -> bool {
80+
self.data_read.write().unwrap().extend_from_slice(data);
81+
true
82+
}
83+
}
84+
85+
#[derive(Default)]
86+
struct LifetimeCallbacks {
87+
data: RwLock<Vec<u8>>,
88+
}
89+
90+
impl WebsocketClientCallback for LifetimeCallbacks {
91+
fn connected(&self) -> bool {
92+
true
93+
}
94+
95+
fn disconnected(&self) {}
96+
97+
fn error(&self, _msg: &str) {}
98+
99+
fn read(&self, data: &[u8]) -> bool {
100+
if data == "sent: ".as_bytes() || data == "\n".as_bytes() {
101+
return true;
102+
}
103+
assert_eq!(data, &self.data.read().unwrap()[..]);
79104
true
80105
}
81106
}
@@ -86,12 +111,12 @@ fn reg_websocket_provider() {
86111
let provider = register_websocket_provider::<MyWebsocketProvider>("RustWebsocketProvider");
87112
let client = provider.create_client().unwrap();
88113
let mut callback = MyClientCallbacks::default();
89-
let success = client.initialize_connection(
114+
let connection = client.connect(
90115
"url",
91116
[("header".to_string(), "value".to_string())],
92117
&mut callback,
93118
);
94-
assert!(success, "Failed to initialize connection!");
119+
assert!(connection.is_some(), "Failed to initialize connection!");
95120
}
96121

97122
#[test]
@@ -101,24 +126,51 @@ fn listen_websocket_provider() {
101126

102127
let client = provider.create_client().unwrap();
103128
let mut callback = MyClientCallbacks::default();
104-
client.initialize_connection(
105-
"url",
106-
[("header".to_string(), "value".to_string())],
107-
&mut callback,
108-
);
129+
let connection = client
130+
.connect(
131+
"url",
132+
[("header".to_string(), "value".to_string())],
133+
&callback,
134+
)
135+
.expect("Failed to initialize connection!");
109136

110-
assert!(client.write("test1".as_bytes()));
111-
assert!(client.write("test2".as_bytes()));
137+
assert!(connection.write("test1".as_bytes()));
138+
assert!(connection.write("test2".as_bytes()));
112139

113-
client.notify_error("error");
114-
client.disconnect();
115-
drop(client);
140+
connection.notify_error("error");
141+
connection.disconnect();
116142

117143
assert_eq!(
118-
&callback.data_read[..],
144+
&callback.data_read.read().unwrap()[..],
119145
"sent: test1\nsent: test2\n".as_bytes()
120146
);
121147
// If we disconnected that means the error callback was not notified.
122-
assert!(!callback.did_disconnect);
123-
assert!(callback.did_error);
148+
assert!(!callback.did_disconnect.load(Ordering::Relaxed));
149+
assert!(callback.did_error.load(Ordering::Relaxed));
150+
}
151+
152+
#[test]
153+
fn correct_websocket_client_lifetime() {
154+
let _session = Session::new().expect("Failed to initialize session");
155+
let provider = register_websocket_provider::<MyWebsocketProvider>("RustWebsocketProvider2");
156+
157+
let client = provider.create_client().unwrap();
158+
let callback = LifetimeCallbacks::default();
159+
let connection = client
160+
.connect(
161+
"url",
162+
[("header".to_string(), "value".to_string())],
163+
&callback,
164+
)
165+
.expect("Failed to initialize connection!");
166+
167+
println!("{:?}", callback.data);
168+
callback
169+
.data
170+
.write()
171+
.unwrap()
172+
.extend(vec![0x55, 0x55, 0x55, 0x55, 0x55]);
173+
174+
assert!(connection.write(&[0x55, 0x55, 0x55, 0x55, 0x55]));
175+
assert!(connection.write(&[0x55, 0x55, 0x55, 0x55, 0x55]));
124176
}

0 commit comments

Comments
 (0)