diff --git a/backend/api/api.go b/backend/api/api.go index 21716fedc..ad589e5dd 100644 --- a/backend/api/api.go +++ b/backend/api/api.go @@ -52,7 +52,7 @@ func Init( ckDB ck.Provider, translater i18n.ITranslater, ) (*apis.APIHandler, error) { - foundationHandler, err := apis.InitFoundationHandler(idgen, db, batchObjectStorage) + foundationHandler, err := apis.InitFoundationHandler(idgen, db, batchObjectStorage, configFactory) if err != nil { return nil, err } diff --git a/backend/api/handler/coze/loop/apis/wire.go b/backend/api/handler/coze/loop/apis/wire.go index 2008d8cb9..dd26d446b 100644 --- a/backend/api/handler/coze/loop/apis/wire.go +++ b/backend/api/handler/coze/loop/apis/wire.go @@ -90,6 +90,7 @@ func InitFoundationHandler( idgen idgen.IIDGenerator, db db.Provider, objectStorage fileserver.BatchObjectStorage, + configFactory conf.IConfigLoaderFactory, ) (*FoundationHandler, error) { wire.Build( foundationSet, diff --git a/backend/api/handler/coze/loop/apis/wire_gen.go b/backend/api/handler/coze/loop/apis/wire_gen.go index 971d1bbb4..e83641abc 100644 --- a/backend/api/handler/coze/loop/apis/wire_gen.go +++ b/backend/api/handler/coze/loop/apis/wire_gen.go @@ -8,7 +8,6 @@ package apis import ( "context" - "github.com/cloudwego/kitex/pkg/endpoint" "github.com/coze-dev/coze-loop/backend/infra/ck" "github.com/coze-dev/coze-loop/backend/infra/db" @@ -42,7 +41,7 @@ import ( // Injectors from wire.go: -func InitFoundationHandler(idgen2 idgen.IIDGenerator, db2 db.Provider, objectStorage fileserver.BatchObjectStorage) (*FoundationHandler, error) { +func InitFoundationHandler(idgen2 idgen.IIDGenerator, db2 db.Provider, objectStorage fileserver.BatchObjectStorage, configFactory conf.IConfigLoaderFactory) (*FoundationHandler, error) { authService, err := application.InitAuthApplication(idgen2, db2) if err != nil { return nil, err @@ -55,7 +54,7 @@ func InitFoundationHandler(idgen2 idgen.IIDGenerator, db2 db.Provider, objectSto if err != nil { return nil, err } - userService, err := application.InitUserApplication(idgen2, db2) + userService, err := application.InitUserApplication(idgen2, db2, configFactory) if err != nil { return nil, err } diff --git a/backend/go.sum b/backend/go.sum index 8bf231631..d133d5563 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -465,6 +465,7 @@ github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5/go.mod h1:vavhavw2zAx github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/backend/modules/foundation/application/user.go b/backend/modules/foundation/application/user.go index 1f7329e7c..ea44486e4 100644 --- a/backend/modules/foundation/application/user.go +++ b/backend/modules/foundation/application/user.go @@ -7,7 +7,9 @@ import ( "context" "net/mail" "strconv" + "strings" + "github.com/bytedance/gg/gptr" "github.com/bytedance/gg/gslice" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" @@ -17,31 +19,64 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/foundation/domain/user/entity" "github.com/coze-dev/coze-loop/backend/modules/foundation/domain/user/service" "github.com/coze-dev/coze-loop/backend/modules/foundation/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/conf" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/conv" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" + "github.com/coze-dev/coze-loop/backend/pkg/logs" ) -type UserApplicationImpl struct { - userService service.IUserService -} - func NewUserApplication( userService service.IUserService, -) user.UserService { - return &UserApplicationImpl{ - userService: userService, + configFactory conf.IConfigLoaderFactory, +) (user.UserService, error) { + ua := &UserApplicationImpl{ + userService: userService, + registerController: userRegisterController{}, + } + if loader, err := configFactory.NewConfigLoader("foundation.yaml"); err == nil { + ua.registerController.configLoader = loader } + return ua, nil } -func (u *UserApplicationImpl) Register(ctx context.Context, request *user.UserRegisterRequest) (r *user.UserRegisterResponse, err error) { - if request.Email == nil || request.Password == nil { - return nil, errorx.NewByCode(errno.CommonInvalidParamCode) +type UserApplicationImpl struct { + userService service.IUserService + registerController userRegisterController +} + +type userRegisterController struct { + // configLoader weak dependency, might be nil + configLoader conf.IConfigLoader +} + +type userRegisterControlConfig struct { + Block bool `mapstructure:"block"` + AllowedEmails string `mapstructure:"allowed_emails"` +} + +func (u *userRegisterController) allowRegister(ctx context.Context, email string) bool { + if u.configLoader == nil { + return true + } + + const keyUserRegisterControl = "user_register_control" + var config userRegisterControlConfig + if err := u.configLoader.UnmarshalKey(ctx, keyUserRegisterControl, &config); err != nil { + logs.CtxWarn(ctx, "load user_register_control config fail, err: %v", err) + return false + } + + if !config.Block { + return true } + return slices.Contains(strings.Split(config.AllowedEmails, ";"), email) +} - if _, err = mail.ParseAddress(*request.Email); err != nil { - return nil, errorx.NewByCode(errno.CommonInvalidParamCode, errorx.WithExtraMsg("email is invalid")) +func (u *UserApplicationImpl) Register(ctx context.Context, request *user.UserRegisterRequest) (r *user.UserRegisterResponse, err error) { + if err := u.validateRegisterReq(ctx, request); err != nil { + return nil, err } userDO, err := u.userService.Create(ctx, &service.CreateUserRequest{ @@ -66,6 +101,22 @@ func (u *UserApplicationImpl) Register(ctx context.Context, request *user.UserRe return r, nil } +func (u *UserApplicationImpl) validateRegisterReq(ctx context.Context, request *user.UserRegisterRequest) error { + if request.Email == nil || request.Password == nil { + return errorx.NewByCode(errno.CommonInvalidParamCode) + } + + if _, err := mail.ParseAddress(gptr.Indirect(request.Email)); err != nil { + return errorx.NewByCode(errno.CommonInvalidParamCode, errorx.WithExtraMsg("email is invalid")) + } + + if !u.registerController.allowRegister(ctx, request.GetEmail()) { + return errorx.NewByCode(errno.UserRegistrationControlBlockCode) + } + + return nil +} + func (u *UserApplicationImpl) ResetPassword(ctx context.Context, request *user.ResetPasswordRequest) (r *user.ResetPasswordResponse, err error) { r = user.NewResetPasswordResponse() diff --git a/backend/modules/foundation/application/user_test.go b/backend/modules/foundation/application/user_test.go index d6f28e75c..b47415605 100644 --- a/backend/modules/foundation/application/user_test.go +++ b/backend/modules/foundation/application/user_test.go @@ -19,14 +19,376 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/foundation/domain/user/service" servicemocks "github.com/coze-dev/coze-loop/backend/modules/foundation/domain/user/service/mocks" "github.com/coze-dev/coze-loop/backend/modules/foundation/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/conf" + confmocks "github.com/coze-dev/coze-loop/backend/pkg/conf/mocks" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/unittest" ) +func Test_userRegisterController_allowRegister(t *testing.T) { + type fields struct { + configLoader conf.IConfigLoader + } + type args struct { + ctx context.Context + email string + } + tests := []struct { + name string + fields func(ctrl *gomock.Controller) fields + args args + want bool + }{ + { + name: "configLoader is nil - return true (weak dependency)", + fields: func(ctrl *gomock.Controller) fields { + return fields{ + configLoader: nil, + } + }, + args: args{ + ctx: context.Background(), + email: "test@example.com", + }, + want: true, + }, + { + name: "config load fail - return false", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + Return(errors.New("config load error")) + return fields{ + configLoader: mockLoader, + } + }, + args: args{ + ctx: context.Background(), + email: "test@example.com", + }, + want: false, + }, + { + name: "block=false - allow all users", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = false + config.AllowedEmails = "" + return nil + }) + return fields{ + configLoader: mockLoader, + } + }, + args: args{ + ctx: context.Background(), + email: "test@example.com", + }, + want: true, + }, + { + name: "block=true and email in whitelist - allow", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = true + config.AllowedEmails = "test@example.com;admin@example.com" + return nil + }) + return fields{ + configLoader: mockLoader, + } + }, + args: args{ + ctx: context.Background(), + email: "test@example.com", + }, + want: true, + }, + { + name: "block=true and email not in whitelist - deny", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = true + config.AllowedEmails = "admin@example.com;allowed@example.com" + return nil + }) + return fields{ + configLoader: mockLoader, + } + }, + args: args{ + ctx: context.Background(), + email: "test@example.com", + }, + want: false, + }, + { + name: "block=true and empty whitelist - deny all", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = true + config.AllowedEmails = "" + return nil + }) + return fields{ + configLoader: mockLoader, + } + }, + args: args{ + ctx: context.Background(), + email: "test@example.com", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + fields := tt.fields(ctrl) + u := &userRegisterController{ + configLoader: fields.configLoader, + } + got := u.allowRegister(tt.args.ctx, tt.args.email) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUserApplicationImpl_validateRegisterReq(t *testing.T) { + type fields struct { + userService service.IUserService + registerController userRegisterController + } + type args struct { + ctx context.Context + request *user.UserRegisterRequest + } + tests := []struct { + name string + fields func(ctrl *gomock.Controller) fields + args args + wantErr error + }{ + { + name: "missing email", + fields: func(ctrl *gomock.Controller) fields { + return fields{ + userService: nil, + registerController: userRegisterController{}, + } + }, + args: args{ + ctx: context.Background(), + request: &user.UserRegisterRequest{ + Password: ptr.Of("password123"), + }, + }, + wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "missing password", + fields: func(ctrl *gomock.Controller) fields { + return fields{ + userService: nil, + registerController: userRegisterController{}, + } + }, + args: args{ + ctx: context.Background(), + request: &user.UserRegisterRequest{ + Email: ptr.Of("test@example.com"), + }, + }, + wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "invalid email format", + fields: func(ctrl *gomock.Controller) fields { + return fields{ + userService: nil, + registerController: userRegisterController{}, + } + }, + args: args{ + ctx: context.Background(), + request: &user.UserRegisterRequest{ + Email: ptr.Of("invalid-email"), + Password: ptr.Of("password123"), + }, + }, + wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), + }, + { + name: "registration blocked by control", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = true + config.AllowedEmails = "admin@example.com" + return nil + }) + return fields{ + userService: nil, + registerController: userRegisterController{ + configLoader: mockLoader, + }, + } + }, + args: args{ + ctx: context.Background(), + request: &user.UserRegisterRequest{ + Email: ptr.Of("test@example.com"), + Password: ptr.Of("password123"), + }, + }, + wantErr: errorx.NewByCode(errno.UserRegistrationControlBlockCode), + }, + { + name: "all validations pass with config loader", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = false + config.AllowedEmails = "" + return nil + }) + return fields{ + userService: nil, + registerController: userRegisterController{ + configLoader: mockLoader, + }, + } + }, + args: args{ + ctx: context.Background(), + request: &user.UserRegisterRequest{ + Email: ptr.Of("test@example.com"), + Password: ptr.Of("password123"), + }, + }, + wantErr: nil, + }, + { + name: "all validations pass with nil config loader (weak dependency)", + fields: func(ctrl *gomock.Controller) fields { + return fields{ + userService: nil, + registerController: userRegisterController{ + configLoader: nil, + }, + } + }, + args: args{ + ctx: context.Background(), + request: &user.UserRegisterRequest{ + Email: ptr.Of("test@example.com"), + Password: ptr.Of("password123"), + }, + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + fields := tt.fields(ctrl) + u := &UserApplicationImpl{ + userService: fields.userService, + registerController: fields.registerController, + } + err := u.validateRegisterReq(tt.args.ctx, tt.args.request) + unittest.AssertErrorEqual(t, tt.wantErr, err) + }) + } +} + +func TestNewUserApplication(t *testing.T) { + type args struct { + userService service.IUserService + configFactory conf.IConfigLoaderFactory + } + tests := []struct { + name string + args func(ctrl *gomock.Controller) args + want user.UserService + wantErr error + }{ + { + name: "config loader return error - weak dependency", + args: func(ctrl *gomock.Controller) args { + mockUserService := servicemocks.NewMockIUserService(ctrl) + mockConfigFactory := confmocks.NewMockIConfigLoaderFactory(ctrl) + mockConfigFactory.EXPECT().NewConfigLoader("foundation.yaml"). + Return(nil, errors.New("config loader creation failed")) + return args{ + userService: mockUserService, + configFactory: mockConfigFactory, + } + }, + want: &UserApplicationImpl{}, + wantErr: nil, + }, + { + name: "success with config loader", + args: func(ctrl *gomock.Controller) args { + mockUserService := servicemocks.NewMockIUserService(ctrl) + mockConfigFactory := confmocks.NewMockIConfigLoaderFactory(ctrl) + mockConfigLoader := confmocks.NewMockIConfigLoader(ctrl) + mockConfigFactory.EXPECT().NewConfigLoader("foundation.yaml"). + Return(mockConfigLoader, nil) + return args{ + userService: mockUserService, + configFactory: mockConfigFactory, + } + }, + want: &UserApplicationImpl{}, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + args := tt.args(ctrl) + got, err := NewUserApplication(args.userService, args.configFactory) + unittest.AssertErrorEqual(t, tt.wantErr, err) + if tt.wantErr == nil { + assert.NotNil(t, got) + impl, ok := got.(*UserApplicationImpl) + assert.True(t, ok) + assert.Equal(t, args.userService, impl.userService) + if tt.name == "config loader return error - weak dependency" { + assert.Nil(t, impl.registerController.configLoader) + } else { + assert.NotNil(t, impl.registerController.configLoader) + } + } else { + assert.Nil(t, got) + } + }) + } +} + func TestUserApplicationImpl_Register(t *testing.T) { type fields struct { - userService service.IUserService + userService service.IUserService + registerController userRegisterController } type args struct { ctx context.Context @@ -39,15 +401,18 @@ func TestUserApplicationImpl_Register(t *testing.T) { } tests := []struct { name string - fields fields + fields func(ctrl *gomock.Controller) fields args args want *user.UserRegisterResponse wantErr error }{ { name: "invalid email", - fields: fields{ - userService: nil, + fields: func(ctrl *gomock.Controller) fields { + return fields{ + userService: nil, + registerController: userRegisterController{}, + } }, args: args{ ctx: context.Background(), @@ -61,8 +426,11 @@ func TestUserApplicationImpl_Register(t *testing.T) { }, { name: "missing email", - fields: fields{ - userService: nil, + fields: func(ctrl *gomock.Controller) fields { + return fields{ + userService: nil, + registerController: userRegisterController{}, + } }, args: args{ ctx: context.Background(), @@ -73,18 +441,54 @@ func TestUserApplicationImpl_Register(t *testing.T) { want: nil, wantErr: errorx.NewByCode(errno.CommonInvalidParamCode), }, + { + name: "registration blocked by control", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = true + config.AllowedEmails = "admin@example.com" + return nil + }) + return fields{ + userService: nil, + registerController: userRegisterController{ + configLoader: mockLoader, + }, + } + }, + args: args{ + ctx: context.Background(), + req: &user.UserRegisterRequest{ + Email: ptr.Of("test@example.com"), + Password: ptr.Of("password123"), + }, + }, + want: nil, + wantErr: errorx.NewByCode(errno.UserRegistrationControlBlockCode), + }, { name: "create user error", - fields: fields{ - userService: func() service.IUserService { - ctrl := gomock.NewController(t) - mockService := servicemocks.NewMockIUserService(ctrl) - mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ - Email: "test@example.com", - Password: "password123", - }).Return(nil, errors.New("db error")) - return mockService - }(), + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = false + config.AllowedEmails = "" + return nil + }) + mockService := servicemocks.NewMockIUserService(ctrl) + mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ + Email: "test@example.com", + Password: "password123", + }).Return(nil, errors.New("db error")) + return fields{ + userService: mockService, + registerController: userRegisterController{ + configLoader: mockLoader, + }, + } }, args: args{ ctx: context.Background(), @@ -98,17 +502,26 @@ func TestUserApplicationImpl_Register(t *testing.T) { }, { name: "create session error", - fields: fields{ - userService: func() service.IUserService { - ctrl := gomock.NewController(t) - mockService := servicemocks.NewMockIUserService(ctrl) - mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ - Email: "test@example.com", - Password: "password123", - }).Return(mockUser, nil) - mockService.EXPECT().CreateSession(gomock.Any(), mockUser.UserID).Return("", errors.New("session error")) - return mockService - }(), + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = false + config.AllowedEmails = "" + return nil + }) + mockService := servicemocks.NewMockIUserService(ctrl) + mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ + Email: "test@example.com", + Password: "password123", + }).Return(mockUser, nil) + mockService.EXPECT().CreateSession(gomock.Any(), mockUser.UserID).Return("", errors.New("session error")) + return fields{ + userService: mockService, + registerController: userRegisterController{ + configLoader: mockLoader, + }, + } }, args: args{ ctx: context.Background(), @@ -121,18 +534,94 @@ func TestUserApplicationImpl_Register(t *testing.T) { wantErr: errors.New("session error"), }, { - name: "success", - fields: fields{ - userService: func() service.IUserService { - ctrl := gomock.NewController(t) - mockService := servicemocks.NewMockIUserService(ctrl) - mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ - Email: "test@example.com", - Password: "password123", - }).Return(mockUser, nil) - mockService.EXPECT().CreateSession(gomock.Any(), mockUser.UserID).Return("session_key", nil) - return mockService - }(), + name: "success - block=false allows all", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = false + config.AllowedEmails = "" + return nil + }) + mockService := servicemocks.NewMockIUserService(ctrl) + mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ + Email: "test@example.com", + Password: "password123", + }).Return(mockUser, nil) + mockService.EXPECT().CreateSession(gomock.Any(), mockUser.UserID).Return("session_key", nil) + return fields{ + userService: mockService, + registerController: userRegisterController{ + configLoader: mockLoader, + }, + } + }, + args: args{ + ctx: context.Background(), + req: &user.UserRegisterRequest{ + Email: ptr.Of("test@example.com"), + Password: ptr.Of("password123"), + }, + }, + want: &user.UserRegisterResponse{ + UserInfo: convertor.UserDO2DTO(mockUser), + Token: ptr.Of("session_key"), + ExpireTime: ptr.Of(int64(session.SessionExpires)), + }, + wantErr: nil, + }, + { + name: "success - block=true but email in whitelist", + fields: func(ctrl *gomock.Controller) fields { + mockLoader := confmocks.NewMockIConfigLoader(ctrl) + mockLoader.EXPECT().UnmarshalKey(gomock.Any(), "user_register_control", gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, config *userRegisterControlConfig, opts ...conf.DecodeOptionFn) error { + config.Block = true + config.AllowedEmails = "test@example.com;admin@example.com" + return nil + }) + mockService := servicemocks.NewMockIUserService(ctrl) + mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ + Email: "test@example.com", + Password: "password123", + }).Return(mockUser, nil) + mockService.EXPECT().CreateSession(gomock.Any(), mockUser.UserID).Return("session_key", nil) + return fields{ + userService: mockService, + registerController: userRegisterController{ + configLoader: mockLoader, + }, + } + }, + args: args{ + ctx: context.Background(), + req: &user.UserRegisterRequest{ + Email: ptr.Of("test@example.com"), + Password: ptr.Of("password123"), + }, + }, + want: &user.UserRegisterResponse{ + UserInfo: convertor.UserDO2DTO(mockUser), + Token: ptr.Of("session_key"), + ExpireTime: ptr.Of(int64(session.SessionExpires)), + }, + wantErr: nil, + }, + { + name: "success - nil config loader (weak dependency)", + fields: func(ctrl *gomock.Controller) fields { + mockService := servicemocks.NewMockIUserService(ctrl) + mockService.EXPECT().Create(gomock.Any(), &service.CreateUserRequest{ + Email: "test@example.com", + Password: "password123", + }).Return(mockUser, nil) + mockService.EXPECT().CreateSession(gomock.Any(), mockUser.UserID).Return("session_key", nil) + return fields{ + userService: mockService, + registerController: userRegisterController{ + configLoader: nil, + }, + } }, args: args{ ctx: context.Background(), @@ -151,8 +640,13 @@ func TestUserApplicationImpl_Register(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + fields := tt.fields(ctrl) p := &UserApplicationImpl{ - userService: tt.fields.userService, + userService: fields.userService, + registerController: fields.registerController, } got, err := p.Register(tt.args.ctx, tt.args.req) unittest.AssertErrorEqual(t, tt.wantErr, err) diff --git a/backend/modules/foundation/application/wire.go b/backend/modules/foundation/application/wire.go index d951d73a6..3607c11f0 100644 --- a/backend/modules/foundation/application/wire.go +++ b/backend/modules/foundation/application/wire.go @@ -23,6 +23,7 @@ import ( auth2 "github.com/coze-dev/coze-loop/backend/modules/foundation/infra/auth" "github.com/coze-dev/coze-loop/backend/modules/foundation/infra/repo" "github.com/coze-dev/coze-loop/backend/modules/foundation/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/pkg/conf" ) var ( @@ -92,6 +93,7 @@ func InitSpaceApplication( func InitUserApplication( idgen idgen.IIDGenerator, db db.Provider, + configFactory conf.IConfigLoaderFactory, ) (user.UserService, error) { wire.Build(userSet) return nil, nil diff --git a/backend/modules/foundation/application/wire_gen.go b/backend/modules/foundation/application/wire_gen.go index ca9476d6f..d75769646 100644 --- a/backend/modules/foundation/application/wire_gen.go +++ b/backend/modules/foundation/application/wire_gen.go @@ -21,6 +21,7 @@ import ( auth2 "github.com/coze-dev/coze-loop/backend/modules/foundation/infra/auth" "github.com/coze-dev/coze-loop/backend/modules/foundation/infra/repo" "github.com/coze-dev/coze-loop/backend/modules/foundation/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/pkg/conf" "github.com/google/wire" ) @@ -54,13 +55,16 @@ func InitSpaceApplication(idgen2 idgen.IIDGenerator, db2 db.Provider) (space.Spa return spaceService, nil } -func InitUserApplication(idgen2 idgen.IIDGenerator, db2 db.Provider) (user.UserService, error) { +func InitUserApplication(idgen2 idgen.IIDGenerator, db2 db.Provider, configFactory conf.IConfigLoaderFactory) (user.UserService, error) { iUserDAO := mysql.NewUserDAOImpl(db2) iSpaceDAO := mysql.NewSpaceDAOImpl(db2) iSpaceUserDAO := mysql.NewSpaceUserDAOImpl(db2) iUserRepo := repo.NewUserRepo(db2, idgen2, iUserDAO, iSpaceDAO, iSpaceUserDAO) iUserService := service.NewUserService(db2, iUserRepo, idgen2) - userService := NewUserApplication(iUserService) + userService, err := NewUserApplication(iUserService, configFactory) + if err != nil { + return nil, err + } return userService, nil } diff --git a/backend/modules/foundation/pkg/errno/foundation.go b/backend/modules/foundation/pkg/errno/foundation.go index 481bb973c..5f682a51b 100644 --- a/backend/modules/foundation/pkg/errno/foundation.go +++ b/backend/modules/foundation/pkg/errno/foundation.go @@ -79,6 +79,10 @@ const ( AccountOverdraftCodeCode = 602002007 accountOverdraftCodeMessage = "account overdraft" accountOverdraftCodeNoAffectStability = true + + UserRegistrationControlBlockCode = 602002008 + userRegistrationControlBlockMessage = "email address is restricted from registration based on account security protocols" + userRegistrationControlBlockNoAffectStability = true ) func init() { @@ -191,4 +195,10 @@ func init() { code.WithAffectStability(!accountOverdraftCodeNoAffectStability), ) + code.Register( + UserRegistrationControlBlockCode, + userRegistrationControlBlockMessage, + code.WithAffectStability(!userRegistrationControlBlockNoAffectStability), + ) + } diff --git a/backend/script/errorx/README.md b/backend/script/errorx/README.md index f985fcf71..2b1136c44 100644 --- a/backend/script/errorx/README.md +++ b/backend/script/errorx/README.md @@ -74,7 +74,7 @@ Where: # Generate code for evaluation domain ./code_gen.py evaluation --output-dir backend/module/evaluation/pkg/errno - # Or use default output directory (GOPATH/src/github.com/coze-dev/backend/module/{biz}/pkg/errno) + # Or use default output directory ({project_path}/backend/module/{biz}/pkg/errno) ./code_gen.py evaluation ``` diff --git a/backend/script/errorx/code_gen.py b/backend/script/errorx/code_gen.py index d0f7b863f..b808652cf 100755 --- a/backend/script/errorx/code_gen.py +++ b/backend/script/errorx/code_gen.py @@ -134,15 +134,15 @@ def generate_biz_code(biz_name: str, biz_code: int, common_errors: List[Dict], o biz_errors = load_yaml(biz_error_file)['error_code'] # Generate and output code - project_dir = os.environ.get('PROJECT_DIR', - os.path.join(os.environ['GOPATH'], 'src/github.com/coze-dev/coze-loop')) + current_file_dir = os.path.dirname(os.path.abspath(__file__)) + project_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_file_dir))) if not output_dir: output_dir = os.path.join(project_dir, 'backend/modules', biz_name, 'pkg/errno') else: output_dir = os.path.expandvars(output_dir) if not os.path.isabs(output_dir): output_dir = os.path.join(project_dir, output_dir) - + return generate_go_code(biz_name, biz_code, common_errors, biz_errors, output_dir) diff --git a/backend/script/errorx/foundation.yaml b/backend/script/errorx/foundation.yaml index 6d72c7814..ebef242ad 100644 --- a/backend/script/errorx/foundation.yaml +++ b/backend/script/errorx/foundation.yaml @@ -43,3 +43,8 @@ error_code: code: 2007 message: account overdraft no_affect_stability: true + + - name: UserRegistrationControlBlock + code: 2008 + message: email address is restricted from registration based on account security protocols + no_affect_stability: true diff --git a/conf/default/app/runtime/foundation.yaml b/conf/default/app/runtime/foundation.yaml new file mode 100644 index 000000000..273a72f24 --- /dev/null +++ b/conf/default/app/runtime/foundation.yaml @@ -0,0 +1,3 @@ +user_register_control: + block: false + allowed_emails: "locala@doamin;localb@domain" \ No newline at end of file diff --git a/conf/default/app/runtime/locales/zh-CN.yaml b/conf/default/app/runtime/locales/zh-CN.yaml index 654da74da..baf567231 100644 --- a/conf/default/app/runtime/locales/zh-CN.yaml +++ b/conf/default/app/runtime/locales/zh-CN.yaml @@ -146,3 +146,4 @@ "602002005": "账户余额不足" "602002006": "账户已过期" "602002007": "账户透支" +"602002008": "根据账户安全协议,该邮箱地址禁止注册"