Skip to content

Commit d500907

Browse files
committed
fix: coderabbit comments
1 parent 548d97f commit d500907

17 files changed

Lines changed: 107 additions & 163 deletions

internal/bootstrap/app_bootstrap.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (app *BootstrapApp) Setup() error {
102102

103103
app.runtime.OAuthWhitelist = oauthWhitelist
104104

105-
// Setup oauth providers
105+
// setup oauth providers
106106
app.runtime.OAuthProviders = app.config.OAuth.Providers
107107

108108
for id, provider := range app.runtime.OAuthProviders {
@@ -168,6 +168,14 @@ func (app *BootstrapApp) Setup() error {
168168
return fmt.Errorf("failed to setup database: %w", err)
169169
}
170170

171+
// after this point, we start initializing dependencies so it's a good time to setup a defer
172+
// to ensure that resources are cleaned up properly in case of an error during initialization
173+
defer func() {
174+
app.cancel()
175+
app.wg.Wait()
176+
app.db.Close()
177+
}()
178+
171179
// queries
172180
queries := repository.New(app.db)
173181
app.queries = queries
@@ -279,9 +287,6 @@ func (app *BootstrapApp) Setup() error {
279287
for {
280288
select {
281289
case <-app.ctx.Done():
282-
app.wg.Wait()
283-
app.log.App.Debug().Msg("Closing database")
284-
app.db.Close()
285290
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
286291
return nil
287292
case err := <-errChan:
@@ -305,7 +310,7 @@ func (app *BootstrapApp) serveHTTP() error {
305310
go func() {
306311
<-app.ctx.Done()
307312
app.log.App.Debug().Msg("Shutting down http listener")
308-
server.Close()
313+
server.Shutdown(app.ctx)
309314
}()
310315

311316
err := server.ListenAndServe()
@@ -345,21 +350,23 @@ func (app *BootstrapApp) serveUnix() error {
345350
Handler: app.router.Handler(),
346351
}
347352

348-
defer server.Close()
349-
defer listener.Close()
350-
defer os.Remove(app.config.Server.SocketPath)
353+
shutdown := func() {
354+
server.Shutdown(app.ctx)
355+
listener.Close()
356+
os.Remove(app.config.Server.SocketPath)
357+
}
358+
359+
defer shutdown()
351360

352361
go func() {
353362
<-app.ctx.Done()
354363
app.log.App.Debug().Msg("Shutting down unix socket listener")
355-
server.Close()
356-
listener.Close()
357-
os.Remove(app.config.Server.SocketPath)
364+
shutdown()
358365
}()
359366

360367
err = server.Serve(listener)
361368

362-
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
369+
if err != nil && !errors.Is(err, http.ErrServerClosed) {
363370
return fmt.Errorf("failed to start unix socket listener: %w", err)
364371
}
365372

internal/bootstrap/db_bootstrap.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ func (app *BootstrapApp) SetupDatabase() error {
2727
return fmt.Errorf("failed to open database: %w", err)
2828
}
2929

30+
// Close the database if there is an error during migration
31+
defer func() {
32+
if err != nil {
33+
db.Close()
34+
}
35+
}()
36+
3037
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
3138
// if the sqlite connection starts being a bottleneck
3239
db.SetMaxOpenConns(1)

internal/bootstrap/router_bootstrap.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func (app *BootstrapApp) setupRouter() error {
4343

4444
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
4545
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
46-
controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
46+
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
4747
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
4848
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
4949
controller.NewResourcesController(app.config, &engine.RouterGroup)

internal/controller/context_controller_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/stretchr/testify/require"
1212
"github.com/tinyauthapp/tinyauth/internal/controller"
1313
"github.com/tinyauthapp/tinyauth/internal/model"
14+
"github.com/tinyauthapp/tinyauth/internal/test"
1415
"github.com/tinyauthapp/tinyauth/internal/utils"
1516
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
1617
)
@@ -19,7 +20,7 @@ func TestContextController(t *testing.T) {
1920
log := logger.NewLogger().WithTestConfig()
2021
log.Init()
2122

22-
cfg, runtime := createTestConfigs(t)
23+
cfg, runtime := test.CreateTestConfigs(t)
2324

2425
tests := []struct {
2526
description string

internal/controller/oauth_controller.go

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
131131

132132
if err != nil {
133133
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
134-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
134+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
135135
return
136136
}
137137

@@ -141,7 +141,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
141141

142142
if err != nil {
143143
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
144-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
144+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
145145
return
146146
}
147147

@@ -150,7 +150,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
150150
state := c.Query("state")
151151
if state != oauthPendingSession.State {
152152
controller.log.App.Warn().Msg("OAuth state mismatch")
153-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
153+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
154154
return
155155
}
156156

@@ -159,15 +159,27 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
159159

160160
if err != nil {
161161
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
162-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
162+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
163163
return
164164
}
165165

166166
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
167167

168+
if err != nil {
169+
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
170+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
171+
return
172+
}
173+
174+
if user == nil {
175+
controller.log.App.Warn().Msg("OAuth provider did not return user info")
176+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
177+
return
178+
}
179+
168180
if user.Email == "" {
169181
controller.log.App.Warn().Msg("OAuth provider did not return an email")
170-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
182+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
171183
return
172184
}
173185

@@ -181,11 +193,11 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
181193

182194
if err != nil {
183195
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
184-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
196+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
185197
return
186198
}
187199

188-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
200+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
189201
return
190202
}
191203

@@ -213,13 +225,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
213225

214226
if err != nil {
215227
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
216-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
228+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
217229
return
218230
}
219231

220232
if svc.ID() != req.Provider {
221233
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
222-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
234+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
223235
return
224236
}
225237

@@ -239,7 +251,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
239251

240252
if err != nil {
241253
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
242-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
254+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
243255
return
244256
}
245257

@@ -252,10 +264,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
252264
queries, err := query.Values(oauthPendingSession.CallbackParams)
253265
if err != nil {
254266
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
255-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
267+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
256268
return
257269
}
258-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
270+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
259271
return
260272
}
261273

@@ -266,15 +278,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
266278

267279
if err != nil {
268280
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
269-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
281+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
270282
return
271283
}
272284

273-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
285+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
274286
return
275287
}
276288

277-
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
289+
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
278290
}
279291

280292
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {

internal/controller/oidc_controller.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ import (
1717
)
1818

1919
type OIDCController struct {
20-
log *logger.Logger
21-
oidc *service.OIDCService
20+
log *logger.Logger
21+
oidc *service.OIDCService
22+
runtime model.RuntimeConfig
2223
}
2324

2425
type AuthorizeCallback struct {
@@ -58,10 +59,12 @@ type ClientCredentials struct {
5859
func NewOIDCController(
5960
log *logger.Logger,
6061
oidcService *service.OIDCService,
62+
runtimeConfig model.RuntimeConfig,
6163
router *gin.RouterGroup) *OIDCController {
6264
controller := &OIDCController{
63-
log: log,
64-
oidc: oidcService,
65+
log: log,
66+
oidc: oidcService,
67+
runtime: runtimeConfig,
6568
}
6669

6770
oidcGroup := router.Group("/oidc")
@@ -75,6 +78,15 @@ func NewOIDCController(
7578
}
7679

7780
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
81+
if controller.oidc == nil {
82+
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
83+
c.JSON(500, gin.H{
84+
"status": 500,
85+
"message": "OIDC not configured",
86+
})
87+
return
88+
}
89+
7890
var req ClientRequest
7991

8092
err := c.BindUri(&req)
@@ -198,8 +210,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
198210
func (controller *OIDCController) Token(c *gin.Context) {
199211
if controller.oidc == nil {
200212
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
201-
c.JSON(404, gin.H{
202-
"error": "not_found",
213+
c.JSON(500, gin.H{
214+
"error": "server_error",
203215
})
204216
return
205217
}
@@ -374,8 +386,8 @@ func (controller *OIDCController) Token(c *gin.Context) {
374386
func (controller *OIDCController) Userinfo(c *gin.Context) {
375387
if controller.oidc == nil {
376388
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
377-
c.JSON(404, gin.H{
378-
"error": "not_found",
389+
c.JSON(500, gin.H{
390+
"error": "server_error",
379391
})
380392
return
381393
}
@@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
507519
return
508520
}
509521

522+
redirectUrl := ""
523+
524+
if controller.oidc != nil {
525+
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
526+
} else {
527+
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
528+
}
529+
510530
c.JSON(200, gin.H{
511531
"status": 200,
512-
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
532+
"redirect_uri": redirectUrl,
513533
})
514534
}

internal/controller/oidc_controller_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ import (
2020
"github.com/tinyauthapp/tinyauth/internal/model"
2121
"github.com/tinyauthapp/tinyauth/internal/repository"
2222
"github.com/tinyauthapp/tinyauth/internal/service"
23+
"github.com/tinyauthapp/tinyauth/internal/test"
2324
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
2425
)
2526

2627
func TestOIDCController(t *testing.T) {
2728
log := logger.NewLogger().WithTestConfig()
2829
log.Init()
2930

30-
cfg, runtime := createTestConfigs(t)
31+
cfg, runtime := test.CreateTestConfigs(t)
3132

3233
simpleCtx := func(c *gin.Context) {
3334
c.Set("context", &model.UserContext{
@@ -861,7 +862,7 @@ func TestOIDCController(t *testing.T) {
861862
group := router.Group("/api")
862863
gin.SetMode(gin.TestMode)
863864

864-
controller.NewOIDCController(log, oidcService, group)
865+
controller.NewOIDCController(log, oidcService, runtime, group)
865866

866867
recorder := httptest.NewRecorder()
867868

internal/controller/proxy_controller_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ import (
1414
"github.com/tinyauthapp/tinyauth/internal/model"
1515
"github.com/tinyauthapp/tinyauth/internal/repository"
1616
"github.com/tinyauthapp/tinyauth/internal/service"
17+
"github.com/tinyauthapp/tinyauth/internal/test"
1718
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
1819
)
1920

2021
func TestProxyController(t *testing.T) {
2122
log := logger.NewLogger().WithTestConfig()
2223
log.Init()
2324

24-
cfg, runtime := createTestConfigs(t)
25+
cfg, runtime := test.CreateTestConfigs(t)
2526

2627
acls := map[string]model.App{
2728
"app_path_allow": {

internal/controller/resources_controller_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ import (
1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
1212
"github.com/tinyauthapp/tinyauth/internal/controller"
13+
"github.com/tinyauthapp/tinyauth/internal/test"
1314
)
1415

1516
func TestResourcesController(t *testing.T) {
16-
cfg, _ := createTestConfigs(t)
17+
cfg, _ := test.CreateTestConfigs(t)
1718

1819
err := os.MkdirAll(cfg.Resources.Path, 0777)
1920
require.NoError(t, err)

internal/controller/user_controller_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@ import (
1919
"github.com/tinyauthapp/tinyauth/internal/model"
2020
"github.com/tinyauthapp/tinyauth/internal/repository"
2121
"github.com/tinyauthapp/tinyauth/internal/service"
22+
"github.com/tinyauthapp/tinyauth/internal/test"
2223
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
2324
)
2425

2526
func TestUserController(t *testing.T) {
2627
log := logger.NewLogger().WithTestConfig()
2728
log.Init()
2829

29-
cfg, runtime := createTestConfigs(t)
30+
cfg, runtime := test.CreateTestConfigs(t)
3031

3132
totpCtx := func(c *gin.Context) {
3233
c.Set("context", &model.UserContext{

0 commit comments

Comments
 (0)