Skip to content

Commit e86623c

Browse files
Merge pull request #216 from actiontech/fix-issue2309-1
Fix issue2309 1
2 parents 2cf3439 + 0a7d553 commit e86623c

11 files changed

Lines changed: 230 additions & 605 deletions

File tree

internal/apiserver/cmd/server/dms.pid

Lines changed: 0 additions & 1 deletion
This file was deleted.

internal/apiserver/cmd/server/logs/dms.log

Lines changed: 0 additions & 601 deletions
This file was deleted.

internal/apiserver/service/dms_controller.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,17 @@ func (a *DMSController) GenAccessToken(c echo.Context) error {
748748
if nil != err {
749749
return NewErrResp(c, err, apiError.BadRequestErr)
750750
}
751-
reply := &dmsV1.GenAccessTokenReply{}
751+
752+
// get current user id
753+
currentUid, err := jwt.GetUserUidStrFromContext(c)
754+
if err != nil {
755+
return NewErrResp(c, err, apiError.DMSServiceErr)
756+
}
757+
758+
reply, err := a.DMS.GenAccessToken(c.Request().Context(), currentUid, req)
759+
if nil != err {
760+
return NewErrResp(c, err, apiError.DMSServiceErr)
761+
}
752762
return NewOkRespWithReply(c, reply)
753763
}
754764

internal/apiserver/service/router.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,7 @@ func (s *APIServer) installMiddleware() error {
232232
if strings.HasSuffix(c.Request().RequestURI, dmsV1.SessionRouterGroup) ||
233233
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/oauth2" /* TODO 使用统一方法skip */) ||
234234
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/personalization/logo") ||
235-
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/configurations/license" /* TODO 使用统一方法skip */) ||
236-
!strings.HasPrefix(c.Request().RequestURI, dmsV1.CurrentGroupVersion) {
235+
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/configurations/license" /* TODO 使用统一方法skip */) {
237236
logger.Debugf("skipper url jwt check: %v", c.Request().RequestURI)
238237
return true
239238
}
@@ -245,6 +244,8 @@ func (s *APIServer) installMiddleware() error {
245244

246245
s.echo.Use(dmsMiddleware.LicenseAdapter(s.DMSController.DMS.LicenseUsecase))
247246

247+
s.echo.Use(s.DMSController.DMS.AuthAccessTokenUseCase.CheckLatestAccessToken())
248+
248249
s.echo.Use(middleware.ProxyWithConfig(middleware.ProxyConfig{
249250
Skipper: s.DMSController.DMS.DmsProxyUsecase.GetEchoProxySkipper(),
250251
Balancer: s.DMSController.DMS.DmsProxyUsecase.GetEchoProxyBalancer(),

internal/dms/biz/access_token.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package biz
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
7+
jwtPkg "github.com/actiontech/dms/pkg/dms-common/api/jwt"
8+
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"
9+
"github.com/golang-jwt/jwt/v4"
10+
"github.com/labstack/echo/v4"
11+
)
12+
13+
const AccessTokenLogin = "access_token_login"
14+
15+
type AuthAccessTokenUsecase struct {
16+
userUsecase *UserUsecase
17+
log *utilLog.Helper
18+
}
19+
20+
func NewAuthAccessTokenUsecase(log utilLog.Logger, usecase *UserUsecase) *AuthAccessTokenUsecase {
21+
au := &AuthAccessTokenUsecase{
22+
userUsecase: usecase,
23+
log: utilLog.NewHelper(log, utilLog.WithMessageKey("biz.accesstoken")),
24+
}
25+
return au
26+
}
27+
28+
func (au *AuthAccessTokenUsecase) CheckLatestAccessToken() echo.MiddlewareFunc {
29+
return func(next echo.HandlerFunc) echo.HandlerFunc {
30+
return func(c echo.Context) error {
31+
user := c.Get("user")
32+
// 获取token为空,代表该请求不需要校验token,例如:/v1/dms/oauth2
33+
if user == nil {
34+
return next(c)
35+
}
36+
token, ok := user.(*jwt.Token)
37+
if !ok {
38+
return echo.NewHTTPError(http.StatusBadRequest, "failed to convert user from jwt token")
39+
}
40+
41+
claims, ok := token.Claims.(jwt.MapClaims)
42+
if !ok {
43+
return echo.NewHTTPError(http.StatusBadRequest, "failed to convert token claims to jwt")
44+
}
45+
46+
// 如果不存在JWTLoginType字段,代表是账号密码登录获取的token或者是扫描任务的凭证,不进行校验
47+
loginType, ok := claims[jwtPkg.JWTLoginType]
48+
if !ok {
49+
return next(c)
50+
}
51+
if loginType != AccessTokenLogin {
52+
return echo.NewHTTPError(http.StatusUnauthorized, "access token login type is error")
53+
}
54+
uidStr := fmt.Sprintf("%v", claims[jwtPkg.JWTUserId])
55+
accessTokenInfo, err := au.userUsecase.repo.GetAccessTokenByUser(c.Request().Context(), uidStr)
56+
if err != nil {
57+
return err
58+
}
59+
60+
if accessTokenInfo.Token != token.Raw {
61+
return echo.NewHTTPError(http.StatusUnauthorized, "access token is not latest")
62+
}
63+
64+
return next(c)
65+
}
66+
}
67+
}

internal/dms/biz/user.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/tls"
66
"errors"
77
"fmt"
8+
"strconv"
89
"time"
910

1011
pkgConst "github.com/actiontech/dms/internal/dms/pkg/constant"
@@ -83,6 +84,13 @@ type User struct {
8384
Deleted bool
8485
}
8586

87+
type AccessTokenInfo struct {
88+
UID string
89+
UserID uint
90+
Token string
91+
ExpiredTime time.Time
92+
}
93+
8694
func initUsers() []*User {
8795
return []*User{
8896
{
@@ -162,6 +170,8 @@ type UserRepo interface {
162170
GetUserGroupsByUser(ctx context.Context, userUid string) ([]*UserGroup, error)
163171
GetOpPermissionsByUser(ctx context.Context, userUid string) ([]*OpPermission, error)
164172
GetUserByThirdPartyUserID(ctx context.Context, thirdPartyUserUID string) (*User, error)
173+
SaveAccessToken(ctx context.Context, accessTokenInfo *AccessTokenInfo) error
174+
GetAccessTokenByUser(ctx context.Context, UserUid string) (*AccessTokenInfo, error)
165175
}
166176

167177
type UserUsecase struct {
@@ -769,3 +779,25 @@ func (d *UserUsecase) GetBizUserWithNameByUids(ctx context.Context, uids []strin
769779
}
770780
return ret
771781
}
782+
783+
func (d *UserUsecase) SaveAccessToken(ctx context.Context, userId string, token string, expiredTime time.Time) error {
784+
userIdInt, err := strconv.Atoi(userId)
785+
if err != nil {
786+
return err
787+
}
788+
uid, err := pkgRand.GenStrUid()
789+
if err != nil {
790+
return err
791+
}
792+
793+
tokenInfo := &AccessTokenInfo{UID: uid, UserID: uint(userIdInt), Token: token, ExpiredTime: expiredTime}
794+
return d.repo.SaveAccessToken(ctx, tokenInfo)
795+
}
796+
797+
func (d *UserUsecase) GetAccessTokenByUser(ctx context.Context, UserUid string) (*AccessTokenInfo, error) {
798+
accessTokenInfo, err := d.repo.GetAccessTokenByUser(ctx, UserUid)
799+
if err != nil {
800+
return nil, err
801+
}
802+
return accessTokenInfo, nil
803+
}

internal/dms/service/service.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type DMSService struct {
3737
ClusterUsecase *biz.ClusterUsecase
3838
DataExportWorkflowUsecase *biz.DataExportWorkflowUsecase
3939
DataMaskingUsecase *biz.DataMaskingUsecase
40+
AuthAccessTokenUseCase *biz.AuthAccessTokenUsecase
4041
log *utilLog.Helper
4142
shutdownCallback func() error
4243
}
@@ -111,6 +112,7 @@ func NewAndInitDMSService(logger utilLog.Logger, opts *conf.DMSOptions) (*DMSSer
111112
workflowRepo := storage.NewWorkflowRepo(logger, st)
112113
DataExportWorkflowUsecase := biz.NewDataExportWorkflowUsecase(logger, tx, workflowRepo, dataExportTaskRepo, dbServiceRepo, opPermissionVerifyUsecase, projectUsecase, dmsProxyTargetRepo, clusterUsecase, webhookConfigurationUsecase, userUsecase, fmt.Sprintf("%s:%d", opts.ReportHost, opts.APIServiceOpts.Port))
113114
dataMasking, err := maskingBiz.NewDataMaskingUseCase(logger)
115+
authAccessTokenUsecase := biz.NewAuthAccessTokenUsecase(logger, userUsecase)
114116
if err != nil {
115117
return nil, fmt.Errorf("failed to new data masking use case: %v", err)
116118
}
@@ -147,6 +149,7 @@ func NewAndInitDMSService(logger utilLog.Logger, opts *conf.DMSOptions) (*DMSSer
147149
ClusterUsecase: clusterUsecase,
148150
DataExportWorkflowUsecase: DataExportWorkflowUsecase,
149151
DataMaskingUsecase: dataMaskingUsecase,
152+
AuthAccessTokenUseCase: authAccessTokenUsecase,
150153
log: utilLog.NewHelper(logger, utilLog.WithMessageKey("dms.service")),
151154
shutdownCallback: func() error {
152155
if err := st.Close(); nil != err {

internal/dms/service/user.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@ package service
33
import (
44
"context"
55
"fmt"
6+
"strconv"
67
"strings"
8+
"time"
79

810
dmsV1 "github.com/actiontech/dms/api/dms/service/v1"
911
"github.com/actiontech/dms/internal/dms/biz"
1012
pkgConst "github.com/actiontech/dms/internal/dms/pkg/constant"
1113

1214
dmsCommonV1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
15+
jwtPkg "github.com/actiontech/dms/pkg/dms-common/api/jwt"
16+
"github.com/golang-jwt/jwt/v4"
1317
)
1418

1519
func (d *DMSService) VerifyUserLogin(ctx context.Context, req *dmsV1.VerifyUserLoginReq) (reply *dmsV1.VerifyUserLoginReply, err error) {
@@ -505,6 +509,19 @@ func (d *DMSService) GetUser(ctx context.Context, req *dmsCommonV1.GetUserReq) (
505509
}
506510
dmsCommonUser.UserBindProjects = userBindProjects
507511

512+
// 获取用户access token
513+
tokenInfo, err := d.UserUsecase.GetAccessTokenByUser(ctx, u.UID)
514+
if err != nil {
515+
return nil, fmt.Errorf("failed to get user access token: %v", err)
516+
}
517+
accessToken := dmsCommonV1.AccessTokenInfo{}
518+
accessToken.AccessToken = tokenInfo.Token
519+
accessToken.ExpiredTime = tokenInfo.ExpiredTime.Format("2006-01-02T15:04:05-07:00")
520+
if tokenInfo.ExpiredTime.Before(time.Now()) {
521+
accessToken.IsExpired = true
522+
}
523+
dmsCommonUser.AccessTokenInfo = accessToken
524+
508525
reply = &dmsCommonV1.GetUserReply{
509526
Data: dmsCommonUser,
510527
}
@@ -513,6 +530,31 @@ func (d *DMSService) GetUser(ctx context.Context, req *dmsCommonV1.GetUserReq) (
513530
return reply, nil
514531
}
515532

533+
func (d *DMSService) GenAccessToken(ctx context.Context, currentUserUid string, req *dmsCommonV1.GenAccessToken) (reply *dmsCommonV1.GenAccessTokenReply, err error) {
534+
days, err := strconv.ParseUint(req.ExpirationDays, 10, 64)
535+
if err != nil {
536+
return nil, err
537+
}
538+
539+
expiredTime := time.Now().Add(time.Duration(days) * 24 * time.Hour)
540+
token, err := jwtPkg.GenJwtTokenWithExpirationTime(jwt.NewNumericDate(expiredTime), jwtPkg.WithUserId(currentUserUid), jwtPkg.WithAccessTokenMark(biz.AccessTokenLogin))
541+
if err != nil {
542+
return nil, fmt.Errorf("gen access token failed: %v", err)
543+
}
544+
if err := d.UserUsecase.SaveAccessToken(ctx, currentUserUid, token, expiredTime); err != nil {
545+
return nil, fmt.Errorf("save access token failed: %v", err)
546+
}
547+
548+
reply = &dmsCommonV1.GenAccessTokenReply{
549+
Data: &dmsCommonV1.AccessTokenInfo{
550+
AccessToken: token,
551+
ExpiredTime: expiredTime.Format("2006-01-02T15:04:05-07:00"),
552+
},
553+
}
554+
555+
return reply, nil
556+
}
557+
516558
func convertBizOpPermission(opPermissionUid string) (apiOpPermissionTyp dmsCommonV1.OpPermissionType, err error) {
517559
switch opPermissionUid {
518560
case pkgConst.UIDOfOpPermissionCreateWorkflow:

internal/dms/storage/model/model.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ var AutoMigrateList = []interface{}{
4545
WorkflowStep{},
4646
DataExportTask{},
4747
DataExportTaskRecord{},
48+
UserAccessToken{},
4849
}
4950

5051
type Model struct {
@@ -140,6 +141,15 @@ type OpPermission struct {
140141
RangeType string `json:"range_type" gorm:"size:255;column:range_type"`
141142
}
142143

144+
type UserAccessToken struct {
145+
Model
146+
Token string `json:"token" gorm:"size:255"`
147+
ExpiredTime time.Time `json:"expired_time" example:"2018-10-21T16:40:23+08:00"`
148+
UserID uint `json:"user_id" gorm:"size:32;index:user_id,unique"`
149+
150+
User *User `json:"user" gorm:"foreignkey:user_id"`
151+
}
152+
143153
type DMSConfig struct {
144154
Model
145155
NeedInitOpPermissions bool `json:"need_init_op_permissions" gorm:"column:need_init_op_permissions"`

internal/dms/storage/user.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package storage
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67

78
"github.com/actiontech/dms/internal/dms/biz"
@@ -12,6 +13,7 @@ import (
1213
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"
1314

1415
"gorm.io/gorm"
16+
"gorm.io/gorm/clause"
1517
)
1618

1719
var _ biz.UserRepo = (*UserRepo)(nil)
@@ -331,3 +333,43 @@ func (d *UserRepo) GetUserByThirdPartyUserID(ctx context.Context, thirdPartyUser
331333
}
332334
return ret, nil
333335
}
336+
337+
func (d *UserRepo) SaveAccessToken(ctx context.Context, tokenInfo *biz.AccessTokenInfo) error {
338+
userAccessToekn := &model.UserAccessToken{
339+
Model: model.Model{
340+
UID: tokenInfo.UID,
341+
},
342+
UserID: tokenInfo.UserID,
343+
Token: tokenInfo.Token,
344+
ExpiredTime: tokenInfo.ExpiredTime,
345+
}
346+
347+
tx := d.db.Clauses(clause.OnConflict{
348+
Columns: []clause.Column{{Name: "user_id"}},
349+
DoUpdates: clause.Assignments(map[string]interface{}{"token": tokenInfo.Token, "expired_time": tokenInfo.ExpiredTime}),
350+
}).Create(userAccessToekn)
351+
352+
if tx.Error != nil {
353+
return fmt.Errorf("failed to save access token: %v", tx.Error)
354+
}
355+
356+
return nil
357+
}
358+
359+
func (d *UserRepo) GetAccessTokenByUser(ctx context.Context, userUid string) (*biz.AccessTokenInfo, error) {
360+
var userToken *model.UserAccessToken
361+
if err := transaction(d.log, ctx, d.db, func(tx *gorm.DB) error {
362+
if err := tx.First(&userToken, "user_id = ?", userUid).Error; err != nil {
363+
// 未找到记录返回空,不影响获取用户信息的功能
364+
if errors.Is(err, gorm.ErrRecordNotFound) {
365+
return nil
366+
}
367+
return fmt.Errorf("failed to get user access token: %v", err)
368+
}
369+
return nil
370+
}); err != nil {
371+
return nil, err
372+
}
373+
374+
return &biz.AccessTokenInfo{Token: userToken.Token, ExpiredTime: userToken.ExpiredTime}, nil
375+
}

0 commit comments

Comments
 (0)