Skip to content

Commit 2cf3439

Browse files
Merge pull request #217 from actiontech/issues-213
/issues/213: support switch cloudbeaver connection address
2 parents a42692f + 596b680 commit 2cf3439

3 files changed

Lines changed: 86 additions & 54 deletions

File tree

internal/dms/biz/cloudbeaver.go

Lines changed: 81 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -568,20 +568,11 @@ type UserList struct {
568568
} `json:"listUsers"`
569569
}
570570

571-
func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbeaverUserId string, dmsUser *User) error {
572-
cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
573-
if err != nil {
574-
return err
575-
}
571+
var reservedCloudbeaverUserId = map[string]struct{}{"admin": {}, "user": {}}
576572

577-
fingerprint := cu.userUsecase.GetUserFingerprint(dmsUser)
578-
if exist && cloudbeaverUser.DMSFingerprint == fingerprint {
579-
return nil
580-
}
581-
582-
reservedCloudbeaverUserId := map[string]struct{}{"admin": {}, "user": {}}
573+
func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbeaverUserId string, dmsUser *User) error {
583574
if _, ok := reservedCloudbeaverUserId[cloudbeaverUserId]; ok {
584-
return fmt.Errorf("this username cannot be used")
575+
return fmt.Errorf("username %s is reserved, cann't be used", cloudbeaverUserId)
585576
}
586577

587578
// 使用管理员身份登录
@@ -591,7 +582,6 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
591582
}
592583

593584
checkExistReq := cloudbeaver.NewRequest(cu.graphQl.IsUserExistQuery(cloudbeaverUserId))
594-
595585
cloudbeaverUserList := UserList{}
596586
err = graphQLClient.Run(ctx, checkExistReq, &cloudbeaverUserList)
597587
if err != nil {
@@ -619,6 +609,15 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
619609
if err != nil {
620610
return fmt.Errorf("grant cloudbeaver user failed: %v", err)
621611
}
612+
} else {
613+
cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
614+
if err != nil {
615+
return err
616+
}
617+
618+
if exist && cloudbeaverUser.DMSFingerprint == cu.userUsecase.GetUserFingerprint(dmsUser) {
619+
return nil
620+
}
622621
}
623622

624623
// 设置CloudBeaver用户密码
@@ -633,9 +632,9 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
633632
return fmt.Errorf("update cloudbeaver user failed: %v", err)
634633
}
635634

636-
cloudbeaverUser = &CloudbeaverUser{
635+
cloudbeaverUser := &CloudbeaverUser{
637636
DMSUserID: dmsUser.UID,
638-
DMSFingerprint: fingerprint,
637+
DMSFingerprint: cu.userUsecase.GetUserFingerprint(dmsUser),
639638
CloudbeaverUserID: cloudbeaverUserId,
640639
}
641640

@@ -701,17 +700,18 @@ func (cu *CloudbeaverUsecase) connectManagement(ctx context.Context, cloudbeaver
701700
activeDBServices = lastActiveDBServices
702701
}
703702

704-
if err = cu.operateConnection(ctx, activeDBServices, dmsUser.UID); err != nil {
705-
return err
706-
}
707-
708703
cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
709704
if err != nil {
710705
return err
711706
}
712707
if !exist {
713708
return fmt.Errorf("cloudbeaver user: %s not eixst", cloudbeaverUserId)
714709
}
710+
711+
if err = cu.operateConnection(ctx, cloudbeaverUser, dmsUser, activeDBServices); err != nil {
712+
return err
713+
}
714+
715715
if err = cu.grantAccessConnection(ctx, cloudbeaverUser, dmsUser, activeDBServices); err != nil {
716716
return err
717717
}
@@ -725,11 +725,43 @@ func getDBPrimaryKey(dbUid, purpose, userUid string) string {
725725
return fmt.Sprint(dbUid, ":", purpose, ":", userUid)
726726
}
727727

728-
func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBServices []*DBService, userId string) error {
728+
type UserConnectionsResp struct {
729+
Connections []*struct {
730+
Id string `json:"id"`
731+
Template bool `json:"template"`
732+
} `json:"connections"`
733+
}
734+
735+
// 获取用户当前数据库连接ID
736+
func (cu *CloudbeaverUsecase) getUserConnectionIds(ctx context.Context, cloudbeaverUser *CloudbeaverUser, dmsUser *User) ([]string, error) {
737+
client, err := cu.getGraphQLClient(cloudbeaverUser.CloudbeaverUserID, dmsUser.Password)
738+
if err != nil {
739+
return nil, err
740+
}
741+
742+
var userConnectionsResp UserConnectionsResp
743+
744+
variables := map[string]interface{}{"projectId": cloudbeaverProjectId}
745+
err = client.Run(ctx, cloudbeaver.NewRequest(cu.graphQl.GetUserConnectionsQuery(), variables), &userConnectionsResp)
746+
if err != nil {
747+
return nil, err
748+
}
749+
750+
ret := make([]string, 0, len(userConnectionsResp.Connections))
751+
for _, connection := range userConnectionsResp.Connections {
752+
if !connection.Template {
753+
ret = append(ret, connection.Id)
754+
}
755+
}
756+
757+
return ret, nil
758+
}
759+
760+
func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, cloudbeaverUser *CloudbeaverUser, dmsUser *User, activeDBServices []*DBService) error {
729761
dbServiceMap := map[string]*DBService{}
730762
projectMap := map[string]string{}
731763
for _, service := range activeDBServices {
732-
dbServiceMap[getDBPrimaryKey(service.UID, service.AccountPurpose, userId)] = service
764+
dbServiceMap[getDBPrimaryKey(service.UID, service.AccountPurpose, dmsUser.UID)] = service
733765

734766
project, err := cu.dbServiceUsecase.projectUsecase.GetProject(ctx, service.ProjectUID)
735767
if err != nil {
@@ -741,33 +773,41 @@ func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBSer
741773
}
742774

743775
//获取当前用户所有已创建的连接
744-
cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserId(ctx, userId)
776+
localCloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserId(ctx, dmsUser.UID)
745777
if err != nil {
746778
return err
747779
}
748780

781+
// cloudbeaver连接数为空则重置缓存
782+
if userConnectionIds, err := cu.getUserConnectionIds(ctx, cloudbeaverUser, dmsUser); err != nil {
783+
return err
784+
} else if len(userConnectionIds) == 0 {
785+
localCloudbeaverConnections = []*CloudbeaverConnection{}
786+
}
787+
749788
var deleteConnections []*CloudbeaverConnection
750789

751790
cloudbeaverConnectionMap := map[string]*CloudbeaverConnection{}
752-
for _, connection := range cloudbeaverConnections {
791+
for _, connection := range localCloudbeaverConnections {
753792
// 删除用户关联的连接
754-
if connection.DMSUserId == userId {
793+
if connection.DMSUserId == dmsUser.UID {
755794
cloudbeaverConnectionMap[connection.PrimaryKey()] = connection
756795
if _, ok := dbServiceMap[connection.PrimaryKey()]; !ok {
757796
deleteConnections = append(deleteConnections, connection)
758797
}
759798
}
760799
}
761800

762-
createConnections, updateConnections := []*CloudbeaverConnection{}, []*CloudbeaverConnection{}
801+
var createConnections []*CloudbeaverConnection
802+
var updateConnections []*CloudbeaverConnection
763803

764804
for _, dbService := range dbServiceMap {
765-
if cloudbeaverConnection, ok := cloudbeaverConnectionMap[getDBPrimaryKey(dbService.UID, dbService.AccountPurpose, userId)]; ok {
805+
if cloudbeaverConnection, ok := cloudbeaverConnectionMap[getDBPrimaryKey(dbService.UID, dbService.AccountPurpose, dmsUser.UID)]; ok {
766806
if cloudbeaverConnection.DMSDBServiceFingerprint != cu.dbServiceUsecase.GetDBServiceFingerprint(dbService) {
767-
updateConnections = append(updateConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: userId})
807+
updateConnections = append(updateConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: dmsUser.UID})
768808
}
769809
} else {
770-
createConnections = append(createConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: userId})
810+
createConnections = append(createConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: dmsUser.UID})
771811
}
772812
}
773813

@@ -783,20 +823,20 @@ func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBSer
783823

784824
// 同步实例连接信息
785825
for _, createConnection := range createConnections {
786-
if err = cu.createCloudbeaverConnection(ctx, cloudbeaverClient, dbServiceMap[getDBPrimaryKey(createConnection.DMSDBServiceID, createConnection.Purpose, userId)],
787-
projectMap[createConnection.DMSDBServiceID], userId); err != nil {
826+
if err = cu.createCloudbeaverConnection(ctx, cloudbeaverClient, dbServiceMap[getDBPrimaryKey(createConnection.DMSDBServiceID, createConnection.Purpose, dmsUser.UID)],
827+
projectMap[createConnection.DMSDBServiceID], dmsUser.UID); err != nil {
788828
cu.log.Errorf("create connection %v failed: %v", createConnection, err)
789829
}
790830
}
791831

792832
for _, updateConnection := range updateConnections {
793-
if err = cu.updateCloudbeaverConnection(ctx, cloudbeaverClient, updateConnection.CloudbeaverConnectionID, dbServiceMap[getDBPrimaryKey(updateConnection.DMSDBServiceID, updateConnection.Purpose, userId)], projectMap[updateConnection.DMSDBServiceID], userId); err != nil {
833+
if err = cu.updateCloudbeaverConnection(ctx, cloudbeaverClient, updateConnection.CloudbeaverConnectionID, dbServiceMap[getDBPrimaryKey(updateConnection.DMSDBServiceID, updateConnection.Purpose, dmsUser.UID)], projectMap[updateConnection.DMSDBServiceID], dmsUser.UID); err != nil {
794834
cu.log.Errorf("update dnServerId %s to connection failed: %v", updateConnection, err)
795835
}
796836
}
797837

798838
for _, deleteConnection := range deleteConnections {
799-
if err = cu.deleteCloudbeaverConnection(ctx, cloudbeaverClient, deleteConnection.CloudbeaverConnectionID, deleteConnection.DMSDBServiceID, userId, deleteConnection.Purpose); err != nil {
839+
if err = cu.deleteCloudbeaverConnection(ctx, cloudbeaverClient, deleteConnection.CloudbeaverConnectionID, deleteConnection.DMSDBServiceID, dmsUser.UID, deleteConnection.Purpose); err != nil {
800840
cu.log.Errorf("delete connection %v failed: %v", deleteConnection, err)
801841
}
802842
}
@@ -842,41 +882,29 @@ func (cu *CloudbeaverUsecase) grantAccessConnection(ctx context.Context, cloudbe
842882
for _, dbService := range activeDBServices {
843883
dbServiceIds = append(dbServiceIds, dbService.UID)
844884
}
845-
cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserIdAndDBServiceIds(ctx, dmsUser.UID, dbServiceIds)
885+
localCloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserIdAndDBServiceIds(ctx, dmsUser.UID, dbServiceIds)
846886
if err != nil {
847887
return err
848888
}
849889

850890
// 从缓存中获取需要同步的CloudBeaver实例
851891
cloudbeaverConnectionMap := map[string]*CloudbeaverConnection{}
852-
for _, cloudbeaverConnection := range cloudbeaverConnections {
853-
cloudbeaverConnectionMap[cloudbeaverConnection.CloudbeaverConnectionID] = cloudbeaverConnection
854-
}
855-
856-
// 获取用户当前实例列表
857-
connResp := &struct {
858-
Connections []*struct {
859-
Id string `json:"id"`
860-
} `json:"connections"`
861-
}{}
862-
863-
client, err := cu.getGraphQLClient(cloudbeaverUser.CloudbeaverUserID, dmsUser.Password)
864-
if err != nil {
865-
return err
892+
for _, connection := range localCloudbeaverConnections {
893+
cloudbeaverConnectionMap[connection.CloudbeaverConnectionID] = connection
866894
}
867895

868-
err = client.Run(ctx, cloudbeaver.NewRequest(cu.graphQl.GetUserConnectionsQuery(), nil), connResp)
896+
cloudbeaverConnectionIds, err := cu.getUserConnectionIds(ctx, cloudbeaverUser, dmsUser)
869897
if err != nil {
870898
return err
871899
}
872900

873-
if len(connResp.Connections) != len(cloudbeaverConnections) {
874-
return cu.bindUserAccessConnection(ctx, cloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
901+
if len(cloudbeaverConnectionIds) != len(localCloudbeaverConnections) {
902+
return cu.bindUserAccessConnection(ctx, localCloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
875903
}
876904

877-
for _, connection := range connResp.Connections {
878-
if _, ok := cloudbeaverConnectionMap[connection.Id]; !ok {
879-
return cu.bindUserAccessConnection(ctx, cloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
905+
for _, connectionId := range cloudbeaverConnectionIds {
906+
if _, ok := cloudbeaverConnectionMap[connectionId]; !ok {
907+
return cu.bindUserAccessConnection(ctx, localCloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
880908
}
881909
}
882910

internal/dms/storage/cloudbeaver.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/actiontech/dms/internal/dms/biz"
88
"github.com/actiontech/dms/internal/dms/storage/model"
9+
"gorm.io/gorm/clause"
910

1011
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"
1112

@@ -118,7 +119,9 @@ func (cr *CloudbeaverRepo) GetCloudbeaverConnectionsByUserId(ctx context.Context
118119

119120
func (cr *CloudbeaverRepo) UpdateCloudbeaverConnectionCache(ctx context.Context, u *biz.CloudbeaverConnection) error {
120121
return transaction(cr.log, ctx, cr.db, func(tx *gorm.DB) error {
121-
if err := tx.WithContext(ctx).Save(convertBizCloudbeaverConnection(u)).Error; err != nil {
122+
if err := tx.WithContext(ctx).Clauses(clause.OnConflict{
123+
UpdateAll: true,
124+
}).Create(convertBizCloudbeaverConnection(u)).Error; err != nil {
122125
return fmt.Errorf("failed to update cloudbeaver db Service: %v", err)
123126
}
124127
return nil

internal/pkg/cloudbeaver/cloudbeaver.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ query getUserConnections (
139139
}
140140
fragment DatabaseConnection on ConnectionInfo {
141141
id
142+
template
142143
}
143144
`
144145
}

0 commit comments

Comments
 (0)