Skip to content

Commit e06d56b

Browse files
committed
standartise stream close in getters
1 parent 1dd3546 commit e06d56b

6 files changed

Lines changed: 36 additions & 34 deletions

File tree

share/shwap/p2p/shrex/error.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package shrex
22

33
import "errors"
44

5-
// errorContains reports whether any error in err's tree matches any error in targets tree.
6-
func errorContains(err, target error) bool {
5+
// ErrorContains reports whether any error in err's tree matches any error in targets tree.
6+
func ErrorContains(err, target error) bool {
77
if errors.Is(err, target) || target == nil {
88
return true
99
}
@@ -12,5 +12,5 @@ func errorContains(err, target error) bool {
1212
if target == nil {
1313
return false
1414
}
15-
return errorContains(err, target)
15+
return ErrorContains(err, target)
1616
}

share/shwap/p2p/shrex/error_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ func Test_ErrorContains(t *testing.T) {
105105
t.Run(tt.name, func(t *testing.T) {
106106
assert.Equalf(t,
107107
tt.want,
108-
errorContains(tt.args.err, tt.args.target),
109-
"errorContains(%v, %v)", tt.args.err, tt.args.target)
108+
ErrorContains(tt.args.err, tt.args.target),
109+
"ErrorContains(%v, %v)", tt.args.err, tt.args.target)
110110
})
111111
}
112112
}

share/shwap/p2p/shrex/shrex_getter/shrex.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func (sg *Getter) GetEDS(ctx context.Context, header *header.ExtendedHeader) (*r
201201
setStatus(peers.ResultCooldownPeer)
202202
}
203203

204-
if !shrex.errorContains(err, getErr) {
204+
if !shrex.ErrorContains(err, getErr) {
205205
err = errors.Join(err, getErr)
206206
}
207207
log.Debugw("eds: request failed",
@@ -285,7 +285,7 @@ func (sg *Getter) GetSharesByNamespace(
285285
setStatus(peers.ResultCooldownPeer)
286286
}
287287

288-
if !shrex.errorContains(err, getErr) {
288+
if !shrex.ErrorContains(err, getErr) {
289289
err = errors.Join(err, getErr)
290290
}
291291
log.Debugw("nd: request failed",

share/shwap/p2p/shrex/shrexeds/client.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/celestiaorg/go-libp2p-messenger/serde"
1818
"github.com/celestiaorg/rsmt2d"
1919

20+
"github.com/celestiaorg/celestia-node/libs/utils"
2021
"github.com/celestiaorg/celestia-node/share"
2122
eds "github.com/celestiaorg/celestia-node/share/new_eds"
2223
"github.com/celestiaorg/celestia-node/share/shwap"
@@ -96,7 +97,7 @@ func (c *Client) doRequest(
9697
if err != nil {
9798
return nil, fmt.Errorf("open stream: %w", err)
9899
}
99-
defer stream.Close()
100+
defer utils.CloseAndLog(log, "client", stream)
100101

101102
c.setStreamDeadlines(ctx, stream)
102103
// request ODS
@@ -109,7 +110,6 @@ func (c *Client) doRequest(
109110
}
110111
_, err = id.WriteTo(stream)
111112
if err != nil {
112-
stream.Reset() //nolint:errcheck
113113
return nil, fmt.Errorf("write request to stream: %w", err)
114114
}
115115

@@ -131,7 +131,6 @@ func (c *Client) doRequest(
131131
c.metrics.ObserveRequests(ctx, 1, shrex.StatusRateLimited)
132132
return nil, shrex.ErrNotFound
133133
}
134-
stream.Reset() //nolint:errcheck
135134
return nil, fmt.Errorf("read status from stream: %w", err)
136135
}
137136
switch resp.Status {

share/shwap/p2p/shrex/shrexeds/server.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func NewServer(params *Parameters, host host.Host, store *store.Store) (*Server,
5454

5555
func (s *Server) Start(context.Context) error {
5656
s.ctx, s.cancel = context.WithCancel(context.Background())
57-
s.host.SetStreamHandler(s.protocolID, s.middleware.RateLimitHandler(s.handleStream))
57+
s.host.SetStreamHandler(s.protocolID, s.middleware.RateLimitHandler(s.streamHandler(s.ctx)))
5858
return nil
5959
}
6060

@@ -71,18 +71,29 @@ func (s *Server) observeRateLimitedRequests() {
7171
}
7272
}
7373

74-
func (s *Server) handleStream(stream network.Stream) {
74+
func (srv *Server) streamHandler(ctx context.Context) network.StreamHandler {
75+
return func(s network.Stream) {
76+
err := srv.handleEDS(s)
77+
if err != nil {
78+
s.Reset() //nolint:errcheck
79+
return
80+
}
81+
srv.metrics.ObserveRequests(ctx, 1, shrex.StatusSuccess)
82+
if err = s.Close(); err != nil {
83+
log.Debugw("server: closing stream", "err", err)
84+
}
85+
}
86+
}
87+
88+
func (s *Server) handleEDS(stream network.Stream) error {
7589
logger := log.With("peer", stream.Conn().RemotePeer().String())
7690
logger.Debug("server: handling eds request")
7791

78-
s.observeRateLimitedRequests()
79-
8092
// read request from stream to get the dataHash for store lookup
8193
id, err := s.readRequest(logger, stream)
8294
if err != nil {
8395
logger.Warnw("server: reading request from stream", "err", err)
84-
stream.Reset() //nolint:errcheck
85-
return
96+
return err
8697
}
8798

8899
logger = logger.With("height", id.Height)
@@ -112,31 +123,20 @@ func (s *Server) handleStream(stream network.Stream) {
112123
err = s.writeStatus(logger, status, stream)
113124
if err != nil {
114125
logger.Warnw("server: writing status to stream", "err", err)
115-
stream.Reset() //nolint:errcheck
116-
return
126+
return err
117127
}
118128
// if we cannot serve the EDS, we are already done
119129
if status != shrexpb.Status_OK {
120-
err = stream.Close()
121-
if err != nil {
122-
logger.Debugw("server: closing stream", "err", err)
123-
}
124-
return
130+
return nil
125131
}
126132

127133
// start streaming the ODS to the client
128134
err = s.writeODS(logger, file, stream)
129135
if err != nil {
130136
logger.Warnw("server: writing ods to stream", "err", err)
131-
stream.Reset() //nolint:errcheck
132-
return
133-
}
134-
135-
s.metrics.ObserveRequests(ctx, 1, shrex.StatusSuccess)
136-
err = stream.Close()
137-
if err != nil {
138-
logger.Debugw("server: closing stream", "err", err)
137+
return err
139138
}
139+
return nil
140140
}
141141

142142
func (s *Server) readRequest(logger *zap.SugaredLogger, stream network.Stream) (shwap.EdsID, error) {

share/shwap/p2p/shrex/shrexnd/client.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
"github.com/celestiaorg/go-libp2p-messenger/serde"
1717

18+
"github.com/celestiaorg/celestia-node/libs/utils"
1819
"github.com/celestiaorg/celestia-node/share"
1920
"github.com/celestiaorg/celestia-node/share/shwap"
2021
"github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex"
@@ -85,11 +86,13 @@ func (c *Client) doRequest(
8586
namespace share.Namespace,
8687
peerID peer.ID,
8788
) (shwap.NamespacedData, error) {
88-
stream, err := c.host.NewStream(ctx, peerID, c.protocolID)
89+
streamOpenCtx, cancel := context.WithTimeout(ctx, c.params.ServerReadTimeout)
90+
defer cancel()
91+
stream, err := c.host.NewStream(streamOpenCtx, peerID, c.protocolID)
8992
if err != nil {
9093
return nil, err
9194
}
92-
defer stream.Close()
95+
defer utils.CloseAndLog(log, "client", stream)
9396

9497
c.setStreamDeadlines(ctx, stream)
9598

@@ -101,7 +104,6 @@ func (c *Client) doRequest(
101104
_, err = req.WriteTo(stream)
102105
if err != nil {
103106
c.metrics.ObserveRequests(ctx, 1, shrex.StatusSendReqErr)
104-
stream.Reset() //nolint:errcheck
105107
return nil, fmt.Errorf("client-nd: writing request: %w", err)
106108
}
107109

@@ -111,6 +113,7 @@ func (c *Client) doRequest(
111113
}
112114

113115
if err := c.readStatus(ctx, stream); err != nil {
116+
c.metrics.ObserveRequests(ctx, 1, shrex.StatusReadRespErr)
114117
return nil, err
115118
}
116119

0 commit comments

Comments
 (0)