Skip to content

Commit 7f18b45

Browse files
authored
feat: support for the prompt parameter in the oidc flow (#948)
1 parent 6ccc894 commit 7f18b45

6 files changed

Lines changed: 131 additions & 29 deletions

File tree

frontend/src/lib/hooks/screen-params.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ type ScreenParams = {
66
oidc_ticket?: string;
77
oidc_scope?: string;
88
oidc_name?: string;
9+
oidc_prompt?: "none" | "login";
910
};
1011

1112
const zodScreenParams = z.object({
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
1415
oidc_ticket: z.string().optional(),
1516
oidc_scope: z.string().optional(),
1617
oidc_name: z.string().optional(),
18+
oidc_prompt: z.enum(["none", "login"]).optional(),
1719
});
1820

1921
export function useScreenParams(params: URLSearchParams): ScreenParams {

frontend/src/pages/authorize-page.tsx

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {
2525
recompileScreenParams,
2626
useScreenParams,
2727
} from "@/lib/hooks/screen-params";
28+
import { useEffect } from "react";
2829

2930
type Scope = {
3031
id: string;
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
9091
const isOidc = screenParams.login_for === "oidc";
9192
const compiledParams = recompileScreenParams(screenParams);
9293

93-
const authorizeMutation = useMutation({
94+
// TODO: maybe a better way to do this
95+
const shouldAutoAuthorize =
96+
auth.authenticated &&
97+
isOidc &&
98+
screenParams.oidc_ticket !== undefined &&
99+
screenParams.oidc_scope !== undefined &&
100+
screenParams.oidc_prompt === "none";
101+
102+
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
94103
mutationFn: () => {
95104
return axios.post("/api/oidc/authorize-complete", {
96105
ticket: screenParams.oidc_ticket,
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
110119
},
111120
});
112121

122+
useEffect(() => {
123+
if (shouldAutoAuthorize) {
124+
authorizeMutate();
125+
}
126+
}, [shouldAutoAuthorize, authorizeMutate]);
127+
113128
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
114129
return (
115130
<Navigate
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
119134
);
120135
}
121136

122-
if (!auth.authenticated) {
137+
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
123138
return <Navigate to={`/login${compiledParams}`} replace />;
124139
}
125140

@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
168183
)}
169184
<CardFooter className="flex flex-col items-stretch gap-3">
170185
<Button
171-
onClick={() => authorizeMutation.mutate()}
172-
loading={authorizeMutation.isPending}
186+
onClick={() => authorizeMutate()}
187+
loading={authorizePending}
188+
disabled={shouldAutoAuthorize}
173189
>
174190
{t("authorizeTitle")}
175191
</Button>
176192
<Button
177193
onClick={() => navigate(`/logout${compiledParams}`)}
178-
disabled={authorizeMutation.isPending}
194+
disabled={authorizePending || shouldAutoAuthorize}
179195
variant="outline"
180196
>
181197
{t("cancelTitle")}

frontend/src/pages/login-page.tsx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ export const LoginPage = () => {
6363

6464
const searchParams = new URLSearchParams(search);
6565
const screenParams = useScreenParams(searchParams);
66-
const compiledParams = recompileScreenParams(screenParams);
66+
const compiledParams = recompileScreenParams({
67+
...screenParams,
68+
oidc_prompt: undefined,
69+
});
6770
const loginForUrl = useLoginFor({
6871
login_for: screenParams.login_for,
6972
compiledParams,
@@ -196,7 +199,7 @@ export const LoginPage = () => {
196199
};
197200
}, [redirectTimer, redirectButtonTimer]);
198201

199-
if (auth.authenticated) {
202+
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
200203
return <Navigate to={loginForUrl} replace />;
201204
}
202205

internal/controller/oidc_controller.go

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,11 @@ type ClientCredentials struct {
6969
}
7070

7171
type AuthorizeScreenParams struct {
72-
LoginFor FrontendLoginFor `url:"login_for"`
73-
OIDCTicket string `url:"oidc_ticket"`
74-
OIDCScope string `url:"oidc_scope"`
75-
OIDCName string `url:"oidc_name"`
72+
LoginFor FrontendLoginFor `url:"login_for"`
73+
OIDCTicket string `url:"oidc_ticket"`
74+
OIDCScope string `url:"oidc_scope"`
75+
OIDCName string `url:"oidc_name"`
76+
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
7677
}
7778

7879
type AuthorizeCompleteRequest struct {
@@ -167,20 +168,65 @@ func (controller *OIDCController) authorize(c *gin.Context) {
167168
return
168169
}
169170

171+
prompts := controller.oidc.GetPrompt(req.Prompt)
172+
173+
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
174+
controller.authorizeError(c, authorizeErrorParams{
175+
err: errors.New("invalid prompt"),
176+
reason: "Invalid prompt",
177+
reasonPublic: "The prompt parameters are invalid",
178+
callback: req.RedirectURI,
179+
callbackError: "invalid_request",
180+
state: req.State,
181+
})
182+
return
183+
}
184+
185+
userContext, err := new(model.UserContext).NewFromGin(c)
186+
187+
if err != nil {
188+
if !errors.Is(err, model.ErrUserContextNotFound) {
189+
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
190+
}
191+
}
192+
193+
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
194+
controller.authorizeError(c, authorizeErrorParams{
195+
err: errors.New("user not logged in"),
196+
reason: "User not logged in",
197+
reasonPublic: "The user is not logged in",
198+
callback: req.RedirectURI,
199+
callbackError: "login_required",
200+
state: req.State,
201+
})
202+
return
203+
}
204+
170205
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
171206

172-
queries, err := query.Values(AuthorizeScreenParams{
207+
values := AuthorizeScreenParams{
173208
LoginFor: FrontendLoginForOIDC,
174209
OIDCTicket: ticket,
175210
OIDCScope: req.Scope,
176211
OIDCName: client.Name,
177-
})
212+
}
213+
214+
if slices.Contains(prompts, service.OIDCPromptLogin) {
215+
values.OIDCPrompt = service.OIDCPromptLogin
216+
} else if slices.Contains(prompts, service.OIDCPromptNone) {
217+
values.OIDCPrompt = service.OIDCPromptNone
218+
}
219+
220+
queries, err := query.Values(values)
178221

179222
if err != nil {
180223
controller.authorizeError(c, authorizeErrorParams{
181-
err: err,
182-
reason: "Failed to compile authorize queries",
183-
reasonPublic: "An internal error occured while processing your request",
224+
err: err,
225+
reason: "Failed to compile authorize queries",
226+
reasonPublic: "An internal error occured while processing your request",
227+
callback: req.RedirectURI,
228+
callbackError: "server_error",
229+
state: req.State,
184230
})
185231
return
186232
}
@@ -208,16 +254,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
208254
userContext, err := new(model.UserContext).NewFromGin(c)
209255

210256
if err != nil {
211-
controller.authorizeError(c, authorizeErrorParams{
212-
err: err,
213-
reason: "Failed to get user context",
214-
reasonPublic: "User is not logged in or the session is invalid",
215-
json: true,
216-
})
217-
return
257+
if !errors.Is(err, model.ErrUserContextNotFound) {
258+
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
259+
}
218260
}
219261

220-
if !userContext.Authenticated {
262+
if err != nil || !userContext.Authenticated {
221263
controller.authorizeError(c, authorizeErrorParams{
222264
err: errors.New("err user not logged in"),
223265
reason: "User not logged in",
@@ -425,7 +467,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
425467
return
426468
}
427469

428-
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
470+
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
429471

430472
if err != nil {
431473
controller.log.App.Error().Err(err).Msg("Failed to generate access token")

internal/model/context.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ const (
2525
type UserContext struct {
2626
Authenticated bool
2727
Provider ProviderType
28+
AuthTime int64
2829
Local *LocalContext
2930
OAuth *OAuthContext
3031
LDAP *LDAPContext
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
110111
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
111112
*c = UserContext{
112113
Authenticated: !session.TotpPending,
114+
AuthTime: session.CreatedAt,
113115
}
114116

115117
switch session.Provider {

internal/service/oidc_service.go

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ var (
4444
ErrInvalidClient = errors.New("invalid_client")
4545
)
4646

47+
type OIDCPrompt string
48+
49+
const (
50+
OIDCPromptLogin OIDCPrompt = "login"
51+
OIDCPromptNone OIDCPrompt = "none"
52+
)
53+
54+
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
55+
4756
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
4857
// it has became a "standard" and apps are looking for the claims in the ID tokens
4958
// instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -54,6 +63,7 @@ type ClaimSet struct {
5463
Sub string `json:"sub"`
5564
Iat int64 `json:"iat"`
5665
Exp int64 `json:"exp"`
66+
AuthTime int64 `json:"auth_time,omitempty"`
5767
Name string `json:"name,omitempty"`
5868
GivenName string `json:"given_name,omitempty"`
5969
FamilyName string `json:"family_name,omitempty"`
@@ -117,6 +127,7 @@ type AuthorizeRequest struct {
117127
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
118128
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
119129
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
130+
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
120131
}
121132

122133
type AuthorizeCodeEntry struct {
@@ -127,6 +138,7 @@ type AuthorizeCodeEntry struct {
127138
Nonce string
128139
CodeChallenge string
129140
Userinfo UserinfoResponse
141+
AuthTime int64
130142
}
131143

132144
type UsedCodeEntry struct {
@@ -423,6 +435,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
423435
ClientID: req.ClientID,
424436
Nonce: req.Nonce,
425437
Userinfo: service.userinfoFromContext(userContext, sub),
438+
AuthTime: userContext.AuthTime,
426439
}
427440

428441
if req.CodeChallenge != "" {
@@ -512,7 +525,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
512525
return &entry, true
513526
}
514527

515-
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
528+
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
516529
createdAt := time.Now().Unix()
517530
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
518531

@@ -557,6 +570,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
557570
Nonce: nonce,
558571
}
559572

573+
if authTime != nil {
574+
claims.AuthTime = *authTime
575+
}
576+
560577
payload, err := json.Marshal(claims)
561578

562579
if err != nil {
@@ -578,8 +595,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
578595
return token, nil
579596
}
580597

581-
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
582-
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
598+
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
599+
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
583600

584601
if err != nil {
585602
return nil, err
@@ -658,9 +675,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
658675
return nil, err
659676
}
660677

678+
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
661679
idToken, err := service.generateIDToken(model.OIDCClientConfig{
662680
ClientID: entry.ClientID,
663-
}, userInfo, entry.Scope, entry.Nonce)
681+
}, userInfo, entry.Scope, entry.Nonce, nil)
664682

665683
if err != nil {
666684
return nil, err
@@ -929,5 +947,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
929947
Nonce: get("nonce"),
930948
CodeChallenge: get("code_challenge"),
931949
CodeChallengeMethod: get("code_challenge_method"),
950+
Prompt: get("prompt"),
932951
}, nil
933952
}
953+
954+
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
955+
if prompt == "" {
956+
return []OIDCPrompt{}
957+
}
958+
959+
parsedPromps := make([]OIDCPrompt, 0)
960+
prompts := strings.SplitSeq(prompt, " ")
961+
962+
for p := range prompts {
963+
if !slices.Contains(SupportedPrompts, p) {
964+
continue
965+
}
966+
parsedPromps = append(parsedPromps, OIDCPrompt(p))
967+
}
968+
969+
return parsedPromps
970+
}

0 commit comments

Comments
 (0)