Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions api/handler/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package handler

import (
"errors"
"fmt"
"log/slog"
"net/http"
"slices"
"strconv"

"opencsg.com/csghub-server/common/errorx"
Expand All @@ -30,6 +32,69 @@ type UserHandler struct {
user component.UserComponent
}

func hasUserRepoFilterQuery(ctx *gin.Context) bool {
filterKeys := []string{
"search", "sort", "source", "status", "xnet_migration_status",
"dataset_type", "user_purchased", "list_serverless",
"tag_category", "tag_name", "tag_group",
"model_tree",
"model_params_min", "model_params_max",
"repo_size_min", "repo_size_max",
}
for _, key := range filterKeys {
if _, ok := ctx.GetQuery(key); ok {
return true
}
}
return false
}

func parseUserRepoFilter(ctx *gin.Context, owner, currentUser string) (*types.RepoFilter, error) {
if !hasUserRepoFilterQuery(ctx) {
return nil, nil
}

filter := &types.RepoFilter{
Tags: parseTagReqs(ctx),
Owner: owner,
Username: currentUser,
}

tree, err := parseTreeReqs(ctx)
if err != nil {
return nil, err
}
filter.Tree = tree
filter = getFilterFromContext(ctx, filter)
filter.SpaceSDK = ctx.Query("sdk")

if listServerless, err := strconv.ParseBool(ctx.Query("list_serverless")); err == nil {
filter.ListServerless = listServerless
}

filter.ModelParamsMin, filter.ModelParamsMax, err = parseFloatRangeFromContext(ctx, "model_params_min", "model_params_max")
if err != nil {
return nil, errorx.ReqParamInvalid(err, errorx.Ctx().Set("query", "model_params_range"))
}

filter.RepoSizeMin, filter.RepoSizeMax, err = parseInt64RangeFromContext(ctx, "repo_size_min", "repo_size_max")
if err != nil {
return nil, errorx.ReqParamInvalid(err, errorx.Ctx().Set("query", "repo_size_range"))
}

if !slices.Contains(types.Sorts, filter.Sort) {
err := fmt.Errorf("sort parameter must be one of %v", types.Sorts)
return nil, errorx.ReqParamInvalid(err, errorx.Ctx().Set("query", "sort_filter"))
}

if filter.Source != "" && !slices.Contains(types.Sources, filter.Source) {
err := fmt.Errorf("source parameter must be one of %v", types.Sources)
return nil, errorx.ReqParamInvalid(err, errorx.Ctx().Set("query", "source_filter"))
}

return filter, nil
}

// GetUserDatasets godoc
// @Security ApiKey
// @Summary Get user datasets
Expand All @@ -53,6 +118,12 @@ func (h *UserHandler) Datasets(ctx *gin.Context) {

req.Owner = ctx.Param("username")
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Filter, err = parseUserRepoFilter(ctx, req.Owner, req.CurrentUser)
if err != nil {
slog.ErrorContext(ctx.Request.Context(), "Bad user datasets filter request format", "error", err)
httpbase.BadRequestWithExt(ctx, err)
return
}
req.Page = page
req.PageSize = per
ds, total, err := h.user.Datasets(ctx.Request.Context(), &req)
Expand Down Expand Up @@ -94,6 +165,12 @@ func (h *UserHandler) Models(ctx *gin.Context) {

req.Owner = ctx.Param("username")
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Filter, err = parseUserRepoFilter(ctx, req.Owner, req.CurrentUser)
if err != nil {
slog.ErrorContext(ctx.Request.Context(), "Bad user models filter request format", "error", err)
httpbase.BadRequestWithExt(ctx, err)
return
}
req.Page = page
req.PageSize = per
ms, total, err := h.user.Models(ctx.Request.Context(), &req)
Expand Down Expand Up @@ -136,6 +213,12 @@ func (h *UserHandler) Codes(ctx *gin.Context) {

req.Owner = ctx.Param("username")
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Filter, err = parseUserRepoFilter(ctx, req.Owner, req.CurrentUser)
if err != nil {
slog.ErrorContext(ctx.Request.Context(), "Bad user codes filter request format", "error", err)
httpbase.BadRequestWithExt(ctx, err)
return
}
req.Page = page
req.PageSize = per
ms, total, err := h.user.Codes(ctx.Request.Context(), &req)
Expand Down Expand Up @@ -178,6 +261,12 @@ func (h *UserHandler) Spaces(ctx *gin.Context) {
req.SDK = ctx.Query("sdk")
req.Owner = ctx.Param("username")
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Filter, err = parseUserRepoFilter(ctx, req.Owner, req.CurrentUser)
if err != nil {
slog.ErrorContext(ctx.Request.Context(), "Bad user spaces filter request format", "error", err)
httpbase.BadRequestWithExt(ctx, err)
return
}
req.Page = page
req.PageSize = per
ms, total, err := h.user.Spaces(ctx.Request.Context(), &req)
Expand Down Expand Up @@ -950,6 +1039,12 @@ func (h *UserHandler) MCPServers(ctx *gin.Context) {

req.Owner = ctx.Param("username")
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Filter, err = parseUserRepoFilter(ctx, req.Owner, req.CurrentUser)
if err != nil {
slog.ErrorContext(ctx.Request.Context(), "Bad user mcp servers filter request format", "error", err)
httpbase.BadRequestWithExt(ctx, err)
return
}
req.Page = page
req.PageSize = per
mcps, total, err := h.user.MCPServers(ctx.Request.Context(), &req)
Expand Down Expand Up @@ -992,6 +1087,12 @@ func (h *UserHandler) Skills(ctx *gin.Context) {

req.Owner = ctx.Param("username")
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Filter, err = parseUserRepoFilter(ctx, req.Owner, req.CurrentUser)
if err != nil {
slog.ErrorContext(ctx.Request.Context(), "Bad user skills filter request format", "error", err)
httpbase.BadRequestWithExt(ctx, err)
return
}
req.Page = page
req.PageSize = per
skills, total, err := h.user.Skills(ctx.Request.Context(), &req)
Expand Down
75 changes: 75 additions & 0 deletions api/handler/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,55 @@ func TestUserHandler_Models(t *testing.T) {
})
}

func TestUserHandler_Models_WithRepoFilters(t *testing.T) {
tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc {
return h.Models
})

modelParamsMin := 1.5
modelParamsMax := 9.5
tester.mocks.user.EXPECT().Models(tester.Ctx(), &types.UserDatasetsReq{
Owner: "go",
CurrentUser: "u",
Filter: &types.RepoFilter{
Owner: "go",
Username: "u",
Search: "hello",
Sort: "trending",
Source: "local",
ListServerless: true,
Tags: []types.TagReq{{
Category: "runtime_framework",
Group: "inference",
}},
ModelParamsMin: &modelParamsMin,
ModelParamsMax: &modelParamsMax,
},
PageOpts: types.PageOpts{
Page: 1,
PageSize: 10,
},
}).Return([]types.Model{{Name: "ds"}}, 100, nil)
tester.AddPagination(1, 10).
WithUser().
WithParam("username", "go").
WithQuery("search", "hello").
WithQuery("sort", "trending").
WithQuery("source", "local").
WithQuery("list_serverless", "true").
WithQuery("tag_category", "runtime_framework").
WithQuery("tag_group", "inference").
WithQuery("tag_name", "").
WithQuery("model_params_min", "1.5").
WithQuery("model_params_max", "9.5").
Execute()
tester.ResponseEqSimple(t, 200, gin.H{
"message": "OK",
"data": []types.Model{{Name: "ds"}},
"total": 100,
})
}

func TestUserHandler_Codes(t *testing.T) {
tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc {
return h.Codes
Expand All @@ -97,6 +146,32 @@ func TestUserHandler_Codes(t *testing.T) {
})
}

func TestUserHandler_Spaces_WithSDKOnlyKeepsLegacyRequest(t *testing.T) {
tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc {
return h.Spaces
})

tester.mocks.user.EXPECT().Spaces(tester.Ctx(), &types.UserSpacesReq{
SDK: "gradio",
Owner: "go",
CurrentUser: "u",
PageOpts: types.PageOpts{
Page: 1,
PageSize: 10,
},
}).Return([]types.Space{{Name: "ds"}}, 100, nil)
tester.AddPagination(1, 10).
WithUser().
WithParam("username", "go").
WithQuery("sdk", "gradio").
Execute()
tester.ResponseEqSimple(t, 200, gin.H{
"message": "OK",
"data": []types.Space{{Name: "ds"}},
"total": 100,
})
}

func TestUserHandler_Spaces(t *testing.T) {
tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc {
return h.Spaces
Expand Down
17 changes: 17 additions & 0 deletions builder/store/database/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ var (
redisOnce sync.Once
)

func escapeLikePattern(value string) string {
var builder strings.Builder
for _, r := range value {
if r == '\\' || r == '%' || r == '_' {
builder.WriteRune('\\')
}
builder.WriteRune(r)
}
return builder.String()
}

type repoStoreImpl struct {
config *config.Config
db *DB
Expand Down Expand Up @@ -677,6 +688,9 @@ func (s *repoStoreImpl) PublicToUser(ctx context.Context, repoType types.Reposit
Relation("Tags")

q.Where("repository.repository_type = ?", repoType)
if filter.Owner != "" {
q.Where("repository.path LIKE ? ESCAPE '\\'", fmt.Sprintf("%s/%%", escapeLikePattern(filter.Owner)))
}

switch repoType {
case types.ModelRepo:
Expand Down Expand Up @@ -842,6 +856,9 @@ func (s *repoStoreImpl) publicToUserTrending(ctx context.Context, repoType types
Join("JOIN repositories AS r ON r.id = rrs.repository_id").
Where("rrs.weight_name = ?", RecomWeightTotal).
Where("r.repository_type = ?", repoType)
if filter.Owner != "" {
q.Where("r.path LIKE ? ESCAPE '\\'", fmt.Sprintf("%s/%%", escapeLikePattern(filter.Owner)))
}

// Join with business table
q.Join(fmt.Sprintf("INNER JOIN %s ON %s.repository_id = r.id", bizTable, bizTable))
Expand Down
38 changes: 38 additions & 0 deletions builder/store/database/repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,44 @@ func TestRepoStore_PublicToUser(t *testing.T) {
}
}

func TestRepoStore_PublicToUserOwnerFilter(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

store := database.NewRepoStoreWithDB(db)
repos := []database.Repository{
{Name: "match", Path: "owner/match", GitPath: "owner/match", UserID: 123, RepositoryType: types.CodeRepo},
{Name: "other", Path: "other/match", GitPath: "other/match", UserID: 123, RepositoryType: types.CodeRepo},
{Name: "underscore-match", Path: "own_er/match", GitPath: "own_er/match", UserID: 123, RepositoryType: types.CodeRepo},
{Name: "underscore-other", Path: "ownerX/match", GitPath: "ownerX/match", UserID: 123, RepositoryType: types.CodeRepo},
}
for _, repo := range repos {
created, err := store.CreateRepo(ctx, repo)
require.Nil(t, err)
_, err = db.Core.NewInsert().Model(&database.Code{RepositoryID: created.ID}).Exec(ctx)
require.Nil(t, err)
}

rs, count, err := store.PublicToUser(ctx, types.CodeRepo, []int64{123}, &types.RepoFilter{
Owner: "owner",
Sort: "recently_update",
}, 10, 1, false)
require.Nil(t, err)
require.Equal(t, 1, count)
require.Len(t, rs, 1)
require.Equal(t, "owner/match", rs[0].Path)

rs, count, err = store.PublicToUser(ctx, types.CodeRepo, []int64{123}, &types.RepoFilter{
Owner: "own_er",
Sort: "recently_update",
}, 10, 1, false)
require.Nil(t, err)
require.Equal(t, 1, count)
require.Len(t, rs, 1)
require.Equal(t, "own_er/match", rs[0].Path)
}

func TestRepoStore_PublicToUserRangeFilters(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
Expand Down
1 change: 1 addition & 0 deletions common/types/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ type RepoFilter struct {
Sort string
Search string
Source string
Owner string
Username string
Tree *TreeReq
ListServerless bool
Expand Down
2 changes: 2 additions & 0 deletions common/types/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,15 @@ type UpdateAPIKeyRequest struct {
type UserDatasetsReq struct {
Owner string `json:"owner"`
CurrentUser string `json:"current_user"`
Filter *RepoFilter
PageOpts
}

type UserSpacesReq struct {
SDK string `json:"sdk"`
Owner string `json:"owner"`
CurrentUser string `json:"current_user"`
Filter *RepoFilter
PageOpts
}

Expand Down
Loading
Loading