Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions frontend/src/lib/hooks/screen-params.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ type ScreenParams = {
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
oidc_prompt?: "none" | "login";
};

const zodScreenParams = z.object({
Expand All @@ -14,6 +15,7 @@ const zodScreenParams = z.object({
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
oidc_prompt: z.enum(["none", "login"]).optional(),
});

export function useScreenParams(params: URLSearchParams): ScreenParams {
Expand Down
26 changes: 21 additions & 5 deletions frontend/src/pages/authorize-page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useEffect } from "react";

type Scope = {
id: string;
Expand Down Expand Up @@ -90,7 +91,15 @@ export const AuthorizePage = () => {
const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams);

const authorizeMutation = useMutation({
// TODO: maybe a better way to do this
const shouldAutoAuthorize =
auth.authenticated &&
isOidc &&
screenParams.oidc_ticket !== undefined &&
screenParams.oidc_scope !== undefined &&
screenParams.oidc_prompt === "none";

const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", {
ticket: screenParams.oidc_ticket,
Expand All @@ -110,6 +119,12 @@ export const AuthorizePage = () => {
},
});

useEffect(() => {
if (shouldAutoAuthorize) {
authorizeMutate();
}
}, [shouldAutoAuthorize, authorizeMutate]);

if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return (
<Navigate
Expand All @@ -119,7 +134,7 @@ export const AuthorizePage = () => {
);
}

if (!auth.authenticated) {
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
return <Navigate to={`/login${compiledParams}`} replace />;
}

Expand Down Expand Up @@ -168,14 +183,15 @@ export const AuthorizePage = () => {
)}
<CardFooter className="flex flex-col items-stretch gap-3">
<Button
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
onClick={() => authorizeMutate()}
loading={authorizePending}
disabled={shouldAutoAuthorize}
>
{t("authorizeTitle")}
</Button>
<Button
onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending}
disabled={authorizePending || shouldAutoAuthorize}
variant="outline"
>
{t("cancelTitle")}
Expand Down
7 changes: 5 additions & 2 deletions frontend/src/pages/login-page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ export const LoginPage = () => {

const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const compiledParams = recompileScreenParams({
...screenParams,
oidc_prompt: undefined,
});
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
Expand Down Expand Up @@ -196,7 +199,7 @@ export const LoginPage = () => {
};
}, [redirectTimer, redirectButtonTimer]);

if (auth.authenticated) {
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
return <Navigate to={loginForUrl} replace />;
}

Expand Down
102 changes: 84 additions & 18 deletions internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
Expand Down Expand Up @@ -69,10 +71,11 @@ type ClientCredentials struct {
}

type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
}

type AuthorizeCompleteRequest struct {
Expand Down Expand Up @@ -167,20 +170,87 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return
}

prompts := controller.oidc.GetPrompt(req.Prompt)

if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("invalid prompt"),
reason: "Invalid prompt",
reasonPublic: "The prompt parameters are invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}

userContext, err := new(model.UserContext).NewFromGin(c)

if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}

if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}

ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)

queries, err := query.Values(AuthorizeScreenParams{
values := AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
})
}

if slices.Contains(prompts, service.OIDCPromptLogin) {
values.OIDCPrompt = service.OIDCPromptLogin
} else if slices.Contains(prompts, service.OIDCPromptNone) {
values.OIDCPrompt = service.OIDCPromptNone
}

if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Invalid max_age",
reasonPublic: "The max_age parameter is invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}

if userContext.Authenticated {
authTime := time.Unix(userContext.AuthTime, 0)
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
values.OIDCPrompt = service.OIDCPromptLogin
}
}
}

queries, err := query.Values(values)

if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
return
}
Expand Down Expand Up @@ -208,16 +278,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)

if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}

if !userContext.Authenticated {
if err != nil || !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
Expand Down Expand Up @@ -425,7 +491,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}

tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)

if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
Expand Down
2 changes: 2 additions & 0 deletions internal/model/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
type UserContext struct {
Authenticated bool
Provider ProviderType
AuthTime int64
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
Expand Down Expand Up @@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
AuthTime: session.CreatedAt,
}

switch session.Provider {
Expand Down
42 changes: 38 additions & 4 deletions internal/service/oidc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)

type OIDCPrompt string

const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)

var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}

// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
Expand All @@ -54,6 +63,7 @@ type ClaimSet struct {
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Expand Down Expand Up @@ -117,6 +127,8 @@ type AuthorizeRequest struct {
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
}

type AuthorizeCodeEntry struct {
Expand All @@ -127,6 +139,7 @@ type AuthorizeCodeEntry struct {
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
AuthTime int64
}

type UsedCodeEntry struct {
Expand Down Expand Up @@ -423,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
}

if req.CodeChallenge != "" {
Expand Down Expand Up @@ -512,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true
}

func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, auth_time int64) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()

Expand Down Expand Up @@ -549,6 +563,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Sub: user.Sub,
Iat: createdAt,
Exp: expiresAt,
AuthTime: auth_time,
Name: userInfo.Name,
Email: userInfo.Email,
EmailVerified: userInfo.EmailVerified,
Expand Down Expand Up @@ -578,8 +593,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}

func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, authTime)

if err != nil {
return nil, err
Expand Down Expand Up @@ -660,7 +675,7 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken

idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce)
}, userInfo, entry.Scope, entry.Nonce, 0) // auth_time is not available during refresh, so we set it to 0

if err != nil {
return nil, err
Expand Down Expand Up @@ -929,5 +944,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
Nonce: get("nonce"),
CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil
}

func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
if prompt == "" {
return []OIDCPrompt{}
}

parsedPromps := make([]OIDCPrompt, 0)
prompts := strings.SplitSeq(prompt, " ")

for p := range prompts {
if !slices.Contains(SupportedPrompts, p) {
continue
}
parsedPromps = append(parsedPromps, OIDCPrompt(p))
}

return parsedPromps
}