Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions frontend/src/lib/hooks/screen-params.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ type ScreenParams = {
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
oidc_prompt?: "none" | "login";
};

const zodScreenParams = z.object({
Expand All @@ -14,6 +15,7 @@ const zodScreenParams = z.object({
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
oidc_prompt: z.enum(["none", "login"]).optional(),
});

export function useScreenParams(params: URLSearchParams): ScreenParams {
Expand Down
26 changes: 21 additions & 5 deletions frontend/src/pages/authorize-page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useEffect } from "react";

type Scope = {
id: string;
Expand Down Expand Up @@ -90,7 +91,15 @@ export const AuthorizePage = () => {
const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams);

const authorizeMutation = useMutation({
// TODO: maybe a better way to do this
const shouldAutoAuthorize =
auth.authenticated &&
isOidc &&
screenParams.oidc_ticket !== undefined &&
screenParams.oidc_scope !== undefined &&
screenParams.oidc_prompt === "none";

const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", {
ticket: screenParams.oidc_ticket,
Expand All @@ -110,6 +119,12 @@ export const AuthorizePage = () => {
},
});

useEffect(() => {
if (shouldAutoAuthorize) {
authorizeMutate();
}
}, [shouldAutoAuthorize, authorizeMutate]);

if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return (
<Navigate
Expand All @@ -119,7 +134,7 @@ export const AuthorizePage = () => {
);
}

if (!auth.authenticated) {
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
return <Navigate to={`/login${compiledParams}`} replace />;
}

Expand Down Expand Up @@ -168,14 +183,15 @@ export const AuthorizePage = () => {
)}
<CardFooter className="flex flex-col items-stretch gap-3">
<Button
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
onClick={() => authorizeMutate()}
loading={authorizePending}
disabled={shouldAutoAuthorize}
>
{t("authorizeTitle")}
</Button>
<Button
onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending}
disabled={authorizePending || shouldAutoAuthorize}
variant="outline"
>
{t("cancelTitle")}
Expand Down
7 changes: 5 additions & 2 deletions frontend/src/pages/login-page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ export const LoginPage = () => {

const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const compiledParams = recompileScreenParams({
...screenParams,
oidc_prompt: undefined,
});
Comment thread
Rycochet marked this conversation as resolved.
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
Expand Down Expand Up @@ -196,7 +199,7 @@ export const LoginPage = () => {
};
}, [redirectTimer, redirectButtonTimer]);

if (auth.authenticated) {
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
return <Navigate to={loginForUrl} replace />;
}

Expand Down
61 changes: 43 additions & 18 deletions internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ type ClientCredentials struct {
}

type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
}

type AuthorizeCompleteRequest struct {
Expand Down Expand Up @@ -167,20 +168,48 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return
}

prompt := controller.oidc.GetPrompt(req.Prompt)

userContext, err := new(model.UserContext).NewFromGin(c)

if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}

if (err != nil || !userContext.Authenticated) && prompt == service.OIDCPromptNone {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}

ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)

queries, err := query.Values(AuthorizeScreenParams{
values := AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
})
OIDCPrompt: prompt,
}

queries, err := query.Values(values)

if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
return
}
Expand Down Expand Up @@ -208,16 +237,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)

if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}

if !userContext.Authenticated {
if err != nil || !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
Expand Down Expand Up @@ -425,7 +450,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}

tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)

if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
Expand Down
2 changes: 2 additions & 0 deletions internal/model/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
type UserContext struct {
Authenticated bool
Provider ProviderType
AuthTime int64
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
Expand Down Expand Up @@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
AuthTime: session.CreatedAt,
}

switch session.Provider {
Expand Down
40 changes: 36 additions & 4 deletions internal/service/oidc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)

type OIDCPrompt string

const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)

var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}

// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
Expand All @@ -54,6 +63,7 @@ type ClaimSet struct {
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Expand Down Expand Up @@ -117,6 +127,7 @@ type AuthorizeRequest struct {
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
}

type AuthorizeCodeEntry struct {
Expand All @@ -127,6 +138,7 @@ type AuthorizeCodeEntry struct {
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
AuthTime int64
}

type UsedCodeEntry struct {
Expand Down Expand Up @@ -423,6 +435,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
}

if req.CodeChallenge != "" {
Expand Down Expand Up @@ -512,7 +525,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true
}

func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, auth_time int64) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()

Expand Down Expand Up @@ -549,6 +562,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Sub: user.Sub,
Iat: createdAt,
Exp: expiresAt,
AuthTime: auth_time,
Name: userInfo.Name,
Email: userInfo.Email,
EmailVerified: userInfo.EmailVerified,
Expand Down Expand Up @@ -578,8 +592,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}

func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, authTime)

if err != nil {
return nil, err
Expand Down Expand Up @@ -660,7 +674,7 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken

idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce)
}, userInfo, entry.Scope, entry.Nonce, 0) // auth_time is not available during refresh, so we set it to 0
Comment thread
steveiliop56 marked this conversation as resolved.
Outdated

if err != nil {
return nil, err
Expand Down Expand Up @@ -929,5 +943,23 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
Nonce: get("nonce"),
CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil
}

// Return the first prompt in the list of prompts, or an empty string if no prompt is specified
func (service *OIDCService) GetPrompt(prompt string) OIDCPrompt {
if prompt == "" {
return ""
}

prompts := strings.Split(prompt, " ")

for _, p := range prompts {
if slices.Contains(SupportedPrompts, p) {
return OIDCPrompt(p)
}
}

return ""
}
Comment thread
steveiliop56 marked this conversation as resolved.