@@ -14,6 +14,8 @@ use crate::{
1414} ;
1515
1616const COUNTER_INCREMENT_HEADER : & str = "Nats-Incr" ;
17+ const COUNTER_SOURCES_HEADER : & str = "Nats-Counter-Sources" ;
18+ type CounterSources = HashMap < String , HashMap < String , i128 > > ;
1719
1820#[ pyo3:: pyclass( from_py_object, get_all, set_all) ]
1921#[ derive( Debug , Clone , Default ) ]
@@ -250,6 +252,38 @@ impl TryFrom<CountersConfig> for async_nats::jetstream::stream::Config {
250252 }
251253}
252254
255+ #[ derive( Debug , Clone , serde:: Serialize , serde:: Deserialize ) ]
256+ pub struct CounterPayload < ' a > {
257+ val : & ' a str ,
258+ }
259+
260+ #[ pyo3:: pyclass( from_py_object, get_all) ]
261+ #[ derive( Clone ) ]
262+ pub struct CounterEntry {
263+ pub subject : String ,
264+ pub value : i128 ,
265+ pub sources : CounterSources ,
266+ pub increment : Option < i128 > ,
267+ }
268+
269+ impl TryFrom < async_nats:: jetstream:: message:: StreamMessage > for CounterEntry {
270+ type Error = NatsrpyError ;
271+
272+ fn try_from ( value : async_nats:: jetstream:: message:: StreamMessage ) -> Result < Self , Self :: Error > {
273+ let counter_value = serde_json:: from_slice :: < CounterPayload > ( & value. payload ) ?
274+ . val
275+ . parse :: < i128 > ( ) ?;
276+ let sources = parse_sources ( & value. headers ) ?;
277+ let increment = parse_increment ( & value. headers ) ?;
278+ Ok ( Self {
279+ subject : value. subject . to_string ( ) ,
280+ value : counter_value,
281+ sources,
282+ increment,
283+ } )
284+ }
285+ }
286+
253287#[ pyo3:: pyclass]
254288#[ allow( dead_code) ]
255289pub struct Counters {
@@ -269,6 +303,31 @@ impl Counters {
269303 }
270304}
271305
306+ fn parse_sources ( headers : & HeaderMap ) -> NatsrpyResult < CounterSources > {
307+ let Some ( sources) = headers. get ( COUNTER_SOURCES_HEADER ) else {
308+ return Ok ( CounterSources :: new ( ) ) ;
309+ } ;
310+ let raw_sources =
311+ serde_json:: from_str :: < HashMap < String , HashMap < String , String > > > ( sources. as_str ( ) ) ?;
312+ let mut sources = CounterSources :: new ( ) ;
313+ for ( source_id, subjects) in raw_sources {
314+ let mut subject_values = HashMap :: new ( ) ;
315+ for ( subject, value_str) in subjects {
316+ subject_values. insert ( subject, value_str. parse ( ) ?) ;
317+ }
318+ sources. insert ( source_id, subject_values) ;
319+ }
320+
321+ Ok ( sources)
322+ }
323+
324+ pub fn parse_increment ( headers : & HeaderMap ) -> NatsrpyResult < Option < i128 > > {
325+ let Some ( header_value) = headers. get ( COUNTER_INCREMENT_HEADER ) else {
326+ return Ok ( None ) ;
327+ } ;
328+ Ok ( Some ( header_value. as_str ( ) . parse ( ) ?) )
329+ }
330+
272331#[ pyo3:: pymethods]
273332impl Counters {
274333 #[ pyo3( signature=( key, value, timeout=None ) ) ]
@@ -321,10 +380,28 @@ impl Counters {
321380 ) -> NatsrpyResult < Bound < ' py , PyAny > > {
322381 self . add ( py, key, -1 , timeout)
323382 }
383+
384+ #[ pyo3( signature=( key, timeout=None ) ) ]
385+ pub fn get < ' py > (
386+ & self ,
387+ py : Python < ' py > ,
388+ key : String ,
389+ timeout : Option < TimeValue > ,
390+ ) -> NatsrpyResult < Bound < ' py , PyAny > > {
391+ let stream_guard = self . stream . clone ( ) ;
392+ natsrpy_future_with_timeout ( py, timeout, async move {
393+ let message = stream_guard
394+ . read ( )
395+ . await
396+ . direct_get_last_for_subject ( key)
397+ . await ?;
398+ CounterEntry :: try_from ( message)
399+ } )
400+ }
324401}
325402
326403#[ pyo3:: pymodule( submodule, name = "counters" ) ]
327404pub mod pymod {
328405 #[ pymodule_export]
329- use super :: { Counters , CountersConfig } ;
406+ use super :: { CounterEntry , Counters , CountersConfig } ;
330407}
0 commit comments