Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion integration/cypher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ func TestCypher(t *testing.T) {

db, ctx := SetupDB(t, datasetNames...)

driver := *driverFlag
driver, err := driverFromConnStr(os.Getenv("CONNECTION_STRING"))
if err != nil {
t.Fatalf("Failed to detect driver: %v", err)
}

for _, g := range groups {
ClearGraph(t, db, ctx)
Expand Down
38 changes: 27 additions & 11 deletions integration/harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"flag"
"fmt"
"net/url"
"os"
"sort"
"strings"
Expand All @@ -31,16 +32,30 @@ import (
"github.com/specterops/dawgs/opengraph"
"github.com/specterops/dawgs/util/size"

// Register drivers
_ "github.com/specterops/dawgs/drivers/neo4j"
"github.com/specterops/dawgs/drivers/neo4j"
)

var (
driverFlag = flag.String("driver", "pg", "database driver to test against (pg, neo4j)")
connStrFlag = flag.String("connection", "", "database connection string (overrides PG_CONNECTION_STRING env var)")
localDatasetFlag = flag.String("local-dataset", "", "name of a local dataset to test (e.g. local/phantom)")
)

// driverFromConnStr returns the dawgs driver name based on the connection string scheme.
func driverFromConnStr(connStr string) (string, error) {
u, err := url.Parse(connStr)
if err != nil {
return "", fmt.Errorf("failed to parse connection string: %w", err)
}

switch u.Scheme {
case "postgresql", "postgres":
return pg.DriverName, nil
case neo4j.DriverName, "neo4j+s", "neo4j+ssc":
return neo4j.DriverName, nil
default:
return "", fmt.Errorf("unknown connection string scheme %q", u.Scheme)
}
}

// SetupDB opens a database connection for the selected driver, asserts a schema
// derived from the given datasets, and registers cleanup. Returns the database
// and a background context.
Expand All @@ -49,29 +64,30 @@ func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context)

ctx := context.Background()

connStr := *connStrFlag
connStr := os.Getenv("CONNECTION_STRING")
if connStr == "" {
connStr = os.Getenv("PG_CONNECTION_STRING")
t.Fatal("CONNECTION_STRING env var is not set")
}
if connStr == "" {
t.Fatal("no connection string: set -connection flag or PG_CONNECTION_STRING env var")

driver, err := driverFromConnStr(connStr)
if err != nil {
t.Fatalf("Failed to detect driver: %v", err)
}

cfg := dawgs.Config{
GraphQueryMemoryLimit: size.Gibibyte,
ConnectionString: connStr,
}

// PG needs a pool with composite type registration
if *driverFlag == pg.DriverName {
if driver == pg.DriverName {
pool, err := pg.NewPool(connStr)
if err != nil {
t.Fatalf("Failed to create PG pool: %v", err)
}
cfg.Pool = pool
}

db, err := dawgs.Open(ctx, *driverFlag, cfg)
db, err := dawgs.Open(ctx, driver, cfg)
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
Expand Down
Loading