Skip to content

Commit 2d968c3

Browse files
authored
fix: apply group filter to channel list queries (#4885)
1 parent cb7a614 commit 2d968c3

2 files changed

Lines changed: 119 additions & 90 deletions

File tree

controller/channel.go

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/QuantumNous/new-api/service"
2020

2121
"github.com/gin-gonic/gin"
22+
"gorm.io/gorm"
2223
)
2324

2425
type OpenAIModel struct {
@@ -68,12 +69,33 @@ func clearChannelInfo(channel *model.Channel) {
6869
}
6970
}
7071

72+
func applyChannelStatusFilter(query *gorm.DB, statusFilter int) *gorm.DB {
73+
if statusFilter == common.ChannelStatusEnabled {
74+
return query.Where("status = ?", common.ChannelStatusEnabled)
75+
}
76+
if statusFilter == 0 {
77+
return query.Where("status != ?", common.ChannelStatusEnabled)
78+
}
79+
return query
80+
}
81+
82+
func buildChannelListQuery(group string, statusFilter int, typeFilter int) *gorm.DB {
83+
query := model.DB.Model(&model.Channel{})
84+
query = model.ApplyChannelGroupFilter(query, group)
85+
query = applyChannelStatusFilter(query, statusFilter)
86+
if typeFilter >= 0 {
87+
query = query.Where("type = ?", typeFilter)
88+
}
89+
return query
90+
}
91+
7192
func GetAllChannels(c *gin.Context) {
7293
pageInfo := common.GetPageQuery(c)
7394
channelData := make([]*model.Channel, 0)
7495
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
7596
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
7697
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
98+
groupFilter := model.NormalizeChannelGroupFilter(c.Query("group"))
7799
statusParam := c.Query("status")
78100
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
79101
statusFilter := parseStatusFilter(statusParam)
@@ -85,69 +107,49 @@ func GetAllChannels(c *gin.Context) {
85107
typeFilter = t
86108
}
87109
}
88-
// group filter
89-
groupFilter := c.Query("group")
90110

91111
var total int64
92112

93113
if enableTagMode {
94-
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
114+
tags, err := model.GetPaginatedChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter), pageInfo.GetStartIdx(), pageInfo.GetPageSize())
95115
if err != nil {
96116
common.SysError("failed to get paginated tags: " + err.Error())
97117
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
98118
return
99119
}
120+
total, err = model.CountChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter))
121+
if err != nil {
122+
common.SysError("failed to count tags: " + err.Error())
123+
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签数量失败,请稍后重试"})
124+
return
125+
}
100126
for _, tag := range tags {
101127
if tag == nil || *tag == "" {
102128
continue
103129
}
104-
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
130+
var tagChannels []*model.Channel
131+
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter).Where("tag = ?", *tag)).
132+
Omit("key").
133+
Find(&tagChannels).Error
105134
if err != nil {
106-
continue
107-
}
108-
filtered := make([]*model.Channel, 0)
109-
for _, ch := range tagChannels {
110-
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
111-
continue
112-
}
113-
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
114-
continue
115-
}
116-
if typeFilter >= 0 && ch.Type != typeFilter {
117-
continue
118-
}
119-
if groupFilter != "" && groupFilter != "null" {
120-
if !strings.Contains(","+ch.Group+",", ","+groupFilter+",") {
121-
continue
122-
}
123-
}
124-
filtered = append(filtered, ch)
135+
common.SysError("failed to get channels by tag: " + err.Error())
136+
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签渠道失败,请稍后重试"})
137+
return
125138
}
126-
channelData = append(channelData, filtered...)
139+
channelData = append(channelData, tagChannels...)
127140
}
128-
total, _ = model.CountAllTags()
129141
} else {
130-
baseQuery := model.DB.Model(&model.Channel{})
131-
if typeFilter >= 0 {
132-
baseQuery = baseQuery.Where("type = ?", typeFilter)
133-
}
134-
if statusFilter == common.ChannelStatusEnabled {
135-
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
136-
} else if statusFilter == 0 {
137-
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
138-
}
139-
if groupFilter != "" && groupFilter != "null" {
140-
if common.UsingMySQL {
141-
baseQuery = baseQuery.Where("CONCAT(',', `group`, ',') LIKE ?", "%,"+groupFilter+",%")
142-
} else {
143-
// SQLite, PostgreSQL
144-
baseQuery = baseQuery.Where("(',' || \"group\" || ',') LIKE ?", "%,"+groupFilter+",%")
145-
}
142+
if err := buildChannelListQuery(groupFilter, statusFilter, typeFilter).Count(&total).Error; err != nil {
143+
common.SysError("failed to count channels: " + err.Error())
144+
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道数量失败,请稍后重试"})
145+
return
146146
}
147147

148-
baseQuery.Count(&total)
149-
150-
err := sortOptions.Apply(baseQuery).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
148+
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter)).
149+
Limit(pageInfo.GetPageSize()).
150+
Offset(pageInfo.GetStartIdx()).
151+
Omit("key").
152+
Find(&channelData).Error
151153
if err != nil {
152154
common.SysError("failed to get channels: " + err.Error())
153155
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
@@ -159,17 +161,16 @@ func GetAllChannels(c *gin.Context) {
159161
clearChannelInfo(datum)
160162
}
161163

162-
countQuery := model.DB.Model(&model.Channel{})
163-
if statusFilter == common.ChannelStatusEnabled {
164-
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
165-
} else if statusFilter == 0 {
166-
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
167-
}
164+
countQuery := buildChannelListQuery(groupFilter, statusFilter, -1)
168165
var results []struct {
169166
Type int64
170167
Count int64
171168
}
172-
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
169+
if err := countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error; err != nil {
170+
common.SysError("failed to count channel types: " + err.Error())
171+
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道类型统计失败,请稍后重试"})
172+
return
173+
}
173174
typeCounts := make(map[int64]int64)
174175
for _, r := range results {
175176
typeCounts[r.Type] = r.Count
@@ -277,10 +278,18 @@ func SearchChannels(c *gin.Context) {
277278
}
278279
for _, tag := range tags {
279280
if tag != nil && *tag != "" {
280-
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
281-
if err == nil {
282-
channelData = append(channelData, tagChannel...)
281+
var tagChannels []*model.Channel
282+
err := sortOptions.Apply(buildChannelListQuery(group, -1, -1).Where("tag = ?", *tag)).
283+
Omit("key").
284+
Find(&tagChannels).Error
285+
if err != nil {
286+
c.JSON(http.StatusOK, gin.H{
287+
"success": false,
288+
"message": err.Error(),
289+
})
290+
return
283291
}
292+
channelData = append(channelData, tagChannels...)
284293
}
285294
}
286295
} else {

model/channel.go

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,38 @@ func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) Ch
128128
return options
129129
}
130130

131+
func NormalizeChannelGroupFilter(group string) string {
132+
group = strings.TrimSpace(group)
133+
if group == "" || strings.EqualFold(group, "all") || strings.EqualFold(group, "null") {
134+
return ""
135+
}
136+
return group
137+
}
138+
139+
func channelGroupFilterCondition() string {
140+
if common.UsingMySQL {
141+
return `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ? ESCAPE '!'`
142+
}
143+
return `(',' || ` + commonGroupCol + ` || ',') LIKE ? ESCAPE '!'`
144+
}
145+
146+
func channelGroupFilterPattern(group string) string {
147+
group = strings.NewReplacer(
148+
"!", "!!",
149+
"%", "!%",
150+
"_", "!_",
151+
).Replace(group)
152+
return "%," + group + ",%"
153+
}
154+
155+
func ApplyChannelGroupFilter(query *gorm.DB, group string) *gorm.DB {
156+
group = NormalizeChannelGroupFilter(group)
157+
if group == "" {
158+
return query
159+
}
160+
return query.Where(channelGroupFilterCondition(), channelGroupFilterPattern(group))
161+
}
162+
131163
// Value implements driver.Valuer interface
132164
func (c ChannelInfo) Value() (driver.Value, error) {
133165
return common.Marshal(&c)
@@ -365,25 +397,12 @@ func SearchChannels(keyword string, group string, model string, idSort bool, sor
365397
baseQuery := DB.Model(&Channel{}).Omit("key")
366398

367399
// 构造WHERE子句
368-
var whereClause string
369-
var args []interface{}
370-
if group != "" && group != "null" {
371-
var groupCondition string
372-
if common.UsingMySQL {
373-
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
374-
} else {
375-
// sqlite, PostgreSQL
376-
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
377-
}
378-
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
379-
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
380-
} else {
381-
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
382-
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
383-
}
400+
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
401+
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
402+
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
384403

385404
// 执行查询
386-
err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error
405+
err := order.Apply(baseQuery).Find(&channels).Error
387406
if err != nil {
388407
return nil, err
389408
}
@@ -828,8 +847,18 @@ func DeleteDisabledChannel() (int64, error) {
828847
}
829848

830849
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
850+
return GetPaginatedChannelTags(DB.Model(&Channel{}), offset, limit)
851+
}
852+
853+
func GetPaginatedChannelTags(query *gorm.DB, offset int, limit int) ([]*string, error) {
831854
var tags []*string
832-
err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
855+
err := query.
856+
Select("DISTINCT tag").
857+
Where("tag is not null AND tag != ''").
858+
Order(clause.OrderByColumn{Column: clause.Column{Name: "tag"}}).
859+
Offset(offset).
860+
Limit(limit).
861+
Find(&tags).Error
833862
return tags, err
834863
}
835864

@@ -857,24 +886,11 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
857886
baseQuery := DB.Model(&Channel{}).Omit("key")
858887

859888
// 构造WHERE子句
860-
var whereClause string
861-
var args []interface{}
862-
if group != "" && group != "null" {
863-
var groupCondition string
864-
if common.UsingMySQL {
865-
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
866-
} else {
867-
// sqlite, PostgreSQL
868-
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
869-
}
870-
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
871-
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
872-
} else {
873-
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
874-
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
875-
}
889+
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
890+
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
891+
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
876892

877-
subQuery := baseQuery.Where(whereClause, args...).
893+
subQuery := baseQuery.
878894
Select("tag").
879895
Where("tag != ''").
880896
Order(order)
@@ -1015,8 +1031,12 @@ func CountAllChannels() (int64, error) {
10151031

10161032
// CountAllTags returns number of non-empty distinct tags
10171033
func CountAllTags() (int64, error) {
1034+
return CountChannelTags(DB.Model(&Channel{}))
1035+
}
1036+
1037+
func CountChannelTags(query *gorm.DB) (int64, error) {
10181038
var total int64
1019-
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
1039+
err := query.Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
10201040
return total, err
10211041
}
10221042

0 commit comments

Comments
 (0)