Skip to content

Commit 16ee1a7

Browse files
authored
Small enhancements. (#45)
1 parent 8a69d10 commit 16ee1a7

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

src/nats_cls.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ pub struct NatsCls {
2828
request_timeout: Option<TimeValue>,
2929
}
3030

31-
/// Helper to read the client from the `RwLock`. Returns a clone of the Client if present.
32-
fn get_client(session: &RwLock<Option<async_nats::Client>>) -> NatsrpyResult<async_nats::Client> {
33-
session
34-
.read()
35-
.map_err(|_| NatsrpyError::SessionError("Lock poisoned".to_string()))?
36-
.clone()
37-
.ok_or(NatsrpyError::NotInitialized)
31+
impl NatsCls {
32+
// Small utility for getting nats session.
33+
fn get_client(&self) -> NatsrpyResult<async_nats::Client> {
34+
self.nats_session
35+
.read()
36+
.map_err(|_| NatsrpyError::PoisonedLock)?
37+
.clone()
38+
.ok_or(NatsrpyError::NotInitialized)
39+
}
3840
}
3941

4042
#[pyo3::pymethods]
@@ -137,7 +139,7 @@ impl NatsCls {
137139
reply: Option<String>,
138140
err_on_disconnect: bool,
139141
) -> NatsrpyResult<Bound<'py, PyAny>> {
140-
let client = get_client(&self.nats_session)?;
142+
let client = self.get_client()?;
141143
let data = bytes::Bytes::from(payload);
142144
let headermap = headers
143145
.map(async_nats::HeaderMap::from_pydict)
@@ -175,7 +177,7 @@ impl NatsCls {
175177
inbox: Option<String>,
176178
timeout: Option<TimeValue>,
177179
) -> NatsrpyResult<Bound<'py, PyAny>> {
178-
let client = get_client(&self.nats_session)?;
180+
let client = self.get_client()?;
179181
let data = payload.map(bytes::Bytes::from);
180182
let headermap = headers
181183
.map(async_nats::HeaderMap::from_pydict)
@@ -198,7 +200,7 @@ impl NatsCls {
198200

199201
pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
200202
log::debug!("Draining NATS session");
201-
let client = get_client(&self.nats_session)?;
203+
let client = self.get_client()?;
202204
natsrpy_future(py, async move {
203205
client.drain().await?;
204206
Ok(())
@@ -214,7 +216,7 @@ impl NatsCls {
214216
queue: Option<String>,
215217
) -> NatsrpyResult<Bound<'py, PyAny>> {
216218
log::debug!("Subscribing to '{subject}'");
217-
let client = get_client(&self.nats_session)?;
219+
let client = self.get_client()?;
218220
natsrpy_future(py, async move {
219221
let subscriber = if let Some(queue) = queue {
220222
client.queue_subscribe(subject, queue).await?
@@ -258,7 +260,7 @@ impl NatsCls {
258260
"Either domain or api_prefix should be specified, not both.",
259261
)));
260262
}
261-
let client = get_client(&self.nats_session)?;
263+
let client = self.get_client()?;
262264
natsrpy_future(py, async move {
263265
let mut builder =
264266
async_nats::jetstream::ContextBuilder::new().concurrency_limit(concurrency_limit);
@@ -288,7 +290,7 @@ impl NatsCls {
288290
pub fn shutdown<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
289291
log::debug!("Closing nats session");
290292
let session = self.nats_session.clone();
291-
let client = get_client(&session)?;
293+
let client = self.get_client()?;
292294
// Set session to None immediately so no new operations can start.
293295
{
294296
let mut guard = session
@@ -304,7 +306,7 @@ impl NatsCls {
304306

305307
pub fn flush<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
306308
log::debug!("Flushing streams");
307-
let client = get_client(&self.nats_session)?;
309+
let client = self.get_client()?;
308310
natsrpy_future(py, async move {
309311
client.flush().await?;
310312
Ok(())

src/utils/headers.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ impl NatsrpyHeadermapExt for async_nats::HeaderMap {
1414
fn from_pydict(pydict: Bound<PyDict>) -> NatsrpyResult<Self> {
1515
let mut headermap = Self::new();
1616
for (name, val) in pydict {
17-
let rs_name = name.extract::<String>()?;
18-
if let Ok(parsed_str) = val.extract::<String>() {
17+
let rs_name = name.extract::<&str>()?;
18+
if let Ok(parsed_str) = val.extract::<&str>() {
1919
headermap.insert(rs_name, parsed_str);
2020
continue;
2121
}
2222
if let Ok(parsed_list) = val.extract::<Vec<String>>() {
2323
for inner in parsed_list {
24-
headermap.append(rs_name.as_str(), inner);
24+
headermap.append(rs_name, inner);
2525
}
2626
continue;
2727
}

0 commit comments

Comments
 (0)