@@ -54,7 +54,7 @@ func NewServer(params *Parameters, host host.Host, store *store.Store) (*Server,
5454
5555func (s * Server ) Start (context.Context ) error {
5656 s .ctx , s .cancel = context .WithCancel (context .Background ())
57- s .host .SetStreamHandler (s .protocolID , s .middleware .RateLimitHandler (s .handleStream ))
57+ s .host .SetStreamHandler (s .protocolID , s .middleware .RateLimitHandler (s .streamHandler ( s . ctx ) ))
5858 return nil
5959}
6060
@@ -71,18 +71,29 @@ func (s *Server) observeRateLimitedRequests() {
7171 }
7272}
7373
74- func (s * Server ) handleStream (stream network.Stream ) {
74+ func (srv * Server ) streamHandler (ctx context.Context ) network.StreamHandler {
75+ return func (s network.Stream ) {
76+ err := srv .handleEDS (s )
77+ if err != nil {
78+ s .Reset () //nolint:errcheck
79+ return
80+ }
81+ srv .metrics .ObserveRequests (ctx , 1 , shrex .StatusSuccess )
82+ if err = s .Close (); err != nil {
83+ log .Debugw ("server: closing stream" , "err" , err )
84+ }
85+ }
86+ }
87+
88+ func (s * Server ) handleEDS (stream network.Stream ) error {
7589 logger := log .With ("peer" , stream .Conn ().RemotePeer ().String ())
7690 logger .Debug ("server: handling eds request" )
7791
78- s .observeRateLimitedRequests ()
79-
8092 // read request from stream to get the dataHash for store lookup
8193 id , err := s .readRequest (logger , stream )
8294 if err != nil {
8395 logger .Warnw ("server: reading request from stream" , "err" , err )
84- stream .Reset () //nolint:errcheck
85- return
96+ return err
8697 }
8798
8899 logger = logger .With ("height" , id .Height )
@@ -112,31 +123,20 @@ func (s *Server) handleStream(stream network.Stream) {
112123 err = s .writeStatus (logger , status , stream )
113124 if err != nil {
114125 logger .Warnw ("server: writing status to stream" , "err" , err )
115- stream .Reset () //nolint:errcheck
116- return
126+ return err
117127 }
118128 // if we cannot serve the EDS, we are already done
119129 if status != shrexpb .Status_OK {
120- err = stream .Close ()
121- if err != nil {
122- logger .Debugw ("server: closing stream" , "err" , err )
123- }
124- return
130+ return nil
125131 }
126132
127133 // start streaming the ODS to the client
128134 err = s .writeODS (logger , file , stream )
129135 if err != nil {
130136 logger .Warnw ("server: writing ods to stream" , "err" , err )
131- stream .Reset () //nolint:errcheck
132- return
133- }
134-
135- s .metrics .ObserveRequests (ctx , 1 , shrex .StatusSuccess )
136- err = stream .Close ()
137- if err != nil {
138- logger .Debugw ("server: closing stream" , "err" , err )
137+ return err
139138 }
139+ return nil
140140}
141141
142142func (s * Server ) readRequest (logger * zap.SugaredLogger , stream network.Stream ) (shwap.EdsID , error ) {
0 commit comments