Skip to content

Commit 165197e

Browse files
authored
feat: add pkce support to oidc server (#766)
* feat: add pkce support to oidc server * tests: add test cases for pkce * fix: review comments * chore: remove debug line * chore: remove simple logger from testing * tests: add test for invalid challenge method * chore: fix typo
1 parent 431cd33 commit 165197e

18 files changed

Lines changed: 350 additions & 39 deletions

frontend/src/lib/hooks/oidc.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ export type OIDCValues = {
55
redirect_uri: string;
66
state: string;
77
nonce: string;
8+
code_challenge: string;
9+
code_challenge_method: string;
810
};
911

1012
interface IuseOIDCParams {
@@ -14,7 +16,12 @@ interface IuseOIDCParams {
1416
missingParams: string[];
1517
}
1618

17-
const optionalParams: string[] = ["state", "nonce"];
19+
const optionalParams: string[] = [
20+
"state",
21+
"nonce",
22+
"code_challenge",
23+
"code_challenge_method",
24+
];
1825

1926
export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
2027
let compiled: string = "";
@@ -28,6 +35,8 @@ export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
2835
redirect_uri: params.get("redirect_uri") ?? "",
2936
state: params.get("state") ?? "",
3037
nonce: params.get("nonce") ?? "",
38+
code_challenge: params.get("code_challenge") ?? "",
39+
code_challenge_method: params.get("code_challenge_method") ?? "",
3140
};
3241

3342
for (const key of Object.keys(values)) {
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge";
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT "";

internal/controller/context_controller_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ import (
1010
"github.com/steveiliop56/tinyauth/internal/config"
1111
"github.com/steveiliop56/tinyauth/internal/controller"
1212
"github.com/steveiliop56/tinyauth/internal/utils"
13+
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
1314
"github.com/stretchr/testify/assert"
1415
)
1516

1617
func TestContextController(t *testing.T) {
18+
tlog.NewTestLogger().Init()
1719
controllerConfig := controller.ContextControllerConfig{
1820
Providers: []controller.Provider{
1921
{

internal/controller/health_controller_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ import (
88

99
"github.com/gin-gonic/gin"
1010
"github.com/steveiliop56/tinyauth/internal/controller"
11+
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
1112
"github.com/stretchr/testify/assert"
1213
)
1314

1415
func TestHealthController(t *testing.T) {
16+
tlog.NewTestLogger().Init()
1517
tests := []struct {
1618
description string
1719
path string

internal/controller/oidc_controller.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type TokenRequest struct {
3434
RefreshToken string `form:"refresh_token" url:"refresh_token"`
3535
ClientSecret string `form:"client_secret" url:"client_secret"`
3636
ClientID string `form:"client_id" url:"client_id"`
37+
CodeVerifier string `form:"code_verifier" url:"code_verifier"`
3738
}
3839

3940
type CallbackError struct {
@@ -308,6 +309,16 @@ func (controller *OIDCController) Token(c *gin.Context) {
308309
return
309310
}
310311

312+
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
313+
314+
if !ok {
315+
tlog.App.Warn().Msg("PKCE validation failed")
316+
c.JSON(400, gin.H{
317+
"error": "invalid_grant",
318+
})
319+
return
320+
}
321+
311322
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
312323

313324
if err != nil {

internal/controller/oidc_controller_test.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package controller_test
22

33
import (
4+
"crypto/sha256"
5+
"encoding/base64"
46
"encoding/json"
57
"net/http/httptest"
68
"net/url"
@@ -15,11 +17,13 @@ import (
1517
"github.com/steveiliop56/tinyauth/internal/controller"
1618
"github.com/steveiliop56/tinyauth/internal/repository"
1719
"github.com/steveiliop56/tinyauth/internal/service"
20+
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
1821
"github.com/stretchr/testify/assert"
1922
"github.com/stretchr/testify/require"
2023
)
2124

2225
func TestOIDCController(t *testing.T) {
26+
tlog.NewTestLogger().Init()
2327
tempDir := t.TempDir()
2428

2529
oidcServiceCfg := service.OIDCServiceConfig{
@@ -431,6 +435,227 @@ func TestOIDCController(t *testing.T) {
431435
assert.False(t, ok, "Did not expect email claim in userinfo response")
432436
},
433437
},
438+
{
439+
description: "Ensure plain PKCE succeeds",
440+
middlewares: []gin.HandlerFunc{
441+
simpleCtx,
442+
},
443+
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
444+
reqBody := service.AuthorizeRequest{
445+
Scope: "openid",
446+
ResponseType: "code",
447+
ClientID: "some-client-id",
448+
RedirectURI: "https://test.example.com/callback",
449+
State: "some-state",
450+
Nonce: "some-nonce",
451+
CodeChallenge: "some-challenge",
452+
// Not setting a code challenge method should default to "plain"
453+
CodeChallengeMethod: "",
454+
}
455+
reqBodyBytes, err := json.Marshal(reqBody)
456+
assert.NoError(t, err)
457+
458+
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
459+
req.Header.Set("Content-Type", "application/json")
460+
router.ServeHTTP(recorder, req)
461+
assert.Equal(t, 200, recorder.Code)
462+
463+
var res map[string]any
464+
err = json.Unmarshal(recorder.Body.Bytes(), &res)
465+
assert.NoError(t, err)
466+
467+
redirectURI := res["redirect_uri"].(string)
468+
url, err := url.Parse(redirectURI)
469+
assert.NoError(t, err)
470+
471+
queryParams := url.Query()
472+
assert.Equal(t, queryParams.Get("state"), "some-state")
473+
474+
code := queryParams.Get("code")
475+
assert.NotEmpty(t, code)
476+
477+
// Now exchange the code for a token
478+
recorder = httptest.NewRecorder()
479+
tokenReqBody := controller.TokenRequest{
480+
GrantType: "authorization_code",
481+
Code: code,
482+
RedirectURI: "https://test.example.com/callback",
483+
CodeVerifier: "some-challenge",
484+
}
485+
reqBodyEncoded, err := query.Values(tokenReqBody)
486+
assert.NoError(t, err)
487+
488+
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
489+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
490+
req.SetBasicAuth("some-client-id", "some-client-secret")
491+
router.ServeHTTP(recorder, req)
492+
493+
assert.Equal(t, 200, recorder.Code)
494+
},
495+
},
496+
{
497+
description: "Ensure S256 PKCE succeeds",
498+
middlewares: []gin.HandlerFunc{
499+
simpleCtx,
500+
},
501+
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
502+
hasher := sha256.New()
503+
hasher.Write([]byte("some-challenge"))
504+
codeChallenge := hasher.Sum(nil)
505+
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
506+
reqBody := service.AuthorizeRequest{
507+
Scope: "openid",
508+
ResponseType: "code",
509+
ClientID: "some-client-id",
510+
RedirectURI: "https://test.example.com/callback",
511+
State: "some-state",
512+
Nonce: "some-nonce",
513+
CodeChallenge: codeChallengeEncoded,
514+
CodeChallengeMethod: "S256",
515+
}
516+
reqBodyBytes, err := json.Marshal(reqBody)
517+
assert.NoError(t, err)
518+
519+
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
520+
req.Header.Set("Content-Type", "application/json")
521+
router.ServeHTTP(recorder, req)
522+
assert.Equal(t, 200, recorder.Code)
523+
524+
var res map[string]any
525+
err = json.Unmarshal(recorder.Body.Bytes(), &res)
526+
assert.NoError(t, err)
527+
528+
redirectURI := res["redirect_uri"].(string)
529+
url, err := url.Parse(redirectURI)
530+
assert.NoError(t, err)
531+
532+
queryParams := url.Query()
533+
assert.Equal(t, queryParams.Get("state"), "some-state")
534+
535+
code := queryParams.Get("code")
536+
assert.NotEmpty(t, code)
537+
538+
// Now exchange the code for a token
539+
recorder = httptest.NewRecorder()
540+
tokenReqBody := controller.TokenRequest{
541+
GrantType: "authorization_code",
542+
Code: code,
543+
RedirectURI: "https://test.example.com/callback",
544+
CodeVerifier: "some-challenge",
545+
}
546+
reqBodyEncoded, err := query.Values(tokenReqBody)
547+
assert.NoError(t, err)
548+
549+
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
550+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
551+
req.SetBasicAuth("some-client-id", "some-client-secret")
552+
router.ServeHTTP(recorder, req)
553+
554+
assert.Equal(t, 200, recorder.Code)
555+
},
556+
},
557+
{
558+
description: "Ensure request with invalid PKCE fails",
559+
middlewares: []gin.HandlerFunc{
560+
simpleCtx,
561+
},
562+
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
563+
hasher := sha256.New()
564+
hasher.Write([]byte("some-challenge"))
565+
codeChallenge := hasher.Sum(nil)
566+
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
567+
reqBody := service.AuthorizeRequest{
568+
Scope: "openid",
569+
ResponseType: "code",
570+
ClientID: "some-client-id",
571+
RedirectURI: "https://test.example.com/callback",
572+
State: "some-state",
573+
Nonce: "some-nonce",
574+
CodeChallenge: codeChallengeEncoded,
575+
CodeChallengeMethod: "S256",
576+
}
577+
reqBodyBytes, err := json.Marshal(reqBody)
578+
assert.NoError(t, err)
579+
580+
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
581+
req.Header.Set("Content-Type", "application/json")
582+
router.ServeHTTP(recorder, req)
583+
assert.Equal(t, 200, recorder.Code)
584+
585+
var res map[string]any
586+
err = json.Unmarshal(recorder.Body.Bytes(), &res)
587+
assert.NoError(t, err)
588+
589+
redirectURI := res["redirect_uri"].(string)
590+
url, err := url.Parse(redirectURI)
591+
assert.NoError(t, err)
592+
593+
queryParams := url.Query()
594+
assert.Equal(t, queryParams.Get("state"), "some-state")
595+
596+
code := queryParams.Get("code")
597+
assert.NotEmpty(t, code)
598+
599+
// Now exchange the code for a token
600+
recorder = httptest.NewRecorder()
601+
tokenReqBody := controller.TokenRequest{
602+
GrantType: "authorization_code",
603+
Code: code,
604+
RedirectURI: "https://test.example.com/callback",
605+
CodeVerifier: "some-challenge-1",
606+
}
607+
reqBodyEncoded, err := query.Values(tokenReqBody)
608+
assert.NoError(t, err)
609+
610+
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
611+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
612+
req.SetBasicAuth("some-client-id", "some-client-secret")
613+
router.ServeHTTP(recorder, req)
614+
615+
assert.Equal(t, 400, recorder.Code)
616+
},
617+
},
618+
{
619+
description: "Ensure request with invalid challenge method fails",
620+
middlewares: []gin.HandlerFunc{
621+
simpleCtx,
622+
},
623+
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
624+
hasher := sha256.New()
625+
hasher.Write([]byte("some-challenge"))
626+
codeChallenge := hasher.Sum(nil)
627+
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
628+
reqBody := service.AuthorizeRequest{
629+
Scope: "openid",
630+
ResponseType: "code",
631+
ClientID: "some-client-id",
632+
RedirectURI: "https://test.example.com/callback",
633+
State: "some-state",
634+
Nonce: "some-nonce",
635+
CodeChallenge: codeChallengeEncoded,
636+
CodeChallengeMethod: "foo",
637+
}
638+
reqBodyBytes, err := json.Marshal(reqBody)
639+
assert.NoError(t, err)
640+
641+
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
642+
req.Header.Set("Content-Type", "application/json")
643+
router.ServeHTTP(recorder, req)
644+
assert.Equal(t, 200, recorder.Code)
645+
646+
var res map[string]any
647+
err = json.Unmarshal(recorder.Body.Bytes(), &res)
648+
assert.NoError(t, err)
649+
650+
redirectURI := res["redirect_uri"].(string)
651+
url, err := url.Parse(redirectURI)
652+
assert.NoError(t, err)
653+
654+
queryParams := url.Query()
655+
error := queryParams.Get("error")
656+
assert.NotEmpty(t, error)
657+
},
658+
},
434659
}
435660

436661
app := bootstrap.NewBootstrapApp(config.Config{})

internal/controller/proxy_controller_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
)
1818

1919
func TestProxyController(t *testing.T) {
20+
tlog.NewTestLogger().Init()
2021
tempDir := t.TempDir()
2122

2223
authServiceCfg := service.AuthServiceConfig{
@@ -390,8 +391,6 @@ func TestProxyController(t *testing.T) {
390391
},
391392
}
392393

393-
tlog.NewSimpleLogger().Init()
394-
395394
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
396395

397396
app := bootstrap.NewBootstrapApp(config.Config{})

internal/controller/resources_controller_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ import (
88

99
"github.com/gin-gonic/gin"
1010
"github.com/steveiliop56/tinyauth/internal/controller"
11+
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
1112
"github.com/stretchr/testify/assert"
1213
"github.com/stretchr/testify/require"
1314
)
1415

1516
func TestResourcesController(t *testing.T) {
17+
tlog.NewTestLogger().Init()
1618
tempDir := t.TempDir()
1719

1820
resourcesControllerCfg := controller.ResourcesControllerConfig{

internal/controller/user_controller_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
)
2323

2424
func TestUserController(t *testing.T) {
25+
tlog.NewTestLogger().Init()
2526
tempDir := t.TempDir()
2627

2728
authServiceCfg := service.AuthServiceConfig{
@@ -274,8 +275,6 @@ func TestUserController(t *testing.T) {
274275
},
275276
}
276277

277-
tlog.NewSimpleLogger().Init()
278-
279278
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
280279

281280
app := bootstrap.NewBootstrapApp(config.Config{})

0 commit comments

Comments
 (0)