Skip to content

Commit 880dad8

Browse files
committed
[#2473] move extract groups claim
1 parent a7209bd commit 880dad8

2 files changed

Lines changed: 91 additions & 19 deletions

File tree

backend/server/oidc/controller.go

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,29 @@ func sanitizeReturnURL(returnURL string) string {
343343
return sanitizedPath
344344
}
345345

346+
// Extracts groups claim from raw claims and returns the groups as slice of strings.
347+
func (ctl *Controller) extractGroupsFromClaim(rawClaims map[string]interface{}) []string {
348+
// Do custom unmarshaling of the groups claim, because we can't be sure
349+
// how the claim is formatted on the OpenID Provider side.
350+
// We should have the groups extracted as slice of strings.
351+
var groups []string
352+
if val, ok := rawClaims[ctl.settings.GroupsClaim]; ok {
353+
switch claim := val.(type) {
354+
case []interface{}:
355+
for _, g := range claim {
356+
if s, ok := g.(string); ok {
357+
groups = append(groups, s)
358+
}
359+
}
360+
case []string:
361+
groups = claim
362+
case string:
363+
groups = []string{claim}
364+
}
365+
}
366+
return groups
367+
}
368+
346369
// Handles OIDC callback endpoint which interprets redirection from OpenID Provider
347370
// after user successfully authenticates at the OP and authorizes Stork as a
348371
// Relying Party. It verifies the response, extracts required parameters and
@@ -431,25 +454,8 @@ func (ctl *Controller) callbackHandler(w http.ResponseWriter, r *http.Request) {
431454
http.Redirect(w, r, authErrorURLPath, http.StatusFound)
432455
return
433456
}
434-
// Do custom unmarshaling of the groups claim, because we can't be sure
435-
// how the claim is formatted on the OpenID Provider side.
436-
// We should have the groups extracted as slice of strings.
437-
if val, ok := rawClaims[ctl.settings.GroupsClaim]; ok {
438-
var groups []string
439-
switch claim := val.(type) {
440-
case []interface{}:
441-
for _, g := range claim {
442-
if s, ok := g.(string); ok {
443-
groups = append(groups, s)
444-
}
445-
}
446-
case []string:
447-
groups = claim
448-
case string:
449-
groups = []string{claim}
450-
}
451-
claims.Groups = groups
452-
}
457+
groups := ctl.extractGroupsFromClaim(rawClaims)
458+
claims.Groups = groups
453459
}
454460
log.Debugf("Claims received during OIDC authentication %+v", claims)
455461

backend/server/oidc/controller_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,3 +911,69 @@ func TestCallbackEndpointAuthorizesUserGroupMappingDisabled(t *testing.T) {
911911
require.Len(t, dbUser.Groups, 1)
912912
require.Equal(t, dbmodel.ReadOnlyGroupID, dbUser.Groups[0].ID)
913913
}
914+
915+
// Test if extractGroupsFromClaim works fine.
916+
func TestExtractGroupsFromClaim(t *testing.T) {
917+
// Arrange
918+
db, _, teardown := dbtest.SetupDatabaseTestCase(t)
919+
defer teardown()
920+
issuerURL, srvTeardown, err := oidctest.PrepareTestOIDCServer()
921+
require.NoError(t, err)
922+
defer srvTeardown()
923+
settings := Settings{
924+
IssuerURL: issuerURL,
925+
ClientID: "clientID",
926+
GroupsClaim: "groups",
927+
MandatoryAllowGroup: "stork-users",
928+
}
929+
controller := NewController(settings, db)
930+
require.NotNil(t, controller)
931+
932+
// Act & Assert
933+
m := make(map[string]interface{})
934+
t.Run("empty claims", func(t *testing.T) {
935+
res := controller.extractGroupsFromClaim(m)
936+
require.Empty(t, res)
937+
})
938+
939+
m["sub"] = "foo"
940+
941+
t.Run("no groups", func(t *testing.T) {
942+
res := controller.extractGroupsFromClaim(m)
943+
require.Empty(t, res)
944+
})
945+
946+
t.Run("slice of strings", func(t *testing.T) {
947+
m["groups"] = []string{"a", "b", "c"}
948+
res := controller.extractGroupsFromClaim(m)
949+
require.NotEmpty(t, res)
950+
require.Len(t, res, 3)
951+
require.Contains(t, res, "a")
952+
require.Contains(t, res, "b")
953+
require.Contains(t, res, "c")
954+
})
955+
956+
t.Run("one string", func(t *testing.T) {
957+
m["groups"] = "groupA"
958+
res := controller.extractGroupsFromClaim(m)
959+
require.NotEmpty(t, res)
960+
require.Len(t, res, 1)
961+
require.Contains(t, res, "groupA")
962+
})
963+
964+
t.Run("slice of interfaces", func(t *testing.T) {
965+
var groupA, groupB, groupC interface{}
966+
groupA = "groupA"
967+
groupB = "groupB"
968+
groupC = "groupC"
969+
var groups []interface{}
970+
groups = append(groups, groupA, groupB, groupC)
971+
m["groups"] = groups
972+
res := controller.extractGroupsFromClaim(m)
973+
require.NotEmpty(t, res)
974+
require.Len(t, res, 3)
975+
require.Contains(t, res, "groupA")
976+
require.Contains(t, res, "groupB")
977+
require.Contains(t, res, "groupC")
978+
})
979+
}

0 commit comments

Comments
 (0)