From cbec511b68b17e7dfa6cf4ebbbb3dc8beeffdef2 Mon Sep 17 00:00:00 2001 From: Wagner Abrantes <51835962+wagnerdevocelot@users.noreply.github.com> Date: Sun, 13 Jul 2025 21:27:22 -0300 Subject: [PATCH] Fix benchmark context --- benchmark_test.go | 39 +++++++ db_adapter_test.go | 141 ++++++++++++++++++++++++ go.mod | 1 + go.sum | 1 + integration_legacy_db_test.go | 100 +++++++++++++++++ logging_adapter_test.go | 200 ++++++++++++++++++++++++++++++++++ 6 files changed, 482 insertions(+) create mode 100644 benchmark_test.go create mode 100644 db_adapter_test.go create mode 100644 integration_legacy_db_test.go create mode 100644 logging_adapter_test.go diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..0029743 --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "context" + "database/sql" + "strconv" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/ory/fosite" +) + +func BenchmarkLoggingAdapterCreateToken(b *testing.B) { + adapter := NewLoggingAdapter(NewInMemoryStore()) + req := &fosite.Request{} + for i := 0; i < b.N; i++ { + adapter.CreateToken(context.Background(), "access_token", "sig"+strconv.Itoa(i), "c1", req) + } +} + +func BenchmarkLegacyDBCreateToken(b *testing.B) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + b.Fatal(err) + } + stmts := []string{ + `CREATE TABLE tokens (signature TEXT PRIMARY KEY, client_id TEXT, token_type TEXT, data BLOB, revoked_at TIMESTAMP)`, + } + for _, s := range stmts { + if _, err := db.Exec(s); err != nil { + b.Fatal(err) + } + } + adapter := NewLegacyDBAdapter(db) + req := []byte("data") + for i := 0; i < b.N; i++ { + adapter.CreateToken(context.Background(), "access_token", "sig"+strconv.Itoa(i), "c1", req) + } +} diff --git a/db_adapter_test.go b/db_adapter_test.go new file mode 100644 index 0000000..4a38cbb --- /dev/null +++ b/db_adapter_test.go @@ -0,0 +1,141 @@ +package main + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" + "github.com/ory/fosite" +) + +func setupTestDB(t *testing.T) *sql.DB { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + // create tables + statements := []string{ + `CREATE TABLE clients (id TEXT PRIMARY KEY, secret TEXT, redirect_uris TEXT, scopes TEXT, is_public BOOLEAN)`, + `CREATE TABLE tokens (signature TEXT PRIMARY KEY, client_id TEXT, token_type TEXT, data BLOB, revoked_at TIMESTAMP)`, + `CREATE TABLE sessions (id TEXT PRIMARY KEY, session_type TEXT, data BLOB)`, + `CREATE TABLE used_jtis (jti TEXT PRIMARY KEY, expires_at TIMESTAMP)`, + } + for _, stmt := range statements { + if _, err := db.Exec(stmt); err != nil { + t.Fatalf("failed to create table: %v", err) + } + } + return db +} + +func TestLegacyDBAdapterClientMethods(t *testing.T) { + ctx := context.Background() + db := setupTestDB(t) + adapter := NewLegacyDBAdapter(db) + + // Create client + client := &fosite.DefaultClient{ID: "c1", Secret: []byte("secret"), RedirectURIs: []string{"http://localhost"}} + if err := adapter.CreateClient(ctx, client); err != nil { + t.Fatalf("CreateClient failed: %v", err) + } + + // Get client + got, err := adapter.GetClient(ctx, "c1") + if err != nil { + t.Fatalf("GetClient failed: %v", err) + } + if got.GetID() != "c1" { + t.Errorf("expected id c1 got %s", got.GetID()) + } + + // Update client + client.RedirectURIs = []string{"http://127.0.0.1"} + if err := adapter.UpdateClient(ctx, client); err != nil { + t.Fatalf("UpdateClient failed: %v", err) + } + + // Delete client + if err := adapter.DeleteClient(ctx, "c1"); err != nil { + t.Fatalf("DeleteClient failed: %v", err) + } + if _, err := adapter.GetClient(ctx, "c1"); !errors.Is(err, fosite.ErrNotFound) { + t.Errorf("expected not found after delete got %v", err) + } +} + +func TestLegacyDBAdapterTokenMethods(t *testing.T) { + ctx := context.Background() + db := setupTestDB(t) + adapter := NewLegacyDBAdapter(db) + + // create token + payload := []byte("data") + if err := adapter.CreateToken(ctx, "access_token", "sig1", "client", payload); err != nil { + t.Fatalf("CreateToken failed: %v", err) + } + // get token + got, err := adapter.GetToken(ctx, "access_token", "sig1") + if err != nil { + t.Fatalf("GetToken failed: %v", err) + } + if string(got.([]byte)) != "data" { + t.Errorf("unexpected token data: %v", got) + } + // revoke + if err := adapter.RevokeToken(ctx, "access_token", "sig1"); err != nil { + t.Fatalf("RevokeToken failed: %v", err) + } + // delete + if err := adapter.DeleteToken(ctx, "access_token", "sig1"); err != nil { + t.Fatalf("DeleteToken failed: %v", err) + } + if _, err := adapter.GetToken(ctx, "access_token", "sig1"); !errors.Is(err, fosite.ErrNotFound) { + t.Errorf("expected not found after delete got %v", err) + } +} + +func TestLegacyDBAdapterSessionMethods(t *testing.T) { + ctx := context.Background() + db := setupTestDB(t) + adapter := NewLegacyDBAdapter(db) + + data := []byte("session") + if err := adapter.CreateSession(ctx, "openid", "s1", data); err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + got, err := adapter.GetSession(ctx, "openid", "s1") + if err != nil { + t.Fatalf("GetSession failed: %v", err) + } + if string(got.([]byte)) != "session" { + t.Errorf("unexpected session data: %v", got) + } + if err := adapter.DeleteSession(ctx, "openid", "s1"); err != nil { + t.Fatalf("DeleteSession failed: %v", err) + } + if _, err := adapter.GetSession(ctx, "openid", "s1"); !errors.Is(err, fosite.ErrNotFound) { + t.Errorf("expected not found after delete got %v", err) + } +} + +func TestLegacyDBAdapterJWTMethods(t *testing.T) { + ctx := context.Background() + db := setupTestDB(t) + adapter := NewLegacyDBAdapter(db) + + // validate new jti + if err := adapter.ValidateJWT(ctx, "j1"); err != nil { + t.Fatalf("ValidateJWT unexpected: %v", err) + } + // mark used + if err := adapter.MarkJWTAsUsed(ctx, "j1", time.Now().Add(time.Hour)); err != nil { + t.Fatalf("MarkJWTAsUsed failed: %v", err) + } + // now validate should fail + if err := adapter.ValidateJWT(ctx, "j1"); !errors.Is(err, fosite.ErrJTIKnown) { + t.Errorf("expected JTIKnown after mark used got %v", err) + } +} diff --git a/go.mod b/go.mod index 12a1052..5ec3b13 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module identity-go go 1.22 require ( + github.com/mattn/go-sqlite3 v1.14.16 github.com/ory/fosite v0.49.0 golang.org/x/crypto v0.31.0 ) diff --git a/go.sum b/go.sum index ca36e34..4963384 100644 --- a/go.sum +++ b/go.sum @@ -296,6 +296,7 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/goveralls v0.0.12 h1:PEEeF0k1SsTjOBQ8FOmrOAoCu4ytuMaWCnWe94zxbCg= github.com/mattn/goveralls v0.0.12/go.mod h1:44ImGEUfmqH8bBtaMrYKsM65LXfNLWmwaxFGjZwgMSQ= diff --git a/integration_legacy_db_test.go b/integration_legacy_db_test.go new file mode 100644 index 0000000..2c33fe0 --- /dev/null +++ b/integration_legacy_db_test.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/token/jwt" +) + +// setupLegacyStore initializes the global store with a LoggingAdapter over LegacyDBAdapter. +func setupLegacyStore(t *testing.T) func() { + db := setupTestDB(t) + orig := store + adapter := NewLegacyDBAdapter(db) + // seed client similar to in-memory store + client := &fosite.DefaultClient{ + ID: "my-test-client", + Secret: []byte("foobar"), + RedirectURIs: []string{"http://localhost:3000/callback", "http://127.0.0.1:3000/callback"}, + GrantTypes: fosite.Arguments{"authorization_code", "refresh_token", "client_credentials"}, + ResponseTypes: fosite.Arguments{"code", "token", "id_token", "code id_token", "code token", "id_token token", "code id_token token"}, + Scopes: fosite.Arguments{"openid", "profile", "email", "offline"}, + Audience: fosite.Arguments{"https://my-api.com"}, + } + if err := adapter.CreateClient(context.Background(), client); err != nil { + t.Fatalf("seed client: %v", err) + } + store = adapter + return func() { store = orig } +} + +// TestFullFlowLegacyDB simulates login, consent, token issuance and validation using the legacy DB. +func TestFullFlowLegacyDB(t *testing.T) { + teardown := setupLegacyStore(t) + defer teardown() + + // mimic authorization code flow using fosite directly + router := setupRouter() + srv := httptest.NewServer(router) + defer srv.Close() + + // create authorize request manually + arReq, _ := http.NewRequest("GET", srv.URL+"/oauth2/auth?response_type=code&client_id=my-test-client&redirect_uri=http://localhost:3000/callback&scope=openid+profile+offline&state=12345678", nil) + ar, err := oauth2Provider.NewAuthorizeRequest(arReq.Context(), arReq) + if err != nil { + t.Fatalf("authorize request: %v", err) + } + ar.GrantScope("openid") + ar.GrantScope("profile") + ar.GrantScope("offline") + sess := &openid.DefaultSession{Claims: &jwt.IDTokenClaims{Subject: "user"}, Headers: &jwt.Headers{}, Subject: "user"} + resp, err := oauth2Provider.NewAuthorizeResponse(arReq.Context(), ar, sess) + if err != nil { + t.Fatalf("authorize response: %v", err) + } + recorder := httptest.NewRecorder() + oauth2Provider.WriteAuthorizeResponse(arReq.Context(), recorder, ar, resp) + location, _ := recorder.Result().Location() + code := location.Query().Get("code") + + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("redirect_uri", "http://localhost:3000/callback") + tokenReq, _ := http.NewRequest("POST", srv.URL+"/oauth2/token", strings.NewReader(data.Encode())) + tokenReq.SetBasicAuth("my-test-client", "foobar") + tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resToken, err := http.DefaultClient.Do(tokenReq) + if err != nil { + t.Fatalf("token request err: %v", err) + } + var tokenResp map[string]interface{} + json.NewDecoder(resToken.Body).Decode(&tokenResp) + resToken.Body.Close() + + access := tokenResp["access_token"].(string) + + introspect := url.Values{"token": {access}} + introReq, _ := http.NewRequest("POST", srv.URL+"/oauth2/introspect", strings.NewReader(introspect.Encode())) + introReq.SetBasicAuth("my-test-client", "foobar") + introReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resIntro, err := http.DefaultClient.Do(introReq) + if err != nil { + t.Fatalf("introspect err: %v", err) + } + var introResp map[string]interface{} + json.NewDecoder(resIntro.Body).Decode(&introResp) + resIntro.Body.Close() + if active, ok := introResp["active"].(bool); !ok || !active { + t.Errorf("token not active: %v", introResp) + } +} diff --git a/logging_adapter_test.go b/logging_adapter_test.go new file mode 100644 index 0000000..fdba289 --- /dev/null +++ b/logging_adapter_test.go @@ -0,0 +1,200 @@ +package main + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" +) + +// failingStore wraps InMemoryStore and allows forcing errors for selected methods. +type failingStore struct { + *InMemoryStore + fail map[string]bool +} + +func newFailingStore(fail map[string]bool) *failingStore { + return &failingStore{InMemoryStore: NewInMemoryStore(), fail: fail} +} + +func (s *failingStore) shouldFail(method string) bool { return s.fail[method] } + +func (s *failingStore) GetClient(ctx context.Context, id string) (fosite.Client, error) { + if s.shouldFail("GetClient") { + return nil, errors.New("fail") + } + return s.InMemoryStore.GetClient(ctx, id) +} +func (s *failingStore) CreateClient(ctx context.Context, c fosite.Client) error { + if s.shouldFail("CreateClient") { + return errors.New("fail") + } + return s.InMemoryStore.CreateClient(ctx, c) +} +func (s *failingStore) UpdateClient(ctx context.Context, c fosite.Client) error { + if s.shouldFail("UpdateClient") { + return errors.New("fail") + } + return s.InMemoryStore.UpdateClient(ctx, c) +} +func (s *failingStore) DeleteClient(ctx context.Context, id string) error { + if s.shouldFail("DeleteClient") { + return errors.New("fail") + } + return s.InMemoryStore.DeleteClient(ctx, id) +} +func (s *failingStore) CreateToken(ctx context.Context, tokenType, signature, clientID string, data interface{}) error { + if s.shouldFail("CreateToken") { + return errors.New("fail") + } + return s.InMemoryStore.CreateToken(ctx, tokenType, signature, clientID, data) +} +func (s *failingStore) GetToken(ctx context.Context, tokenType, signature string) (interface{}, error) { + if s.shouldFail("GetToken") { + return nil, errors.New("fail") + } + return s.InMemoryStore.GetToken(ctx, tokenType, signature) +} +func (s *failingStore) DeleteToken(ctx context.Context, tokenType, signature string) error { + if s.shouldFail("DeleteToken") { + return errors.New("fail") + } + return s.InMemoryStore.DeleteToken(ctx, tokenType, signature) +} +func (s *failingStore) RevokeToken(ctx context.Context, tokenType, signature string) error { + if s.shouldFail("RevokeToken") { + return errors.New("fail") + } + return s.InMemoryStore.RevokeToken(ctx, tokenType, signature) +} +func (s *failingStore) CreateSession(ctx context.Context, sessionType, id string, data interface{}) error { + if s.shouldFail("CreateSession") { + return errors.New("fail") + } + return s.InMemoryStore.CreateSession(ctx, sessionType, id, data) +} +func (s *failingStore) GetSession(ctx context.Context, sessionType, id string) (interface{}, error) { + if s.shouldFail("GetSession") { + return nil, errors.New("fail") + } + return s.InMemoryStore.GetSession(ctx, sessionType, id) +} +func (s *failingStore) DeleteSession(ctx context.Context, sessionType, id string) error { + if s.shouldFail("DeleteSession") { + return errors.New("fail") + } + return s.InMemoryStore.DeleteSession(ctx, sessionType, id) +} +func (s *failingStore) ValidateJWT(ctx context.Context, jti string) error { + if s.shouldFail("ValidateJWT") { + return errors.New("fail") + } + return s.InMemoryStore.ValidateJWT(ctx, jti) +} +func (s *failingStore) MarkJWTAsUsed(ctx context.Context, jti string, exp time.Time) error { + if s.shouldFail("MarkJWTAsUsed") { + return errors.New("fail") + } + return s.InMemoryStore.MarkJWTAsUsed(ctx, jti, exp) +} +func (s *failingStore) GetPKCERequestSession(ctx context.Context, signature string, sess fosite.Session) (fosite.Requester, error) { + if s.shouldFail("GetPKCERequestSession") { + return nil, errors.New("fail") + } + return s.InMemoryStore.GetPKCERequestSession(ctx, signature, sess) +} +func (s *failingStore) CreatePKCERequestSession(ctx context.Context, signature string, r fosite.Requester) error { + if s.shouldFail("CreatePKCERequestSession") { + return errors.New("fail") + } + return s.InMemoryStore.CreatePKCERequestSession(ctx, signature, r) +} +func (s *failingStore) DeletePKCERequestSession(ctx context.Context, signature string) error { + if s.shouldFail("DeletePKCERequestSession") { + return errors.New("fail") + } + return s.InMemoryStore.DeletePKCERequestSession(ctx, signature) +} + +func TestLoggingAdapterMetrics(t *testing.T) { + ctx := context.Background() + adapter := NewLoggingAdapter(NewInMemoryStore()) + + c := &fosite.DefaultClient{ID: "c1"} + if err := adapter.CreateClient(ctx, c); err != nil { + t.Fatalf("CreateClient err: %v", err) + } + if _, err := adapter.GetClient(ctx, "c1"); err != nil { + t.Fatalf("GetClient err: %v", err) + } + if err := adapter.UpdateClient(ctx, c); err != nil { + t.Fatalf("UpdateClient err: %v", err) + } + if err := adapter.DeleteClient(ctx, "c1"); err != nil { + t.Fatalf("DeleteClient err: %v", err) + } + + req := &fosite.Request{} + if err := adapter.CreateToken(ctx, "access_token", "sig", "c1", req); err != nil { + t.Fatalf("CreateToken err: %v", err) + } + if _, err := adapter.GetToken(ctx, "access_token", "sig"); err != nil { + t.Fatalf("GetToken err: %v", err) + } + if err := adapter.RevokeToken(ctx, "access_token", "sig"); err != nil { + t.Fatalf("RevokeToken err: %v", err) + } + if err := adapter.DeleteToken(ctx, "access_token", "sig"); err != nil { + t.Fatalf("DeleteToken err: %v", err) + } + + if err := adapter.CreateSession(ctx, "openid", "s1", req); err != nil { + t.Fatalf("CreateSession err: %v", err) + } + if _, err := adapter.GetSession(ctx, "openid", "s1"); err != nil { + t.Fatalf("GetSession err: %v", err) + } + if err := adapter.DeleteSession(ctx, "openid", "s1"); err != nil { + t.Fatalf("DeleteSession err: %v", err) + } + + if err := adapter.ValidateJWT(ctx, "j1"); err != nil { + t.Fatalf("ValidateJWT err: %v", err) + } + if err := adapter.MarkJWTAsUsed(ctx, "j1", time.Now()); err != nil { + t.Fatalf("MarkJWTAsUsed err: %v", err) + } + + if err := adapter.CreatePKCERequestSession(ctx, "p1", req); err != nil { + t.Fatalf("CreatePKCERequestSession err: %v", err) + } + if _, err := adapter.GetPKCERequestSession(ctx, "p1", &openid.DefaultSession{}); err != nil { + t.Fatalf("GetPKCERequestSession err: %v", err) + } + if err := adapter.DeletePKCERequestSession(ctx, "p1"); err != nil { + t.Fatalf("DeletePKCERequestSession err: %v", err) + } + + metrics := adapter.Metrics() + expected := []string{"CreateClient", "GetClient", "UpdateClient", "DeleteClient", "CreateToken", "GetToken", "RevokeToken", "DeleteToken", "CreateSession", "GetSession", "DeleteSession", "ValidateJWT", "MarkJWTAsUsed", "CreatePKCESession", "GetPKCESession", "DeletePKCESession"} + for _, m := range expected { + if metrics[m] == 0 { + t.Errorf("metric %s not incremented", m) + } + } +} + +func TestLoggingAdapterErrorMetrics(t *testing.T) { + ctx := context.Background() + fs := newFailingStore(map[string]bool{"GetClient": true}) + adapter := NewLoggingAdapter(fs) + if _, err := adapter.GetClient(ctx, "x"); err == nil { + t.Fatal("expected error") + } + if adapter.Metrics()["GetClientError"] != 1 { + t.Errorf("error metric not incremented") + } +}