Skip to content

Commit 8a878fa

Browse files
committed
feat(func): support virtual host
1 parent 12bd5a0 commit 8a878fa

6 files changed

Lines changed: 282 additions & 13 deletions

File tree

internal/conf/const.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,4 +192,5 @@ const (
192192
SharingIDKey
193193
SkipHookKey
194194
VirtualHostKey
195+
VhostPrefixKey
195196
)

internal/op/virtual_host.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package op
2+
3+
import (
4+
"time"
5+
6+
"github.com/OpenListTeam/OpenList/v4/internal/db"
7+
"github.com/OpenListTeam/OpenList/v4/internal/model"
8+
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
9+
"github.com/OpenListTeam/go-cache"
10+
"github.com/pkg/errors"
11+
"gorm.io/gorm"
12+
)
13+
14+
var vhostCache = cache.NewMemCache(cache.WithShards[*model.VirtualHost](2))
15+
16+
// GetVirtualHostByDomain 根据域名获取虚拟主机配置(带缓存)
17+
func GetVirtualHostByDomain(domain string) (*model.VirtualHost, error) {
18+
if v, ok := vhostCache.Get(domain); ok {
19+
if v == nil {
20+
utils.Log.Infof("[VirtualHost] cache hit (nil) for domain=%q", domain)
21+
return nil, errors.New("virtual host not found")
22+
}
23+
utils.Log.Infof("[VirtualHost] cache hit for domain=%q id=%d", domain, v.ID)
24+
return v, nil
25+
}
26+
utils.Log.Infof("[VirtualHost] cache miss for domain=%q, querying db...", domain)
27+
v, err := db.GetVirtualHostByDomain(domain)
28+
if err != nil {
29+
if errors.Is(errors.Cause(err), gorm.ErrRecordNotFound) {
30+
utils.Log.Infof("[VirtualHost] domain=%q not found in db, caching nil", domain)
31+
vhostCache.Set(domain, nil, cache.WithEx[*model.VirtualHost](time.Minute*5))
32+
return nil, errors.New("virtual host not found")
33+
}
34+
utils.Log.Errorf("[VirtualHost] db error for domain=%q: %v", domain, err)
35+
return nil, err
36+
}
37+
utils.Log.Infof("[VirtualHost] db found domain=%q id=%d enabled=%v web_hosting=%v", domain, v.ID, v.Enabled, v.WebHosting)
38+
vhostCache.Set(domain, v, cache.WithEx[*model.VirtualHost](time.Hour))
39+
return v, nil
40+
}
41+
42+
func GetVirtualHostById(id uint) (*model.VirtualHost, error) {
43+
return db.GetVirtualHostById(id)
44+
}
45+
46+
func CreateVirtualHost(v *model.VirtualHost) error {
47+
v.Path = utils.FixAndCleanPath(v.Path)
48+
vhostCache.Del(v.Domain)
49+
return db.CreateVirtualHost(v)
50+
}
51+
52+
func UpdateVirtualHost(v *model.VirtualHost) error {
53+
v.Path = utils.FixAndCleanPath(v.Path)
54+
old, err := db.GetVirtualHostById(v.ID)
55+
if err != nil {
56+
return err
57+
}
58+
// 如果域名变更,清除旧域名缓存
59+
vhostCache.Del(old.Domain)
60+
vhostCache.Del(v.Domain)
61+
return db.UpdateVirtualHost(v)
62+
}
63+
64+
func DeleteVirtualHostById(id uint) error {
65+
old, err := db.GetVirtualHostById(id)
66+
if err != nil {
67+
return err
68+
}
69+
vhostCache.Del(old.Domain)
70+
return db.DeleteVirtualHostById(id)
71+
}
72+
73+
func GetVirtualHosts(pageIndex, pageSize int) ([]model.VirtualHost, int64, error) {
74+
return db.GetVirtualHosts(pageIndex, pageSize)
75+
}

server/handles/fsread.go

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ func FsListSplit(c *gin.Context) {
6868
SharingList(c, &req)
6969
return
7070
}
71+
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
72+
req.Path = applyVhostPathMapping(c, req.Path)
7173
user := c.Request.Context().Value(conf.UserKey).(*model.User)
7274
if user.IsGuest() && user.Disabled {
7375
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
@@ -273,6 +275,11 @@ func FsGetSplit(c *gin.Context) {
273275
SharingGet(c, &req)
274276
return
275277
}
278+
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
279+
// 同时将 vhost.Path 前缀存入 context,供 FsGet 生成 /p/ 链接时去掉前缀
280+
var vhostPrefix string
281+
req.Path, vhostPrefix = applyVhostPathMappingWithPrefix(c, req.Path)
282+
common.GinWithValue(c, conf.VhostPrefixKey, vhostPrefix)
276283
user := c.Request.Context().Value(conf.UserKey).(*model.User)
277284
if user.IsGuest() && user.Disabled {
278285
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
@@ -322,12 +329,14 @@ func FsGet(c *gin.Context, req *FsGetReq, user *model.User) {
322329
rawURL = common.GenerateDownProxyURL(storage.GetStorage(), reqPath)
323330
if rawURL == "" {
324331
query := ""
332+
// 生成 /p/ 链接时,去掉 vhost 路径前缀,保持前端看到的路径一致
333+
downPath := stripVhostPrefix(c, reqPath)
325334
if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) {
326335
query = "?sign=" + sign.Sign(reqPath)
327336
}
328337
rawURL = fmt.Sprintf("%s/p%s%s",
329338
common.GetApiUrl(c),
330-
utils.EncodePath(reqPath, true),
339+
utils.EncodePath(downPath, true),
331340
query)
332341
}
333342
} else {
@@ -432,3 +441,60 @@ func FsOther(c *gin.Context) {
432441
}
433442
common.SuccessResp(c, res)
434443
}
444+
445+
// applyVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则,将请求路径映射到实际路径。
446+
func applyVhostPathMapping(c *gin.Context, reqPath string) string {
447+
mapped, _ := applyVhostPathMappingWithPrefix(c, reqPath)
448+
return mapped
449+
}
450+
451+
// applyVhostPathMappingWithPrefix 根据请求的 Host 头匹配虚拟主机规则,
452+
// 将请求路径映射到虚拟主机配置的实际路径,同时返回 vhost.Path 前缀(用于生成下载链接时去掉前缀)。
453+
// 例如:vhost.Path="/123pan/Downloads",reqPath="/",则返回 ("/123pan/Downloads", "/123pan/Downloads")
454+
// 例如:vhost.Path="/123pan/Downloads",reqPath="/subdir",则返回 ("/123pan/Downloads/subdir", "/123pan/Downloads")
455+
// 如果没有匹配的虚拟主机规则,则返回 (原始路径, "")
456+
func applyVhostPathMappingWithPrefix(c *gin.Context, reqPath string) (string, string) {
457+
rawHost := c.Request.Host
458+
domain := stripHostPortForVhost(rawHost)
459+
if domain == "" {
460+
return reqPath, ""
461+
}
462+
vhost, err := op.GetVirtualHostByDomain(domain)
463+
if err != nil || vhost == nil {
464+
return reqPath, ""
465+
}
466+
if !vhost.Enabled || vhost.WebHosting {
467+
// 未启用,或者是 Web 托管模式(Web 托管不做路径重映射)
468+
return reqPath, ""
469+
}
470+
// 路径重映射:将 reqPath 拼接到 vhost.Path 后面
471+
mapped := stdpath.Join(vhost.Path, reqPath)
472+
utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped)
473+
return mapped, vhost.Path
474+
}
475+
476+
// stripVhostPrefix 从 gin context 中取出 vhost 路径前缀,并从 path 中去掉该前缀。
477+
// 用于生成 /p/ 下载链接时,将真实路径还原为前端看到的路径。
478+
func stripVhostPrefix(c *gin.Context, path string) string {
479+
prefix, ok := c.Request.Context().Value(conf.VhostPrefixKey).(string)
480+
if !ok || prefix == "" {
481+
return path
482+
}
483+
if strings.HasPrefix(path, prefix+"/") {
484+
return path[len(prefix):]
485+
}
486+
if path == prefix {
487+
return "/"
488+
}
489+
return path
490+
}
491+
492+
// stripHostPortForVhost 去掉 host 中的端口号,返回纯域名
493+
func stripHostPortForVhost(host string) string {
494+
if idx := strings.LastIndex(host, ":"); idx != -1 {
495+
if !strings.Contains(host, "[") {
496+
return host[:idx]
497+
}
498+
}
499+
return host
500+
}

server/middlewares/down.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package middlewares
22

33
import (
4+
stdpath "path"
45
"strings"
56

67
"github.com/OpenListTeam/OpenList/v4/internal/conf"
@@ -17,10 +18,46 @@ import (
1718

1819
func PathParse(c *gin.Context) {
1920
rawPath := parsePath(c.Param("path"))
21+
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
22+
// 例如:vhost.Path="/123pan/Downloads",rawPath="/tests.html" -> "/123pan/Downloads/tests.html"
23+
rawPath = applyDownVhostPathMapping(c, rawPath)
2024
common.GinWithValue(c, conf.PathKey, rawPath)
2125
c.Next()
2226
}
2327

28+
// applyDownVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则,
29+
// 将下载/预览路由的路径映射到虚拟主机配置的实际路径。
30+
// 仅在虚拟主机启用且非 Web 托管模式时生效。
31+
func applyDownVhostPathMapping(c *gin.Context, reqPath string) string {
32+
rawHost := c.Request.Host
33+
domain := stripDownHostPort(rawHost)
34+
if domain == "" {
35+
return reqPath
36+
}
37+
vhost, err := op.GetVirtualHostByDomain(domain)
38+
if err != nil || vhost == nil {
39+
return reqPath
40+
}
41+
if !vhost.Enabled || vhost.WebHosting {
42+
// 未启用,或者是 Web 托管模式(Web 托管不做路径重映射)
43+
return reqPath
44+
}
45+
// 路径重映射:将 reqPath 拼接到 vhost.Path 后面
46+
mapped := stdpath.Join(vhost.Path, reqPath)
47+
utils.Log.Debugf("[VirtualHost] down path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped)
48+
return mapped
49+
}
50+
51+
// stripDownHostPort 去掉 host 中的端口号,返回纯域名
52+
func stripDownHostPort(host string) string {
53+
if idx := strings.LastIndex(host, ":"); idx != -1 {
54+
if !strings.Contains(host, "[") {
55+
return host[:idx]
56+
}
57+
}
58+
return host
59+
}
60+
2461
func Down(verifyFunc func(string, string) error) func(c *gin.Context) {
2562
return func(c *gin.Context) {
2663
rawPath := c.Request.Context().Value(conf.PathKey).(string)

server/router.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ func Init(e *gin.Engine) {
3636
g.GET("/i/:link_name", handles.Plist)
3737
common.SecretKey = []byte(conf.Conf.JwtSecret)
3838
g.Use(middlewares.StoragesLoaded)
39-
g.Use(middlewares.VirtualHost)
4039
if conf.Conf.MaxConnections > 0 {
4140
g.Use(middlewares.MaxAllowed(conf.Conf.MaxConnections))
4241
}

0 commit comments

Comments
 (0)