Skip to content

Commit 5dd165d

Browse files
authored
Updated subscription interface (#14)
1 parent fa0b9e0 commit 5dd165d

File tree

10 files changed

+394
-65
lines changed

10 files changed

+394
-65
lines changed

python/natsrpy/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from ._natsrpy_rs import Message, Nats, Subscription
1+
from ._natsrpy_rs import CallbackSubscription, IteratorSubscription, Message, Nats
22

33
__all__ = [
4+
"CallbackSubscription",
5+
"IteratorSubscription",
46
"Message",
57
"Nats",
6-
"Subscription",
78
]

python/natsrpy/_natsrpy_rs/__init__.pyi

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
from collections.abc import Awaitable, Callable
12
from datetime import timedelta
2-
from typing import Any
3+
from typing import Any, overload
34

45
from natsrpy._natsrpy_rs.js import JetStream
56
from natsrpy._natsrpy_rs.message import Message
67

7-
class Subscription:
8-
def __aiter__(self) -> Subscription: ...
8+
class IteratorSubscription:
9+
def __aiter__(self) -> IteratorSubscription: ...
910
async def __anext__(self) -> Message: ...
11+
async def unsubscribe(self, limit: int | None = None) -> None: ...
12+
async def drain(self) -> None: ...
13+
14+
class CallbackSubscription:
15+
async def unsubscribe(self, limit: int | None = None) -> None: ...
16+
async def drain(self) -> None: ...
1017

1118
class Nats:
1219
def __init__(
@@ -37,7 +44,18 @@ class Nats:
3744
async def request(self, subject: str, payload: bytes) -> None: ...
3845
async def drain(self) -> None: ...
3946
async def flush(self) -> None: ...
40-
async def subscribe(self, subject: str) -> Subscription: ...
47+
@overload
48+
async def subscribe(
49+
self,
50+
subject: str,
51+
callback: Callable[[Message], Awaitable[None]],
52+
) -> CallbackSubscription: ...
53+
@overload
54+
async def subscribe(
55+
self,
56+
subject: str,
57+
callback: None = None,
58+
) -> IteratorSubscription: ...
4159
async def jetstream(self) -> JetStream: ...
4260

43-
__all__ = ["Message", "Nats", "Subscription"]
61+
__all__ = ["CallbackSubscription", "IteratorSubscription", "Message", "Nats"]

src/exceptions/rust_err.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub enum NatsrpyError {
4141
#[error(transparent)]
4242
SubscribeError(#[from] async_nats::SubscribeError),
4343
#[error(transparent)]
44+
UnsubscribeError(#[from] async_nats::UnsubscribeError),
45+
#[error(transparent)]
4446
KeyValueError(#[from] async_nats::jetstream::context::KeyValueError),
4547
#[error(transparent)]
4648
CreateKeyValueError(#[from] async_nats::jetstream::context::CreateKeyValueError),

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub mod exceptions;
2424
pub mod js;
2525
pub mod message;
2626
pub mod nats_cls;
27-
pub mod subscription;
27+
pub mod subscriptions;
2828
pub mod utils;
2929

3030
#[pyo3::pymodule]
@@ -38,7 +38,7 @@ pub mod _natsrpy_rs {
3838
#[pymodule_export]
3939
use super::nats_cls::NatsCls;
4040
#[pymodule_export]
41-
use super::subscription::Subscription;
41+
use super::subscriptions::{callback::CallbackSubscription, iterator::IteratorSubscription};
4242

4343
#[pymodule_export]
4444
use super::js::pymod as js;

src/message.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ pub struct Message {
1717
pub length: usize,
1818
}
1919

20-
impl TryFrom<async_nats::Message> for Message {
20+
impl TryFrom<&async_nats::Message> for Message {
2121
type Error = NatsrpyError;
2222

23-
fn try_from(value: async_nats::Message) -> Result<Self, Self::Error> {
23+
fn try_from(value: &async_nats::Message) -> Result<Self, Self::Error> {
2424
Python::attach(move |gil| {
25-
let headers = match value.headers {
25+
let headers = match &value.headers {
2626
Some(headermap) => headermap.to_pydict(gil)?.unbind(),
2727
None => PyDict::new(gil).unbind(),
2828
};
@@ -32,13 +32,21 @@ impl TryFrom<async_nats::Message> for Message {
3232
payload: PyBytes::new(gil, &value.payload).unbind(),
3333
headers,
3434
status: value.status.map(Into::<u16>::into),
35-
description: value.description,
35+
description: value.description.clone(),
3636
length: value.length,
3737
})
3838
})
3939
}
4040
}
4141

42+
impl TryFrom<async_nats::Message> for Message {
43+
type Error = NatsrpyError;
44+
45+
fn try_from(value: async_nats::Message) -> Result<Self, Self::Error> {
46+
Self::try_from(&value)
47+
}
48+
}
49+
4250
#[pyo3::pymethods]
4351
impl Message {
4452
#[must_use]

src/nats_cls.rs

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
use async_nats::{Subject, client::traits::Publisher, message::OutboundMessage};
22
use pyo3::{
3-
Bound, PyAny, PyResult, Python,
3+
Bound, IntoPyObjectExt, Py, PyAny, Python,
44
types::{PyBytes, PyBytesMethods, PyDict},
55
};
66
use std::{sync::Arc, time::Duration};
77
use tokio::sync::RwLock;
88

99
use crate::{
10-
exceptions::rust_err::NatsrpyError,
11-
subscription::Subscription,
10+
exceptions::rust_err::{NatsrpyError, NatsrpyResult},
11+
subscriptions::{callback::CallbackSubscription, iterator::IteratorSubscription},
1212
utils::{
1313
futures::natsrpy_future_with_timeout,
1414
headers::NatsrpyHeadermapExt,
@@ -75,7 +75,7 @@ impl NatsCls {
7575
}
7676
}
7777

78-
pub fn startup<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
78+
pub fn startup<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
7979
let mut conn_opts = async_nats::ConnectOptions::new();
8080
if let Some((username, passwd)) = &self.user_and_pass {
8181
conn_opts = conn_opts.user_and_password(username.clone(), passwd.clone());
@@ -100,23 +100,19 @@ impl NatsCls {
100100
let session = self.nats_session.clone();
101101
let address = self.addr.clone();
102102
let timeout = self.connection_timeout;
103-
return Ok(natsrpy_future_with_timeout(
104-
py,
105-
Some(timeout),
106-
async move {
107-
if session.read().await.is_some() {
108-
return Err(NatsrpyError::SessionError(
109-
"NATS session already exists".to_string(),
110-
));
111-
}
112-
// Scoping for early-dropping of a guard.
113-
{
114-
let mut sesion_guard = session.write().await;
115-
*sesion_guard = Some(conn_opts.connect(address).await?);
116-
}
117-
Ok(())
118-
},
119-
)?);
103+
natsrpy_future_with_timeout(py, Some(timeout), async move {
104+
if session.read().await.is_some() {
105+
return Err(NatsrpyError::SessionError(
106+
"NATS session already exists".to_string(),
107+
));
108+
}
109+
// Scoping for early-dropping of a guard.
110+
{
111+
let mut sesion_guard = session.write().await;
112+
*sesion_guard = Some(conn_opts.connect(address).await?);
113+
}
114+
Ok(())
115+
})
120116
}
121117

122118
#[pyo3(signature = (subject, payload, *, headers=None, reply=None, err_on_disconnect = false))]
@@ -128,14 +124,14 @@ impl NatsCls {
128124
headers: Option<Bound<PyDict>>,
129125
reply: Option<String>,
130126
err_on_disconnect: bool,
131-
) -> PyResult<Bound<'py, PyAny>> {
127+
) -> NatsrpyResult<Bound<'py, PyAny>> {
132128
let session = self.nats_session.clone();
133129
log::info!("Payload: {payload:?}");
134130
let data = payload.into();
135131
let headermap = headers
136132
.map(async_nats::HeaderMap::from_pydict)
137133
.transpose()?;
138-
Ok(natsrpy_future(py, async move {
134+
natsrpy_future(py, async move {
139135
if let Some(session) = session.read().await.as_ref() {
140136
if err_on_disconnect
141137
&& session.connection_state() == async_nats::connection::State::Disconnected
@@ -154,7 +150,7 @@ impl NatsCls {
154150
} else {
155151
Err(NatsrpyError::NotInitialized)
156152
}
157-
})?)
153+
})
158154
}
159155

160156
#[pyo3(signature = (subject, payload, *, headers=None, inbox = None, timeout=None))]
@@ -166,13 +162,13 @@ impl NatsCls {
166162
headers: Option<Bound<PyDict>>,
167163
inbox: Option<String>,
168164
timeout: Option<Duration>,
169-
) -> PyResult<Bound<'py, PyAny>> {
165+
) -> NatsrpyResult<Bound<'py, PyAny>> {
170166
let session = self.nats_session.clone();
171167
let data = payload.map(|inner| bytes::Bytes::from(inner.as_bytes().to_vec()));
172168
let headermap = headers
173169
.map(async_nats::HeaderMap::from_pydict)
174170
.transpose()?;
175-
Ok(natsrpy_future(py, async move {
171+
natsrpy_future(py, async move {
176172
if let Some(session) = session.read().await.as_ref() {
177173
let request = async_nats::Request {
178174
payload: data,
@@ -185,32 +181,44 @@ impl NatsCls {
185181
} else {
186182
Err(NatsrpyError::NotInitialized)
187183
}
188-
})?)
184+
})
189185
}
190186

191-
pub fn drain<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
187+
pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
192188
log::debug!("Draining NATS session");
193189
let session = self.nats_session.clone();
194-
Ok(natsrpy_future(py, async move {
190+
natsrpy_future(py, async move {
195191
if let Some(session) = session.write().await.as_ref() {
196192
session.drain().await?;
197193
Ok(())
198194
} else {
199195
Err(NatsrpyError::NotInitialized)
200196
}
201-
})?)
197+
})
202198
}
203199

204-
pub fn subscribe<'py>(&self, py: Python<'py>, subject: String) -> PyResult<Bound<'py, PyAny>> {
200+
#[pyo3(signature=(subject, callback=None))]
201+
pub fn subscribe<'py>(
202+
&self,
203+
py: Python<'py>,
204+
subject: String,
205+
callback: Option<Py<PyAny>>,
206+
) -> NatsrpyResult<Bound<'py, PyAny>> {
205207
log::debug!("Subscribing to '{subject}'");
206208
let session = self.nats_session.clone();
207-
Ok(natsrpy_future(py, async move {
209+
natsrpy_future(py, async move {
208210
if let Some(session) = session.read().await.as_ref() {
209-
Ok(Subscription::new(session.subscribe(subject).await?))
211+
if let Some(cb) = callback {
212+
let sub = CallbackSubscription::new(session.subscribe(subject).await?, cb)?;
213+
Ok(Python::attach(|gil| sub.into_py_any(gil))?)
214+
} else {
215+
let sub = IteratorSubscription::new(session.subscribe(subject).await?);
216+
Ok(Python::attach(|gil| sub.into_py_any(gil))?)
217+
}
210218
} else {
211219
Err(NatsrpyError::NotInitialized)
212220
}
213-
})?)
221+
})
214222
}
215223

216224
#[pyo3(signature = (
@@ -233,10 +241,10 @@ impl NatsCls {
233241
concurrency_limit: Option<usize>,
234242
max_ack_inflight: Option<usize>,
235243
backpressure_on_inflight: Option<bool>,
236-
) -> PyResult<Bound<'py, PyAny>> {
244+
) -> NatsrpyResult<Bound<'py, PyAny>> {
237245
log::debug!("Creating JetStream context");
238246
let session = self.nats_session.clone();
239-
Ok(natsrpy_future(py, async move {
247+
natsrpy_future(py, async move {
240248
let mut builder =
241249
async_nats::jetstream::ContextBuilder::new().concurrency_limit(concurrency_limit);
242250
if let Some(timeout) = ack_timeout {
@@ -269,13 +277,13 @@ impl NatsCls {
269277
Ok(crate::js::jetstream::JetStream::new(js))
270278
},
271279
)
272-
})?)
280+
})
273281
}
274282

275-
pub fn shutdown<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
283+
pub fn shutdown<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
276284
log::debug!("Closing nats session");
277285
let session = self.nats_session.clone();
278-
Ok(natsrpy_future(py, async move {
286+
natsrpy_future(py, async move {
279287
let mut write_guard = session.write().await;
280288
let Some(session) = write_guard.as_ref() else {
281289
return Err(NatsrpyError::NotInitialized);
@@ -284,20 +292,20 @@ impl NatsCls {
284292
*write_guard = None;
285293
drop(write_guard);
286294
Ok(())
287-
})?)
295+
})
288296
}
289297

290-
pub fn flush<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
298+
pub fn flush<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
291299
log::debug!("Flushing streams");
292300
let session = self.nats_session.clone();
293-
Ok(natsrpy_future(py, async move {
301+
natsrpy_future(py, async move {
294302
if let Some(session) = session.write().await.as_ref() {
295303
session.flush().await?;
296304
Ok(())
297305
} else {
298306
Err(NatsrpyError::NotInitialized)
299307
}
300-
})?)
308+
})
301309
}
302310
}
303311

0 commit comments

Comments
 (0)