diff --git a/pkg/oidc/introspection.go b/pkg/oidc/introspection.go index 517f6c93..34377529 100644 --- a/pkg/oidc/introspection.go +++ b/pkg/oidc/introspection.go @@ -16,21 +16,21 @@ type ClientAssertionParams struct { // https://www.rfc-editor.org/rfc/rfc7662.html#section-2.2. // https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims. type IntrospectionResponse struct { - Active bool `json:"active"` - Scope SpaceDelimitedArray `json:"scope,omitempty"` - ClientID string `json:"client_id,omitempty"` - TokenType string `json:"token_type,omitempty"` - Expiration Time `json:"exp,omitempty"` - IssuedAt Time `json:"iat,omitempty"` - AuthTime Time `json:"auth_time,omitempty"` - NotBefore Time `json:"nbf,omitempty"` - Subject string `json:"sub,omitempty"` - Audience Audience `json:"aud,omitempty"` - AuthenticationMethodsReferences []string `json:"amr,omitempty"` - Issuer string `json:"iss,omitempty"` - JWTID string `json:"jti,omitempty"` - Username string `json:"username,omitempty"` - Actor *ActorClaims `json:"act,omitempty"` + Active bool `json:"active"` + Scope SpaceDelimitedArray `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + TokenType string `json:"token_type,omitempty"` + Expiration Time `json:"exp,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` + Audience Audience `json:"aud,omitempty"` + AuthenticationMethodsReferences AuthenticationMethodsReferences `json:"amr,omitempty"` + Issuer string `json:"iss,omitempty"` + JWTID string `json:"jti,omitempty"` + Username string `json:"username,omitempty"` + Actor *ActorClaims `json:"act,omitempty"` UserInfoProfile UserInfoEmail UserInfoPhone diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index d2b6f6d4..e1a0ae93 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -35,20 +35,20 @@ type Tokens[C IDClaims] struct { // TokenClaims implements the Claims interface, // and can be used to extend larger claim types by embedding. type TokenClaims struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Audience Audience `json:"aud,omitempty"` - Expiration Time `json:"exp,omitempty"` - IssuedAt Time `json:"iat,omitempty"` - AuthTime Time `json:"auth_time,omitempty"` - NotBefore Time `json:"nbf,omitempty"` - Nonce string `json:"nonce,omitempty"` - AuthenticationContextClassReference string `json:"acr,omitempty"` - AuthenticationMethodsReferences []string `json:"amr,omitempty"` - AuthorizedParty string `json:"azp,omitempty"` - ClientID string `json:"client_id,omitempty"` - JWTID string `json:"jti,omitempty"` - Actor *ActorClaims `json:"act,omitempty"` + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Audience Audience `json:"aud,omitempty"` + Expiration Time `json:"exp,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences AuthenticationMethodsReferences `json:"amr,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + ClientID string `json:"client_id,omitempty"` + JWTID string `json:"jti,omitempty"` + Actor *ActorClaims `json:"act,omitempty"` // Additional information set by this framework SignatureAlg jose.SignatureAlgorithm `json:"-"` diff --git a/pkg/oidc/token_test.go b/pkg/oidc/token_test.go index ec8c126a..4e288525 100644 --- a/pkg/oidc/token_test.go +++ b/pkg/oidc/token_test.go @@ -1,6 +1,7 @@ package oidc import ( + "encoding/json" "testing" "time" @@ -243,6 +244,20 @@ func TestIDTokenClaims_GetUserInfo(t *testing.T) { assert.Equal(t, want, got) } +func TestIDTokenClaims_UnmarshalJSON_StringAMR(t *testing.T) { + var got IDTokenClaims + err := json.Unmarshal([]byte(`{"iss":"zitadel","sub":"hello@me.com","aud":"foo","exp":12345,"iat":12000,"amr":"pwd"}`), &got) + assert.NoError(t, err) + assert.Equal(t, AuthenticationMethodsReferences{"pwd"}, got.AuthenticationMethodsReferences) +} + +func TestIntrospectionResponse_UnmarshalJSON_StringAMR(t *testing.T) { + var got IntrospectionResponse + err := json.Unmarshal([]byte(`{"active":true,"sub":"hello@me.com","amr":"pwd"}`), &got) + assert.NoError(t, err) + assert.Equal(t, AuthenticationMethodsReferences{"pwd"}, got.AuthenticationMethodsReferences) +} + func TestNewLogoutTokenClaims(t *testing.T) { want := &LogoutTokenClaims{ Issuer: "zitadel", diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 8b6010fc..4c3df88b 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -35,6 +35,31 @@ func (a *Audience) UnmarshalJSON(text []byte) error { return nil } +type AuthenticationMethodsReferences []string + +func (a *AuthenticationMethodsReferences) UnmarshalJSON(data []byte) error { + var dst any + if err := json.Unmarshal(data, &dst); err != nil { + return fmt.Errorf("oidc amr: %w", err) + } + + switch v := dst.(type) { + case nil: + *a = nil + case string: + *a = AuthenticationMethodsReferences{v} + case []any: + refs, err := gu.AssertInterfaces[string](v) + if err != nil { + return fmt.Errorf("oidc amr: %w", err) + } + *a = AuthenticationMethodsReferences(refs) + default: + return fmt.Errorf("oidc amr: unsupported type: %T", v) + } + return nil +} + type Display string func (d *Display) UnmarshalText(text []byte) error { diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index 1219f6ec..4b983cca 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -70,6 +70,50 @@ func TestAudience_UnmarshalText(t *testing.T) { } } +func TestAuthenticationMethodsReferences_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + want AuthenticationMethodsReferences + wantErr bool + }{ + { + name: "single auth method", + input: `{"amr":"pwd"}`, + want: AuthenticationMethodsReferences{"pwd"}, + }, + { + name: "multiple auth methods", + input: `{"amr":["pwd","mfa"]}`, + want: AuthenticationMethodsReferences{"pwd", "mfa"}, + }, + { + name: "null", + input: `{"amr":null}`, + }, + { + name: "invalid type", + input: `{"amr":1}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got struct { + AMR AuthenticationMethodsReferences `json:"amr,omitempty"` + } + err := json.Unmarshal([]byte(tt.input), &got) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got.AMR) + }) + } +} + func TestDisplay_UnmarshalText(t *testing.T) { type args struct { text []byte