Skip to content

Commit c346113

Browse files
feat: support provider-specific OAuth whitelists (#882)
Co-authored-by: Puneet Dixit <236133619+puneetdixit200@users.noreply.github.com>
1 parent 3f584ca commit c346113

7 files changed

Lines changed: 77 additions & 20 deletions

File tree

.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID=
101101
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
102102
# Path to the file containing the OAuth client secret.
103103
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
104+
# Comma-separated list of allowed OAuth domains for this provider.
105+
TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST=
106+
# Path to the OAuth whitelist file for this provider.
107+
TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE=
104108
# OAuth scopes.
105109
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
106110
# OAuth redirect URL.

internal/bootstrap/app_bootstrap.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ func (app *BootstrapApp) Setup() error {
117117
app.runtime.OAuthProviders = app.config.OAuth.Providers
118118

119119
for id, provider := range app.runtime.OAuthProviders {
120+
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
121+
if err != nil {
122+
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
123+
}
124+
125+
provider.Whitelist = providerWhitelist
126+
120127
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
121128
provider.ClientSecret = secret
122129
provider.ClientSecretFile = ""

internal/controller/oauth_controller.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
183183
return
184184
}
185185

186-
if !controller.auth.IsEmailWhitelisted(user.Email) {
186+
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
187+
188+
if err != nil {
189+
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
190+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
191+
return
192+
}
193+
194+
if svc.ID() != req.Provider {
195+
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
196+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
197+
return
198+
}
199+
200+
if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) {
187201
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
188-
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
202+
controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted")
189203

190204
queries, err := query.Values(UnauthorizedQuery{
191205
Username: user.Email,
@@ -226,20 +240,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
226240
username = strings.Replace(user.Email, "@", "_", 1)
227241
}
228242

229-
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
230-
231-
if err != nil {
232-
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
233-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
234-
return
235-
}
236-
237-
if svc.ID() != req.Provider {
238-
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
239-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
240-
return
241-
}
242-
243243
sessionCookie := repository.Session{
244244
Username: username,
245245
Name: name,

internal/middleware/context_middleware.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
205205
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
206206
}
207207

208-
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
208+
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
209209
m.auth.DeleteSession(ctx, uuid)
210210
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
211211
}

internal/model/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ type OAuthServiceConfig struct {
226226
ClientID string `description:"OAuth client ID." yaml:"clientId"`
227227
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
228228
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
229+
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"`
230+
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"`
229231
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
230232
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
231233
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`

internal/service/auth_service.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,15 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
285285
}
286286
}
287287

288-
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
289-
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
288+
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
289+
whitelist := auth.runtime.OAuthWhitelist
290+
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
291+
whitelist = providerConfig.Whitelist
292+
}
293+
294+
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
290295
if err != nil {
291-
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
296+
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern")
292297
return false
293298
}
294299
return match
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package service
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/tinyauthapp/tinyauth/internal/model"
8+
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
9+
)
10+
11+
func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
12+
log := logger.NewLogger().WithTestConfig()
13+
log.Init()
14+
15+
auth := &AuthService{
16+
log: log,
17+
runtime: model.RuntimeConfig{
18+
OAuthWhitelist: []string{"global@example.com"},
19+
OAuthProviders: map[string]model.OAuthServiceConfig{
20+
"github": {
21+
Whitelist: []string{"github@example.com"},
22+
},
23+
"pocketid": {
24+
Whitelist: []string{"pocket@example.com"},
25+
},
26+
"gitlab": {
27+
Whitelist: []string{},
28+
},
29+
},
30+
},
31+
}
32+
33+
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
34+
assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com"))
35+
assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com"))
36+
assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com"))
37+
assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com"))
38+
assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com"))
39+
}

0 commit comments

Comments
 (0)