Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 48 additions & 192 deletions internal/bootstrap/app_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,26 @@ import (
"time"
"tinyauth/internal/config"
"tinyauth/internal/controller"
"tinyauth/internal/middleware"
"tinyauth/internal/model"
"tinyauth/internal/service"
"tinyauth/internal/utils"

"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)

type Controller interface {
SetupRoutes()
}

type Middleware interface {
Middleware() gin.HandlerFunc
Init() error
}

type Service interface {
Init() error
}

type BootstrapApp struct {
config config.Config
uuid string
config config.Config
context struct {
uuid string
cookieDomain string
sessionCookieName string
csrfCookieName string
redirectCookieName string
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider
}
services Services
}

func NewBootstrapApp(config config.Config) *BootstrapApp {
Expand All @@ -55,111 +49,51 @@ func (app *BootstrapApp) Setup() error {
return err
}

app.context.users = users

// Get OAuth configs
oauthProviders, err := utils.GetOAuthProvidersConfig(os.Environ(), os.Args, app.config.AppURL)

if err != nil {
return err
}

app.context.oauthProviders = oauthProviders

// Get cookie domain
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)

if err != nil {
return err
}

app.context.cookieDomain = cookieDomain

// Cookie names
appUrl, _ := url.Parse(app.config.AppURL) // Already validated
uuid := utils.GenerateUUID(appUrl.Hostname())
app.uuid = uuid
cookieId := strings.Split(uuid, "-")[0]
sessionCookieName := fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
csrfCookieName := fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
redirectCookieName := fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)

// Dumps
log.Trace().Interface("config", app.config).Msg("Config dump")
log.Trace().Interface("users", users).Msg("Users dump")
log.Trace().Interface("oauthProviders", oauthProviders).Msg("OAuth providers dump")
log.Trace().Str("cookieDomain", cookieDomain).Msg("Cookie domain")
log.Trace().Str("sessionCookieName", sessionCookieName).Msg("Session cookie name")
log.Trace().Str("csrfCookieName", csrfCookieName).Msg("CSRF cookie name")
log.Trace().Str("redirectCookieName", redirectCookieName).Msg("Redirect cookie name")

// Create configs
authConfig := service.AuthServiceConfig{
Users: users,
OauthWhitelist: app.config.OAuthWhitelist,
SessionExpiry: app.config.SessionExpiry,
SecureCookie: app.config.SecureCookie,
CookieDomain: cookieDomain,
LoginTimeout: app.config.LoginTimeout,
LoginMaxRetries: app.config.LoginMaxRetries,
SessionCookieName: sessionCookieName,
}

// Setup services
var ldapService *service.LdapService

if app.config.LdapAddress != "" {
ldapConfig := service.LdapServiceConfig{
Address: app.config.LdapAddress,
BindDN: app.config.LdapBindDN,
BindPassword: app.config.LdapBindPassword,
BaseDN: app.config.LdapBaseDN,
Insecure: app.config.LdapInsecure,
SearchFilter: app.config.LdapSearchFilter,
}

ldapService = service.NewLdapService(ldapConfig)
log.Trace().Interface("users", app.context.users).Msg("Users dump")
log.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
log.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
log.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
log.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
log.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")

err := ldapService.Init()

if err != nil {
log.Warn().Err(err).Msg("Failed to initialize LDAP service, continuing without LDAP")
ldapService = nil
}
}

// Bootstrap database
databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{
DatabasePath: app.config.DatabasePath,
})

log.Debug().Str("service", fmt.Sprintf("%T", databaseService)).Msg("Initializing service")

err = databaseService.Init()
// Services
services, err := app.initServices()

if err != nil {
return fmt.Errorf("failed to initialize database service: %w", err)
return fmt.Errorf("failed to initialize services: %w", err)
}

database := databaseService.GetDatabase()

// Create services
dockerService := service.NewDockerService()
aclsService := service.NewAccessControlsService(dockerService)
authService := service.NewAuthService(authConfig, dockerService, ldapService, database)
oauthBrokerService := service.NewOAuthBrokerService(oauthProviders)

// Initialize services (order matters)
services := []Service{
dockerService,
aclsService,
authService,
oauthBrokerService,
}

for _, svc := range services {
if svc != nil {
log.Debug().Str("service", fmt.Sprintf("%T", svc)).Msg("Initializing service")
err := svc.Init()
if err != nil {
return err
}
}
}
app.services = services

// Configured providers
configuredProviders := make([]controller.Provider, 0)
Expand All @@ -176,7 +110,7 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name
})

if authService.UserAuthConfigured() || ldapService != nil {
if services.authService.UserAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "Username",
ID: "username",
Expand All @@ -190,106 +124,27 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("no authentication providers configured")
}

// Create engine
engine := gin.New()
engine.Use(gin.Recovery())

if len(app.config.TrustedProxies) > 0 {
err := engine.SetTrustedProxies(strings.Split(app.config.TrustedProxies, ","))

if err != nil {
return fmt.Errorf("failed to set trusted proxies: %w", err)
}
}

// Create middlewares
var middlewares []Middleware

contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
CookieDomain: cookieDomain,
}, authService, oauthBrokerService)

uiMiddleware := middleware.NewUIMiddleware()
zerologMiddleware := middleware.NewZerologMiddleware()
app.context.configuredProviders = configuredProviders

middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware)
// Setup router
router, err := app.setupRouter()

for _, middleware := range middlewares {
log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware")
err := middleware.Init()
if err != nil {
return fmt.Errorf("failed to initialize middleware %T: %w", middleware, err)
}
engine.Use(middleware.Middleware())
}

// Create routers
mainRouter := engine.Group("")
apiRouter := engine.Group("/api")

// Create controllers
contextController := controller.NewContextController(controller.ContextControllerConfig{
Providers: configuredProviders,
Title: app.config.Title,
AppURL: app.config.AppURL,
CookieDomain: cookieDomain,
ForgotPasswordMessage: app.config.ForgotPasswordMessage,
BackgroundImage: app.config.BackgroundImage,
OAuthAutoRedirect: app.config.OAuthAutoRedirect,
DisableUIWarnings: app.config.DisableUIWarnings,
}, apiRouter)

oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL,
SecureCookie: app.config.SecureCookie,
CSRFCookieName: csrfCookieName,
RedirectCookieName: redirectCookieName,
CookieDomain: cookieDomain,
}, apiRouter, authService, oauthBrokerService)

proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, aclsService, authService)

userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: cookieDomain,
}, apiRouter, authService)

resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
ResourcesDir: app.config.ResourcesDir,
ResourcesDisabled: app.config.DisableResources,
}, mainRouter)

healthController := controller.NewHealthController(apiRouter)

// Setup routes
controller := []Controller{
contextController,
oauthController,
proxyController,
userController,
healthController,
resourcesController,
if err != nil {
return fmt.Errorf("failed to setup routes: %w", err)
}

for _, ctrl := range controller {
log.Debug().Msgf("Setting up %T controller", ctrl)
ctrl.SetupRoutes()
}
// Start DB cleanup routine
log.Debug().Msg("Starting database cleanup routine")
go app.dbCleanup(services.databaseService.GetDatabase())

// If analytics are not disabled, start heartbeat
if !app.config.DisableAnalytics {
log.Debug().Msg("Starting heartbeat routine")
go app.heartbeat()
}

// Start DB cleanup routine
log.Debug().Msg("Starting database cleanup routine")
go app.dbCleanup(database)

// If we have an socket path, bind to it
if app.config.SocketPath != "" {
// Remove existing socket file
if _, err := os.Stat(app.config.SocketPath); err == nil {
log.Info().Msgf("Removing existing socket file %s", app.config.SocketPath)
err := os.Remove(app.config.SocketPath)
Expand All @@ -298,9 +153,8 @@ func (app *BootstrapApp) Setup() error {
}
}

// Start server with unix socket
log.Info().Msgf("Starting server on unix socket %s", app.config.SocketPath)
if err := engine.RunUnix(app.config.SocketPath); err != nil {
if err := router.RunUnix(app.config.SocketPath); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}

Expand All @@ -310,7 +164,7 @@ func (app *BootstrapApp) Setup() error {
// Start server
address := fmt.Sprintf("%s:%d", app.config.Address, app.config.Port)
log.Info().Msgf("Starting server on %s", address)
if err := engine.Run(address); err != nil {
if err := router.Run(address); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}

Expand All @@ -328,7 +182,7 @@ func (app *BootstrapApp) heartbeat() {

var body heartbeat

body.UUID = app.uuid
body.UUID = app.context.uuid
body.Version = config.Version

bodyJson, err := json.Marshal(body)
Expand All @@ -338,7 +192,9 @@ func (app *BootstrapApp) heartbeat() {
return
}

client := &http.Client{}
client := &http.Client{
Timeout: time.Duration(10) * time.Second, // The server should never take more than 10 seconds to respond
}

heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"

Expand Down
Loading