Skip to content

Commit ce34736

Browse files
shoyufcodex
andcommitted
fix(smb): close smb sessions during transfers
Co-authored-by: Codex <codex@openai.com> Signed-off-by: shoyuf <shoyuf@shoyuf.top>
1 parent 054db9f commit ce34736

2 files changed

Lines changed: 133 additions & 36 deletions

File tree

drivers/smb/driver.go

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package smb
33
import (
44
"context"
55
"errors"
6+
"net"
67
"path"
78
"path/filepath"
89
"strings"
10+
"sync"
911

1012
"github.com/OpenListTeam/OpenList/v4/internal/driver"
1113
"github.com/OpenListTeam/OpenList/v4/internal/model"
@@ -19,7 +21,11 @@ type SMB struct {
1921
lastConnTime int64
2022
model.Storage
2123
Addition
22-
fs *smb2.Share
24+
connMu sync.Mutex
25+
activeOps int
26+
conn net.Conn
27+
session *smb2.Session
28+
fs *smb2.Share
2329
}
2430

2531
func (d *SMB) Config() driver.Config {
@@ -38,18 +44,17 @@ func (d *SMB) Init(ctx context.Context) error {
3844
}
3945

4046
func (d *SMB) Drop(ctx context.Context) error {
41-
if d.fs != nil {
42-
_ = d.fs.Umount()
43-
}
44-
return nil
47+
return d.closeFS()
4548
}
4649

4750
func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
48-
if err := d.checkConn(ctx); err != nil {
51+
fs, release, err := d.acquireConn(ctx)
52+
if err != nil {
4953
return nil, err
5054
}
55+
defer release()
5156
fullPath := dir.GetPath()
52-
rawFiles, err := d.fs.ReadDir(fullPath)
57+
rawFiles, err := fs.ReadDir(fullPath)
5358
if err != nil {
5459
d.cleanLastConnTime()
5560
return nil, err
@@ -72,11 +77,18 @@ func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m
7277
}
7378

7479
func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
75-
if err := d.checkConn(ctx); err != nil {
80+
fs, release, err := d.acquireConn(ctx)
81+
if err != nil {
7682
return nil, err
7783
}
84+
needRelease := true
85+
defer func() {
86+
if needRelease {
87+
release()
88+
}
89+
}()
7890
fullPath := file.GetPath()
79-
remoteFile, err := d.fs.Open(fullPath)
91+
remoteFile, err := fs.Open(fullPath)
8092
if err != nil {
8193
d.cleanLastConnTime()
8294
return nil, err
@@ -87,19 +99,25 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
8799
Limiter: stream.ServerDownloadLimit,
88100
Ctx: ctx,
89101
}
102+
needRelease = false
90103
return &model.Link{
91-
RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile),
92-
SyncClosers: utils.NewSyncClosers(remoteFile),
104+
RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile),
105+
SyncClosers: utils.NewSyncClosers(remoteFile, utils.CloseFunc(func() error {
106+
release()
107+
return nil
108+
})),
93109
RequireReference: true,
94110
}, nil
95111
}
96112

97113
func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
98-
if err := d.checkConn(ctx); err != nil {
114+
fs, release, err := d.acquireConn(ctx)
115+
if err != nil {
99116
return err
100117
}
118+
defer release()
101119
fullPath := filepath.Join(parentDir.GetPath(), dirName)
102-
err := d.fs.MkdirAll(fullPath, 0700)
120+
err = fs.MkdirAll(fullPath, 0700)
103121
if err != nil {
104122
d.cleanLastConnTime()
105123
return err
@@ -109,12 +127,14 @@ func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string)
109127
}
110128

111129
func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
112-
if err := d.checkConn(ctx); err != nil {
130+
fs, release, err := d.acquireConn(ctx)
131+
if err != nil {
113132
return err
114133
}
134+
defer release()
115135
srcPath := srcObj.GetPath()
116136
dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName())
117-
err := d.fs.Rename(srcPath, dstPath)
137+
err = fs.Rename(srcPath, dstPath)
118138
if err != nil {
119139
d.cleanLastConnTime()
120140
return err
@@ -124,12 +144,14 @@ func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
124144
}
125145

126146
func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
127-
if err := d.checkConn(ctx); err != nil {
147+
fs, release, err := d.acquireConn(ctx)
148+
if err != nil {
128149
return err
129150
}
151+
defer release()
130152
srcPath := srcObj.GetPath()
131153
dstPath := filepath.Join(filepath.Dir(srcPath), newName)
132-
err := d.fs.Rename(srcPath, dstPath)
154+
err = fs.Rename(srcPath, dstPath)
133155
if err != nil {
134156
d.cleanLastConnTime()
135157
return err
@@ -139,12 +161,13 @@ func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) erro
139161
}
140162

141163
func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
142-
if err := d.checkConn(ctx); err != nil {
164+
_, release, err := d.acquireConn(ctx)
165+
if err != nil {
143166
return err
144167
}
168+
defer release()
145169
srcPath := srcObj.GetPath()
146170
dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName())
147-
var err error
148171
if srcObj.IsDir() {
149172
err = d.CopyDir(srcPath, dstPath)
150173
} else {
@@ -159,15 +182,16 @@ func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
159182
}
160183

161184
func (d *SMB) Remove(ctx context.Context, obj model.Obj) error {
162-
if err := d.checkConn(ctx); err != nil {
185+
fs, release, err := d.acquireConn(ctx)
186+
if err != nil {
163187
return err
164188
}
165-
var err error
189+
defer release()
166190
fullPath := obj.GetPath()
167191
if obj.IsDir() {
168-
err = d.fs.RemoveAll(fullPath)
192+
err = fs.RemoveAll(fullPath)
169193
} else {
170-
err = d.fs.Remove(fullPath)
194+
err = fs.Remove(fullPath)
171195
}
172196
if err != nil {
173197
d.cleanLastConnTime()
@@ -178,11 +202,13 @@ func (d *SMB) Remove(ctx context.Context, obj model.Obj) error {
178202
}
179203

180204
func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
181-
if err := d.checkConn(ctx); err != nil {
205+
fs, release, err := d.acquireConn(ctx)
206+
if err != nil {
182207
return err
183208
}
209+
defer release()
184210
fullPath := filepath.Join(dstDir.GetPath(), stream.GetName())
185-
out, err := d.fs.Create(fullPath)
211+
out, err := fs.Create(fullPath)
186212
if err != nil {
187213
d.cleanLastConnTime()
188214
return err
@@ -191,7 +217,7 @@ func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream
191217
defer func() {
192218
_ = out.Close()
193219
if errors.Is(err, context.Canceled) {
194-
_ = d.fs.Remove(fullPath)
220+
_ = fs.Remove(fullPath)
195221
}
196222
}()
197223
err = utils.CopyWithCtx(ctx, out, driver.NewLimitedUploadStream(ctx, stream), stream.GetSize(), up)
@@ -202,13 +228,16 @@ func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream
202228
}
203229

204230
func (d *SMB) GetDetails(ctx context.Context) (*model.StorageDetails, error) {
205-
if err := d.checkConn(ctx); err != nil {
231+
fs, release, err := d.acquireConn(ctx)
232+
if err != nil {
206233
return nil, err
207234
}
208-
stat, err := d.fs.Statfs(d.RootFolderPath)
235+
defer release()
236+
stat, err := fs.Statfs(d.RootFolderPath)
209237
if err != nil {
210238
return nil, err
211239
}
240+
d.updateLastConnTime()
212241
total := int64(stat.BlockSize() * stat.TotalBlockCount())
213242
free := int64(stat.BlockSize() * stat.AvailableBlockCount())
214243
return &model.StorageDetails{

drivers/smb/util.go

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package smb
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io/fs"
8+
"net"
79
"os"
810
"path/filepath"
911
"sync/atomic"
@@ -29,37 +31,103 @@ func (d *SMB) getLastConnTime() time.Time {
2931

3032
func (d *SMB) initFS(ctx context.Context) error {
3133
_, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("SMB.initFS:%p", d), func() (any, error) {
32-
return nil, d._initFS(ctx)
34+
d.connMu.Lock()
35+
defer d.connMu.Unlock()
36+
return nil, d.initFSLocked(ctx)
3337
})
3438
return err
3539
}
40+
3641
func (d *SMB) _initFS(ctx context.Context) error {
42+
d.connMu.Lock()
43+
defer d.connMu.Unlock()
44+
return d.initFSLocked(ctx)
45+
}
46+
47+
func (d *SMB) initFSLocked(ctx context.Context) error {
48+
_ = d.closeFSLocked()
3749
dialer := &smb2.Dialer{
3850
Initiator: &smb2.NTLMInitiator{
3951
User: d.Username,
4052
Password: d.Password,
4153
},
4254
}
43-
s, err := dialer.Dial(ctx, d.Address)
55+
conn, err := net.Dial("tcp", d.Address)
4456
if err != nil {
4557
return err
4658
}
47-
d.fs, err = s.Mount(d.ShareName)
59+
s, err := dialer.DialConn(ctx, conn, d.Address)
4860
if err != nil {
61+
_ = conn.Close()
4962
return err
5063
}
64+
fs, err := s.Mount(d.ShareName)
65+
if err != nil {
66+
_ = s.Logoff()
67+
_ = conn.Close()
68+
return err
69+
}
70+
d.conn = conn
71+
d.session = s
72+
d.fs = fs
5173
d.updateLastConnTime()
74+
return nil
75+
}
76+
77+
func (d *SMB) closeFS() error {
78+
d.connMu.Lock()
79+
defer d.connMu.Unlock()
80+
return d.closeFSLocked()
81+
}
82+
83+
func (d *SMB) closeFSLocked() error {
84+
var err error
85+
if d.fs != nil {
86+
err = errors.Join(err, d.fs.Umount())
87+
d.fs = nil
88+
}
89+
if d.session != nil {
90+
err = errors.Join(err, d.session.Logoff())
91+
d.session = nil
92+
}
93+
if d.conn != nil {
94+
err = errors.Join(err, d.conn.Close())
95+
d.conn = nil
96+
}
97+
d.cleanLastConnTime()
5298
return err
5399
}
54100

55101
func (d *SMB) checkConn(ctx context.Context) error {
56-
if time.Since(d.getLastConnTime()) < 5*time.Minute {
57-
return nil
102+
_, release, err := d.acquireConn(ctx)
103+
if release != nil {
104+
release()
58105
}
59-
if d.fs != nil {
60-
_ = d.fs.Umount()
106+
return err
107+
}
108+
109+
func (d *SMB) acquireConn(ctx context.Context) (*smb2.Share, func(), error) {
110+
d.connMu.Lock()
111+
defer d.connMu.Unlock()
112+
113+
if d.fs == nil || (time.Since(d.getLastConnTime()) >= 5*time.Minute && d.activeOps == 0) {
114+
if err := d.initFSLocked(ctx); err != nil {
115+
return nil, nil, err
116+
}
117+
}
118+
if d.fs == nil {
119+
return nil, nil, errors.New("smb share is not initialized")
120+
}
121+
d.activeOps++
122+
return d.fs, d.releaseConn, nil
123+
}
124+
125+
func (d *SMB) releaseConn() {
126+
d.connMu.Lock()
127+
defer d.connMu.Unlock()
128+
if d.activeOps > 0 {
129+
d.activeOps--
61130
}
62-
return d.initFS(ctx)
63131
}
64132

65133
// CopyFile File copies a single file from src to dst

0 commit comments

Comments
 (0)