Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions internal/service/generic_oauth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/steveiliop56/tinyauth/internal/config"
Expand Down Expand Up @@ -124,9 +125,50 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
return user, err
}

// If the userinfo endpoint did not return groups (e.g. Microsoft Entra ID),
// try to extract them from the ID token which may contain a "groups" claim.
if user.Groups == nil {
if idTokenStr, ok := generic.token.Extra("id_token").(string); ok && idTokenStr != "" {
if groups, err := parseIDTokenGroups(idTokenStr); err == nil {
tlog.App.Debug().Msg("Extracted groups from ID token (userinfo had none)")
user.Groups = groups
} else {
tlog.App.Debug().Err(err).Msg("Could not extract groups from ID token")
}
}
}

return user, nil
}

// parseIDTokenGroups decodes the payload of a JWT ID token and extracts the
// "groups" claim. The token signature is not verified here because the token
// was received directly from the token endpoint over TLS in the same request.
func parseIDTokenGroups(idToken string) (any, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}

payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
}

var claims struct {
Groups any `json:"groups"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("failed to parse JWT payload: %w", err)
}

if claims.Groups == nil {
return nil, fmt.Errorf("no groups claim in ID token")
}

return claims.Groups, nil
}

func (generic *GenericOAuthService) GetName() string {
return generic.name
}