@@ -568,20 +568,11 @@ type UserList struct {
568568 } `json:"listUsers"`
569569}
570570
571- func (cu * CloudbeaverUsecase ) createUserIfNotExist (ctx context.Context , cloudbeaverUserId string , dmsUser * User ) error {
572- cloudbeaverUser , exist , err := cu .repo .GetCloudbeaverUserByID (ctx , cloudbeaverUserId )
573- if err != nil {
574- return err
575- }
571+ var reservedCloudbeaverUserId = map [string ]struct {}{"admin" : {}, "user" : {}}
576572
577- fingerprint := cu .userUsecase .GetUserFingerprint (dmsUser )
578- if exist && cloudbeaverUser .DMSFingerprint == fingerprint {
579- return nil
580- }
581-
582- reservedCloudbeaverUserId := map [string ]struct {}{"admin" : {}, "user" : {}}
573+ func (cu * CloudbeaverUsecase ) createUserIfNotExist (ctx context.Context , cloudbeaverUserId string , dmsUser * User ) error {
583574 if _ , ok := reservedCloudbeaverUserId [cloudbeaverUserId ]; ok {
584- return fmt .Errorf ("this username cannot be used" )
575+ return fmt .Errorf ("username %s is reserved, cann't be used" , cloudbeaverUserId )
585576 }
586577
587578 // 使用管理员身份登录
@@ -591,7 +582,6 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
591582 }
592583
593584 checkExistReq := cloudbeaver .NewRequest (cu .graphQl .IsUserExistQuery (cloudbeaverUserId ))
594-
595585 cloudbeaverUserList := UserList {}
596586 err = graphQLClient .Run (ctx , checkExistReq , & cloudbeaverUserList )
597587 if err != nil {
@@ -619,6 +609,15 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
619609 if err != nil {
620610 return fmt .Errorf ("grant cloudbeaver user failed: %v" , err )
621611 }
612+ } else {
613+ cloudbeaverUser , exist , err := cu .repo .GetCloudbeaverUserByID (ctx , cloudbeaverUserId )
614+ if err != nil {
615+ return err
616+ }
617+
618+ if exist && cloudbeaverUser .DMSFingerprint == cu .userUsecase .GetUserFingerprint (dmsUser ) {
619+ return nil
620+ }
622621 }
623622
624623 // 设置CloudBeaver用户密码
@@ -633,9 +632,9 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
633632 return fmt .Errorf ("update cloudbeaver user failed: %v" , err )
634633 }
635634
636- cloudbeaverUser = & CloudbeaverUser {
635+ cloudbeaverUser : = & CloudbeaverUser {
637636 DMSUserID : dmsUser .UID ,
638- DMSFingerprint : fingerprint ,
637+ DMSFingerprint : cu . userUsecase . GetUserFingerprint ( dmsUser ) ,
639638 CloudbeaverUserID : cloudbeaverUserId ,
640639 }
641640
@@ -701,17 +700,18 @@ func (cu *CloudbeaverUsecase) connectManagement(ctx context.Context, cloudbeaver
701700 activeDBServices = lastActiveDBServices
702701 }
703702
704- if err = cu .operateConnection (ctx , activeDBServices , dmsUser .UID ); err != nil {
705- return err
706- }
707-
708703 cloudbeaverUser , exist , err := cu .repo .GetCloudbeaverUserByID (ctx , cloudbeaverUserId )
709704 if err != nil {
710705 return err
711706 }
712707 if ! exist {
713708 return fmt .Errorf ("cloudbeaver user: %s not eixst" , cloudbeaverUserId )
714709 }
710+
711+ if err = cu .operateConnection (ctx , cloudbeaverUser , dmsUser , activeDBServices ); err != nil {
712+ return err
713+ }
714+
715715 if err = cu .grantAccessConnection (ctx , cloudbeaverUser , dmsUser , activeDBServices ); err != nil {
716716 return err
717717 }
@@ -725,11 +725,43 @@ func getDBPrimaryKey(dbUid, purpose, userUid string) string {
725725 return fmt .Sprint (dbUid , ":" , purpose , ":" , userUid )
726726}
727727
728- func (cu * CloudbeaverUsecase ) operateConnection (ctx context.Context , activeDBServices []* DBService , userId string ) error {
728+ type UserConnectionsResp struct {
729+ Connections []* struct {
730+ Id string `json:"id"`
731+ Template bool `json:"template"`
732+ } `json:"connections"`
733+ }
734+
735+ // 获取用户当前数据库连接ID
736+ func (cu * CloudbeaverUsecase ) getUserConnectionIds (ctx context.Context , cloudbeaverUser * CloudbeaverUser , dmsUser * User ) ([]string , error ) {
737+ client , err := cu .getGraphQLClient (cloudbeaverUser .CloudbeaverUserID , dmsUser .Password )
738+ if err != nil {
739+ return nil , err
740+ }
741+
742+ var userConnectionsResp UserConnectionsResp
743+
744+ variables := map [string ]interface {}{"projectId" : cloudbeaverProjectId }
745+ err = client .Run (ctx , cloudbeaver .NewRequest (cu .graphQl .GetUserConnectionsQuery (), variables ), & userConnectionsResp )
746+ if err != nil {
747+ return nil , err
748+ }
749+
750+ ret := make ([]string , 0 , len (userConnectionsResp .Connections ))
751+ for _ , connection := range userConnectionsResp .Connections {
752+ if ! connection .Template {
753+ ret = append (ret , connection .Id )
754+ }
755+ }
756+
757+ return ret , nil
758+ }
759+
760+ func (cu * CloudbeaverUsecase ) operateConnection (ctx context.Context , cloudbeaverUser * CloudbeaverUser , dmsUser * User , activeDBServices []* DBService ) error {
729761 dbServiceMap := map [string ]* DBService {}
730762 projectMap := map [string ]string {}
731763 for _ , service := range activeDBServices {
732- dbServiceMap [getDBPrimaryKey (service .UID , service .AccountPurpose , userId )] = service
764+ dbServiceMap [getDBPrimaryKey (service .UID , service .AccountPurpose , dmsUser . UID )] = service
733765
734766 project , err := cu .dbServiceUsecase .projectUsecase .GetProject (ctx , service .ProjectUID )
735767 if err != nil {
@@ -741,33 +773,41 @@ func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBSer
741773 }
742774
743775 //获取当前用户所有已创建的连接
744- cloudbeaverConnections , err := cu .repo .GetCloudbeaverConnectionsByUserId (ctx , userId )
776+ localCloudbeaverConnections , err := cu .repo .GetCloudbeaverConnectionsByUserId (ctx , dmsUser . UID )
745777 if err != nil {
746778 return err
747779 }
748780
781+ // cloudbeaver连接数为空则重置缓存
782+ if userConnectionIds , err := cu .getUserConnectionIds (ctx , cloudbeaverUser , dmsUser ); err != nil {
783+ return err
784+ } else if len (userConnectionIds ) == 0 {
785+ localCloudbeaverConnections = []* CloudbeaverConnection {}
786+ }
787+
749788 var deleteConnections []* CloudbeaverConnection
750789
751790 cloudbeaverConnectionMap := map [string ]* CloudbeaverConnection {}
752- for _ , connection := range cloudbeaverConnections {
791+ for _ , connection := range localCloudbeaverConnections {
753792 // 删除用户关联的连接
754- if connection .DMSUserId == userId {
793+ if connection .DMSUserId == dmsUser . UID {
755794 cloudbeaverConnectionMap [connection .PrimaryKey ()] = connection
756795 if _ , ok := dbServiceMap [connection .PrimaryKey ()]; ! ok {
757796 deleteConnections = append (deleteConnections , connection )
758797 }
759798 }
760799 }
761800
762- createConnections , updateConnections := []* CloudbeaverConnection {}, []* CloudbeaverConnection {}
801+ var createConnections []* CloudbeaverConnection
802+ var updateConnections []* CloudbeaverConnection
763803
764804 for _ , dbService := range dbServiceMap {
765- if cloudbeaverConnection , ok := cloudbeaverConnectionMap [getDBPrimaryKey (dbService .UID , dbService .AccountPurpose , userId )]; ok {
805+ if cloudbeaverConnection , ok := cloudbeaverConnectionMap [getDBPrimaryKey (dbService .UID , dbService .AccountPurpose , dmsUser . UID )]; ok {
766806 if cloudbeaverConnection .DMSDBServiceFingerprint != cu .dbServiceUsecase .GetDBServiceFingerprint (dbService ) {
767- updateConnections = append (updateConnections , & CloudbeaverConnection {DMSDBServiceID : dbService .UID , Purpose : dbService .AccountPurpose , DMSUserId : userId })
807+ updateConnections = append (updateConnections , & CloudbeaverConnection {DMSDBServiceID : dbService .UID , Purpose : dbService .AccountPurpose , DMSUserId : dmsUser . UID })
768808 }
769809 } else {
770- createConnections = append (createConnections , & CloudbeaverConnection {DMSDBServiceID : dbService .UID , Purpose : dbService .AccountPurpose , DMSUserId : userId })
810+ createConnections = append (createConnections , & CloudbeaverConnection {DMSDBServiceID : dbService .UID , Purpose : dbService .AccountPurpose , DMSUserId : dmsUser . UID })
771811 }
772812 }
773813
@@ -783,20 +823,20 @@ func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBSer
783823
784824 // 同步实例连接信息
785825 for _ , createConnection := range createConnections {
786- if err = cu .createCloudbeaverConnection (ctx , cloudbeaverClient , dbServiceMap [getDBPrimaryKey (createConnection .DMSDBServiceID , createConnection .Purpose , userId )],
787- projectMap [createConnection .DMSDBServiceID ], userId ); err != nil {
826+ if err = cu .createCloudbeaverConnection (ctx , cloudbeaverClient , dbServiceMap [getDBPrimaryKey (createConnection .DMSDBServiceID , createConnection .Purpose , dmsUser . UID )],
827+ projectMap [createConnection .DMSDBServiceID ], dmsUser . UID ); err != nil {
788828 cu .log .Errorf ("create connection %v failed: %v" , createConnection , err )
789829 }
790830 }
791831
792832 for _ , updateConnection := range updateConnections {
793- if err = cu .updateCloudbeaverConnection (ctx , cloudbeaverClient , updateConnection .CloudbeaverConnectionID , dbServiceMap [getDBPrimaryKey (updateConnection .DMSDBServiceID , updateConnection .Purpose , userId )], projectMap [updateConnection .DMSDBServiceID ], userId ); err != nil {
833+ if err = cu .updateCloudbeaverConnection (ctx , cloudbeaverClient , updateConnection .CloudbeaverConnectionID , dbServiceMap [getDBPrimaryKey (updateConnection .DMSDBServiceID , updateConnection .Purpose , dmsUser . UID )], projectMap [updateConnection .DMSDBServiceID ], dmsUser . UID ); err != nil {
794834 cu .log .Errorf ("update dnServerId %s to connection failed: %v" , updateConnection , err )
795835 }
796836 }
797837
798838 for _ , deleteConnection := range deleteConnections {
799- if err = cu .deleteCloudbeaverConnection (ctx , cloudbeaverClient , deleteConnection .CloudbeaverConnectionID , deleteConnection .DMSDBServiceID , userId , deleteConnection .Purpose ); err != nil {
839+ if err = cu .deleteCloudbeaverConnection (ctx , cloudbeaverClient , deleteConnection .CloudbeaverConnectionID , deleteConnection .DMSDBServiceID , dmsUser . UID , deleteConnection .Purpose ); err != nil {
800840 cu .log .Errorf ("delete connection %v failed: %v" , deleteConnection , err )
801841 }
802842 }
@@ -842,41 +882,29 @@ func (cu *CloudbeaverUsecase) grantAccessConnection(ctx context.Context, cloudbe
842882 for _ , dbService := range activeDBServices {
843883 dbServiceIds = append (dbServiceIds , dbService .UID )
844884 }
845- cloudbeaverConnections , err := cu .repo .GetCloudbeaverConnectionsByUserIdAndDBServiceIds (ctx , dmsUser .UID , dbServiceIds )
885+ localCloudbeaverConnections , err := cu .repo .GetCloudbeaverConnectionsByUserIdAndDBServiceIds (ctx , dmsUser .UID , dbServiceIds )
846886 if err != nil {
847887 return err
848888 }
849889
850890 // 从缓存中获取需要同步的CloudBeaver实例
851891 cloudbeaverConnectionMap := map [string ]* CloudbeaverConnection {}
852- for _ , cloudbeaverConnection := range cloudbeaverConnections {
853- cloudbeaverConnectionMap [cloudbeaverConnection .CloudbeaverConnectionID ] = cloudbeaverConnection
854- }
855-
856- // 获取用户当前实例列表
857- connResp := & struct {
858- Connections []* struct {
859- Id string `json:"id"`
860- } `json:"connections"`
861- }{}
862-
863- client , err := cu .getGraphQLClient (cloudbeaverUser .CloudbeaverUserID , dmsUser .Password )
864- if err != nil {
865- return err
892+ for _ , connection := range localCloudbeaverConnections {
893+ cloudbeaverConnectionMap [connection .CloudbeaverConnectionID ] = connection
866894 }
867895
868- err = client . Run (ctx , cloudbeaver . NewRequest ( cu . graphQl . GetUserConnectionsQuery (), nil ), connResp )
896+ cloudbeaverConnectionIds , err := cu . getUserConnectionIds (ctx , cloudbeaverUser , dmsUser )
869897 if err != nil {
870898 return err
871899 }
872900
873- if len (connResp . Connections ) != len (cloudbeaverConnections ) {
874- return cu .bindUserAccessConnection (ctx , cloudbeaverConnections , cloudbeaverUser .CloudbeaverUserID )
901+ if len (cloudbeaverConnectionIds ) != len (localCloudbeaverConnections ) {
902+ return cu .bindUserAccessConnection (ctx , localCloudbeaverConnections , cloudbeaverUser .CloudbeaverUserID )
875903 }
876904
877- for _ , connection := range connResp . Connections {
878- if _ , ok := cloudbeaverConnectionMap [connection . Id ]; ! ok {
879- return cu .bindUserAccessConnection (ctx , cloudbeaverConnections , cloudbeaverUser .CloudbeaverUserID )
905+ for _ , connectionId := range cloudbeaverConnectionIds {
906+ if _ , ok := cloudbeaverConnectionMap [connectionId ]; ! ok {
907+ return cu .bindUserAccessConnection (ctx , localCloudbeaverConnections , cloudbeaverUser .CloudbeaverUserID )
880908 }
881909 }
882910
0 commit comments