Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions cmd/api/src/api/dbpool/dbpool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2026 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dbpool

import (
"context"
"time"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/specterops/bloodhound/cmd/api/src/config"
"github.com/specterops/dawgs/drivers/pg"
)

const (
poolInitConnectionTimeout = time.Second * 10
)

func newPoolCfg(cfg config.DatabaseConfiguration) (*pgxpool.Config, error) {
poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString())
if err != nil {
return nil, err
}

// TODO: Min and Max connections for the pool should be configurable
poolCfg.MinConns = 5
poolCfg.MaxConns = 50

if cfg.EnableRDSIAMAuth {
// Only enable the BeforeConnect handler if RDS IAM Auth is enabled
poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error {
if newPoolCfg, err := pgxpool.ParseConfig(cfg.RDSIAMAuthConnectionString()); err != nil {
return err
} else {
connCfg.Host = newPoolCfg.ConnConfig.Host
connCfg.Port = newPoolCfg.ConnConfig.Port

connCfg.User = newPoolCfg.ConnConfig.User
connCfg.Password = newPoolCfg.ConnConfig.Password
connCfg.Database = newPoolCfg.ConnConfig.Database
}

return nil
}
}

return poolCfg, nil
}

func NewDawgsPool(cfg config.DatabaseConfiguration) (*pgxpool.Pool, error) {
if poolCfg, err := newPoolCfg(cfg); err != nil {
return nil, err
} else {
return pg.NewPool(poolCfg)
}
}

func NewAppPool(cfg config.DatabaseConfiguration) (*pgxpool.Pool, error) {
poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout)
defer done()

if poolCfg, err := newPoolCfg(cfg); err != nil {
return nil, err
} else {
return pgxpool.NewWithConfig(poolCtx, poolCfg)
}
}
15 changes: 8 additions & 7 deletions cmd/api/src/api/tools/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype"
"github.com/specterops/bloodhound/cmd/api/src/api"
"github.com/specterops/bloodhound/cmd/api/src/api/dbpool"
"github.com/specterops/bloodhound/cmd/api/src/config"
"github.com/specterops/bloodhound/packages/go/bhlog/attr"
"github.com/specterops/bloodhound/packages/go/bhlog/measure"
Expand Down Expand Up @@ -255,7 +256,7 @@ type PGMigrator struct {
migrationCancelFunc func()
State MigratorState
lock *sync.Mutex
Cfg config.Configuration
cfg config.Configuration
}

func NewPGMigrator(serverCtx context.Context, cfg config.Configuration, graphSchema graph.Schema, graphDBSwitch *graph.DatabaseSwitch) *PGMigrator {
Expand All @@ -265,7 +266,7 @@ func NewPGMigrator(serverCtx context.Context, cfg config.Configuration, graphSch
ServerCtx: serverCtx,
State: StateIdle,
lock: &sync.Mutex{},
Cfg: cfg,
cfg: cfg,
}
}

Expand Down Expand Up @@ -297,7 +298,7 @@ func (s *PGMigrator) SwitchPostgreSQL(response http.ResponseWriter, request *htt
}, http.StatusInternalServerError, response)
} else if err := pgDB.AssertSchema(request.Context(), s.graphSchema); err != nil {
slog.ErrorContext(request.Context(), "Unable to assert graph schema in PostgreSQL", attr.Error(err))
} else if err := SetGraphDriver(request.Context(), s.Cfg, pg.DriverName); err != nil {
} else if err := SetGraphDriver(request.Context(), s.cfg, pg.DriverName); err != nil {
api.WriteJSONResponse(request.Context(), map[string]any{
"error": fmt.Errorf("failed updating graph database driver preferences: %w", err),
}, http.StatusInternalServerError, response)
Expand All @@ -314,7 +315,7 @@ func (s *PGMigrator) SwitchNeo4j(response http.ResponseWriter, request *http.Req
api.WriteJSONResponse(request.Context(), map[string]any{
"error": fmt.Errorf("failed connecting to Neo4j: %w", err),
}, http.StatusInternalServerError, response)
} else if err := SetGraphDriver(request.Context(), s.Cfg, neo4j.DriverName); err != nil {
} else if err := SetGraphDriver(request.Context(), s.cfg, neo4j.DriverName); err != nil {
api.WriteJSONResponse(request.Context(), map[string]any{
"error": fmt.Errorf("failed updating graph database driver preferences: %w", err),
}, http.StatusInternalServerError, response)
Expand Down Expand Up @@ -449,12 +450,12 @@ func (s *PGMigrator) MigrationStatus(response http.ResponseWriter, request *http
}

func (s *PGMigrator) OpenPostgresGraphConnection() (graph.Database, error) {
if pool, err := pg.NewPool(s.Cfg.Database); err != nil {
if pool, err := dbpool.NewDawgsPool(s.cfg.Database); err != nil {
return nil, err
} else {
return dawgs.Open(s.ServerCtx, pg.DriverName, dawgs.Config{
GraphQueryMemoryLimit: size.Gibibyte,
ConnectionString: s.Cfg.Database.PostgreSQLConnectionString(),
ConnectionString: s.cfg.Database.PostgreSQLConnectionString(),
Pool: pool,
})
}
Expand All @@ -463,6 +464,6 @@ func (s *PGMigrator) OpenPostgresGraphConnection() (graph.Database, error) {
func (s *PGMigrator) OpenNeo4jGraphConnection() (graph.Database, error) {
return dawgs.Open(s.ServerCtx, neo4j.DriverName, dawgs.Config{
GraphQueryMemoryLimit: size.Gibibyte,
ConnectionString: s.Cfg.Neo4J.Neo4jConnectionString(),
ConnectionString: s.cfg.Neo4J.Neo4JConnectionString(),
})
}
5 changes: 3 additions & 2 deletions cmd/api/src/bootstrap/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"os"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/specterops/bloodhound/cmd/api/src/api/dbpool"
"github.com/specterops/bloodhound/cmd/api/src/api/tools"
"github.com/specterops/bloodhound/cmd/api/src/config"
"github.com/specterops/dawgs"
Expand Down Expand Up @@ -92,13 +93,13 @@ func ConnectGraph(ctx context.Context, cfg config.Configuration) (*graph.Databas
switch driverName {
case neo4j.DriverName:
slog.InfoContext(ctx, "Connecting to graph using Neo4j")
connectionString = cfg.Neo4J.Neo4jConnectionString()
connectionString = cfg.Neo4J.Neo4JConnectionString()

case pg.DriverName:
slog.InfoContext(ctx, "Connecting to graph using PostgreSQL")
connectionString = cfg.Database.PostgreSQLConnectionString()

pool, err = pg.NewPool(cfg.Database)
pool, err = dbpool.NewDawgsPool(cfg.Database)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/api/src/cmd/dawgs-harness/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ import (

"github.com/jackc/pgx/v5/pgxpool"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/specterops/bloodhound/cmd/api/src/api/dbpool"
"github.com/specterops/bloodhound/cmd/api/src/cmd/dawgs-harness/tests"
"github.com/specterops/bloodhound/cmd/api/src/config"
"github.com/specterops/bloodhound/packages/go/bhlog"
schema "github.com/specterops/bloodhound/packages/go/graphschema"
"github.com/specterops/dawgs"
"github.com/specterops/dawgs/drivers"
"github.com/specterops/dawgs/drivers/neo4j"
"github.com/specterops/dawgs/drivers/pg"
"github.com/specterops/dawgs/graph"
Expand All @@ -46,14 +46,14 @@ func fatalf(format string, args ...any) {
os.Exit(1)
}

func RunTestSuite(ctx context.Context, connectionStr, driverName string, cfg drivers.DatabaseConfiguration) tests.TestSuite {
func RunTestSuite(ctx context.Context, connectionStr, driverName string, cfg config.DatabaseConfiguration) tests.TestSuite {
var (
pool *pgxpool.Pool
err error
)

if driverName == pg.DriverName {
pool, err = pg.NewPool(cfg)
pool, err = dbpool.NewDawgsPool(cfg)
if err != nil {
fatalf("Failed creating a new pgxpool: %s", err)
}
Expand Down
138 changes: 104 additions & 34 deletions cmd/api/src/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,26 @@
package config

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/url"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"

awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"

"github.com/specterops/bloodhound/cmd/api/src/serde"
"github.com/specterops/bloodhound/packages/go/bhlog/attr"
"github.com/specterops/bloodhound/packages/go/crypto"
dawgs "github.com/specterops/dawgs/drivers"
)

const (
Expand All @@ -51,6 +56,71 @@ func (s TLSConfiguration) Enabled() bool {
return s.CertFile != "" && s.KeyFile != ""
}

type DatabaseConfiguration struct {
Connection string `json:"connection"`
Address string `json:"addr"`
Database string `json:"database"`
Username string `json:"username"`
Secret string `json:"secret"`
MaxConcurrentSessions int `json:"max_concurrent_sessions"`
EnableRDSIAMAuth bool `json:"enable_rds_iam_auth"`
}

func (s DatabaseConfiguration) PostgreSQLConnectionString() string {
if s.Connection != "" {
return s.Connection
}

return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database)
}

func (s DatabaseConfiguration) Neo4JConnectionString() string {
if s.Connection != "" {
return s.Connection
}

return fmt.Sprintf("neo4j://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database)
}

func (s DatabaseConfiguration) RDSIAMAuthConnectionString() string {
if cfg, err := awsConfig.LoadDefaultConfig(context.TODO()); err != nil {
slog.Error("AWS Config Loading Error", slog.String("err", err.Error()))
} else {
// Must use instance endpoint with IAM auth
endpoint := s.LookupEndpoint()

slog.Info("Requesting RDS IAM Auth Token")

if authenticationToken, err := auth.BuildAuthToken(context.TODO(), endpoint, cfg.Region, s.Username, cfg.Credentials); err != nil {
slog.Error("RDS IAM Auth Token Request Error", slog.String("err", err.Error()))
} else {
slog.Info("RDS IAM Auth Token Created")
return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(authenticationToken), endpoint, s.Database)
}
}

slog.Warn("Failed to create IAM auth token. Falling back to default Postgres connection string")
return s.PostgreSQLConnectionString()
}

func (s DatabaseConfiguration) LookupEndpoint() string {
host, port, err := net.SplitHostPort(s.Address)
if err != nil {
slog.Warn("Missing port in address. Using default port 5432.", slog.String("err", err.Error()))
host = s.Address
port = "5432"
}

if hostCName, err := net.DefaultResolver.LookupCNAME(context.TODO(), host); err != nil {
slog.Warn("Error looking up CNAME for DB host. Using original address.", slog.String("err", err.Error()))
} else {
host = hostCName
}

// Instance endpoint always returns with a trailing '.'
return net.JoinHostPort(strings.TrimSuffix(host, "."), port)
}

type CollectorManifest struct {
Latest string `json:"latest"`
Versions []CollectorVersion `json:"versions"`
Expand Down Expand Up @@ -111,39 +181,39 @@ type DefaultAdminConfiguration struct {
}

type Configuration struct {
Version int `json:"version"`
BindAddress string `json:"bind_addr"`
SlowQueryThreshold int64 `json:"slow_query_threshold"`
MaxGraphQueryCacheSize int `json:"max_graphdb_cache_size"`
MaxAPICacheSize int `json:"max_api_cache_size"`
MetricsPort string `json:"metrics_port"`
RootURL serde.URL `json:"root_url"`
WorkDir string `json:"work_dir"`
LogLevel string `json:"log_level"`
LogPath string `json:"log_path"`
TLS TLSConfiguration `json:"tls"`
GraphDriver string `json:"graph_driver"`
Database dawgs.DatabaseConfiguration `json:"database"`
Neo4J dawgs.DatabaseConfiguration `json:"neo4j"`
Crypto CryptoConfiguration `json:"crypto"`
SAML SAMLConfiguration `json:"saml"`
DefaultAdmin DefaultAdminConfiguration `json:"default_admin"`
CollectorsBucketURL serde.URL `json:"collectors_bucket_url"`
CollectorsBasePath string `json:"collectors_base_path"`
DatapipeInterval int `json:"datapipe_interval"`
EnableStartupWaitPeriod bool `json:"enable_startup_wait_period"`
EnableAPILogging bool `json:"enable_api_logging"`
EnableCypherMutations bool `json:"enable_cypher_mutations"`
DisableAnalysis bool `json:"disable_analysis"`
DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"`
DisableIngest bool `json:"disable_ingest"`
DisableMigrations bool `json:"disable_migrations"`
GraphQueryMemoryLimit uint16 `json:"graph_query_memory_limit"`
EnableTextLogger bool `json:"enable_text_logger"`
RecreateDefaultAdmin bool `json:"recreate_default_admin"`
EnableUserAnalytics bool `json:"enable_user_analytics"`
ForceDownloadEmbeddedCollectors bool `json:"force_download_embedded_collectors"`
EnableAuditLogStdout bool `json:"enable_audit_log_stdout"`
Version int `json:"version"`
BindAddress string `json:"bind_addr"`
SlowQueryThreshold int64 `json:"slow_query_threshold"`
MaxGraphQueryCacheSize int `json:"max_graphdb_cache_size"`
MaxAPICacheSize int `json:"max_api_cache_size"`
MetricsPort string `json:"metrics_port"`
RootURL serde.URL `json:"root_url"`
WorkDir string `json:"work_dir"`
LogLevel string `json:"log_level"`
LogPath string `json:"log_path"`
TLS TLSConfiguration `json:"tls"`
GraphDriver string `json:"graph_driver"`
Database DatabaseConfiguration `json:"database"`
Neo4J DatabaseConfiguration `json:"neo4j"`
Crypto CryptoConfiguration `json:"crypto"`
SAML SAMLConfiguration `json:"saml"`
DefaultAdmin DefaultAdminConfiguration `json:"default_admin"`
CollectorsBucketURL serde.URL `json:"collectors_bucket_url"`
CollectorsBasePath string `json:"collectors_base_path"`
DatapipeInterval int `json:"datapipe_interval"`
EnableStartupWaitPeriod bool `json:"enable_startup_wait_period"`
EnableAPILogging bool `json:"enable_api_logging"`
EnableCypherMutations bool `json:"enable_cypher_mutations"`
DisableAnalysis bool `json:"disable_analysis"`
DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"`
DisableIngest bool `json:"disable_ingest"`
DisableMigrations bool `json:"disable_migrations"`
GraphQueryMemoryLimit uint16 `json:"graph_query_memory_limit"`
EnableTextLogger bool `json:"enable_text_logger"`
RecreateDefaultAdmin bool `json:"recreate_default_admin"`
EnableUserAnalytics bool `json:"enable_user_analytics"`
ForceDownloadEmbeddedCollectors bool `json:"force_download_embedded_collectors"`
EnableAuditLogStdout bool `json:"enable_audit_log_stdout"`
}

func (s Configuration) TempDirectory() string {
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/src/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestSetValuesFromEnv(t *testing.T) {
"bhe_database_secret=supersecretpassword",
}))

assert.Equal(t, "neo4j://neo4j:neo4jj@localhost:7070/neo4j", cfg.Neo4J.Neo4jConnectionString())
assert.Equal(t, "neo4j://neo4j:neo4jj@localhost:7070/neo4j", cfg.Neo4J.Neo4JConnectionString())
assert.Equal(t, "postgresql://bhe:supersecretpassword@localhost:5432/bhe", cfg.Database.PostgreSQLConnectionString())
})

Expand Down
Loading
Loading