diff --git a/cmd/benchmark/main.go b/cmd/benchmark/main.go index 6f72302..898ab4b 100644 --- a/cmd/benchmark/main.go +++ b/cmd/benchmark/main.go @@ -26,6 +26,7 @@ import ( "time" "github.com/specterops/dawgs" + "github.com/specterops/dawgs/drivers" "github.com/specterops/dawgs/drivers/pg" "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/opengraph" @@ -39,10 +40,11 @@ func main() { driver = flag.String("driver", "pg", "database driver (pg, neo4j)") connStr = flag.String("connection", "", "database connection string (or PG_CONNECTION_STRING)") iterations = flag.Int("iterations", 10, "timed iterations per scenario") - output = flag.String("output", "", "markdown output file (default: stdout)") - datasetDir = flag.String("dataset-dir", "integration/testdata", "path to testdata directory") + output = flag.String("output", "", "markdown output file (default: stdout)") + datasetDir = flag.String("dataset-dir", "integration/testdata", "path to testdata directory") localDataset = flag.String("local-dataset", "", "additional local dataset (e.g. local/phantom)") onlyDataset = flag.String("dataset", "", "run only this dataset (e.g. diamond, local/phantom)") + dbcfg = drivers.DatabaseConfiguration{} ) flag.Parse() @@ -55,6 +57,8 @@ func main() { fatal("no connection string: set -connection flag or PG_CONNECTION_STRING env var") } + dbcfg.Connection = conn + ctx := context.Background() cfg := dawgs.Config{ @@ -63,7 +67,7 @@ func main() { } if *driver == pg.DriverName { - pool, err := pg.NewPool(conn) + pool, err := pg.NewPool(dbcfg) if err != nil { fatal("failed to create pool: %v", err) } diff --git a/cmd/export/main.go b/cmd/export/main.go index 43ac05c..8dc23ad 100644 --- a/cmd/export/main.go +++ b/cmd/export/main.go @@ -5,6 +5,7 @@ import ( "fmt" "os" + "github.com/specterops/dawgs/drivers" "github.com/specterops/dawgs/drivers/pg" "github.com/specterops/dawgs/opengraph" "github.com/specterops/dawgs/util/size" @@ -16,7 +17,10 @@ func main() { connStr = "postgresql://bloodhound:bloodhoundcommunityedition@localhost:5432/bloodhound" } - pool, err := pg.NewPool(connStr) + dbcfg := drivers.DatabaseConfiguration{} + dbcfg.Connection = connStr + + pool, err := pg.NewPool(dbcfg) if err != nil { fmt.Fprintf(os.Stderr, "failed to connect: %v\n", err) os.Exit(1) diff --git a/drivers/config.go b/drivers/config.go new file mode 100644 index 0000000..f8bbdce --- /dev/null +++ b/drivers/config.go @@ -0,0 +1,76 @@ +package drivers + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/url" + "strings" + + awsConfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" +) + +type DatabaseConfiguration struct { + Connection string `json:"connection"` + Address string `json:"addr"` + Database string `json:"database"` + Username string `json:"username"` + Secret string `json:"secret"` + MaxConcurrentSessions int `json:"max_concurrent_sessions"` + EnableRDSIAMAuth bool `json:"enable_rds_iam_auth"` +} + +func (s DatabaseConfiguration) defaultPostgreSQLConnectionString() string { + if s.Connection != "" { + return s.Connection + } + + return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(s.Secret), s.Address, s.Database) +} + +func (s DatabaseConfiguration) RDSIAMAuthConnectionString() string { + slog.Info("Loading RDS Configuration With IAM Auth") + + if cfg, err := awsConfig.LoadDefaultConfig(context.TODO()); err != nil { + slog.Error("AWS Config Loading Error", slog.String("err", err.Error())) + } else { + host := s.Address + + if hostCName, err := net.LookupCNAME(s.Address); err != nil { + slog.Warn("Error looking up CNAME for DB host. Using original address.", slog.String("err", err.Error())) + } else { + host = hostCName + } + + endpoint := strings.TrimSuffix(host, ".") + ":5432" + + slog.Info("Requesting RDS IAM Auth Token") + + if authenticationToken, err := auth.BuildAuthToken(context.TODO(), endpoint, cfg.Region, s.Username, cfg.Credentials); err != nil { + slog.Error("RDS IAM Auth Token Request Error", slog.String("err", err.Error())) + } else { + slog.Info("RDS IAM Auth Token Created") + return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(authenticationToken), endpoint, s.Database) + } + } + + return s.defaultPostgreSQLConnectionString() +} + +func (s DatabaseConfiguration) PostgreSQLConnectionString() string { + if s.EnableRDSIAMAuth { + return s.RDSIAMAuthConnectionString() + } + + return s.defaultPostgreSQLConnectionString() +} + +func (s DatabaseConfiguration) Neo4jConnectionString() string { + if s.Connection == "" { + return fmt.Sprintf("neo4j://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database) + } + + return s.Connection +} diff --git a/drivers/pg/pg.go b/drivers/pg/pg.go index 1e2d0a2..047c1de 100644 --- a/drivers/pg/pg.go +++ b/drivers/pg/pg.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/specterops/dawgs" "github.com/specterops/dawgs/cypher/models/pgsql" + "github.com/specterops/dawgs/drivers" "github.com/specterops/dawgs/graph" ) @@ -50,15 +51,12 @@ func afterPooledConnectionRelease(conn *pgx.Conn) bool { return true } -func NewPool(connectionString string) (*pgxpool.Pool, error) { - if connectionString == "" { - return nil, fmt.Errorf("graph connection requires a connection url to be set") - } +func NewPool(cfg drivers.DatabaseConfiguration) (*pgxpool.Pool, error) { poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout) defer done() - poolCfg, err := pgxpool.ParseConfig(connectionString) + poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()) if err != nil { return nil, err } @@ -73,6 +71,21 @@ func NewPool(connectionString string) (*pgxpool.Pool, error) { poolCfg.AfterConnect = afterPooledConnectionEstablished poolCfg.AfterRelease = afterPooledConnectionRelease + if cfg.EnableRDSIAMAuth { + // Only enable the BeforeConnect handler if RDS IAM Auth is enabled + poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error { + slog.Debug("New Connection RDS IAM Auth") + + if newPoolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()); err != nil { + return err + } else { + connCfg.Password = newPoolCfg.ConnConfig.Password + } + + return nil + } + } + pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg) if err != nil { return nil, err diff --git a/go.mod b/go.mod index b6f080b..5777b1d 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( cuelang.org/go v0.16.0 github.com/RoaringBitmap/roaring/v2 v2.16.0 github.com/antlr4-go/antlr/v4 v4.13.1 + github.com/aws/aws-sdk-go-v2/config v1.31.13 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10 github.com/axiomhq/hyperloglog v0.2.6 github.com/bits-and-blooms/bitset v1.24.4 github.com/cespare/xxhash/v2 v2.3.0 @@ -17,6 +19,18 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2 v1.39.3 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect + github.com/aws/smithy-go v1.23.1 // indirect github.com/cockroachdb/apd/v3 v3.2.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 // indirect diff --git a/go.sum b/go.sum index b4fd42a..ed5b1b2 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,34 @@ github.com/RoaringBitmap/roaring/v2 v2.16.0 h1:Kys1UNf49d5W8Tq3bpuAhIr/Z8/yPB+59 github.com/RoaringBitmap/roaring/v2 v2.16.0/go.mod h1:eq4wdNXxtJIS/oikeCzdX1rBzek7ANzbth041hrU8Q4= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM= +github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k= +github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10 h1:xfgjONWMae6+y//dlhVukwt9N+I++FPuiwcQt7DI7Qg= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10/go.mod h1:FO6aarJTHA2N3S8F2A4wKfnX9Jr6MPerJFaqoLgTctU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/axiomhq/hyperloglog v0.2.6 h1:sRhvvF3RIXWQgAXaTphLp4yJiX4S0IN3MWTaAgZoRJw= github.com/axiomhq/hyperloglog v0.2.6/go.mod h1:YjX/dQqCR/7QYX0g8mu8UZAjpIenz1FKM71UEsjFoTo= github.com/bits-and-blooms/bitset v1.24.4 h1:95H15Og1clikBrKr/DuzMXkQzECs1M6hhoGXLwLQOZE= diff --git a/integration/harness.go b/integration/harness.go index 5cc5c70..a7c2e86 100644 --- a/integration/harness.go +++ b/integration/harness.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/specterops/dawgs" + "github.com/specterops/dawgs/drivers" "github.com/specterops/dawgs/drivers/pg" "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/opengraph" @@ -79,8 +80,11 @@ func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context) ConnectionString: connStr, } + dbcfg := drivers.DatabaseConfiguration{} + dbcfg.Connection = connStr + if driver == pg.DriverName { - pool, err := pg.NewPool(connStr) + pool, err := pg.NewPool(dbcfg) if err != nil { t.Fatalf("Failed to create PG pool: %v", err) }