Skip to content

Commit ecb828e

Browse files
committed
feat: implement media download functionality and enhance TryxClient methods
1 parent 3ec4f23 commit ecb828e

4 files changed

Lines changed: 100 additions & 8 deletions

File tree

python/tryx/client.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,19 @@ class Tryx:
1919
def run(self) -> Awaitable[None]: ...
2020
def run_blocking(self) -> None: ...
2121

22+
DownloadableMedia = (
23+
MessageProto.ImageMessage
24+
| MessageProto.VideoMessage
25+
| MessageProto.AudioMessage
26+
| MessageProto.DocumentMessage
27+
| MessageProto.StickerMessage
28+
)
29+
2230
class TryxClient:
2331
async def send_message(self, chat: JID, message: MessageProto) -> str: ...
2432
async def upload(self, data: bytes, media_type: MediaType) -> UploadResponse: ...
2533
async def upload_file(self, path: str, media_type: MediaType) -> UploadResponse: ...
34+
async def download_media(self, message: DownloadableMedia) -> bytes: ...
2635

2736
class Nu[T]:
2837
X: T

src/client.rs

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use whatsapp_rust::store::Backend;
1515
use whatsapp_rust_tokio_transport::TokioWebSocketTransportFactory;
1616
use whatsapp_rust_ureq_http_client::UreqHttpClient;
1717
use waproto::whatsapp::Message as WhatsappMessage;
18+
use waproto::whatsapp::message::{self as wa};
1819
use prost::Message;
1920
use tokio::signal;
2021
use tracing::{debug, error, info, warn};
@@ -55,13 +56,96 @@ pub struct TryxClient {
5556

5657
#[pymethods]
5758
impl TryxClient {
59+
fn is_connected(&self) -> bool {
60+
self.client_rx.borrow().is_some()
61+
}
62+
fn download_media<'py>(&self, py: Python<'py>, message: Py<PyAny>) -> PyResult<Bound<'py, PyAny>> {
63+
let client = self.client_rx.borrow().clone().ok_or_else(|| {
64+
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Bot is not running")
65+
})?;
66+
let message_type_name = message
67+
.getattr(py, "DESCRIPTOR")
68+
.and_then(|descriptor| descriptor.getattr(py, "name"))
69+
.and_then(|name| name.extract::<String>(py))
70+
.unwrap_or_default();
71+
let serialized: Vec<u8> = message
72+
.call_method0(py, "SerializeToString")?
73+
.extract(py)?;
74+
75+
let locals = get_current_locals(py)?;
76+
future_into_py_with_locals(py, locals, async move {
77+
let download = match message_type_name.as_str() {
78+
"ImageMessage" => {
79+
let media = wa::ImageMessage::decode(serialized.as_slice()).map_err(|e| {
80+
PyErr::new::<pyo3::exceptions::PyValueError, _>(
81+
format!("Failed to decode ImageMessage: {}", e),
82+
)
83+
})?;
84+
client.download(&media).await
85+
}
86+
"VideoMessage" => {
87+
let media = wa::VideoMessage::decode(serialized.as_slice()).map_err(|e| {
88+
PyErr::new::<pyo3::exceptions::PyValueError, _>(
89+
format!("Failed to decode VideoMessage: {}", e),
90+
)
91+
})?;
92+
client.download(&media).await
93+
}
94+
"DocumentMessage" => {
95+
let media = wa::DocumentMessage::decode(serialized.as_slice()).map_err(|e| {
96+
PyErr::new::<pyo3::exceptions::PyValueError, _>(
97+
format!("Failed to decode DocumentMessage: {}", e),
98+
)
99+
})?;
100+
client.download(&media).await
101+
}
102+
"AudioMessage" => {
103+
let media = wa::AudioMessage::decode(serialized.as_slice()).map_err(|e| {
104+
PyErr::new::<pyo3::exceptions::PyValueError, _>(
105+
format!("Failed to decode AudioMessage: {}", e),
106+
)
107+
})?;
108+
client.download(&media).await
109+
}
110+
"StickerMessage" => {
111+
let media = wa::StickerMessage::decode(serialized.as_slice()).map_err(|e| {
112+
PyErr::new::<pyo3::exceptions::PyValueError, _>(
113+
format!("Failed to decode StickerMessage: {}", e),
114+
)
115+
})?;
116+
client.download(&media).await
117+
}
118+
_ => {
119+
// Fallback path for unknown wrappers from Python side.
120+
if let Ok(media) = wa::ImageMessage::decode(serialized.as_slice()) {
121+
client.download(&media).await
122+
} else if let Ok(media) = wa::VideoMessage::decode(serialized.as_slice()) {
123+
client.download(&media).await
124+
} else if let Ok(media) = wa::DocumentMessage::decode(serialized.as_slice()) {
125+
client.download(&media).await
126+
} else if let Ok(media) = wa::AudioMessage::decode(serialized.as_slice()) {
127+
client.download(&media).await
128+
} else if let Ok(media) = wa::StickerMessage::decode(serialized.as_slice()) {
129+
client.download(&media).await
130+
} else {
131+
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
132+
"Failed to decode message as supported media message",
133+
));
134+
}
135+
}
136+
};
137+
138+
download.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
139+
})
140+
}
58141
fn upload_file<'py>(&self, py: Python<'py>, path: String, media_type: Py<MediaType>) -> PyResult<Bound<'py, PyAny>> {
59142
let client = self.client_rx.borrow().clone().ok_or_else(|| {
60143
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Bot is not running")
61144
})?;
62145
let media_type_enum = media_type.bind(py).borrow_mut().to_wacore_enum();
63146
let locals = get_current_locals(py)?;
64-
let data = std::fs::read(path.clone()).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
147+
let data = std::fs::read(&path)
148+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
65149
future_into_py_with_locals(py, locals, async move {
66150
let url = client
67151
.upload(data, media_type_enum)

src/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,27 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1111
let backend = Arc::new(SqliteStore::new("whatsapp.db").await?);
1212

1313
// Build the bot
14-
let mut bot = Bot::builder()
14+
let bot = Bot::builder()
1515
.with_backend(backend)
1616
.with_transport_factory(TokioWebSocketTransportFactory::new())
1717
.with_http_client(UreqHttpClient::new())
18-
.on_event(|event, client| async move {
18+
.on_event(|event, _client| async move {
1919
match event {
2020
Event::PairingQrCode { code, .. } => {
2121
println!("Scan this QR code with WhatsApp:\n{}", code);
2222
}
2323
Event::Message(msg, info) => {
2424
println!("Message from {}: {:?}", info.source.sender, msg);
2525
}
26-
Event::Connected(e)=> {
26+
Event::Connected(_e)=> {
2727
println!("Connected to WhatsApp");
2828
}
2929
_ => {}
3030
}
3131
})
3232
.build();
3333
let mut bot2 = bot.await?;
34-
let g = bot2.client();
34+
let _g = bot2.client();
3535

3636
// Start the bot
3737
bot2.run().await?.await?;

src/types.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
use std::sync::Arc;
22

3-
use pyo3::{PyAny, PyErr, PyResult, Python, exceptions::{PyException, PyRuntimeError}, ffi::PyObject, pyclass, pymethods, types::{PyAnyMethods, PyBytes, PyDateTime, PyType}};
4-
use waproto::whatsapp::VerifiedNameCertificate;
3+
use pyo3::{PyAny, PyErr, PyResult, Python, exceptions::PyRuntimeError, pyclass, pymethods, types::{PyAnyMethods, PyBytes, PyDateTime}};
54
use whatsapp_rust::{Jid as WhatsAppJID};
6-
use wacore::types::message::{BotEditType, EditAttribute, MessageInfo as WhatsAppMessageInfo, MessageSource as WhatsAppMessageSource, MsgBotInfo as WhatsAppMsgBotInfo, MsgMetaInfo as WhatsappMsgMetaInfo};
5+
use wacore::types::message::{BotEditType, EditAttribute, MessageInfo as WhatsAppMessageInfo, MessageSource as WhatsAppMessageSource, MsgBotInfo as WhatsAppMsgBotInfo};
76
use prost::Message;
87

98

0 commit comments

Comments
 (0)