Skip to content

Commit d7c5b7d

Browse files
mbilskiclaude
andauthored
Preserve unrecognized token response fields through marshal roundtrip (#150)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9db96f9 commit d7c5b7d

2 files changed

Lines changed: 118 additions & 13 deletions

File tree

internal/oauth2/oauth2.go

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"io"
1010
"log"
11+
"maps"
1112
"net/http"
1213
"net/url"
1314
"strconv"
@@ -341,20 +342,56 @@ func WaitForCallback(clientConfig ClientConfig, serverConfig ServerConfig, hc *h
341342
}
342343

343344
type TokenResponse struct {
344-
AccessToken string `json:"access_token,omitempty"`
345-
ExpiresIn FlexibleInt64 `json:"expires_in,omitempty"`
346-
IDToken string `json:"id_token,omitempty"`
347-
IssuedTokenType string `json:"issued_token_type,omitempty"`
348-
RefreshToken string `json:"refresh_token,omitempty"`
349-
Scope string `json:"scope,omitempty"`
350-
TokenType string `json:"token_type,omitempty"`
351-
AuthorizationDetails []map[string]interface{} `json:"authorization_details,omitempty"`
345+
AccessToken string `json:"access_token,omitempty"`
346+
ExpiresIn FlexibleInt64 `json:"expires_in,omitempty"`
347+
IDToken string `json:"id_token,omitempty"`
348+
IssuedTokenType string `json:"issued_token_type,omitempty"`
349+
RefreshToken string `json:"refresh_token,omitempty"`
350+
Scope string `json:"scope,omitempty"`
351+
TokenType string `json:"token_type,omitempty"`
352+
AuthorizationDetails []map[string]any `json:"authorization_details,omitempty"`
353+
354+
raw map[string]json.RawMessage
355+
}
356+
357+
type tokenResponseAlias TokenResponse
358+
359+
func (t *TokenResponse) UnmarshalJSON(data []byte) error {
360+
var typed tokenResponseAlias
361+
362+
if err := json.Unmarshal(data, &typed); err != nil {
363+
return err
364+
}
365+
366+
*t = TokenResponse(typed)
367+
368+
return json.Unmarshal(data, &t.raw)
369+
}
370+
371+
func (t TokenResponse) MarshalJSON() ([]byte, error) {
372+
var typedMap map[string]json.RawMessage
373+
374+
typed, err := json.Marshal(tokenResponseAlias(t))
375+
if err != nil {
376+
return nil, err
377+
}
378+
379+
if len(t.raw) == 0 {
380+
return typed, nil
381+
}
382+
383+
if err := json.Unmarshal(typed, &typedMap); err != nil {
384+
return nil, err
385+
}
386+
387+
out := make(map[string]json.RawMessage, len(t.raw))
388+
389+
maps.Copy(out, t.raw)
390+
maps.Copy(out, typedMap)
391+
392+
return json.Marshal(out)
352393
}
353394

354-
// FlexibleInt64 is a type that can be unmarshaled from a JSON number or
355-
// string. This was added to support the `expires_in` field in the token
356-
// response. Typically it is expressed as a JSON number, but at least
357-
// login.microsoft.com returns the number as a string.
358395
type FlexibleInt64 int64
359396

360397
func (f *FlexibleInt64) UnmarshalJSON(b []byte) error {
@@ -365,6 +402,7 @@ func (f *FlexibleInt64) UnmarshalJSON(b []byte) error {
365402
// check if we have a number in a string, and parse it if so
366403
if b[0] == '"' {
367404
var s string
405+
368406
if err := json.Unmarshal(b, &s); err != nil {
369407
return err
370408
}
@@ -378,13 +416,14 @@ func (f *FlexibleInt64) UnmarshalJSON(b []byte) error {
378416
return nil
379417
}
380418

381-
// finally we assume that we have a number that's not wrapped in a string
382419
var i int64
420+
383421
if err := json.Unmarshal(b, &i); err != nil {
384422
return err
385423
}
386424

387425
*f = FlexibleInt64(i)
426+
388427
return nil
389428
}
390429

internal/oauth2/oauth2_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,72 @@ import (
1010
"github.com/cloudentity/oauth2c/internal/oauth2"
1111
)
1212

13+
func TestTokenResponseExtraFields(t *testing.T) {
14+
body := []byte(`{
15+
"access_token": "token",
16+
"expires_in": 3600,
17+
"token_type": "Bearer",
18+
"email": "me@email.com",
19+
"environment_id": "envid-token-here",
20+
"legal_entity_name": "oauth environment provider"
21+
}`)
22+
23+
var resp oauth2.TokenResponse
24+
require.NoError(t, json.Unmarshal(body, &resp))
25+
26+
require.Equal(t, "token", resp.AccessToken)
27+
require.Equal(t, oauth2.FlexibleInt64(3600), resp.ExpiresIn)
28+
require.Equal(t, "Bearer", resp.TokenType)
29+
30+
out, err := json.Marshal(resp)
31+
require.NoError(t, err)
32+
33+
var roundTrip map[string]any
34+
require.NoError(t, json.Unmarshal(out, &roundTrip))
35+
36+
require.Equal(t, "token", roundTrip["access_token"])
37+
require.Equal(t, "Bearer", roundTrip["token_type"])
38+
require.Equal(t, "me@email.com", roundTrip["email"])
39+
require.Equal(t, "envid-token-here", roundTrip["environment_id"])
40+
require.Equal(t, "oauth environment provider", roundTrip["legal_entity_name"])
41+
require.NotContains(t, roundTrip, "raw")
42+
}
43+
44+
func TestTokenResponseNoExtraFields(t *testing.T) {
45+
body := []byte(`{"access_token": "token", "expires_in": 3600, "token_type": "Bearer"}`)
46+
47+
var resp oauth2.TokenResponse
48+
require.NoError(t, json.Unmarshal(body, &resp))
49+
50+
out, err := json.Marshal(resp)
51+
require.NoError(t, err)
52+
53+
var roundTrip map[string]any
54+
require.NoError(t, json.Unmarshal(out, &roundTrip))
55+
56+
require.Equal(t, "token", roundTrip["access_token"])
57+
require.Equal(t, "Bearer", roundTrip["token_type"])
58+
require.Len(t, roundTrip, 3)
59+
}
60+
61+
func TestTokenResponseTypedFieldsWinOnConflict(t *testing.T) {
62+
body := []byte(`{"access_token": "from-server", "expires_in": "3600"}`)
63+
64+
var resp oauth2.TokenResponse
65+
require.NoError(t, json.Unmarshal(body, &resp))
66+
67+
resp.AccessToken = "overridden"
68+
69+
out, err := json.Marshal(resp)
70+
require.NoError(t, err)
71+
72+
var roundTrip map[string]any
73+
require.NoError(t, json.Unmarshal(out, &roundTrip))
74+
75+
require.Equal(t, "overridden", roundTrip["access_token"])
76+
require.EqualValues(t, 3600, roundTrip["expires_in"])
77+
}
78+
1379
func TestUnmarshalExpires(t *testing.T) {
1480
tests := map[string]struct {
1581
bytes []byte

0 commit comments

Comments
 (0)