diff --git a/integration/cypher_test.go b/integration/cypher_test.go index bd1f11a..ae18fc3 100644 --- a/integration/cypher_test.go +++ b/integration/cypher_test.go @@ -92,7 +92,10 @@ func TestCypher(t *testing.T) { db, ctx := SetupDB(t, datasetNames...) - driver := *driverFlag + driver, err := driverFromConnStr(os.Getenv("CONNECTION_STRING")) + if err != nil { + t.Fatalf("Failed to detect driver: %v", err) + } for _, g := range groups { ClearGraph(t, db, ctx) diff --git a/integration/harness.go b/integration/harness.go index 2edc508..5cc5c70 100644 --- a/integration/harness.go +++ b/integration/harness.go @@ -20,6 +20,7 @@ import ( "context" "flag" "fmt" + "net/url" "os" "sort" "strings" @@ -31,16 +32,30 @@ import ( "github.com/specterops/dawgs/opengraph" "github.com/specterops/dawgs/util/size" - // Register drivers - _ "github.com/specterops/dawgs/drivers/neo4j" + "github.com/specterops/dawgs/drivers/neo4j" ) var ( - driverFlag = flag.String("driver", "pg", "database driver to test against (pg, neo4j)") - connStrFlag = flag.String("connection", "", "database connection string (overrides PG_CONNECTION_STRING env var)") localDatasetFlag = flag.String("local-dataset", "", "name of a local dataset to test (e.g. local/phantom)") ) +// driverFromConnStr returns the dawgs driver name based on the connection string scheme. +func driverFromConnStr(connStr string) (string, error) { + u, err := url.Parse(connStr) + if err != nil { + return "", fmt.Errorf("failed to parse connection string: %w", err) + } + + switch u.Scheme { + case "postgresql", "postgres": + return pg.DriverName, nil + case neo4j.DriverName, "neo4j+s", "neo4j+ssc": + return neo4j.DriverName, nil + default: + return "", fmt.Errorf("unknown connection string scheme %q", u.Scheme) + } +} + // SetupDB opens a database connection for the selected driver, asserts a schema // derived from the given datasets, and registers cleanup. Returns the database // and a background context. @@ -49,12 +64,14 @@ func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context) ctx := context.Background() - connStr := *connStrFlag + connStr := os.Getenv("CONNECTION_STRING") if connStr == "" { - connStr = os.Getenv("PG_CONNECTION_STRING") + t.Fatal("CONNECTION_STRING env var is not set") } - if connStr == "" { - t.Fatal("no connection string: set -connection flag or PG_CONNECTION_STRING env var") + + driver, err := driverFromConnStr(connStr) + if err != nil { + t.Fatalf("Failed to detect driver: %v", err) } cfg := dawgs.Config{ @@ -62,8 +79,7 @@ func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context) ConnectionString: connStr, } - // PG needs a pool with composite type registration - if *driverFlag == pg.DriverName { + if driver == pg.DriverName { pool, err := pg.NewPool(connStr) if err != nil { t.Fatalf("Failed to create PG pool: %v", err) @@ -71,7 +87,7 @@ func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context) cfg.Pool = pool } - db, err := dawgs.Open(ctx, *driverFlag, cfg) + db, err := dawgs.Open(ctx, driver, cfg) if err != nil { t.Fatalf("Failed to open database: %v", err) }