@@ -17,6 +17,7 @@ module Test.Ouroboros.Network.TxSubmission.Types
1717 , readMempool
1818 , getMempoolReader
1919 , getMempoolWriter
20+ , InvalidTx (.. )
2021 , maxTxSize
2122 , LargeNonEmptyList (.. )
2223 , SimResults (.. )
@@ -31,6 +32,7 @@ import Prelude hiding (seq)
3132import NoThunks.Class
3233
3334import Control.Concurrent.Class.MonadSTM
35+ import Control.Concurrent.Class.MonadSTM.Strict qualified as StrictSTM
3436import Control.DeepSeq
3537import Control.Exception (SomeException (.. ))
3638import Control.Monad.Class.MonadAsync
@@ -48,6 +50,10 @@ import Codec.CBOR.Encoding qualified as CBOR
4850import Codec.CBOR.Read qualified as CBOR
4951
5052import Data.ByteString.Lazy (ByteString )
53+ import Data.Either (partitionEithers )
54+ import Data.List qualified as List
55+ import Data.Sequence qualified as Seq
56+ import Data.Set qualified as Set
5157import Data.Typeable (Typeable )
5258import GHC.Generics (Generic )
5359
@@ -137,18 +143,73 @@ getMempoolWriter :: forall txid m.
137143 => TVar m [txid ]
138144 -> Mempool m txid (Tx txid )
139145 -> TxSubmissionMempoolWriter txid (Tx txid ) Integer m InvalidTx
140- getMempoolWriter duplicateVar =
141- Mempool. getWriter DuplicateTx
142- getTxId
143- (\ _ txs -> return
144- [ if getTxValid tx
145- then Right tx
146- else Left (getTxId tx, InvalidTx )
147- | tx <- txs
148- ]
149- )
150- (\ t -> atomically $ modifyTVar' duplicateVar
151- (map fst (filter ((== DuplicateTx ) . snd ) t) <> ))
146+ getMempoolWriter duplicateVar (Mempool. Mempool mempoolVar) =
147+ TxSubmissionMempoolWriter {
148+ txId = getTxId,
149+ mempoolAddTxs = \ txs -> do
150+ (acceptedTxs, rejectedTxs, duplicateValidTxIds) <- atomically $ do
151+ Mempool. MempoolSeq { Mempool. mempoolSet, Mempool. mempoolSeq, Mempool. nextIdx } <-
152+ StrictSTM. readTVar mempoolVar
153+
154+ let (duplicateTxs, txsToValidate) =
155+ List. partition (\ tx -> getTxId tx `Set.member` mempoolSet) txs
156+ duplicateRejectedTxs =
157+ [ (getTxId tx, DuplicateTx )
158+ | tx <- duplicateTxs
159+ ]
160+ duplicateValidTxIds =
161+ [ getTxId tx
162+ | tx <- duplicateTxs
163+ , getTxValid tx
164+ ]
165+ (invalidRejectedTxs, validTxs) =
166+ partitionEithers
167+ [ if getTxValid tx
168+ then Right tx
169+ else Left (getTxId tx, InvalidTx )
170+ | tx <- txsToValidate
171+ ]
172+
173+ (delta, mempoolSeq', nextIdx', acceptedTxs, duplicateValidTxIds') =
174+ List. foldl'
175+ (\ (set, seq , idx, accepted, duplicates) tx ->
176+ let txid = getTxId tx in
177+ if txid `Set.member` set
178+ then ( set
179+ , seq
180+ , idx
181+ , accepted
182+ , txid : duplicates
183+ )
184+ else ( Set. insert txid set
185+ , seq Seq. |> Mempool. WithIndex idx tx
186+ , succ idx
187+ , txid : accepted
188+ , duplicates
189+ )
190+ )
191+ (Set. empty, mempoolSeq, nextIdx, [] , [] )
192+ validTxs
193+
194+ StrictSTM. writeTVar
195+ mempoolVar
196+ Mempool. MempoolSeq {
197+ Mempool. mempoolSet = mempoolSet `Set.union` delta,
198+ Mempool. mempoolSeq = mempoolSeq',
199+ Mempool. nextIdx = nextIdx'
200+ }
201+
202+ pure
203+ ( acceptedTxs
204+ , invalidRejectedTxs
205+ ++ duplicateRejectedTxs
206+ ++ [ (txid, DuplicateTx ) | txid <- duplicateValidTxIds' ]
207+ , duplicateValidTxIds ++ duplicateValidTxIds'
208+ )
209+
210+ atomically $ modifyTVar' duplicateVar (duplicateValidTxIds <> )
211+ pure (acceptedTxs, rejectedTxs)
212+ }
152213
153214
154215txSubmissionCodec2 :: MonadST m
0 commit comments