Skip to content

Commit 05bb8f6

Browse files
committed
feat: enhance event handling by adding EvPairSuccess and EvReceipt types, and refactor callback execution
1 parent a8ef461 commit 05bb8f6

3 files changed

Lines changed: 183 additions & 140 deletions

File tree

src/clients/tryx.rs

Lines changed: 90 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::sync::{Arc};
22
use std::future::Future;
33
use std::pin::Pin;
4-
use pyo3::{Bound, PyAny, pyclass, pymethods};
4+
use pyo3::{Bound, PyAny, PyTypeInfo, pyclass, pymethods};
55
use pyo3::prelude::*;
66
use pyo3_async_runtimes::{TaskLocals, into_future_with_locals};
77
use pyo3_async_runtimes::tokio::{future_into_py_with_locals, get_current_locals, into_future};
@@ -20,7 +20,7 @@ use tracing::{debug, error, info, warn};
2020
use super::tryx_client::TryxClient;
2121
use crate::log::init_logging;
2222
use crate::backend::{SqliteBackend, BackendBase};
23-
use crate::events::types::{EvArchiveUpdate, EvConnected, EvLoggedOut, EvMessage, EvPairingQrCode};
23+
use crate::events::types::{EvArchiveUpdate, EvConnected, EvLoggedOut, EvMessage, EvPairSuccess, EvPairingQrCode, EvReceipt};
2424
use crate::exceptions::UnsupportedBackend;
2525
use crate::events::dispatcher::Dispatcher;
2626
use crate::types::JID;
@@ -29,6 +29,7 @@ use crate::types::JID;
2929
#[pyclass]
3030
pub struct Tryx {
3131
backend: Arc<dyn Backend>,
32+
#[pyo3(get)]
3233
handlers: Py<Dispatcher>,
3334
tryx_client: Py<TryxClient>,
3435
client_tx: watch::Sender<Option<Arc<Client>>>,
@@ -67,7 +68,7 @@ impl Tryx {
6768
self.tryx_client.clone_ref(py)
6869
}
6970
/// Returns a decorator compatible with:
70-
/// @client.on(Message)
71+
/// @client.on(EvMessage)
7172
/// async def on_message(client, data): ...
7273
fn on(&self, py: Python, event_type: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
7374
debug!("registering event decorator through Tryx.on");
@@ -175,6 +176,41 @@ impl Tryx {
175176

176177

177178
impl Tryx {
179+
async fn call_event<T: PyTypeInfo>(callbacks: Arc<Vec<Py<PyAny>>>, payload: Py<T>, locals: Option<TaskLocals>) -> PyResult<()> {
180+
for callback in callbacks.iter() {
181+
debug!("calling event Python callback");
182+
let py_future = Python::attach(|py| -> PyResult<_> {
183+
let awaitable = callback.bind(py).call1((payload.clone_ref(py),))?;
184+
let fut: Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send>> = match &locals {
185+
Some(locals) => {
186+
let fut = into_future_with_locals(locals, awaitable)?;
187+
Box::pin(async move { fut.await })
188+
}
189+
None => {
190+
let fut = into_future(awaitable)?;
191+
Box::pin(async move { fut.await })
192+
}
193+
};
194+
Ok(fut)
195+
});
196+
197+
match py_future {
198+
Ok(py_future) => {
199+
if let Err(err) = py_future.await {
200+
error!(error = %err, "event callback failed");
201+
Python::attach(|py| err.print(py));
202+
} else {
203+
debug!("event callback finished");
204+
}
205+
}
206+
Err(err) => {
207+
error!(error = %err, "failed to schedule event callback");
208+
Python::attach(|py| err.print(py));
209+
}
210+
}
211+
}
212+
Ok(())
213+
}
178214
async fn run_bot(
179215
backend: Arc<dyn Backend>,
180216
handlers: Py<Dispatcher>,
@@ -187,6 +223,7 @@ impl Tryx {
187223
message_callbacks,
188224
connected_callbacks,
189225
logout_callbacks,
226+
pair_success_callbacks,
190227
receipt_callbacks,
191228
undecryptable_message_callbacks,
192229
notification_callbacks,
@@ -219,6 +256,7 @@ impl Tryx {
219256
dispatcher.message_handlers(py),
220257
dispatcher.connected_handlers(py),
221258
dispatcher.logout_handlers(py),
259+
dispatcher.pair_success_handlers(py),
222260
dispatcher.receipt_handlers(py),
223261
dispatcher.undecryptable_message_handlers(py),
224262
dispatcher.notification_handlers(py),
@@ -250,6 +288,7 @@ impl Tryx {
250288
let message_callbacks = Arc::new(message_callbacks);
251289
let connected_callbacks = Arc::new(connected_callbacks);
252290
let logout_callbacks = Arc::new(logout_callbacks);
291+
let pair_success_callbacks = Arc::new(pair_success_callbacks);
253292
let receipt_callbacks = Arc::new(receipt_callbacks);
254293
let undecryptable_message_callbacks = Arc::new(undecryptable_message_callbacks);
255294
let notification_callbacks = Arc::new(notification_callbacks);
@@ -276,38 +315,38 @@ impl Tryx {
276315
let connect_failure_callbacks = Arc::new(connect_failure_callbacks);
277316
let stream_error_callbacks = Arc::new(stream_error_callbacks);
278317

279-
info!(
280-
pairing_qr_handlers = pairing_qr_callbacks.len(),
281-
message_handlers = message_callbacks.len(),
282-
connected_handlers = connected_callbacks.len(),
283-
logout_handlers = logout_callbacks.len(),
284-
receipt_handlers = receipt_callbacks.len(),
285-
undecryptable_message_handlers = undecryptable_message_callbacks.len(),
286-
notification_handlers = notification_callbacks.len(),
287-
chat_presence_handlers = chat_presence_callbacks.len(),
288-
presence_handlers = presence_callbacks.len(),
289-
picture_update_handlers = picture_update_callbacks.len(),
290-
user_about_update_handlers = user_about_update_callbacks.len(),
291-
joined_group_handlers = joined_group_callbacks.len(),
292-
group_info_update_handlers = group_info_update_callbacks.len(),
293-
contact_update_handlers = contact_update_callbacks.len(),
294-
push_name_update_handlers = push_name_update_callbacks.len(),
295-
self_push_name_updated_handlers = self_push_name_updated_callbacks.len(),
296-
pin_update_handlers = pin_update_callbacks.len(),
297-
mute_update_handlers = mute_update_callbacks.len(),
298-
archive_update_handlers = archive_update_callbacks.len(),
299-
mark_chat_as_read_update_handlers = mark_chat_as_read_update_callbacks.len(),
300-
history_sync_handlers = history_sync_callbacks.len(),
301-
offline_sync_preview_handlers = offline_sync_preview_callbacks.len(),
302-
offline_sync_completed_handlers = offline_sync_completed_callbacks.len(),
303-
device_list_update_handlers = device_list_update_callbacks.len(),
304-
business_status_update_handlers = business_status_update_callbacks.len(),
305-
stream_replaced_handlers = stream_replaced_callbacks.len(),
306-
temporary_ban_handlers = temporary_ban_callbacks.len(),
307-
connect_failure_handlers = connect_failure_callbacks.len(),
308-
stream_error_handlers = stream_error_callbacks.len(),
309-
"cached dispatcher handlers for runtime"
310-
);
318+
// info!(
319+
// pairing_qr_handlers = pairing_qr_callbacks.len(),
320+
// message_handlers = message_callbacks.len(),
321+
// connected_handlers = connected_callbacks.len(),
322+
// logout_handlers = logout_callbacks.len(),
323+
// receipt_handlers = receipt_callbacks.len(),
324+
// undecryptable_message_handlers = undecryptable_message_callbacks.len(),
325+
// notification_handlers = notification_callbacks.len(),
326+
// chat_presence_handlers = chat_presence_callbacks.len(),
327+
// presence_handlers = presence_callbacks.len(),
328+
// picture_update_handlers = picture_update_callbacks.len(),
329+
// user_about_update_handlers = user_about_update_callbacks.len(),
330+
// joined_group_handlers = joined_group_callbacks.len(),
331+
// group_info_update_handlers = group_info_update_callbacks.len(),
332+
// contact_update_handlers = contact_update_callbacks.len(),
333+
// push_name_update_handlers = push_name_update_callbacks.len(),
334+
// self_push_name_updated_handlers = self_push_name_updated_callbacks.len(),
335+
// pin_update_handlers = pin_update_callbacks.len(),
336+
// mute_update_handlers = mute_update_callbacks.len(),
337+
// archive_update_handlers = archive_update_callbacks.len(),
338+
// mark_chat_as_read_update_handlers = mark_chat_as_read_update_callbacks.len(),
339+
// history_sync_handlers = history_sync_callbacks.len(),
340+
// offline_sync_preview_handlers = offline_sync_preview_callbacks.len(),
341+
// offline_sync_completed_handlers = offline_sync_completed_callbacks.len(),
342+
// device_list_update_handlers = device_list_update_callbacks.len(),
343+
// business_status_update_handlers = business_status_update_callbacks.len(),
344+
// stream_replaced_handlers = stream_replaced_callbacks.len(),
345+
// temporary_ban_handlers = temporary_ban_callbacks.len(),
346+
// connect_failure_handlers = connect_failure_callbacks.len(),
347+
// stream_error_handlers = stream_error_callbacks.len(),
348+
// "cached dispatcher handlers for runtime"
349+
// );
311350

312351
info!("building WhatsApp bot");
313352
let mut bot = Bot::builder()
@@ -320,6 +359,7 @@ impl Tryx {
320359
let message_callbacks = Arc::clone(&message_callbacks);
321360
let connected_callbacks = Arc::clone(&connected_callbacks);
322361
let logout_callbacks = Arc::clone(&logout_callbacks);
362+
let pair_success_callbacks = Arc::clone(&pair_success_callbacks);
323363
let receipt_callbacks = Arc::clone(&receipt_callbacks);
324364
let undecryptable_message_callbacks = Arc::clone(&undecryptable_message_callbacks);
325365
let notification_callbacks = Arc::clone(&notification_callbacks);
@@ -349,101 +389,20 @@ impl Tryx {
349389
async move {
350390
match event {
351391
Event::PairingQrCode { code, timeout } => {
352-
info!(timeout_secs = timeout.as_secs(), "received pairing QR event");
353-
info!(handlers = pairing_qr_callbacks.len(), "dispatching pairing QR handlers");
354-
355-
for (idx, callback) in pairing_qr_callbacks.iter().enumerate() {
356-
debug!(handler_index = idx, "calling pairing QR Python callback");
357-
let locals = locals.clone();
358-
let py_future = Python::attach(|py| -> PyResult<_> {
359-
let payload = Py::new(py, EvPairingQrCode::new(code.clone(), timeout.as_secs()))?;
360-
let awaitable = callback.bind(py).call1((py.None(), payload))?;
361-
let fut: Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send>> = match &locals {
362-
Some(locals) => {
363-
let fut = into_future_with_locals(locals, awaitable)?;
364-
Box::pin(async move { fut.await })
365-
}
366-
None => {
367-
let fut = into_future(awaitable)?;
368-
Box::pin(async move { fut.await })
369-
}
370-
};
371-
Ok(fut)
372-
});
373-
374-
match py_future {
375-
Ok(py_future) => {
376-
if let Err(err) = py_future.await {
377-
error!(handler_index = idx, error = %err, "pairing QR callback failed");
378-
Python::attach(|py| err.print(py));
379-
} else {
380-
debug!(handler_index = idx, "pairing QR callback finished");
381-
}
382-
}
383-
Err(err) => {
384-
error!(handler_index = idx, error = %err, "failed to schedule pairing QR callback");
385-
Python::attach(|py| err.print(py));
386-
}
387-
}
388-
}
392+
let payload = Python::attach(|py| Py::new(py, EvPairingQrCode::new(code.clone(), timeout.as_secs()))).map_err(|e| e).unwrap();
393+
Self::call_event(pairing_qr_callbacks, payload, locals.clone()).await.unwrap()
389394
}
390395
Event::Message(msg, info) => {
391396
let payload = Python::attach(|py| Py::new(py, EvMessage::new(msg, info))).map_err(|e| e).unwrap();
392-
for callback in message_callbacks.iter() {
393-
let locals = locals.clone();
394-
let py_future = Python::attach(|py| -> PyResult<_> {
395-
let client_obj = tryx_client.clone_ref(py);
396-
let awaitable = callback.bind(py).call1((client_obj, payload.clone_ref(py)))?;
397-
let fut: Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send>> = match &locals {
398-
Some(locals) => {
399-
let fut = into_future_with_locals(locals, awaitable)?;
400-
Box::pin(async move { fut.await })
401-
}
402-
None => {
403-
let fut = into_future(awaitable)?;
404-
Box::pin(async move { fut.await })
405-
}
406-
};
407-
Ok(fut)
408-
});
409-
410-
match py_future {
411-
Ok(py_future) => {
412-
if let Err(err) = py_future.await {
413-
Python::attach(|py| err.print(py));
414-
} else {
415-
debug!( "message callback finished");
416-
}
417-
}
418-
Err(err) => {
419-
error!("failed to schedule message callback");
420-
Python::attach(|py| err.print(py));
421-
}
422-
}
423-
}
397+
Self::call_event(message_callbacks, payload, locals.clone()).await.unwrap()
424398
}
425399
Event::Connected(_) => {
426400
let payload = Python::attach(|py| pyo3::Py::new(py, EvConnected{})).map_err(|e| e).unwrap();
427-
for callback in connected_callbacks.iter() {
428-
debug!("calling connected event handler");
429-
let _ = Python::attach(|py| -> PyResult<_> {
430-
let awaitable = callback.bind(py).call1((payload.clone_ref(py),))?;
431-
let fut = into_future(awaitable)?;
432-
Ok(fut)
433-
});
434-
}
401+
Self::call_event(connected_callbacks, payload, locals.clone()).await.unwrap();
435402
}
436403
Event::LoggedOut(logout) => {
437404
let payload = Python::attach(|py| pyo3::Py::new(py, EvLoggedOut::new(logout))).map_err(|e| e).unwrap();
438-
for callback in logout_callbacks.iter() {
439-
debug!("calling logged out event handler");
440-
let _ = Python::attach(|py| -> PyResult<_> {
441-
let awaitable = callback.bind(py).call1((payload.clone_ref(py),))?;
442-
let fut = into_future(awaitable)?;
443-
Ok(fut)
444-
});
445-
}
446-
405+
Self::call_event(logout_callbacks, payload, locals.clone()).await.unwrap();
447406
}
448407
Event::ArchiveUpdate(archived) => {
449408

@@ -453,17 +412,18 @@ impl Tryx {
453412
Arc::from(archived.action.clone()),
454413
archived.from_full_sync,
455414
))).map_err(|e| e).unwrap();
456-
for callback in archive_update_callbacks.iter() {
457-
debug!("calling archive update event handler");
458-
let _ = Python::attach(|py| -> PyResult<_> {
459-
let awaitable = callback.bind(py).call1((payload.clone_ref(py),))?;
460-
let fut = into_future(awaitable)?;
461-
Ok(fut)
462-
});
463-
}
464-
// debug!("received archive update event for jid {}", archived.jid);
415+
Self::call_event(archive_update_callbacks, payload, locals.clone()).await.unwrap();
465416

466417
}
418+
Event::Receipt(receipt) => {
419+
let receipt_callbacks = Arc::clone(&receipt_callbacks);
420+
let payload = EvReceipt::new(Arc::new(receipt.source), receipt.message_ids, receipt.timestamp, receipt.r#type, receipt.message_sender);
421+
Self::call_event(receipt_callbacks, payload, locals.clone()).await.unwrap();
422+
}
423+
Event::PairSuccess(pair_success) => {
424+
let payload = Python::attach(|py| pyo3::Py::new(py, EvPairSuccess::new(pair_success.id.into(), pair_success.lid.into(), pair_success.business_name, pair_success.platform))).map_err(|e| e).unwrap();
425+
Self::call_event(pair_success_callbacks, payload, locals.clone()).await.unwrap();
426+
}
467427
_ => {
468428
debug!("received event without registered dispatcher path");
469429
}

0 commit comments

Comments
 (0)