Skip to content

Commit b40e7b0

Browse files
committed
dev: Create new GetAccessTokenWithSession to minimize calls
1 parent 2928a54 commit b40e7b0

7 files changed

Lines changed: 177 additions & 7 deletions

File tree

config/messages.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8846,6 +8846,15 @@
88468846
"file": "oauth.go"
88478847
}
88488848
},
8849+
"error:pkg/oauth:session_expired": {
8850+
"translations": {
8851+
"en": "session expired"
8852+
},
8853+
"description": {
8854+
"package": "pkg/oauth",
8855+
"file": "storage.go"
8856+
}
8857+
},
88498858
"error:pkg/oauth:token": {
88508859
"translations": {
88518860
"en": "invalid token"

pkg/identityserver/bunstore/oauth_store.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,54 @@ func (s *oauthStore) GetAccessToken(ctx context.Context, id string) (*ttnpb.OAut
698698
return pb, nil
699699
}
700700

701+
func (s *oauthStore) GetAccessTokenWithSession(
702+
ctx context.Context, id string,
703+
) (*ttnpb.OAuthAccessToken, *ttnpb.UserSession, error) {
704+
ctx, span := tracer.StartFromContext(ctx, "GetAccessTokenWithSession")
705+
defer span.End()
706+
707+
model := &AccessToken{}
708+
selectQuery := s.newSelectModel(ctx, model).
709+
Where("token_id = ?", id).
710+
Relation("User", func(q *bun.SelectQuery) *bun.SelectQuery {
711+
return q.Column("account_uid")
712+
}).
713+
Relation("Client", func(q *bun.SelectQuery) *bun.SelectQuery {
714+
return q.Column("client_id")
715+
}).
716+
Relation("UserSession")
717+
718+
if err := selectQuery.Scan(ctx); err != nil {
719+
err = storeutil.WrapDriverError(err)
720+
if errors.IsNotFound(err) {
721+
return nil, nil, store.ErrAccessTokenNotFound.WithAttributes(
722+
"access_token_id", id,
723+
)
724+
}
725+
return nil, nil, err
726+
}
727+
728+
// NOTE: This imposes a limitation on the client's rights if the token's user is the unique support user.
729+
if model.User.Account.UID == ttnpb.SupportUserID {
730+
model.Rights = convertIntSlice[ttnpb.Right, int](ttnpb.AllReadAdminRights.GetRights())
731+
}
732+
733+
pb, err := accessTokenToPB(model, nil, nil)
734+
if err != nil {
735+
return nil, nil, err
736+
}
737+
738+
var sessionPB *ttnpb.UserSession
739+
if model.UserSession != nil {
740+
sessionPB, err = userSessionToPB(model.UserSession, pb.UserIds)
741+
if err != nil {
742+
return nil, nil, err
743+
}
744+
}
745+
746+
return pb, sessionPB, nil
747+
}
748+
701749
func (s *oauthStore) DeleteAccessToken(ctx context.Context, id string) error {
702750
ctx, span := tracer.StartFromContext(ctx, "DeleteAccessToken")
703751
defer span.End()

pkg/identityserver/entity_access.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func (is *IdentityServer) authInfo(ctx context.Context) (info *ttnpb.AuthInfoRes
178178
}
179179
case auth.AccessToken:
180180
fetch = func(ctx context.Context, st store.Store) error {
181-
accessToken, err := st.GetAccessToken(ctx, tokenID)
181+
accessToken, session, err := st.GetAccessTokenWithSession(ctx, tokenID)
182182
if err != nil {
183183
if errors.IsNotFound(err) {
184184
return errTokenNotFound.WithCause(err)
@@ -198,12 +198,8 @@ func (is *IdentityServer) authInfo(ctx context.Context) (info *ttnpb.AuthInfoRes
198198
return errTokenExpired.New()
199199
}
200200
if accessToken.UserSessionId != "" {
201-
session, err := st.GetSession(ctx, accessToken.UserIds, accessToken.UserSessionId)
202-
if err != nil {
203-
if errors.IsNotFound(err) {
204-
return errTokenExpired.WithCause(err)
205-
}
206-
return err
201+
if session == nil {
202+
return errTokenExpired.New()
207203
}
208204
if expiresAt := ttnpb.StdTime(session.ExpiresAt); expiresAt != nil && expiresAt.Before(time.Now()) {
209205
return errTokenExpired.New()

pkg/identityserver/entity_access_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@
1515
package identityserver
1616

1717
import (
18+
"context"
1819
"testing"
1920
"time"
2021

22+
"go.thethings.network/lorawan-stack/v3/pkg/auth"
23+
"go.thethings.network/lorawan-stack/v3/pkg/auth/pbkdf2"
2124
"go.thethings.network/lorawan-stack/v3/pkg/errors"
2225
"go.thethings.network/lorawan-stack/v3/pkg/identityserver/storetest"
26+
"go.thethings.network/lorawan-stack/v3/pkg/rpcmetadata"
2327
"go.thethings.network/lorawan-stack/v3/pkg/ttnpb"
2428
"go.thethings.network/lorawan-stack/v3/pkg/util/test"
2529
"go.thethings.network/lorawan-stack/v3/pkg/util/test/assertions/should"
@@ -64,7 +68,86 @@ func TestEntityAccess(t *testing.T) {
6468
storedKey.ExpiresAt = timestamppb.New(time.Now().Add(-10 * time.Minute))
6569
expiredCreds := rpcCreds(expiredKey)
6670

71+
oauthUsr := p.NewUser()
72+
oauthClient := p.NewClient(oauthUsr.GetOrganizationOrUserIdentifiers())
73+
oauthClient.Rights = []ttnpb.Right{ttnpb.Right_RIGHT_USER_ALL}
74+
75+
// Session that is already expired.
76+
expiredSession := p.NewUserSession(oauthUsr.GetIds())
77+
p.UserSessions[len(p.UserSessions)-1].ExpiresAt = timestamppb.New(time.Now().Add(-10 * time.Minute))
78+
79+
// Session that is still valid.
80+
validSession := p.NewUserSession(oauthUsr.GetIds())
81+
p.UserSessions[len(p.UserSessions)-1].ExpiresAt = timestamppb.New(time.Now().Add(10 * time.Minute))
82+
83+
// Generate access token bearer strings. The stored AccessToken is the hashed key.
84+
newBearerAccessToken := func() (bearer, tokenID, hashed string) {
85+
t.Helper()
86+
raw, err := auth.AccessToken.Generate(context.Background(), "")
87+
if err != nil {
88+
t.Fatal(err)
89+
}
90+
_, tokenID, key, err := auth.SplitToken(raw)
91+
if err != nil {
92+
t.Fatal(err)
93+
}
94+
hashValidator := pbkdf2.Default()
95+
hashValidator.Iterations = 10
96+
hashed, err = auth.Hash(auth.NewContextWithHashValidator(context.Background(), hashValidator), key)
97+
if err != nil {
98+
t.Fatal(err)
99+
}
100+
return raw, tokenID, hashed
101+
}
102+
103+
expiredSessionToken, expiredSessionTokenID, expiredSessionTokenHash := newBearerAccessToken()
104+
missingSessionToken, missingSessionTokenID, missingSessionTokenHash := newBearerAccessToken()
105+
validSessionToken, validSessionTokenID, validSessionTokenHash := newBearerAccessToken()
106+
107+
bearerCreds := func(bearer string) grpc.CallOption {
108+
return grpc.PerRPCCredentials(rpcmetadata.MD{
109+
AuthType: "bearer",
110+
AuthValue: bearer,
111+
AllowInsecure: true,
112+
})
113+
}
114+
67115
testWithIdentityServer(t, func(is *IdentityServer, cc *grpc.ClientConn) {
116+
ctx := test.Context()
117+
if _, err := is.store.CreateAccessToken(ctx, &ttnpb.OAuthAccessToken{
118+
UserIds: oauthUsr.GetIds(),
119+
ClientIds: oauthClient.GetIds(),
120+
UserSessionId: expiredSession.GetSessionId(),
121+
Id: expiredSessionTokenID,
122+
AccessToken: expiredSessionTokenHash,
123+
Rights: oauthClient.GetRights(),
124+
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
125+
}, ""); err != nil {
126+
t.Fatal(err)
127+
}
128+
if _, err := is.store.CreateAccessToken(ctx, &ttnpb.OAuthAccessToken{
129+
UserIds: oauthUsr.GetIds(),
130+
ClientIds: oauthClient.GetIds(),
131+
UserSessionId: "00000000-0000-0000-0000-000000000000",
132+
Id: missingSessionTokenID,
133+
AccessToken: missingSessionTokenHash,
134+
Rights: oauthClient.GetRights(),
135+
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
136+
}, ""); err != nil {
137+
t.Fatal(err)
138+
}
139+
if _, err := is.store.CreateAccessToken(ctx, &ttnpb.OAuthAccessToken{
140+
UserIds: oauthUsr.GetIds(),
141+
ClientIds: oauthClient.GetIds(),
142+
UserSessionId: validSession.GetSessionId(),
143+
Id: validSessionTokenID,
144+
AccessToken: validSessionTokenHash,
145+
Rights: oauthClient.GetRights(),
146+
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
147+
}, ""); err != nil {
148+
t.Fatal(err)
149+
}
150+
68151
is.config.UserRegistration.ContactInfoValidation.Required = true
69152

70153
cli := ttnpb.NewEntityAccessClient(cc)
@@ -166,5 +249,29 @@ func TestEntityAccess(t *testing.T) {
166249
a.So(errors.IsUnauthenticated(err), should.BeTrue)
167250
}
168251
})
252+
253+
t.Run("Access Token with Valid Session", func(t *testing.T) {
254+
a, ctx := test.New(t)
255+
authInfo, err := cli.AuthInfo(ctx, ttnpb.Empty, bearerCreds(validSessionToken))
256+
if a.So(err, should.BeNil) && a.So(authInfo, should.NotBeNil) {
257+
a.So(authInfo.GetOauthAccessToken(), should.NotBeNil)
258+
}
259+
})
260+
261+
t.Run("Access Token with Expired Session", func(t *testing.T) {
262+
a, ctx := test.New(t)
263+
_, err := cli.AuthInfo(ctx, ttnpb.Empty, bearerCreds(expiredSessionToken))
264+
if a.So(err, should.NotBeNil) {
265+
a.So(errors.IsUnauthenticated(err), should.BeTrue)
266+
}
267+
})
268+
269+
t.Run("Access Token with Missing Session", func(t *testing.T) {
270+
a, ctx := test.New(t)
271+
_, err := cli.AuthInfo(ctx, ttnpb.Empty, bearerCreds(missingSessionToken))
272+
if a.So(err, should.NotBeNil) {
273+
a.So(errors.IsUnauthenticated(err), should.BeTrue)
274+
}
275+
})
169276
}, withPrivateTestDatabase(p))
170277
}

pkg/identityserver/identityserver.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ func New(c *component.Component, config *Config) (is *IdentityServer, err error)
160160

161161
is.config.OAuth.CSRFAuthKey = is.GetBaseConfig(is.Context()).HTTP.Cookie.HashKey
162162
is.config.OAuth.UI.FrontendConfig.EnableUserRegistration = is.config.UserRegistration.Enabled
163+
if is.config.UserLogin.SessionTTL < 0 {
164+
log.FromContext(is.Context()).WithField("session_ttl", is.config.UserLogin.SessionTTL).Warn(
165+
"Negative is.user-login.session-ttl; resetting to 0 (sessions never expire)",
166+
)
167+
is.config.UserLogin.SessionTTL = 0
168+
}
163169
is.config.OAuth.Login.SessionTTL = is.config.UserLogin.SessionTTL
164170
is.oauth, err = oauth.NewServer(c, &oauthAppStore{is.store}, is.config.OAuth, GenerateCSPString)
165171
if err != nil {

pkg/identityserver/store/store_interfaces.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ type OAuthStore interface {
310310
ctx context.Context, userIDs *ttnpb.UserIdentifiers, clientIDs *ttnpb.ClientIdentifiers,
311311
) ([]*ttnpb.OAuthAccessToken, error)
312312
GetAccessToken(ctx context.Context, id string) (*ttnpb.OAuthAccessToken, error)
313+
// GetAccessTokenWithSession returns the access token and its linked user session.
314+
// The returned session is nil if the token is not linked to a session or the session no longer exists.
315+
GetAccessTokenWithSession(ctx context.Context, id string) (*ttnpb.OAuthAccessToken, *ttnpb.UserSession, error)
313316
DeleteAccessToken(ctx context.Context, id string) error
314317
}
315318

pkg/webui/locales/ja.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,6 +2871,7 @@
28712871
"error:pkg/oauth:no_access_token": "提供されたトークンはアクセストークンではありません",
28722872
"error:pkg/oauth:no_refresh_token": "提供されたトークンは更新トークンではありません",
28732873
"error:pkg/oauth:parse": "リクエストボディの解析",
2874+
"error:pkg/oauth:session_expired": "",
28742875
"error:pkg/oauth:token": "無効なトークン",
28752876
"error:pkg/oauth:token_mismatch": "更新トークンID `{refresh_token_id}` はアクセストークンID `{access_token_id}` と一致しません",
28762877
"error:pkg/oauth:unauthorized_client": "クライアントがこの手段でトークンを要求することは認証されていません",

0 commit comments

Comments
 (0)