diff --git a/cmd/api/src/api/dbpool/dbpool.go b/cmd/api/src/api/dbpool/dbpool.go new file mode 100644 index 00000000000..47a943842bf --- /dev/null +++ b/cmd/api/src/api/dbpool/dbpool.go @@ -0,0 +1,80 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package dbpool + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/specterops/bloodhound/cmd/api/src/config" + "github.com/specterops/dawgs/drivers/pg" +) + +const ( + poolInitConnectionTimeout = time.Second * 10 +) + +func newPoolCfg(cfg config.DatabaseConfiguration) (*pgxpool.Config, error) { + poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()) + if err != nil { + return nil, err + } + + // TODO: Min and Max connections for the pool should be configurable + poolCfg.MinConns = 5 + poolCfg.MaxConns = 50 + + if cfg.EnableRDSIAMAuth { + // Only enable the BeforeConnect handler if RDS IAM Auth is enabled + poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error { + if newPoolCfg, err := pgxpool.ParseConfig(cfg.RDSIAMAuthConnectionString()); err != nil { + return err + } else { + connCfg.Host = newPoolCfg.ConnConfig.Host + connCfg.Port = newPoolCfg.ConnConfig.Port + + connCfg.User = newPoolCfg.ConnConfig.User + connCfg.Password = newPoolCfg.ConnConfig.Password + connCfg.Database = newPoolCfg.ConnConfig.Database + } + + return nil + } + } + + return poolCfg, nil +} + +func NewDawgsPool(cfg config.DatabaseConfiguration) (*pgxpool.Pool, error) { + if poolCfg, err := newPoolCfg(cfg); err != nil { + return nil, err + } else { + return pg.NewPool(poolCfg) + } +} + +func NewAppPool(cfg config.DatabaseConfiguration) (*pgxpool.Pool, error) { + poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout) + defer done() + + if poolCfg, err := newPoolCfg(cfg); err != nil { + return nil, err + } else { + return pgxpool.NewWithConfig(poolCtx, poolCfg) + } +} diff --git a/cmd/api/src/api/tools/pg.go b/cmd/api/src/api/tools/pg.go index bc5b4b488fa..7b21b7782b6 100644 --- a/cmd/api/src/api/tools/pg.go +++ b/cmd/api/src/api/tools/pg.go @@ -25,6 +25,7 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" "github.com/specterops/bloodhound/cmd/api/src/api" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/packages/go/bhlog/attr" "github.com/specterops/bloodhound/packages/go/bhlog/measure" @@ -255,7 +256,7 @@ type PGMigrator struct { migrationCancelFunc func() State MigratorState lock *sync.Mutex - Cfg config.Configuration + cfg config.Configuration } func NewPGMigrator(serverCtx context.Context, cfg config.Configuration, graphSchema graph.Schema, graphDBSwitch *graph.DatabaseSwitch) *PGMigrator { @@ -265,7 +266,7 @@ func NewPGMigrator(serverCtx context.Context, cfg config.Configuration, graphSch ServerCtx: serverCtx, State: StateIdle, lock: &sync.Mutex{}, - Cfg: cfg, + cfg: cfg, } } @@ -297,7 +298,7 @@ func (s *PGMigrator) SwitchPostgreSQL(response http.ResponseWriter, request *htt }, http.StatusInternalServerError, response) } else if err := pgDB.AssertSchema(request.Context(), s.graphSchema); err != nil { slog.ErrorContext(request.Context(), "Unable to assert graph schema in PostgreSQL", attr.Error(err)) - } else if err := SetGraphDriver(request.Context(), s.Cfg, pg.DriverName); err != nil { + } else if err := SetGraphDriver(request.Context(), s.cfg, pg.DriverName); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": fmt.Errorf("failed updating graph database driver preferences: %w", err), }, http.StatusInternalServerError, response) @@ -314,7 +315,7 @@ func (s *PGMigrator) SwitchNeo4j(response http.ResponseWriter, request *http.Req api.WriteJSONResponse(request.Context(), map[string]any{ "error": fmt.Errorf("failed connecting to Neo4j: %w", err), }, http.StatusInternalServerError, response) - } else if err := SetGraphDriver(request.Context(), s.Cfg, neo4j.DriverName); err != nil { + } else if err := SetGraphDriver(request.Context(), s.cfg, neo4j.DriverName); err != nil { api.WriteJSONResponse(request.Context(), map[string]any{ "error": fmt.Errorf("failed updating graph database driver preferences: %w", err), }, http.StatusInternalServerError, response) @@ -449,12 +450,12 @@ func (s *PGMigrator) MigrationStatus(response http.ResponseWriter, request *http } func (s *PGMigrator) OpenPostgresGraphConnection() (graph.Database, error) { - if pool, err := pg.NewPool(s.Cfg.Database); err != nil { + if pool, err := dbpool.NewDawgsPool(s.cfg.Database); err != nil { return nil, err } else { return dawgs.Open(s.ServerCtx, pg.DriverName, dawgs.Config{ GraphQueryMemoryLimit: size.Gibibyte, - ConnectionString: s.Cfg.Database.PostgreSQLConnectionString(), + ConnectionString: s.cfg.Database.PostgreSQLConnectionString(), Pool: pool, }) } @@ -463,6 +464,6 @@ func (s *PGMigrator) OpenPostgresGraphConnection() (graph.Database, error) { func (s *PGMigrator) OpenNeo4jGraphConnection() (graph.Database, error) { return dawgs.Open(s.ServerCtx, neo4j.DriverName, dawgs.Config{ GraphQueryMemoryLimit: size.Gibibyte, - ConnectionString: s.Cfg.Neo4J.Neo4jConnectionString(), + ConnectionString: s.cfg.Neo4J.Neo4JConnectionString(), }) } diff --git a/cmd/api/src/bootstrap/util.go b/cmd/api/src/bootstrap/util.go index 8be699c1ec0..c19d74a0b13 100644 --- a/cmd/api/src/bootstrap/util.go +++ b/cmd/api/src/bootstrap/util.go @@ -23,6 +23,7 @@ import ( "os" "github.com/jackc/pgx/v5/pgxpool" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/api/tools" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/dawgs" @@ -92,13 +93,13 @@ func ConnectGraph(ctx context.Context, cfg config.Configuration) (*graph.Databas switch driverName { case neo4j.DriverName: slog.InfoContext(ctx, "Connecting to graph using Neo4j") - connectionString = cfg.Neo4J.Neo4jConnectionString() + connectionString = cfg.Neo4J.Neo4JConnectionString() case pg.DriverName: slog.InfoContext(ctx, "Connecting to graph using PostgreSQL") connectionString = cfg.Database.PostgreSQLConnectionString() - pool, err = pg.NewPool(cfg.Database) + pool, err = dbpool.NewDawgsPool(cfg.Database) if err != nil { return nil, err } diff --git a/cmd/api/src/cmd/dawgs-harness/main.go b/cmd/api/src/cmd/dawgs-harness/main.go index 3db9c7bb60a..e640b83025b 100644 --- a/cmd/api/src/cmd/dawgs-harness/main.go +++ b/cmd/api/src/cmd/dawgs-harness/main.go @@ -29,12 +29,12 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/jedib0t/go-pretty/v6/table" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/cmd/dawgs-harness/tests" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/packages/go/bhlog" schema "github.com/specterops/bloodhound/packages/go/graphschema" "github.com/specterops/dawgs" - "github.com/specterops/dawgs/drivers" "github.com/specterops/dawgs/drivers/neo4j" "github.com/specterops/dawgs/drivers/pg" "github.com/specterops/dawgs/graph" @@ -46,14 +46,14 @@ func fatalf(format string, args ...any) { os.Exit(1) } -func RunTestSuite(ctx context.Context, connectionStr, driverName string, cfg drivers.DatabaseConfiguration) tests.TestSuite { +func RunTestSuite(ctx context.Context, connectionStr, driverName string, cfg config.DatabaseConfiguration) tests.TestSuite { var ( pool *pgxpool.Pool err error ) if driverName == pg.DriverName { - pool, err = pg.NewPool(cfg) + pool, err = dbpool.NewDawgsPool(cfg) if err != nil { fatalf("Failed creating a new pgxpool: %s", err) } diff --git a/cmd/api/src/config/config.go b/cmd/api/src/config/config.go index 731dc4a63b2..34503ba6e75 100644 --- a/cmd/api/src/config/config.go +++ b/cmd/api/src/config/config.go @@ -17,21 +17,26 @@ package config import ( + "context" "encoding/base64" "encoding/json" "errors" "fmt" "log/slog" + "net" + "net/url" "os" "path/filepath" "regexp" "strconv" "strings" + awsConfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/specterops/bloodhound/cmd/api/src/serde" "github.com/specterops/bloodhound/packages/go/bhlog/attr" "github.com/specterops/bloodhound/packages/go/crypto" - dawgs "github.com/specterops/dawgs/drivers" ) const ( @@ -51,6 +56,71 @@ func (s TLSConfiguration) Enabled() bool { return s.CertFile != "" && s.KeyFile != "" } +type DatabaseConfiguration struct { + Connection string `json:"connection"` + Address string `json:"addr"` + Database string `json:"database"` + Username string `json:"username"` + Secret string `json:"secret"` + MaxConcurrentSessions int `json:"max_concurrent_sessions"` + EnableRDSIAMAuth bool `json:"enable_rds_iam_auth"` +} + +func (s DatabaseConfiguration) PostgreSQLConnectionString() string { + if s.Connection != "" { + return s.Connection + } + + return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database) +} + +func (s DatabaseConfiguration) Neo4JConnectionString() string { + if s.Connection != "" { + return s.Connection + } + + return fmt.Sprintf("neo4j://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database) +} + +func (s DatabaseConfiguration) RDSIAMAuthConnectionString() string { + if cfg, err := awsConfig.LoadDefaultConfig(context.TODO()); err != nil { + slog.Error("AWS Config Loading Error", slog.String("err", err.Error())) + } else { + // Must use instance endpoint with IAM auth + endpoint := s.LookupEndpoint() + + slog.Info("Requesting RDS IAM Auth Token") + + if authenticationToken, err := auth.BuildAuthToken(context.TODO(), endpoint, cfg.Region, s.Username, cfg.Credentials); err != nil { + slog.Error("RDS IAM Auth Token Request Error", slog.String("err", err.Error())) + } else { + slog.Info("RDS IAM Auth Token Created") + return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(authenticationToken), endpoint, s.Database) + } + } + + slog.Warn("Failed to create IAM auth token. Falling back to default Postgres connection string") + return s.PostgreSQLConnectionString() +} + +func (s DatabaseConfiguration) LookupEndpoint() string { + host, port, err := net.SplitHostPort(s.Address) + if err != nil { + slog.Warn("Missing port in address. Using default port 5432.", slog.String("err", err.Error())) + host = s.Address + port = "5432" + } + + if hostCName, err := net.DefaultResolver.LookupCNAME(context.TODO(), host); err != nil { + slog.Warn("Error looking up CNAME for DB host. Using original address.", slog.String("err", err.Error())) + } else { + host = hostCName + } + + // Instance endpoint always returns with a trailing '.' + return net.JoinHostPort(strings.TrimSuffix(host, "."), port) +} + type CollectorManifest struct { Latest string `json:"latest"` Versions []CollectorVersion `json:"versions"` @@ -111,39 +181,39 @@ type DefaultAdminConfiguration struct { } type Configuration struct { - Version int `json:"version"` - BindAddress string `json:"bind_addr"` - SlowQueryThreshold int64 `json:"slow_query_threshold"` - MaxGraphQueryCacheSize int `json:"max_graphdb_cache_size"` - MaxAPICacheSize int `json:"max_api_cache_size"` - MetricsPort string `json:"metrics_port"` - RootURL serde.URL `json:"root_url"` - WorkDir string `json:"work_dir"` - LogLevel string `json:"log_level"` - LogPath string `json:"log_path"` - TLS TLSConfiguration `json:"tls"` - GraphDriver string `json:"graph_driver"` - Database dawgs.DatabaseConfiguration `json:"database"` - Neo4J dawgs.DatabaseConfiguration `json:"neo4j"` - Crypto CryptoConfiguration `json:"crypto"` - SAML SAMLConfiguration `json:"saml"` - DefaultAdmin DefaultAdminConfiguration `json:"default_admin"` - CollectorsBucketURL serde.URL `json:"collectors_bucket_url"` - CollectorsBasePath string `json:"collectors_base_path"` - DatapipeInterval int `json:"datapipe_interval"` - EnableStartupWaitPeriod bool `json:"enable_startup_wait_period"` - EnableAPILogging bool `json:"enable_api_logging"` - EnableCypherMutations bool `json:"enable_cypher_mutations"` - DisableAnalysis bool `json:"disable_analysis"` - DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"` - DisableIngest bool `json:"disable_ingest"` - DisableMigrations bool `json:"disable_migrations"` - GraphQueryMemoryLimit uint16 `json:"graph_query_memory_limit"` - EnableTextLogger bool `json:"enable_text_logger"` - RecreateDefaultAdmin bool `json:"recreate_default_admin"` - EnableUserAnalytics bool `json:"enable_user_analytics"` - ForceDownloadEmbeddedCollectors bool `json:"force_download_embedded_collectors"` - EnableAuditLogStdout bool `json:"enable_audit_log_stdout"` + Version int `json:"version"` + BindAddress string `json:"bind_addr"` + SlowQueryThreshold int64 `json:"slow_query_threshold"` + MaxGraphQueryCacheSize int `json:"max_graphdb_cache_size"` + MaxAPICacheSize int `json:"max_api_cache_size"` + MetricsPort string `json:"metrics_port"` + RootURL serde.URL `json:"root_url"` + WorkDir string `json:"work_dir"` + LogLevel string `json:"log_level"` + LogPath string `json:"log_path"` + TLS TLSConfiguration `json:"tls"` + GraphDriver string `json:"graph_driver"` + Database DatabaseConfiguration `json:"database"` + Neo4J DatabaseConfiguration `json:"neo4j"` + Crypto CryptoConfiguration `json:"crypto"` + SAML SAMLConfiguration `json:"saml"` + DefaultAdmin DefaultAdminConfiguration `json:"default_admin"` + CollectorsBucketURL serde.URL `json:"collectors_bucket_url"` + CollectorsBasePath string `json:"collectors_base_path"` + DatapipeInterval int `json:"datapipe_interval"` + EnableStartupWaitPeriod bool `json:"enable_startup_wait_period"` + EnableAPILogging bool `json:"enable_api_logging"` + EnableCypherMutations bool `json:"enable_cypher_mutations"` + DisableAnalysis bool `json:"disable_analysis"` + DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"` + DisableIngest bool `json:"disable_ingest"` + DisableMigrations bool `json:"disable_migrations"` + GraphQueryMemoryLimit uint16 `json:"graph_query_memory_limit"` + EnableTextLogger bool `json:"enable_text_logger"` + RecreateDefaultAdmin bool `json:"recreate_default_admin"` + EnableUserAnalytics bool `json:"enable_user_analytics"` + ForceDownloadEmbeddedCollectors bool `json:"force_download_embedded_collectors"` + EnableAuditLogStdout bool `json:"enable_audit_log_stdout"` } func (s Configuration) TempDirectory() string { diff --git a/cmd/api/src/config/config_test.go b/cmd/api/src/config/config_test.go index 3d0362e5b6c..28ef453f446 100644 --- a/cmd/api/src/config/config_test.go +++ b/cmd/api/src/config/config_test.go @@ -70,7 +70,7 @@ func TestSetValuesFromEnv(t *testing.T) { "bhe_database_secret=supersecretpassword", })) - assert.Equal(t, "neo4j://neo4j:neo4jj@localhost:7070/neo4j", cfg.Neo4J.Neo4jConnectionString()) + assert.Equal(t, "neo4j://neo4j:neo4jj@localhost:7070/neo4j", cfg.Neo4J.Neo4JConnectionString()) assert.Equal(t, "postgresql://bhe:supersecretpassword@localhost:5432/bhe", cfg.Database.PostgreSQLConnectionString()) }) diff --git a/cmd/api/src/config/default.go b/cmd/api/src/config/default.go index a7d5a8c779f..471d3c7ab38 100644 --- a/cmd/api/src/config/default.go +++ b/cmd/api/src/config/default.go @@ -19,7 +19,6 @@ package config import ( "fmt" - dawgs "github.com/specterops/dawgs/drivers" "github.com/specterops/dawgs/drivers/neo4j" "github.com/specterops/bloodhound/cmd/api/src/serde" @@ -86,10 +85,10 @@ func NewDefaultConfiguration() (Configuration, error) { TLS: TLSConfiguration{}, SAML: SAMLConfiguration{}, GraphDriver: neo4j.DriverName, // Default to PG as the graph driver - Database: dawgs.DatabaseConfiguration{ + Database: DatabaseConfiguration{ MaxConcurrentSessions: 10, }, - Neo4J: dawgs.DatabaseConfiguration{ + Neo4J: DatabaseConfiguration{ MaxConcurrentSessions: 10, }, Crypto: CryptoConfiguration{ diff --git a/cmd/api/src/daemons/changelog/changelog_e2e/main.go b/cmd/api/src/daemons/changelog/changelog_e2e/main.go index caafac33970..cb4c284654a 100644 --- a/cmd/api/src/daemons/changelog/changelog_e2e/main.go +++ b/cmd/api/src/daemons/changelog/changelog_e2e/main.go @@ -26,6 +26,7 @@ import ( "syscall" "time" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/daemons/changelog" @@ -84,8 +85,7 @@ func newHarness() *Harness { os.Exit(1) } - pool, err := pg.NewPool(cfg.Database) - + pool, err := dbpool.NewDawgsPool(cfg.Database) if err != nil { slog.Error("Failed to connect", attr.Error(err)) os.Exit(1) diff --git a/cmd/api/src/daemons/changelog/ingestion_integration_test.go b/cmd/api/src/daemons/changelog/ingestion_integration_test.go index 8fe76277215..6d8b6c8cf3e 100644 --- a/cmd/api/src/daemons/changelog/ingestion_integration_test.go +++ b/cmd/api/src/daemons/changelog/ingestion_integration_test.go @@ -28,6 +28,7 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/database" @@ -100,7 +101,7 @@ func setupIntegrationTest(t *testing.T) IntegrationTestSuite { require.NoError(t, err) // Create connection pool - pool, err := pg.NewPool(cfg.Database) + pool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) // Open graph database diff --git a/cmd/api/src/daemons/datapipe/datapipe_integration_test.go b/cmd/api/src/daemons/datapipe/datapipe_integration_test.go index 38440046b75..89397503b7c 100644 --- a/cmd/api/src/daemons/datapipe/datapipe_integration_test.go +++ b/cmd/api/src/daemons/datapipe/datapipe_integration_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/daemons/changelog" @@ -72,7 +73,7 @@ func setupIntegrationTestSuite(t *testing.T, fixturesPath string) IntegrationTes require.NoError(t, err) //#region Setup for dbs - pool, err := pg.NewPool(cfg.Database) + graphPool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) gormDB, dbPool, err := database.OpenDatabase(cfg.Database) @@ -83,7 +84,7 @@ func setupIntegrationTestSuite(t *testing.T, fixturesPath string) IntegrationTes graphDB, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ GraphQueryMemoryLimit: 1024 * 1024 * 1024 * 2, ConnectionString: connConf.URL(), - Pool: pool, + Pool: graphPool, }) require.NoError(t, err) diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index d8082575a7f..fdb01a298fa 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -29,6 +29,7 @@ import ( "github.com/gofrs/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/database/migration" @@ -36,8 +37,6 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" "github.com/specterops/bloodhound/cmd/api/src/services/upload" "github.com/specterops/bloodhound/packages/go/bhlog/attr" - "github.com/specterops/dawgs/drivers" - "github.com/specterops/dawgs/drivers/pg" "gorm.io/driver/postgres" "gorm.io/gorm" ) @@ -248,23 +247,28 @@ func (s *BloodhoundDB) Transaction(ctx context.Context, fn func(tx *BloodhoundDB }, opts...) } -func OpenDatabase(cfg drivers.DatabaseConfiguration) (*gorm.DB, *pgxpool.Pool, error) { - gormConfig := &gorm.Config{ - Logger: &GormLogAdapter{ - SlowQueryErrorThreshold: time.Second * 30, - SlowQueryWarnThreshold: time.Second * 10, - }, - } - pool, err := pg.NewPool(cfg) +func OpenDatabase(cfg config.DatabaseConfiguration) (*gorm.DB, *pgxpool.Pool, error) { + var ( + gormConfig = &gorm.Config{ + Logger: &GormLogAdapter{ + SlowQueryErrorThreshold: time.Second * 30, + SlowQueryWarnThreshold: time.Second * 10, + }, + } + ) + + // NewAppPool creates a relational database pool without graph composite type hooks, + // keeping the relational and graph connection concerns separate. + pool, err := dbpool.NewAppPool(cfg) if err != nil { return nil, nil, err } - dbPool := stdlib.OpenDBFromPool(pool) + dbConn := stdlib.OpenDBFromPool(pool) - db, err := gorm.Open(postgres.New(postgres.Config{Conn: dbPool}), gormConfig) + db, err := gorm.Open(postgres.New(postgres.Config{Conn: dbConn}), gormConfig) if err != nil { - _ = dbPool.Close() + _ = dbConn.Close() pool.Close() return nil, nil, err } diff --git a/cmd/api/src/queries/queries_integration_test.go b/cmd/api/src/queries/queries_integration_test.go index 82a5edef736..d4d2d6b8bb4 100644 --- a/cmd/api/src/queries/queries_integration_test.go +++ b/cmd/api/src/queries/queries_integration_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/database" @@ -77,7 +78,7 @@ func setupIntegrationTestSuite(t *testing.T, fixturesPath string) IntegrationTes cfg.Database.Connection = connConf.URL() //#region Setup for dbs - pool, err := pg.NewPool(cfg.Database) + graphPool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) gormDB, dbPool, err := database.OpenDatabase(cfg.Database) @@ -90,7 +91,7 @@ func setupIntegrationTestSuite(t *testing.T, fixturesPath string) IntegrationTes graphDB, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ GraphQueryMemoryLimit: 1024 * 1024 * 1024 * 2, ConnectionString: connConf.URL(), - Pool: pool, + Pool: graphPool, }) require.NoError(t, err) diff --git a/cmd/api/src/services/graphify/graphify_integration_test.go b/cmd/api/src/services/graphify/graphify_integration_test.go index f016fc4de89..9ae9fbe322c 100644 --- a/cmd/api/src/services/graphify/graphify_integration_test.go +++ b/cmd/api/src/services/graphify/graphify_integration_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/daemons/changelog" @@ -67,7 +68,7 @@ func setupIntegrationTestSuite(t *testing.T, fixturesPath string) IntegrationTes require.NoError(t, err) //#region Setup for dbs - pool, err := pg.NewPool(cfg.Database) + graphPool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) gormDB, dbPool, err := database.OpenDatabase(cfg.Database) @@ -78,7 +79,7 @@ func setupIntegrationTestSuite(t *testing.T, fixturesPath string) IntegrationTes graphDB, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ GraphQueryMemoryLimit: 1024 * 1024 * 1024 * 2, ConnectionString: connConf.URL(), - Pool: pool, + Pool: graphPool, }) require.NoError(t, err) diff --git a/cmd/api/src/services/graphify/test/test.go b/cmd/api/src/services/graphify/test/test.go index 63f035f352c..84d6276d226 100644 --- a/cmd/api/src/services/graphify/test/test.go +++ b/cmd/api/src/services/graphify/test/test.go @@ -26,6 +26,7 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/test/integration/utils" "github.com/specterops/bloodhound/packages/go/graphschema" @@ -107,7 +108,7 @@ func SetupIntegrationTestSuite(t *testing.T) IntegrationTestSuite { require.NoError(t, err) //#region Setup for dbs - pool, err := pg.NewPool(cfg.Database) + pool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) graphDB, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ diff --git a/cmd/api/src/services/upload/streamdecoder.go b/cmd/api/src/services/upload/streamdecoder.go index 113fb07c367..b5b700e8a6d 100644 --- a/cmd/api/src/services/upload/streamdecoder.go +++ b/cmd/api/src/services/upload/streamdecoder.go @@ -295,13 +295,13 @@ func (s ValidationReport) BuildAPIError() []string { func (s ValidationReport) Error() string { var sb strings.Builder if len(s.CriticalErrors) > 0 { - sb.WriteString(fmt.Sprintf("(%d) critical error(s): [%s]", len(s.CriticalErrors), formatAggregateErrors(s.CriticalErrors))) + fmt.Fprintf(&sb, "(%d) critical error(s): [%s]", len(s.CriticalErrors), formatAggregateErrors(s.CriticalErrors)) if len(s.ValidationErrors) > 0 { sb.WriteString(", ") } } if len(s.ValidationErrors) > 0 { - sb.WriteString(fmt.Sprintf("(%d) validation error(s): [%s]", len(s.ValidationErrors), formatAggregateErrors(s.ValidationErrors))) + fmt.Fprintf(&sb, "(%d) validation error(s): [%s]", len(s.ValidationErrors), formatAggregateErrors(s.ValidationErrors)) } return sb.String() } @@ -310,7 +310,7 @@ func formatSchemaValidationError(arrayName string, index int, err error) string var sb strings.Builder if ve, ok := err.(*jsonschema.ValidationError); ok { numberOfViolations := len(ve.Causes) - sb.WriteString(fmt.Sprintf("%s[%d] schema validation failed with %d error(s): ", arrayName, index, numberOfViolations)) + fmt.Fprintf(&sb, "%s[%d] schema validation failed with %d error(s): ", arrayName, index, numberOfViolations) sb.WriteString("[") @@ -328,17 +328,13 @@ func formatSchemaValidationError(arrayName string, index int, err error) string switch { // Case: property value is an object (not allowed) case isPropertyError && isTypeError(cause, "object"): - sb.WriteString(fmt.Sprintf( - "Invalid property '%s': objects are not allowed in the property bag. Use only strings, numbers, booleans, nulls, or arrays of these types.", - propertyName, - )) + fmt.Fprintf(&sb, "Invalid property '%s': objects are not allowed in the property bag. Use only strings, numbers, booleans, nulls, or arrays of these types.", + propertyName) // Case: array contains a nested object (also not allowed) case isPropertyError && isNotError(cause): - sb.WriteString(fmt.Sprintf( - "Invalid property '%s': array contains an object. Arrays must contain only primitive values (string, number, boolean, or null).", - propertyName, - )) + fmt.Fprintf(&sb, "Invalid property '%s': array contains an object. Arrays must contain only primitive values (string, number, boolean, or null).", + propertyName) default: sb.WriteString(cause.Error()) diff --git a/cmd/api/src/test/integration/dawgs.go b/cmd/api/src/test/integration/dawgs.go index 968762a440d..2ccbcd61275 100644 --- a/cmd/api/src/test/integration/dawgs.go +++ b/cmd/api/src/test/integration/dawgs.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/test" "github.com/specterops/bloodhound/cmd/api/src/test/integration/utils" @@ -58,7 +59,7 @@ func OpenGraphDB(t *testing.T, schema graph.Schema) graph.Database { connConf := pgtestdb.Custom(t, GetPostgresConfig(cfg), pgtestdb.NoopMigrator{}) cfg.Database.Connection = connConf.URL() - pool, err := pg.NewPool(cfg.Database) + pool, err := dbpool.NewDawgsPool(cfg.Database) test.RequireNilErrf(t, err, "Failed to create new pgx pool: %v", err) graphDatabase, err = dawgs.Open(context.Background(), cfg.GraphDriver, dawgs.Config{ ConnectionString: cfg.Database.PostgreSQLConnectionString(), diff --git a/go.mod b/go.mod index c6b506d53b6..a069d2e279c 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,8 @@ require ( cuelang.org/go v0.16.0 github.com/Masterminds/semver/v3 v3.4.0 github.com/RoaringBitmap/roaring/v2 v2.16.0 + github.com/aws/aws-sdk-go-v2/config v1.32.13 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.21 github.com/bloodhoundad/azurehound/v2 v2.12.1 github.com/cespare/xxhash/v2 v2.3.0 github.com/channelmeter/iso8601duration v0.0.0-20150204201828-8da3af7a2a61 @@ -47,7 +49,7 @@ require ( github.com/prometheus/client_golang v1.22.0 github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 github.com/shirou/gopsutil/v3 v3.24.5 - github.com/specterops/dawgs v0.5.3 + github.com/specterops/dawgs v0.5.4 github.com/stretchr/testify v1.11.1 github.com/teambition/rrule-go v1.8.2 github.com/ulule/limiter/v3 v3.11.2 @@ -101,10 +103,8 @@ require ( github.com/ashanbrown/forbidigo/v2 v2.3.0 // indirect github.com/ashanbrown/makezero/v2 v2.1.0 // indirect github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect - github.com/aws/aws-sdk-go-v2/config v1.32.13 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.19.13 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect - github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.21 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect diff --git a/go.sum b/go.sum index 93ef5efb4b0..89574c64a7e 100644 --- a/go.sum +++ b/go.sum @@ -1742,8 +1742,8 @@ github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJ github.com/sourcegraph/go-diff v0.7.0 h1:9uLlrd5T46OXs5qpp8L/MTltk0zikUGi0sNNyCpA8G0= github.com/sourcegraph/go-diff v0.7.0/go.mod h1:iBszgVvyxdc8SFZ7gm69go2KDdt3ag071iBaWPF6cjs= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/specterops/dawgs v0.5.3 h1:I4vorE2lN1zjUQ+1Ebarj3z/OKSe8hcqLFBY4Yf6kVE= -github.com/specterops/dawgs v0.5.3/go.mod h1:lduDY0VNWpdzNt9Cym+/Owky/FICB3+gr+TLSaY2SpE= +github.com/specterops/dawgs v0.5.4 h1:zQQ9x0YgoYTUyYiZiud8y8YhTHOsty7UtSgaj1UGUGQ= +github.com/specterops/dawgs v0.5.4/go.mod h1:6TsRAbrHBNR+JnQJZS1iv4oMOonD13M5rAAKEvYsCUs= github.com/spf13/afero v1.3.3/go.mod h1:5KUK8ByomD5Ti5Artl0RtHeI5pTF7MIDuXL3yY520V4= github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/spf13/afero v1.9.2/go.mod h1:iUV7ddyEEZPO5gA3zD4fJt6iStLlL+Lg4m2cihcDf8Y= diff --git a/packages/go/analysis/ad/ad_integration_suite_test.go b/packages/go/analysis/ad/ad_integration_suite_test.go index 90a0f5d577f..62e27db1a6c 100644 --- a/packages/go/analysis/ad/ad_integration_suite_test.go +++ b/packages/go/analysis/ad/ad_integration_suite_test.go @@ -28,6 +28,7 @@ import ( "github.com/gofrs/uuid" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/migrations" "github.com/specterops/bloodhound/cmd/api/src/test/integration/utils" @@ -58,7 +59,7 @@ func setupIntegrationTestSuite(t *testing.T) IntegrationTestSuite { cfg, err := config.NewDefaultConnectionConfiguration(connConf.URL()) require.NoError(t, err) - pool, err := pg.NewPool(cfg.Database) + pool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) graphDB, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ diff --git a/packages/go/analysis/analysis_integration_test.go b/packages/go/analysis/analysis_integration_test.go index 6e51d614578..2738e6fc9f5 100644 --- a/packages/go/analysis/analysis_integration_test.go +++ b/packages/go/analysis/analysis_integration_test.go @@ -26,6 +26,7 @@ import ( "testing" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/database" @@ -61,7 +62,7 @@ func setupIntegrationTestSuite(t *testing.T) IntegrationTestSuite { cfg.Database.Connection = connConf.URL() - pool, err := pg.NewPool(cfg.Database) + pool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) gormDB, dbPool, err := database.OpenDatabase(cfg.Database) diff --git a/packages/go/analysis/azure/azure_integration_suite_test.go b/packages/go/analysis/azure/azure_integration_suite_test.go index 007fa63679c..9e026d41dae 100644 --- a/packages/go/analysis/azure/azure_integration_suite_test.go +++ b/packages/go/analysis/azure/azure_integration_suite_test.go @@ -25,6 +25,7 @@ import ( "testing" "github.com/peterldowns/pgtestdb" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/migrations" "github.com/specterops/bloodhound/cmd/api/src/test/integration/utils" @@ -55,7 +56,7 @@ func setupIntegrationTestSuite(t *testing.T) IntegrationTestSuite { cfg, err := config.NewDefaultConnectionConfiguration(connConf.URL()) require.NoError(t, err) - pool, err := pg.NewPool(cfg.Database) + pool, err := dbpool.NewDawgsPool(cfg.Database) require.NoError(t, err) graphDB, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ diff --git a/packages/go/graphify/graph/graph.go b/packages/go/graphify/graph/graph.go index 078b0257e7d..8600ea2c5a0 100644 --- a/packages/go/graphify/graph/graph.go +++ b/packages/go/graphify/graph/graph.go @@ -28,6 +28,7 @@ import ( "strings" "time" + "github.com/specterops/bloodhound/cmd/api/src/api/dbpool" "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/config" "github.com/specterops/bloodhound/cmd/api/src/database" @@ -416,7 +417,7 @@ func getNodesAndEdges(ctx context.Context, database graph.Database) ([]*graph.No } func initializeGraphDatabase(ctx context.Context, cfg config.Configuration) (graph.Database, error) { - if pool, err := pg.NewPool(cfg.Database); err != nil { + if pool, err := dbpool.NewDawgsPool(cfg.Database); err != nil { return nil, fmt.Errorf("error creating postgres connection: %w", err) } else if database, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ GraphQueryMemoryLimit: size.Gibibyte, diff --git a/packages/go/schemagen/generator/sql.go b/packages/go/schemagen/generator/sql.go index a2e944e408f..c225ad47bb2 100644 --- a/packages/go/schemagen/generator/sql.go +++ b/packages/go/schemagen/generator/sql.go @@ -178,7 +178,7 @@ func GenerateExtensionSQLAzure(dir string, azSchema model.Azure) error { func GenerateExtensionSQL(name string, displayName string, version string, namespace, dir string, fileName string, nodeKinds []model.StringEnum, relationshipKinds []model.StringEnum, pathfindingRelationshipKinds []model.StringEnum) error { var sb strings.Builder - sb.WriteString(fmt.Sprintf("-- Code generated by Cuelang code gen. DO NOT EDIT!\n-- Cuelang source: %s/", SchemaSourceName)) + fmt.Fprintf(&sb, "-- Code generated by Cuelang code gen. DO NOT EDIT!\n-- Cuelang source: %s/", SchemaSourceName) sb.WriteString(` CREATE OR REPLACE FUNCTION genscript_upsert_kind(node_kind_name TEXT) RETURNS SMALLINT AS $$ @@ -300,20 +300,20 @@ $$ LANGUAGE plpgsql; sb.WriteString("\nDO $$\nDECLARE\n\textension_id INT;\n\tenvironment_id INT;\nBEGIN\n\tLOCK schema_extensions, schema_node_kinds, schema_relationship_kinds, kind;\n\n") - sb.WriteString(fmt.Sprintf("\tIF NOT EXISTS (SELECT id FROM schema_extensions WHERE name = '%s') THEN\n", name)) - sb.WriteString(fmt.Sprintf("\t\tINSERT INTO schema_extensions (name, display_name, version, is_builtin, namespace, created_at, updated_at) VALUES ('%s', '%s', '%s', true, '%s', NOW(), NOW()) RETURNING id INTO extension_id;\n", name, displayName, version, namespace)) + fmt.Fprintf(&sb, "\tIF NOT EXISTS (SELECT id FROM schema_extensions WHERE name = '%s') THEN\n", name) + fmt.Fprintf(&sb, "\t\tINSERT INTO schema_extensions (name, display_name, version, is_builtin, namespace, created_at, updated_at) VALUES ('%s', '%s', '%s', true, '%s', NOW(), NOW()) RETURNING id INTO extension_id;\n", name, displayName, version, namespace) sb.WriteString("\tELSE\n") - sb.WriteString(fmt.Sprintf("\t\tUPDATE schema_extensions SET display_name = '%s', version = '%s', namespace = '%s', updated_at = NOW() WHERE name = '%s' RETURNING id INTO extension_id;\n", displayName, version, namespace, name)) + fmt.Fprintf(&sb, "\t\tUPDATE schema_extensions SET display_name = '%s', version = '%s', namespace = '%s', updated_at = NOW() WHERE name = '%s' RETURNING id INTO extension_id;\n", displayName, version, namespace, name) sb.WriteString("\tEND IF;\n\n") sb.WriteString("\t-- Insert Node Kinds\n") for _, kind := range nodeKinds { - sb.WriteString(fmt.Sprintf("\tPERFORM genscript_upsert_kind('%s');\n", kind.GetRepresentation())) + fmt.Fprintf(&sb, "\tPERFORM genscript_upsert_kind('%s');\n", kind.GetRepresentation()) } sb.WriteString("\n\t-- Insert Relationship Kinds\n") for _, kind := range relationshipKinds { - sb.WriteString(fmt.Sprintf("\tPERFORM genscript_upsert_kind('%s');\n", kind.GetRepresentation())) + fmt.Fprintf(&sb, "\tPERFORM genscript_upsert_kind('%s');\n", kind.GetRepresentation()) } sb.WriteString("\n") @@ -321,7 +321,7 @@ $$ LANGUAGE plpgsql; for _, kind := range nodeKinds { iconInfo, found := nodeIcons[kind.GetRepresentation()] - sb.WriteString(fmt.Sprintf("\tPERFORM genscript_upsert_schema_node_kind(extension_id, '%s', '%s', '', %t, '%s', '%s');\n", kind.GetRepresentation(), kind.GetRepresentation(), found, iconInfo.Icon, iconInfo.Color)) + fmt.Fprintf(&sb, "\tPERFORM genscript_upsert_schema_node_kind(extension_id, '%s', '%s', '', %t, '%s', '%s');\n", kind.GetRepresentation(), kind.GetRepresentation(), found, iconInfo.Icon, iconInfo.Color) } traversableMap := make(map[string]struct{}) @@ -335,7 +335,7 @@ $$ LANGUAGE plpgsql; for _, kind := range relationshipKinds { _, traversable := traversableMap[kind.GetRepresentation()] - sb.WriteString(fmt.Sprintf("\tPERFORM genscript_upsert_schema_relationship_kind(extension_id, '%s', '', %t);\n", kind.GetRepresentation(), traversable)) + fmt.Fprintf(&sb, "\tPERFORM genscript_upsert_schema_relationship_kind(extension_id, '%s', '', %t);\n", kind.GetRepresentation(), traversable) } sb.WriteString("\n")