@@ -19,6 +19,7 @@ import (
1919 "github.com/QuantumNous/new-api/service"
2020
2121 "github.com/gin-gonic/gin"
22+ "gorm.io/gorm"
2223)
2324
2425type OpenAIModel struct {
@@ -68,12 +69,33 @@ func clearChannelInfo(channel *model.Channel) {
6869 }
6970}
7071
72+ func applyChannelStatusFilter (query * gorm.DB , statusFilter int ) * gorm.DB {
73+ if statusFilter == common .ChannelStatusEnabled {
74+ return query .Where ("status = ?" , common .ChannelStatusEnabled )
75+ }
76+ if statusFilter == 0 {
77+ return query .Where ("status != ?" , common .ChannelStatusEnabled )
78+ }
79+ return query
80+ }
81+
82+ func buildChannelListQuery (group string , statusFilter int , typeFilter int ) * gorm.DB {
83+ query := model .DB .Model (& model.Channel {})
84+ query = model .ApplyChannelGroupFilter (query , group )
85+ query = applyChannelStatusFilter (query , statusFilter )
86+ if typeFilter >= 0 {
87+ query = query .Where ("type = ?" , typeFilter )
88+ }
89+ return query
90+ }
91+
7192func GetAllChannels (c * gin.Context ) {
7293 pageInfo := common .GetPageQuery (c )
7394 channelData := make ([]* model.Channel , 0 )
7495 idSort , _ := strconv .ParseBool (c .Query ("id_sort" ))
7596 sortOptions := model .NewChannelSortOptions (c .Query ("sort_by" ), c .Query ("sort_order" ), idSort )
7697 enableTagMode , _ := strconv .ParseBool (c .Query ("tag_mode" ))
98+ groupFilter := model .NormalizeChannelGroupFilter (c .Query ("group" ))
7799 statusParam := c .Query ("status" )
78100 // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
79101 statusFilter := parseStatusFilter (statusParam )
@@ -85,69 +107,49 @@ func GetAllChannels(c *gin.Context) {
85107 typeFilter = t
86108 }
87109 }
88- // group filter
89- groupFilter := c .Query ("group" )
90110
91111 var total int64
92112
93113 if enableTagMode {
94- tags , err := model .GetPaginatedTags ( pageInfo .GetStartIdx (), pageInfo .GetPageSize ())
114+ tags , err := model .GetPaginatedChannelTags ( buildChannelListQuery ( groupFilter , statusFilter , typeFilter ), pageInfo .GetStartIdx (), pageInfo .GetPageSize ())
95115 if err != nil {
96116 common .SysError ("failed to get paginated tags: " + err .Error ())
97117 c .JSON (http .StatusOK , gin.H {"success" : false , "message" : "获取标签失败,请稍后重试" })
98118 return
99119 }
120+ total , err = model .CountChannelTags (buildChannelListQuery (groupFilter , statusFilter , typeFilter ))
121+ if err != nil {
122+ common .SysError ("failed to count tags: " + err .Error ())
123+ c .JSON (http .StatusOK , gin.H {"success" : false , "message" : "获取标签数量失败,请稍后重试" })
124+ return
125+ }
100126 for _ , tag := range tags {
101127 if tag == nil || * tag == "" {
102128 continue
103129 }
104- tagChannels , err := model .GetChannelsByTag (* tag , idSort , false , sortOptions )
130+ var tagChannels []* model.Channel
131+ err := sortOptions .Apply (buildChannelListQuery (groupFilter , statusFilter , typeFilter ).Where ("tag = ?" , * tag )).
132+ Omit ("key" ).
133+ Find (& tagChannels ).Error
105134 if err != nil {
106- continue
107- }
108- filtered := make ([]* model.Channel , 0 )
109- for _ , ch := range tagChannels {
110- if statusFilter == common .ChannelStatusEnabled && ch .Status != common .ChannelStatusEnabled {
111- continue
112- }
113- if statusFilter == 0 && ch .Status == common .ChannelStatusEnabled {
114- continue
115- }
116- if typeFilter >= 0 && ch .Type != typeFilter {
117- continue
118- }
119- if groupFilter != "" && groupFilter != "null" {
120- if ! strings .Contains ("," + ch .Group + "," , "," + groupFilter + "," ) {
121- continue
122- }
123- }
124- filtered = append (filtered , ch )
135+ common .SysError ("failed to get channels by tag: " + err .Error ())
136+ c .JSON (http .StatusOK , gin.H {"success" : false , "message" : "获取标签渠道失败,请稍后重试" })
137+ return
125138 }
126- channelData = append (channelData , filtered ... )
139+ channelData = append (channelData , tagChannels ... )
127140 }
128- total , _ = model .CountAllTags ()
129141 } else {
130- baseQuery := model .DB .Model (& model.Channel {})
131- if typeFilter >= 0 {
132- baseQuery = baseQuery .Where ("type = ?" , typeFilter )
133- }
134- if statusFilter == common .ChannelStatusEnabled {
135- baseQuery = baseQuery .Where ("status = ?" , common .ChannelStatusEnabled )
136- } else if statusFilter == 0 {
137- baseQuery = baseQuery .Where ("status != ?" , common .ChannelStatusEnabled )
138- }
139- if groupFilter != "" && groupFilter != "null" {
140- if common .UsingMySQL {
141- baseQuery = baseQuery .Where ("CONCAT(',', `group`, ',') LIKE ?" , "%," + groupFilter + ",%" )
142- } else {
143- // SQLite, PostgreSQL
144- baseQuery = baseQuery .Where ("(',' || \" group\" || ',') LIKE ?" , "%," + groupFilter + ",%" )
145- }
142+ if err := buildChannelListQuery (groupFilter , statusFilter , typeFilter ).Count (& total ).Error ; err != nil {
143+ common .SysError ("failed to count channels: " + err .Error ())
144+ c .JSON (http .StatusOK , gin.H {"success" : false , "message" : "获取渠道数量失败,请稍后重试" })
145+ return
146146 }
147147
148- baseQuery .Count (& total )
149-
150- err := sortOptions .Apply (baseQuery ).Limit (pageInfo .GetPageSize ()).Offset (pageInfo .GetStartIdx ()).Omit ("key" ).Find (& channelData ).Error
148+ err := sortOptions .Apply (buildChannelListQuery (groupFilter , statusFilter , typeFilter )).
149+ Limit (pageInfo .GetPageSize ()).
150+ Offset (pageInfo .GetStartIdx ()).
151+ Omit ("key" ).
152+ Find (& channelData ).Error
151153 if err != nil {
152154 common .SysError ("failed to get channels: " + err .Error ())
153155 c .JSON (http .StatusOK , gin.H {"success" : false , "message" : "获取渠道列表失败,请稍后重试" })
@@ -159,17 +161,16 @@ func GetAllChannels(c *gin.Context) {
159161 clearChannelInfo (datum )
160162 }
161163
162- countQuery := model .DB .Model (& model.Channel {})
163- if statusFilter == common .ChannelStatusEnabled {
164- countQuery = countQuery .Where ("status = ?" , common .ChannelStatusEnabled )
165- } else if statusFilter == 0 {
166- countQuery = countQuery .Where ("status != ?" , common .ChannelStatusEnabled )
167- }
164+ countQuery := buildChannelListQuery (groupFilter , statusFilter , - 1 )
168165 var results []struct {
169166 Type int64
170167 Count int64
171168 }
172- _ = countQuery .Select ("type, count(*) as count" ).Group ("type" ).Find (& results ).Error
169+ if err := countQuery .Select ("type, count(*) as count" ).Group ("type" ).Find (& results ).Error ; err != nil {
170+ common .SysError ("failed to count channel types: " + err .Error ())
171+ c .JSON (http .StatusOK , gin.H {"success" : false , "message" : "获取渠道类型统计失败,请稍后重试" })
172+ return
173+ }
173174 typeCounts := make (map [int64 ]int64 )
174175 for _ , r := range results {
175176 typeCounts [r .Type ] = r .Count
@@ -277,10 +278,18 @@ func SearchChannels(c *gin.Context) {
277278 }
278279 for _ , tag := range tags {
279280 if tag != nil && * tag != "" {
280- tagChannel , err := model .GetChannelsByTag (* tag , idSort , false , sortOptions )
281- if err == nil {
282- channelData = append (channelData , tagChannel ... )
281+ var tagChannels []* model.Channel
282+ err := sortOptions .Apply (buildChannelListQuery (group , - 1 , - 1 ).Where ("tag = ?" , * tag )).
283+ Omit ("key" ).
284+ Find (& tagChannels ).Error
285+ if err != nil {
286+ c .JSON (http .StatusOK , gin.H {
287+ "success" : false ,
288+ "message" : err .Error (),
289+ })
290+ return
283291 }
292+ channelData = append (channelData , tagChannels ... )
284293 }
285294 }
286295 } else {
0 commit comments