From 18b889abf24611f6d4fa239e5b9afc8600a99c00 Mon Sep 17 00:00:00 2001 From: David Murphy Date: Tue, 31 Dec 2024 12:55:12 -0600 Subject: [PATCH 1/3] feat(stalk-go): initial port to go --- go.mod | 7 + go.sum | 8 + src/go/pt-stalk/collect.go | 300 +++++++++++++++++++++ src/go/pt-stalk/collect_test.go | 292 ++++++++++++++++++++ src/go/pt-stalk/examples/plugins/sample.sh | 36 +++ src/go/pt-stalk/logger.go | 73 +++++ src/go/pt-stalk/main.go | 215 +++++++++++++++ src/go/pt-stalk/main_test.go | 212 +++++++++++++++ src/go/pt-stalk/plugin.go | 165 ++++++++++++ src/go/pt-stalk/stalk.go | 291 ++++++++++++++++++++ src/go/pt-stalk/stalk_test.go | 197 ++++++++++++++ src/go/pt-stalk/utils.go | 169 ++++++++++++ 12 files changed, 1965 insertions(+) create mode 100644 src/go/pt-stalk/collect.go create mode 100644 src/go/pt-stalk/collect_test.go create mode 100644 src/go/pt-stalk/examples/plugins/sample.sh create mode 100644 src/go/pt-stalk/logger.go create mode 100644 src/go/pt-stalk/main.go create mode 100644 src/go/pt-stalk/main_test.go create mode 100644 src/go/pt-stalk/plugin.go create mode 100644 src/go/pt-stalk/stalk.go create mode 100644 src/go/pt-stalk/stalk_test.go create mode 100644 src/go/pt-stalk/utils.go diff --git a/go.mod b/go.mod index 67c214789..0014b5a33 100644 --- a/go.mod +++ b/go.mod @@ -35,12 +35,18 @@ require ( k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 ) +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect +) + require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect + github.com/go-sql-driver/mysql v1.8.1 github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/gofuzz v1.2.0 // indirect @@ 
-52,6 +58,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sevlyar/go-daemon v0.1.6 github.com/tklauser/go-sysconf v0.3.11 // indirect github.com/tklauser/numcpus v0.6.0 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index 8c3980ced..48f92aeab 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/AlekSi/pointer v1.2.0 h1:glcy/gc4h8HnG2Z3ZECSzZ1IX1x2JxRVuDzaJwQE0+w= github.com/AlekSi/pointer v1.2.0/go.mod h1:gZGfd3dpW4vEc/UlyfKKi1roIqcCgwOIvb0tSNSBle0= github.com/Ladicle/tabwriter v1.0.0 h1:DZQqPvMumBDwVNElso13afjYLNp0Z7pHqHnu0r4t9Dg= @@ -30,6 +32,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -53,6 +57,8 @@ github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef h1:A9HsByNhogrvm9cWb github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef/go.mod h1:lADxMC39cJJqL93Duh1xhAs4I2Zs8mKS89XWXFGp9cs= github.com/json-iterator/go v1.1.12 
h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY= @@ -91,6 +97,8 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99 github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/sevlyar/go-daemon v0.1.6 h1:EUh1MDjEM4BI109Jign0EaknA2izkOyi0LV3ro3QQGs= +github.com/sevlyar/go-daemon v0.1.6/go.mod h1:6dJpPatBT9eUwM5VCw9Bt6CdX9Tk6UWvhW3MebLDRKE= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/src/go/pt-stalk/collect.go b/src/go/pt-stalk/collect.go new file mode 100644 index 000000000..c9e46c6f9 --- /dev/null +++ b/src/go/pt-stalk/collect.go @@ -0,0 +1,300 @@ +package main + +import ( + "bufio" + "context" + "database/sql" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "syscall" + "time" +) + +type Collector struct { + stalker *Stalker + db *sql.DB + outDir string + prefix string + wg sync.WaitGroup +} + +func (s *Stalker) collect(ctx context.Context, db *sql.DB, prefix string) error { + outDir := 
filepath.Join(s.config.Dest, prefix) + if err := os.MkdirAll(outDir, 0755); err != nil { + return fmt.Errorf("failed to create output directory: %v", err) + } + + collector := &Collector{ + stalker: s, + db: db, + outDir: outDir, + prefix: prefix, + } + + // Start collection goroutines + collector.wg.Add(1) + go func() { + defer collector.wg.Done() + if err := collector.collectSystemMetrics(); err != nil { + s.logger.Error("System metrics collection failed: %v", err) + } + }() + + if !s.config.SystemOnly { + collector.wg.Add(1) + go func() { + defer collector.wg.Done() + if err := collector.collectMySQLMetrics(); err != nil { + s.logger.Error("MySQL metrics collection failed: %v", err) + } + }() + } + + // Wait for collections with timeout + done := make(chan struct{}) + go func() { + collector.wg.Wait() + close(done) + }() + + select { + case <-done: + s.logger.Info("Collection completed successfully") + case <-time.After(time.Duration(s.config.RunTime) * time.Second): + s.logger.Warn("Collection timed out after %d seconds", s.config.RunTime) + } + + return nil +} + +func (c *Collector) collectSystemMetrics() error { + if c.stalker.config.MySQLOnly { + return nil + } + + metrics := []struct { + name string + command string + args []string + }{ + {"uptime", "uptime", nil}, + {"uname", "uname", []string{"-a"}}, + {"vmstat", "vmstat", []string{"1"}}, + {"iostat", "iostat", []string{"-dx", "1"}}, + {"mpstat", "mpstat", []string{"1"}}, + {"free", "free", []string{"-m"}}, + {"df", "df", []string{"-h"}}, + {"dmesg", "dmesg", nil}, + {"netstat", "netstat", []string{"-antp"}}, + {"top", "top", []string{"-b", "-n", "1"}}, + } + + for _, metric := range metrics { + c.wg.Add(1) + go func(m struct { + name string + command string + args []string + }) { + defer c.wg.Done() + outFile := filepath.Join(c.outDir, fmt.Sprintf("%s-%s.txt", c.prefix, m.name)) + if err := c.runCommand(m.command, m.args, outFile); err != nil { + c.stalker.logger.Error("Failed to collect %s: %v", m.name, 
err) + } + }(metric) + } + + // Collect special metrics that need custom handling + if c.stalker.config.CollectTcpdump { + c.wg.Add(1) + go func() { + defer c.wg.Done() + if err := c.collectTcpdump(); err != nil { + c.stalker.logger.Error("Failed to collect tcpdump: %v", err) + } + }() + } + + return nil +} + +func (c *Collector) collectMySQLMetrics() error { + queries := []struct { + name string + query string + }{ + {"variables", "SHOW GLOBAL VARIABLES"}, + {"status", "SHOW GLOBAL STATUS"}, + {"processlist", "SHOW FULL PROCESSLIST"}, + {"slave_status", "SHOW SLAVE STATUS"}, + {"innodb_status", "SHOW ENGINE INNODB STATUS"}, + {"mutex_status", "SHOW ENGINE INNODB MUTEX"}, + } + + for _, q := range queries { + c.wg.Add(1) + go func(query struct { + name string + query string + }) { + defer c.wg.Done() + outFile := filepath.Join(c.outDir, fmt.Sprintf("%s-mysql-%s.txt", c.prefix, query.name)) + if err := c.collectMySQLQuery(query.query, outFile); err != nil { + c.stalker.logger.Error("Failed to collect MySQL %s: %v", query.name, err) + } + }(q) + } + + // Collect special MySQL metrics that need custom handling + if c.stalker.config.CollectGDB { + c.wg.Add(1) + go func() { + defer c.wg.Done() + if err := c.collectGDBStacktrace(); err != nil { + c.stalker.logger.Error("Failed to collect GDB stacktrace: %v", err) + } + }() + } + + return nil +} + +func (c *Collector) runCommand(command string, args []string, outFile string) error { + cmd := exec.Command(command, args...) 
+ + out, err := os.Create(outFile) + if err != nil { + return fmt.Errorf("failed to create output file: %v", err) + } + defer out.Close() + + cmd.Stdout = out + cmd.Stderr = out + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("command failed: %v", err) + } + case <-time.After(time.Duration(c.stalker.config.RunTime) * time.Second): + if err := cmd.Process.Kill(); err != nil { + c.stalker.logger.Error("Failed to kill process: %v", err) + } + return fmt.Errorf("command timed out") + } + + return nil +} + +func (c *Collector) collectMySQLQuery(query string, outFile string) error { + rows, err := c.db.Query(query) + if err != nil { + return fmt.Errorf("query failed: %v", err) + } + defer rows.Close() + + out, err := os.Create(outFile) + if err != nil { + return fmt.Errorf("failed to create output file: %v", err) + } + defer out.Close() + + w := bufio.NewWriter(out) + + // Get column names + cols, err := rows.Columns() + if err != nil { + return fmt.Errorf("failed to get columns: %v", err) + } + + // Write header + fmt.Fprintf(w, "# %s\n", strings.Join(cols, "\t")) + + // Prepare values holders + vals := make([]interface{}, len(cols)) + for i := range vals { + vals[i] = new(sql.RawBytes) + } + + // Write data + for rows.Next() { + if err := rows.Scan(vals...); err != nil { + return fmt.Errorf("failed to scan row: %v", err) + } + + for i, val := range vals { + if i > 0 { + w.WriteString("\t") + } + if rb, ok := val.(*sql.RawBytes); ok { + w.Write(*rb) + } + } + w.WriteString("\n") + } + + return w.Flush() +} + +func (c *Collector) collectGDBStacktrace() error { + // Find MySQL process ID + var pid int + err := c.db.QueryRow("SELECT @@pid").Scan(&pid) + if err != nil { + return fmt.Errorf("failed to get MySQL PID: %v", err) + } + + outFile := filepath.Join(c.outDir, 
fmt.Sprintf("%s-gdb.txt", c.prefix)) + + gdbCommands := fmt.Sprintf("attach %d\nthread apply all bt\ndetach\nquit", pid) + cmd := exec.Command("gdb", "-batch", "-nx", "-ex", gdbCommands) + + out, err := os.Create(outFile) + if err != nil { + return fmt.Errorf("failed to create output file: %v", err) + } + defer out.Close() + + cmd.Stdout = out + cmd.Stderr = out + + return cmd.Run() +} + +func (c *Collector) collectTcpdump() error { + // Get MySQL port + var port int + err := c.db.QueryRow("SELECT @@port").Scan(&port) + if err != nil { + return fmt.Errorf("failed to get MySQL port: %v", err) + } + + outFile := filepath.Join(c.outDir, fmt.Sprintf("%s-tcpdump.cap", c.prefix)) + + cmd := exec.Command("tcpdump", "-i", "any", fmt.Sprintf("port %d", port), "-w", outFile) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start tcpdump: %v", err) + } + + time.Sleep(time.Duration(c.stalker.config.RunTime) * time.Second) + + if err := cmd.Process.Signal(syscall.SIGTERM); err != nil { + return fmt.Errorf("failed to stop tcpdump: %v", err) + } + + return cmd.Wait() +} diff --git a/src/go/pt-stalk/collect_test.go b/src/go/pt-stalk/collect_test.go new file mode 100644 index 000000000..a4356780b --- /dev/null +++ b/src/go/pt-stalk/collect_test.go @@ -0,0 +1,292 @@ +package main + +import ( + "context" + "database/sql" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestCollectionFunctionality(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-collect-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + cfg := &Config{ + Dest: tmpDir, + RunTime: 2, + CollectGDB: true, + CollectStrace: true, + CollectTcpdump: true, + MySQLOnly: false, + SystemOnly: false, + } + + logger, _ := NewLogger("", 3) + stalker := &Stalker{ + config: cfg, + logger: logger, + } + + db, err := setupTestDB(t) + if err != nil { + t.Skip("MySQL not available:", err) + } + defer db.Close() + + prefix := 
time.Now().Format("2006_01_02_15_04_05") + ctx := context.Background() + if err := stalker.collect(ctx, db, prefix); err != nil { + t.Fatal(err) + } + + // Verify collection files + expectedFiles := []string{ + "mysql-variables.txt", + "mysql-status.txt", + "mysql-processlist.txt", + "vmstat.txt", + "iostat.txt", + "mpstat.txt", + "tcpdump.cap", + } + + outDir := filepath.Join(tmpDir, prefix) + for _, file := range expectedFiles { + path := filepath.Join(outDir, file) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("Expected file not found: %s", file) + } + } +} + +func TestMySQLOnlyCollection(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-mysql-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + cfg := &Config{ + Dest: tmpDir, + RunTime: 2, + MySQLOnly: true, + } + + logger, _ := NewLogger("", 3) + stalker := &Stalker{ + config: cfg, + logger: logger, + } + + db, err := setupTestDB(t) + if err != nil { + t.Skip("MySQL not available:", err) + } + defer db.Close() + + prefix := time.Now().Format("2006_01_02_15_04_05") + ctx := context.Background() + if err := stalker.collect(ctx, db, prefix); err != nil { + t.Fatal(err) + } + + // Verify only MySQL files exist + files, err := os.ReadDir(filepath.Join(tmpDir, prefix)) + if err != nil { + t.Fatal(err) + } + + for _, file := range files { + if !strings.HasPrefix(file.Name(), "mysql-") { + t.Errorf("Found non-MySQL file: %s", file.Name()) + } + } +} + +func TestSystemOnlyCollection(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-system-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + cfg := &Config{ + Dest: tmpDir, + RunTime: 2, + SystemOnly: true, + } + + logger, _ := NewLogger("", 3) + stalker := &Stalker{ + config: cfg, + logger: logger, + } + + prefix := time.Now().Format("2006_01_02_15_04_05") + ctx := context.Background() + if err := stalker.collect(ctx, nil, prefix); err != nil { + t.Fatal(err) + } + + // Verify only system 
files exist + files, err := os.ReadDir(filepath.Join(tmpDir, prefix)) + if err != nil { + t.Fatal(err) + } + + for _, file := range files { + if strings.HasPrefix(file.Name(), "mysql-") { + t.Errorf("Found MySQL file in system-only mode: %s", file.Name()) + } + } +} + +func TestDaemonization(t *testing.T) { + if os.Getenv("TEST_DAEMONIZATION") != "1" { + t.Skip("Skipping daemonization test") + } + + tmpDir, err := os.MkdirTemp("", "pt-stalk-daemon-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + pidFile := filepath.Join(tmpDir, "pt-stalk.pid") + logFile := filepath.Join(tmpDir, "pt-stalk.log") + + cfg := &Config{ + Dest: tmpDir, + Pid: pidFile, + Log: logFile, + Daemonize: true, + } + + logger, _ := NewLogger(logFile, 3) + stalker := &Stalker{ + config: cfg, + logger: logger, + } + + if err := stalker.daemonize(); err != nil { + t.Fatal(err) + } + + // Verify PID file + if _, err := os.Stat(pidFile); os.IsNotExist(err) { + t.Error("PID file not created") + } + + // Verify log file + if _, err := os.Stat(logFile); os.IsNotExist(err) { + t.Error("Log file not created") + } +} + +func TestCollectionErrors(t *testing.T) { + tests := []struct { + name string + config *Config + setupDB bool + expectError bool + }{ + { + name: "invalid_dest", + config: &Config{ + Dest: "/nonexistent/directory", + }, + setupDB: true, + expectError: true, + }, + { + name: "no_db_mysql_only", + config: &Config{ + MySQLOnly: true, + }, + setupDB: false, + expectError: true, + }, + { + name: "invalid_command", + config: &Config{ + CollectGDB: true, + }, + setupDB: true, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var db *sql.DB + if tc.setupDB { + var err error + db, err = setupTestDB(t) + if err != nil { + t.Skip("MySQL not available:", err) + } + defer db.Close() + } + + logger, _ := NewLogger("", 3) + stalker := &Stalker{ + config: tc.config, + logger: logger, + } + + ctx := context.Background() + err := 
stalker.collect(ctx, db, "test") + if tc.expectError && err == nil { + t.Error("Expected error but got none") + } else if !tc.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +func TestCollectionTimeout(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-timeout-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + cfg := &Config{ + Dest: tmpDir, + RunTime: 1, + } + + logger, _ := NewLogger("", 3) + stalker := &Stalker{ + config: cfg, + logger: logger, + } + + db, err := setupTestDB(t) + if err != nil { + t.Skip("MySQL not available:", err) + } + defer db.Close() + + // Create a long-running collection + done := make(chan error) + go func() { + done <- stalker.collect(context.Background(), db, "test") + }() + + select { + case err := <-done: + if err == nil { + t.Error("Expected timeout error") + } + case <-time.After(time.Duration(cfg.RunTime+1) * time.Second): + t.Error("Collection did not timeout as expected") + } +} diff --git a/src/go/pt-stalk/examples/plugins/sample.sh b/src/go/pt-stalk/examples/plugins/sample.sh new file mode 100644 index 000000000..c33665c1f --- /dev/null +++ b/src/go/pt-stalk/examples/plugins/sample.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# This is a sample plugin for pt-stalk that demonstrates the available hooks + +before_stalk() { + echo "Starting stalker with:" + echo " Function: $PT_FUNCTION" + echo " Variable: $PT_VARIABLE" + echo " Threshold: $PT_THRESHOLD" +} + +before_collect() { + local prefix="$1" + echo "About to collect metrics with prefix: $prefix" + echo "Output directory: $PT_DEST/$prefix" +} + +after_collect() { + local prefix="$1" + echo "Finished collecting metrics with prefix: $prefix" + + # Example: Calculate total size of collected data + du -sh "$PT_DEST/$prefix" +} + +after_collect_sleep() { + echo "Finished sleeping after collection" +} + +after_interval_sleep() { + echo "Finished interval sleep" +} + +after_stalk() { + echo "Stalker finished" +} \ No 
newline at end of file diff --git a/src/go/pt-stalk/logger.go b/src/go/pt-stalk/logger.go new file mode 100644 index 000000000..8d35425bd --- /dev/null +++ b/src/go/pt-stalk/logger.go @@ -0,0 +1,73 @@ +package main + +import ( + "fmt" + "log" + "os" + "time" +) + +type Logger struct { + logger *log.Logger + verbose int + filename string +} + +func NewLogger(filename string, verbose int) (*Logger, error) { + var output *os.File + var err error + + if filename == "" || filename == "-" { + output = os.Stdout + } else { + output, err = os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open log file: %v", err) + } + } + + return &Logger{ + logger: log.New(output, "", 0), + verbose: verbose, + filename: filename, + }, nil +} + +func (l *Logger) formatMessage(level string, format string, args ...interface{}) string { + timestamp := time.Now().Format("2006-01-02 15:04:05") + message := fmt.Sprintf(format, args...) + return fmt.Sprintf("%s [%s] %s", timestamp, level, message) +} + +func (l *Logger) Error(format string, args ...interface{}) { + if l.verbose >= 0 { + l.logger.Println(l.formatMessage("ERROR", format, args...)) + } +} + +func (l *Logger) Warn(format string, args ...interface{}) { + if l.verbose >= 1 { + l.logger.Println(l.formatMessage("WARN", format, args...)) + } +} + +func (l *Logger) Info(format string, args ...interface{}) { + if l.verbose >= 2 { + l.logger.Println(l.formatMessage("INFO", format, args...)) + } +} + +func (l *Logger) Debug(format string, args ...interface{}) { + if l.verbose >= 3 { + l.logger.Println(l.formatMessage("DEBUG", format, args...)) + } +} + +func (l *Logger) Close() error { + if l.filename != "" && l.filename != "-" { + if logger, ok := l.logger.Writer().(*os.File); ok { + return logger.Close() + } + } + return nil +} diff --git a/src/go/pt-stalk/main.go b/src/go/pt-stalk/main.go new file mode 100644 index 000000000..a1c35328f --- /dev/null +++ 
b/src/go/pt-stalk/main.go @@ -0,0 +1,215 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "github.com/sevlyar/go-daemon" +) + +type Config struct { + Function string + Variable string + Match string + Threshold float64 + Cycles int + Interval int + RunTime int + Sleep int + SleepCollect int + Dest string + Prefix string + CollectGDB bool + CollectOProfile bool + CollectStrace bool + CollectTcpdump bool + Socket string + Host string + Port int + User string + Password string + DefaultsFile string + Log string + Pid string + Plugin string + Daemonize bool + SystemOnly bool + MySQLOnly bool + RetentionTime int + RetentionCount int + RetentionSize int + DiskBytesFree int64 + DiskPctFree int + NotifyByEmail string + Verbose int +} + +type Stalker struct { + config *Config + ctx context.Context + cancel context.CancelFunc + logger *Logger + plugin *Plugin +} + +func main() { + cfg := &Config{} + + // Parse command line flags + flag.StringVar(&cfg.Function, "function", "status", "Trigger function (status|processlist)") + flag.StringVar(&cfg.Variable, "variable", "Threads_running", "Variable to monitor") + flag.StringVar(&cfg.Match, "match", "", "Pattern to match (for processlist)") + flag.Float64Var(&cfg.Threshold, "threshold", 25, "Threshold value") + flag.IntVar(&cfg.Cycles, "cycles", 5, "Number of cycles before collecting") + flag.IntVar(&cfg.Interval, "interval", 1, "Check interval in seconds") + flag.IntVar(&cfg.RunTime, "run-time", 30, "How long to collect data in seconds") + flag.IntVar(&cfg.Sleep, "sleep", 300, "How long to sleep after collection") + flag.IntVar(&cfg.SleepCollect, "sleep-collect", 1, "How long to sleep between collection cycles") + flag.StringVar(&cfg.Dest, "dest", "/var/lib/pt-stalk", "Output destination directory") + flag.StringVar(&cfg.Prefix, "prefix", "", "Filename prefix for samples") + flag.BoolVar(&cfg.CollectGDB, "collect-gdb", false, "Collect GDB stacktraces") + 
flag.BoolVar(&cfg.CollectOProfile, "collect-oprofile", false, "Collect OProfile data") + flag.BoolVar(&cfg.CollectStrace, "collect-strace", false, "Collect strace data") + flag.BoolVar(&cfg.CollectTcpdump, "collect-tcpdump", false, "Collect tcpdump data") + flag.StringVar(&cfg.Socket, "socket", "", "MySQL socket file") + flag.StringVar(&cfg.Host, "host", "", "MySQL host") + flag.IntVar(&cfg.Port, "port", 3306, "MySQL port") + flag.StringVar(&cfg.User, "user", "", "MySQL user") + flag.StringVar(&cfg.Password, "password", "", "MySQL password") + flag.StringVar(&cfg.DefaultsFile, "defaults-file", "", "MySQL defaults file") + flag.StringVar(&cfg.Log, "log", "/var/log/pt-stalk.log", "Log file when daemonized") + flag.StringVar(&cfg.Pid, "pid", "/var/run/pt-stalk.pid", "PID file") + flag.BoolVar(&cfg.Daemonize, "daemonize", false, "Run as daemon") + flag.BoolVar(&cfg.SystemOnly, "system-only", false, "Collect only system metrics") + flag.BoolVar(&cfg.MySQLOnly, "mysql-only", false, "Collect only MySQL metrics") + flag.IntVar(&cfg.RetentionTime, "retention-time", 30, "Days to retain samples") + flag.IntVar(&cfg.RetentionCount, "retention-count", 0, "Number of samples to retain") + flag.IntVar(&cfg.RetentionSize, "retention-size", 0, "Maximum size in MB to retain") + flag.Int64Var(&cfg.DiskBytesFree, "disk-bytes-free", 100*1024*1024, "Minimum bytes free") + flag.IntVar(&cfg.DiskPctFree, "disk-pct-free", 5, "Minimum percent free") + flag.StringVar(&cfg.NotifyByEmail, "notify-by-email", "", "Email address for notifications") + flag.IntVar(&cfg.Verbose, "verbose", 2, "Verbosity level (0-3)") + flag.Parse() + + // Setup signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + // Create context with cancellation + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Initialize logger + logger, err := NewLogger(cfg.Log, cfg.Verbose) + if err != nil { + fmt.Fprintf(os.Stderr, 
"Error initializing logger: %v\n", err) + os.Exit(1) + } + + stalker := &Stalker{ + config: cfg, + ctx: ctx, + cancel: cancel, + logger: logger, + } + + // Initialize plugin if specified + if err := stalker.initPlugin(); err != nil { + fmt.Fprintf(os.Stderr, "Error initializing plugin: %v\n", err) + os.Exit(1) + } + + // Create PID file if daemonizing + if cfg.Daemonize { + if err := createPIDFile(cfg.Pid); err != nil { + stalker.logger.Error("Failed to create PID file: %v", err) + os.Exit(1) + } + defer os.Remove(cfg.Pid) + } + + // Create destination directory if it doesn't exist + if err := os.MkdirAll(cfg.Dest, 0755); err != nil { + stalker.logger.Error("Failed to create destination directory: %v", err) + os.Exit(1) + } + + // Start stalking in a goroutine + errChan := make(chan error, 1) + go func() { + errChan <- stalker.Stalk() + }() + + // Wait for signal or error + select { + case sig := <-sigChan: + stalker.logger.Info("Received signal: %v", sig) + cancel() + case err := <-errChan: + if err != nil { + stalker.logger.Error("Stalking error: %v", err) + os.Exit(1) + } + } + + // Wait for cleanup + cleanup := make(chan struct{}) + go func() { + stalker.cleanup() + close(cleanup) + }() + + select { + case <-cleanup: + case <-time.After(time.Duration(cfg.RunTime*3) * time.Second): + stalker.logger.Warn("Cleanup timed out") + } +} + +func createPIDFile(pidFile string) error { + if _, err := os.Stat(pidFile); err == nil { + // PID file exists, check if process is running + pidBytes, err := os.ReadFile(pidFile) + if err != nil { + return fmt.Errorf("failed to read PID file: %v", err) + } + + pid := string(pidBytes) + if _, err := os.Stat(filepath.Join("/proc", pid)); err == nil { + return fmt.Errorf("process %s is already running", pid) + } + } + + return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0644) +} + +func (s *Stalker) daemonize() error { + if !s.config.Daemonize { + return nil + } + + cntxt := &daemon.Context{ + PidFileName: 
s.config.Pid, + PidFilePerm: 0644, + LogFileName: s.config.Log, + LogFilePerm: 0640, + WorkDir: "./", + Umask: 027, + } + + d, err := cntxt.Reborn() + if err != nil { + return fmt.Errorf("failed to daemonize: %v", err) + } + if d != nil { + os.Exit(0) + } + + return nil +} diff --git a/src/go/pt-stalk/main_test.go b/src/go/pt-stalk/main_test.go new file mode 100644 index 000000000..2301cf08f --- /dev/null +++ b/src/go/pt-stalk/main_test.go @@ -0,0 +1,212 @@ +package main + +import ( + "context" + "database/sql" + "os" + "path/filepath" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +func TestStalker(t *testing.T) { + // Create temporary directory for test outputs + tmpDir, err := os.MkdirTemp("", "pt-stalk-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Test configuration + cfg := &Config{ + Function: "status", + Variable: "Threads_running", + Threshold: 5, + Cycles: 2, + Interval: 1, + RunTime: 5, + Sleep: 2, + Dest: tmpDir, + Host: "localhost", + Port: 3306, + User: os.Getenv("MYSQL_TEST_USER"), + Password: os.Getenv("MYSQL_TEST_PASS"), + Verbose: 3, + DiskBytesFree: 1024 * 1024, // 1MB + DiskPctFree: 1, + } + + // Initialize logger + logger, err := NewLogger("", cfg.Verbose) + if err != nil { + t.Fatalf("Failed to create logger: %v", err) + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize stalker + stalker := &Stalker{ + config: cfg, + ctx: ctx, + cancel: cancel, + logger: logger, + } + + // Test MySQL connection + db, err := sql.Open("mysql", stalker.buildDSN()) + if err != nil { + t.Skipf("Skipping test, could not connect to MySQL: %v", err) + } + defer db.Close() + + if err := db.Ping(); err != nil { + t.Skipf("Skipping test, MySQL not responding: %v", err) + } + + // Run stalker in goroutine + errChan := make(chan error, 1) + go func() { + errChan <- stalker.Stalk() + }() + + 
// Create some test load + go func() { + for i := 0; i < 10; i++ { + db.Exec("SELECT SLEEP(1)") + time.Sleep(time.Second) + } + }() + + // Wait for stalker to finish or timeout + select { + case err := <-errChan: + if err != nil { + t.Errorf("Stalker failed: %v", err) + } + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("Unexpected context error: %v", ctx.Err()) + } + } + + // Verify outputs + files, err := os.ReadDir(tmpDir) + if err != nil { + t.Fatalf("Failed to read output directory: %v", err) + } + + if len(files) == 0 { + t.Error("No output files were created") + } + + // Check specific files + expectedFiles := []string{ + "mysql-variables.txt", + "mysql-status.txt", + "mysql-processlist.txt", + "uptime.txt", + "vmstat.txt", + "iostat.txt", + } + + for _, dir := range files { + if !dir.IsDir() { + continue + } + + for _, expected := range expectedFiles { + path := filepath.Join(tmpDir, dir.Name(), expected) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("Expected file %s not found", path) + } + } + } +} + +func TestPluginExecution(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test plugin + pluginContent := `#!/bin/bash +before_stalk() { + echo "before_stalk called" + return 0 +} +before_collect() { + echo "before_collect called with $1" + return 0 +} +after_collect() { + echo "after_collect called with $1" + return 0 +}` + + pluginPath := filepath.Join(tmpDir, "test-plugin.sh") + if err := os.WriteFile(pluginPath, []byte(pluginContent), 0755); err != nil { + t.Fatalf("Failed to write test plugin: %v", err) + } + + logger, err := NewLogger("", 3) + if err != nil { + t.Fatalf("Failed to create logger: %v", err) + } + + plugin, err := NewPlugin(pluginPath, logger) + if err != nil { + t.Fatalf("Failed to create plugin: %v", err) + } + + // Test each hook + hooks := []struct { + 
hook PluginHook + args []string + }{ + {BeforeStalk, nil}, + {BeforeCollect, []string{"test_prefix"}}, + {AfterCollect, []string{"test_prefix"}}, + } + + for _, tc := range hooks { + err := plugin.Execute(tc.hook, tc.args...) + if err != nil { + t.Errorf("Plugin execution failed for %s: %v", tc.hook, err) + } + } +} + +func TestSizeParser(t *testing.T) { + tests := []struct { + input string + expected int64 + hasError bool + }{ + {"1K", 1024, false}, + {"1M", 1024 * 1024, false}, + {"1G", 1024 * 1024 * 1024, false}, + {"1T", 1024 * 1024 * 1024 * 1024, false}, + {"1.5G", 1610612736, false}, + {"1024", 1024, false}, + {"invalid", 0, true}, + } + + for _, tc := range tests { + result, err := ParseSize(tc.input) + if tc.hasError && err == nil { + t.Errorf("Expected error for input %s, got none", tc.input) + } + if !tc.hasError && err != nil { + t.Errorf("Unexpected error for input %s: %v", tc.input, err) + } + if !tc.hasError && result != tc.expected { + t.Errorf("For input %s, expected %d, got %d", tc.input, tc.expected, result) + } + } +} diff --git a/src/go/pt-stalk/plugin.go b/src/go/pt-stalk/plugin.go new file mode 100644 index 000000000..d873a61d0 --- /dev/null +++ b/src/go/pt-stalk/plugin.go @@ -0,0 +1,165 @@ +package main + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +type Plugin struct { + path string + env map[string]string + logger *Logger +} + +type PluginHook string + +const ( + BeforeStalk PluginHook = "before_stalk" + BeforeCollect PluginHook = "before_collect" + AfterCollect PluginHook = "after_collect" + AfterCollectSleep PluginHook = "after_collect_sleep" + AfterIntervalSleep PluginHook = "after_interval_sleep" + AfterStalk PluginHook = "after_stalk" +) + +func NewPlugin(path string, logger *Logger) (*Plugin, error) { + if path == "" { + return nil, nil + } + + absPath, err := filepath.Abs(path) + if err != nil { + return nil, fmt.Errorf("failed to resolve plugin path: %v", err) + } + + if _, err := os.Stat(absPath); err != 
nil { + return nil, fmt.Errorf("plugin file not found: %v", err) + } + + return &Plugin{ + path: absPath, + env: make(map[string]string), + logger: logger, + }, nil +} + +func (p *Plugin) SetEnv(key, value string) { + if p != nil { + p.env[key] = value + } +} + +func (p *Plugin) Execute(hook PluginHook, args ...string) error { + if p == nil { + return nil + } + + p.logger.Debug("Executing plugin hook: %s", hook) + + // Prepare environment variables + env := os.Environ() + for k, v := range p.env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + + // Add hook name to environment + env = append(env, fmt.Sprintf("PT_HOOK=%s", hook)) + + // Create temporary script to execute the plugin + tmpScript, err := os.CreateTemp("", "pt-stalk-plugin-*.sh") + if err != nil { + return fmt.Errorf("failed to create temporary script: %v", err) + } + defer os.Remove(tmpScript.Name()) + + // Write plugin execution script + script := fmt.Sprintf(`#!/bin/bash +source "%s" +if type %s >/dev/null 2>&1; then + %s "$@" + exit $? +else + exit 0 +fi +`, p.path, hook, hook) + + if _, err := tmpScript.WriteString(script); err != nil { + return fmt.Errorf("failed to write plugin script: %v", err) + } + + if err := tmpScript.Close(); err != nil { + return fmt.Errorf("failed to close plugin script: %v", err) + } + + if err := os.Chmod(tmpScript.Name(), 0755); err != nil { + return fmt.Errorf("failed to make plugin script executable: %v", err) + } + + // Execute the plugin + cmd := exec.Command(tmpScript.Name(), args...) 
+ cmd.Env = env + cmd.Dir = filepath.Dir(p.path) + + // Capture output + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("plugin hook %s failed: %v\nOutput: %s", hook, err, output) + } + + if len(output) > 0 { + p.logger.Debug("Plugin output (%s):\n%s", hook, strings.TrimSpace(string(output))) + } + + return nil +} + +// Helper methods for the Stalker struct to handle plugins +func (s *Stalker) initPlugin() error { + if s.config.Plugin != "" { + plugin, err := NewPlugin(s.config.Plugin, s.logger) + if err != nil { + return fmt.Errorf("failed to initialize plugin: %v", err) + } + s.plugin = plugin + + // Set up environment variables for the plugin + s.plugin.SetEnv("PT_DEST", s.config.Dest) + s.plugin.SetEnv("PT_MYSQL_USER", s.config.User) + s.plugin.SetEnv("PT_MYSQL_HOST", s.config.Host) + s.plugin.SetEnv("PT_MYSQL_PORT", fmt.Sprintf("%d", s.config.Port)) + s.plugin.SetEnv("PT_INTERVAL", fmt.Sprintf("%d", s.config.Interval)) + s.plugin.SetEnv("PT_SLEEP", fmt.Sprintf("%d", s.config.Sleep)) + s.plugin.SetEnv("PT_FUNCTION", s.config.Function) + s.plugin.SetEnv("PT_VARIABLE", s.config.Variable) + s.plugin.SetEnv("PT_THRESHOLD", fmt.Sprintf("%f", s.config.Threshold)) + } + return nil +} + +func (s *Stalker) executePluginHook(hook PluginHook, args ...string) error { + if s.plugin != nil { + return s.plugin.Execute(hook, args...) 
+ } + return nil +} + +// Example plugin usage in the Stalker.Stalk() method: +/* + // Before starting to stalk + if err := s.executePluginHook(BeforeStalk); err != nil { + return fmt.Errorf("plugin before_stalk hook failed: %v", err) + } + + // Before collecting metrics + if err := s.executePluginHook(BeforeCollect, prefix); err != nil { + return fmt.Errorf("plugin before_collect hook failed: %v", err) + } + + // After collecting metrics + if err := s.executePluginHook(AfterCollect, prefix); err != nil { + s.logger.Warn("Plugin after_collect hook failed: %v", err) + } +*/ diff --git a/src/go/pt-stalk/stalk.go b/src/go/pt-stalk/stalk.go new file mode 100644 index 000000000..7caca0c39 --- /dev/null +++ b/src/go/pt-stalk/stalk.go @@ -0,0 +1,291 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "strconv" + "syscall" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +const ( + defaultInterval = 1 + defaultCycles = 1 + defaultRetention = 30 + defaultDiskPctFree = 5 + timeFormat = "2006_01_02_15_04_05" +) + +type MetricCollector interface { + Collect(ctx context.Context, prefix string) error +} + +type MetricChecker interface { + Check(ctx context.Context) (bool, error) +} + +type StalkConfig interface { + Validate() error +} + +type LogEntry struct { + Level string + Message string + Time time.Time + Fields map[string]interface{} +} + +func (s *Stalker) buildDSN() string { + dsn := "" + if s.config.DefaultsFile != "" { + dsn += fmt.Sprintf("defaults-file=%s", s.config.DefaultsFile) + } + if s.config.User != "" { + if dsn != "" { + dsn += "&" + } + dsn += fmt.Sprintf("user=%s", s.config.User) + } + if s.config.Password != "" { + if dsn != "" { + dsn += "&" + } + dsn += fmt.Sprintf("password=%s", s.config.Password) + } + if s.config.Socket != "" { + if dsn != "" { + dsn += "&" + } + dsn += fmt.Sprintf("socket=%s", s.config.Socket) + } else { + if s.config.Host != "" { + if dsn != "" { + dsn += "&" + } + dsn += 
fmt.Sprintf("host=%s", s.config.Host) + } + if s.config.Port != 0 { + if dsn != "" { + dsn += "&" + } + dsn += fmt.Sprintf("port=%d", s.config.Port) + } + } + return dsn +} + +func (s *Stalker) Stalk() error { + s.logger.Info("Starting stalker with config: %+v", s.config) + + // Don't connect to MySQL if we're only collecting system metrics + var db *sql.DB + var err error + if !s.config.SystemOnly { + db, err = sql.Open("mysql", s.buildDSN()) + if err != nil { + return fmt.Errorf("failed to connect to MySQL: %v", err) + } + defer db.Close() + + // Test the connection + if err := db.Ping(); err != nil { + return fmt.Errorf("failed to ping MySQL: %v", err) + } + } + + triggerCount := 0 + iteration := 0 + + for { + select { + case <-s.ctx.Done(): + s.logger.Info("Stalker received shutdown signal") + return nil + default: + if s.config.SystemOnly { + // For system-only mode, we treat it as always triggered + triggerCount++ + } else { + triggered, err := s.checkTrigger(db) + if err != nil { + return fmt.Errorf("failed to check trigger: %v", err) + } + + if triggered { + triggerCount++ + s.logger.Info("Trigger condition met (%d/%d)", triggerCount, s.config.Cycles) + } else { + if triggerCount > 0 { + s.logger.Debug("Trigger condition reset (was %d/%d)", triggerCount, s.config.Cycles) + } + triggerCount = 0 + } + } + + if triggerCount >= s.config.Cycles { + s.logger.Info("Trigger threshold reached, starting collection") + + // Generate collection prefix + prefix := s.config.Prefix + if prefix == "" { + prefix = time.Now().Format("2006_01_02_15_04_05") + } + + // Check disk space + if err := s.checkDiskSpace(prefix); err != nil { + s.logger.Error("Disk space check failed: %v", err) + return err + } + + // Start collection + if err := s.collectWithTimeout(db, prefix); err != nil { + s.logger.Error("Collection failed: %v", err) + return err + } + + // Reset trigger count + triggerCount = 0 + iteration++ + + // Sleep after collection + s.logger.Info("Sleeping for %d seconds 
after collection", s.config.Sleep) + time.Sleep(time.Duration(s.config.Sleep) * time.Second) + + // Check if we've reached max iterations + if s.config.RetentionCount > 0 && iteration >= s.config.RetentionCount { + s.logger.Info("Reached maximum iterations (%d), shutting down", s.config.RetentionCount) + return nil + } + } + + // Sleep before next check + time.Sleep(time.Duration(s.config.Interval) * time.Second) + } + } +} + +func (s *Stalker) checkTrigger(db *sql.DB) (bool, error) { + switch s.config.Function { + case "status": + return s.checkStatusTrigger(db) + case "processlist": + return s.checkProcesslistTrigger(db) + default: + return false, fmt.Errorf("unknown function: %s", s.config.Function) + } +} + +func (s *Stalker) checkStatusTrigger(db *sql.DB) (bool, error) { + query := "SHOW GLOBAL STATUS WHERE Variable_name = ?" + var name, value string + err := db.QueryRow(query, s.config.Variable).Scan(&name, &value) + if err != nil { + return false, fmt.Errorf("failed to query status: %v", err) + } + + val, err := strconv.ParseFloat(value, 64) + if err != nil { + return false, fmt.Errorf("failed to parse value %s: %v", value, err) + } + + s.logger.Debug("Status check: %s = %v (threshold: %v)", s.config.Variable, val, s.config.Threshold) + return val > s.config.Threshold, nil +} + +func (s *Stalker) checkProcesslistTrigger(db *sql.DB) (bool, error) { + query := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.PROCESSLIST WHERE State = ?` + var count int + err := db.QueryRow(query, s.config.Match).Scan(&count) + if err != nil { + return false, fmt.Errorf("failed to query processlist: %v", err) + } + + s.logger.Debug("Processlist check: count = %d (threshold: %v)", count, s.config.Threshold) + return float64(count) > s.config.Threshold, nil +} + +func (s *Stalker) checkDiskSpace(prefix string) error { + // Get disk usage information + var stat syscall.Statfs_t + err := syscall.Statfs(s.config.Dest, &stat) + if err != nil { + return fmt.Errorf("failed to get disk stats: 
%v", err) + } + + // Calculate free space + blockSize := uint64(stat.Bsize) + totalBlocks := stat.Blocks + freeBlocks := stat.Bfree + + totalBytes := totalBlocks * blockSize + freeBytes := freeBlocks * blockSize + freePercent := float64(freeBytes) / float64(totalBytes) * 100 + + // Check if we have enough free space + if freeBytes < uint64(s.config.DiskBytesFree) { + return fmt.Errorf("insufficient free disk space: %d bytes (need %d)", freeBytes, s.config.DiskBytesFree) + } + + if freePercent < float64(s.config.DiskPctFree) { + return fmt.Errorf("insufficient free disk space: %.2f%% (need %d%%)", freePercent, s.config.DiskPctFree) + } + + s.logger.Debug("Disk space check passed: %.2f%% (%.2f GB) free", freePercent, float64(freeBytes)/(1024*1024*1024)) + return nil +} + +func (s *Stalker) cleanup() error { + s.logger.Info("Starting cleanup") + + // Clean up based on retention time + if s.config.RetentionTime > 0 { + cutoff := time.Now().AddDate(0, 0, -s.config.RetentionTime) + err := filepath.Walk(s.config.Dest, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() && info.ModTime().Before(cutoff) { + if err := os.RemoveAll(path); err != nil { + s.logger.Error("Failed to remove old directory %s: %v", path, err) + return err + } + s.logger.Info("Removed old directory %s", path) + } + return nil + }) + if err != nil { + s.logger.Error("Error during retention cleanup: %v", err) + return err + } + } + return nil +} + +func (s *Stalker) collectWithTimeout(db *sql.DB, prefix string) error { + ctx, cancel := context.WithTimeout(context.Background(), + time.Duration(s.config.RunTime)*time.Second) + defer cancel() + + return s.collect(ctx, db, prefix) +} + +type Metrics struct { + TriggersTotal int64 + CollectionsTotal int64 + ErrorsTotal int64 + // ... 
etc +} + +type CollectionError struct { + Prefix string + Err error +} + +func (e *CollectionError) Error() string { + return fmt.Sprintf("collection failed for %s: %v", e.Prefix, e.Err) +} diff --git a/src/go/pt-stalk/stalk_test.go b/src/go/pt-stalk/stalk_test.go new file mode 100644 index 000000000..daf1460b0 --- /dev/null +++ b/src/go/pt-stalk/stalk_test.go @@ -0,0 +1,197 @@ +package main + +import ( + "database/sql" + "os" + "path/filepath" + "testing" + "time" +) + +func TestTriggerFunctions(t *testing.T) { + tests := []struct { + name string + function string + variable string + match string + threshold float64 + expected bool + }{ + {"status_threads_running", "status", "Threads_running", "", 5, false}, + {"processlist_sleep", "processlist", "", "Sleep", 10, false}, + {"invalid_function", "invalid", "", "", 0, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg := &Config{ + Function: tc.function, + Variable: tc.variable, + Match: tc.match, + Threshold: tc.threshold, + } + stalker := &Stalker{config: cfg} + + // Setup test database connection + db, err := setupTestDB(t) + if err != nil { + t.Skip("MySQL not available:", err) + } + defer db.Close() + + triggered, err := stalker.checkTrigger(db) + if tc.function == "invalid" { + if err == nil { + t.Error("Expected error for invalid function") + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } else if triggered != tc.expected { + t.Errorf("Expected triggered=%v, got %v", tc.expected, triggered) + } + }) + } +} + +func TestRetention(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-retention-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + // Create some test files with different dates + dates := []struct { + dir string + time time.Time + }{ + {"old", time.Now().AddDate(0, 0, -31)}, + {"new", time.Now()}, + } + + for _, d := range dates { + dir := filepath.Join(tmpDir, d.dir) + if err := os.MkdirAll(dir, 0755); err != nil 
{ + t.Fatal(err) + } + if err := os.Chtimes(dir, d.time, d.time); err != nil { + t.Fatal(err) + } + } + + cfg := &Config{ + Dest: tmpDir, + RetentionTime: 30, + } + stalker := &Stalker{config: cfg} + + if err := stalker.cleanup(); err != nil { + t.Fatal(err) + } + + // Check that old directory was removed and new remains + if _, err := os.Stat(filepath.Join(tmpDir, "old")); !os.IsNotExist(err) { + t.Error("Old directory should have been removed") + } + if _, err := os.Stat(filepath.Join(tmpDir, "new")); os.IsNotExist(err) { + t.Error("New directory should still exist") + } +} + +func TestDiskSpace(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-disk-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + tests := []struct { + name string + bytesFree int64 + pctFree int + shouldError bool + }{ + {"sufficient_space", 1024 * 1024 * 1024, 10, false}, + {"insufficient_bytes", 1024, 10, true}, + {"insufficient_percent", 1024 * 1024 * 1024, 99, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg := &Config{ + Dest: tmpDir, + DiskBytesFree: tc.bytesFree, + DiskPctFree: tc.pctFree, + } + stalker := &Stalker{config: cfg} + + err := stalker.checkDiskSpace("test") + if tc.shouldError && err == nil { + t.Error("Expected disk space error") + } else if !tc.shouldError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +func TestPluginHooks(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-plugin-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + // Create test plugin + pluginContent := `#!/bin/bash +before_stalk() { echo "before_stalk"; } +before_collect() { echo "before_collect $1"; } +after_collect() { echo "after_collect $1"; } +after_collect_sleep() { echo "after_collect_sleep"; } +after_interval_sleep() { echo "after_interval_sleep"; } +after_stalk() { echo "after_stalk"; } +` + pluginPath := filepath.Join(tmpDir, "test.sh") + if err := 
// Common size units for parsing. Each unit is 1024 times the previous one.
const (
	_  = iota
	KB = 1 << (10 * iota)
	MB
	GB
	TB
)

// Regular expression for parsing size strings (e.g., "100M", "1.5G"):
// a non-negative decimal number, optional whitespace, an optional unit
// letter (K, M, G or T in either case) and an optional trailing "b"/"B".
var sizeRegex = regexp.MustCompile(`^(\d+(?:\.\d+)?)\s*([kKmMgGtT])?[bB]?$`)

// ParseSize converts a human-readable size string to a number of bytes.
//
// Leading and trailing whitespace is ignored. A number without a unit letter
// is taken as bytes. Fractional values are scaled first and then truncated
// toward zero, so "1.5G" yields 1610612736.
//
// An error is returned when the string does not match the expected format,
// when the numeric part cannot be parsed, or when the scaled value does not
// fit in an int64. The last check matters because converting an out-of-range
// float64 to int64 does not produce a well-defined result.
func ParseSize(size string) (int64, error) {
	matches := sizeRegex.FindStringSubmatch(strings.TrimSpace(size))
	if matches == nil {
		return 0, fmt.Errorf("invalid size format: %s", size)
	}

	value, err := strconv.ParseFloat(matches[1], 64)
	if err != nil {
		return 0, fmt.Errorf("invalid size value: %s", matches[1])
	}

	var multiplier int64 = 1
	switch strings.ToUpper(matches[2]) {
	case "K":
		multiplier = KB
	case "M":
		multiplier = MB
	case "G":
		multiplier = GB
	case "T":
		multiplier = TB
	}

	// 2^63 is exactly representable as a float64 and is the first value that
	// no longer fits in an int64.
	const int64Limit = float64(1 << 63)

	scaled := value * float64(multiplier)
	if scaled >= int64Limit {
		return 0, fmt.Errorf("size out of range: %s", size)
	}

	return int64(scaled), nil
}
// GetMySQLProcessID gets the process ID of the MySQL server.
//
// MySQL does not expose the server's PID as a system variable (there is no
// @@pid), so the previous "SELECT @@pid" query could only fail. The server
// does expose the location of its PID file, so this function asks for
// @@pid_file, resolves a relative path against @@datadir, and reads the PID
// from that file.
//
// NOTE(review): this reads a local file, so it assumes the tool runs on the
// same host as the server and can read the PID file — confirm this matches
// how the function is meant to be used.
func GetMySQLProcessID(db *sql.DB) (int, error) {
	var pidFile, dataDir string
	if err := db.QueryRow("SELECT @@pid_file, @@datadir").Scan(&pidFile, &dataDir); err != nil {
		return 0, fmt.Errorf("failed to get MySQL PID: %v", err)
	}

	// A relative pid_file is interpreted relative to the data directory.
	if !filepath.IsAbs(pidFile) {
		pidFile = filepath.Join(dataDir, pidFile)
	}

	content, err := os.ReadFile(pidFile)
	if err != nil {
		return 0, fmt.Errorf("failed to get MySQL PID: %v", err)
	}

	pid, err := strconv.Atoi(strings.TrimSpace(string(content)))
	if err != nil {
		return 0, fmt.Errorf("failed to get MySQL PID: invalid PID in %s: %v", pidFile, err)
	}
	return pid, nil
}
src/go/pt-stalk/stalk_test.go | 217 ++++++------------ 15 files changed, 899 insertions(+), 1454 deletions(-) create mode 100644 src/go/pt-stalk/README.md create mode 100644 src/go/pt-stalk/collect_mysql.go create mode 100644 src/go/pt-stalk/collect_mysql_test.go create mode 100644 src/go/pt-stalk/collect_system.go create mode 100644 src/go/pt-stalk/collect_system_test.go create mode 100644 src/go/pt-stalk/plugin_test.go diff --git a/go.mod b/go.mod index 0014b5a33..3dce92a3f 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,9 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect + github.com/spf13/pflag v1.0.5 // indirect ) require ( @@ -59,6 +61,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sevlyar/go-daemon v0.1.6 + github.com/spf13/cobra v1.8.1 github.com/tklauser/go-sysconf v0.3.11 // indirect github.com/tklauser/numcpus v0.6.0 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index 48f92aeab..ace2d980b 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,7 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAuRjVTiNNhvNRfY2Wxp9nhfyel4rklc= github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -55,6 +56,8 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef h1:A9HsByNhogrvm9cWb28sjiS3i7tcKCkflWFEkHfuAgM= github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef/go.mod h1:lADxMC39cJJqL93Duh1xhAs4I2Zs8mKS89XWXFGp9cs= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= @@ -97,12 +100,15 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99 github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sevlyar/go-daemon v0.1.6 h1:EUh1MDjEM4BI109Jign0EaknA2izkOyi0LV3ro3QQGs= github.com/sevlyar/go-daemon v0.1.6/go.mod h1:6dJpPatBT9eUwM5VCw9Bt6CdX9Tk6UWvhW3MebLDRKE= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/cobra v1.8.1 
h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/src/go/pt-stalk/README.md b/src/go/pt-stalk/README.md new file mode 100644 index 000000000..7371fd0d1 --- /dev/null +++ b/src/go/pt-stalk/README.md @@ -0,0 +1,118 @@ +# pt-stalk (Go Version) + +A Go implementation of the Percona pt-stalk tool for collecting MySQL and system metrics. + +## Installation + +Install using go: + + go install github.com/percona/pt-stalk@latest + +## Usage Examples + +### Basic MySQL Monitoring + +Collect only MySQL metrics: + + pt-stalk --collectors=mysql \ + --mysql-host=localhost \ + --mysql-user=root \ + --mysql-password=secret \ + --dest=/var/log/mysql/samples \ + --interval=1 + +### MySQL and System Monitoring + +Collect both MySQL and system metrics: + + pt-stalk --collectors=mysql,system \ + --mysql-host=localhost \ + --mysql-user=root \ + --mysql-password=secret \ + --collect-gdb=true \ + --collect-tcpdump=true \ + --dest=/var/log/mysql/samples \ + --interval=1 + +### Running as a Daemon + +Run pt-stalk in the background: + + pt-stalk --collectors=mysql,system \ + --mysql-host=localhost \ + --mysql-user=root \ + --mysql-password=secret \ + --daemonize=true \ + --pid=/var/run/pt-stalk.pid \ + --log=/var/log/pt-stalk.log \ + --dest=/var/log/mysql/samples + +### Using a Custom Plugin + +Run with a custom plugin script: + + pt-stalk --collectors=mysql \ + --mysql-host=localhost \ + --mysql-user=root \ + --mysql-password=secret \ + --plugin=/path/to/custom/plugin.sh \ + --dest=/var/log/mysql/samples + +Example plugin script (plugin.sh): + + #!/bin/bash + # Environment variables available: + # PT_DEST - destination directory + # PT_PREFIX - file 
prefix + # PT_INTERVAL - check interval + # PT_RUNTIME - collection duration + + echo "Custom collection started" > "$PT_DEST/${PT_PREFIX}_custom.txt" + # Add your custom collection logic here + +## Configuration Options + +### Common Options +- --collectors: Comma-separated list of collectors to enable (mysql,system) +- --interval: Check interval in seconds (default: 1) +- --run-time: How long to collect data in seconds (default: 30) +- --sleep: Sleep time between collections in seconds (default: 1) +- --dest: Destination directory for collected data (default: /var/lib/pt-stalk) +- --prefix: Filename prefix for samples +- --daemonize: Run as daemon (default: false) + +### MySQL Collector Options +- --mysql-host: MySQL host (default: localhost) +- --mysql-port: MySQL port (default: 3306) +- --mysql-user: MySQL user +- --mysql-password: MySQL password +- --mysql-socket: MySQL socket file +- --mysql-defaults-file: MySQL configuration file + +### System Collector Options +- --collect-gdb: Collect GDB stacktraces (default: false) +- --collect-oprofile: Collect OProfile data (default: false) +- --collect-strace: Collect strace data (default: false) +- --collect-tcpdump: Collect tcpdump data (default: false) + +### Retention Options +- --retention-time: Days to retain samples (default: 30) +- --retention-count: Number of samples to retain (default: 0) +- --retention-size: Maximum size in MB to retain (default: 0) +- --disk-bytes-free: Minimum bytes free (default: 100MB) +- --disk-pct-free: Minimum percent free (default: 5) + +### Notification Options +- --notify-by-email: Email address for notifications +- --verbose: Verbosity level (0-3) (default: 2) + +## Output Files + +Each collection creates files with the specified prefix and timestamps: +- {prefix}_status.txt: MySQL status variables +- {prefix}_variables.txt: MySQL system variables +- {prefix}_processlist.txt: MySQL process list +- {prefix}_diskstats.txt: System disk statistics +- {prefix}_meminfo.txt: System 
memory information +- {prefix}_loadavg.txt: System load average +- {prefix}_plugin.txt: Custom plugin output (if configured) \ No newline at end of file diff --git a/src/go/pt-stalk/collect.go b/src/go/pt-stalk/collect.go index c9e46c6f9..07e40f2eb 100644 --- a/src/go/pt-stalk/collect.go +++ b/src/go/pt-stalk/collect.go @@ -1,300 +1,24 @@ package main import ( - "bufio" "context" - "database/sql" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "syscall" - "time" -) - -type Collector struct { - stalker *Stalker - db *sql.DB - outDir string - prefix string - wg sync.WaitGroup -} - -func (s *Stalker) collect(ctx context.Context, db *sql.DB, prefix string) error { - outDir := filepath.Join(s.config.Dest, prefix) - if err := os.MkdirAll(outDir, 0755); err != nil { - return fmt.Errorf("failed to create output directory: %v", err) - } - - collector := &Collector{ - stalker: s, - db: db, - outDir: outDir, - prefix: prefix, - } - - // Start collection goroutines - collector.wg.Add(1) - go func() { - defer collector.wg.Done() - if err := collector.collectSystemMetrics(); err != nil { - s.logger.Error("System metrics collection failed: %v", err) - } - }() - - if !s.config.SystemOnly { - collector.wg.Add(1) - go func() { - defer collector.wg.Done() - if err := collector.collectMySQLMetrics(); err != nil { - s.logger.Error("MySQL metrics collection failed: %v", err) - } - }() - } - - // Wait for collections with timeout - done := make(chan struct{}) - go func() { - collector.wg.Wait() - close(done) - }() - - select { - case <-done: - s.logger.Info("Collection completed successfully") - case <-time.After(time.Duration(s.config.RunTime) * time.Second): - s.logger.Warn("Collection timed out after %d seconds", s.config.RunTime) - } - - return nil -} - -func (c *Collector) collectSystemMetrics() error { - if c.stalker.config.MySQLOnly { - return nil - } - - metrics := []struct { - name string - command string - args []string - }{ - {"uptime", "uptime", nil}, - 
{"uname", "uname", []string{"-a"}}, - {"vmstat", "vmstat", []string{"1"}}, - {"iostat", "iostat", []string{"-dx", "1"}}, - {"mpstat", "mpstat", []string{"1"}}, - {"free", "free", []string{"-m"}}, - {"df", "df", []string{"-h"}}, - {"dmesg", "dmesg", nil}, - {"netstat", "netstat", []string{"-antp"}}, - {"top", "top", []string{"-b", "-n", "1"}}, - } - - for _, metric := range metrics { - c.wg.Add(1) - go func(m struct { - name string - command string - args []string - }) { - defer c.wg.Done() - outFile := filepath.Join(c.outDir, fmt.Sprintf("%s-%s.txt", c.prefix, m.name)) - if err := c.runCommand(m.command, m.args, outFile); err != nil { - c.stalker.logger.Error("Failed to collect %s: %v", m.name, err) - } - }(metric) - } - - // Collect special metrics that need custom handling - if c.stalker.config.CollectTcpdump { - c.wg.Add(1) - go func() { - defer c.wg.Done() - if err := c.collectTcpdump(); err != nil { - c.stalker.logger.Error("Failed to collect tcpdump: %v", err) - } - }() - } - - return nil -} - -func (c *Collector) collectMySQLMetrics() error { - queries := []struct { - name string - query string - }{ - {"variables", "SHOW GLOBAL VARIABLES"}, - {"status", "SHOW GLOBAL STATUS"}, - {"processlist", "SHOW FULL PROCESSLIST"}, - {"slave_status", "SHOW SLAVE STATUS"}, - {"innodb_status", "SHOW ENGINE INNODB STATUS"}, - {"mutex_status", "SHOW ENGINE INNODB MUTEX"}, - } - - for _, q := range queries { - c.wg.Add(1) - go func(query struct { - name string - query string - }) { - defer c.wg.Done() - outFile := filepath.Join(c.outDir, fmt.Sprintf("%s-mysql-%s.txt", c.prefix, query.name)) - if err := c.collectMySQLQuery(query.query, outFile); err != nil { - c.stalker.logger.Error("Failed to collect MySQL %s: %v", query.name, err) - } - }(q) - } - - // Collect special MySQL metrics that need custom handling - if c.stalker.config.CollectGDB { - c.wg.Add(1) - go func() { - defer c.wg.Done() - if err := c.collectGDBStacktrace(); err != nil { - c.stalker.logger.Error("Failed to 
collect GDB stacktrace: %v", err) - } - }() - } - - return nil -} - -func (c *Collector) runCommand(command string, args []string, outFile string) error { - cmd := exec.Command(command, args...) - - out, err := os.Create(outFile) - if err != nil { - return fmt.Errorf("failed to create output file: %v", err) - } - defer out.Close() - cmd.Stdout = out - cmd.Stderr = out - - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start command: %v", err) - } - - done := make(chan error, 1) - go func() { - done <- cmd.Wait() - }() - - select { - case err := <-done: - if err != nil { - return fmt.Errorf("command failed: %v", err) - } - case <-time.After(time.Duration(c.stalker.config.RunTime) * time.Second): - if err := cmd.Process.Kill(); err != nil { - c.stalker.logger.Error("Failed to kill process: %v", err) - } - return fmt.Errorf("command timed out") - } - - return nil -} - -func (c *Collector) collectMySQLQuery(query string, outFile string) error { - rows, err := c.db.Query(query) - if err != nil { - return fmt.Errorf("query failed: %v", err) - } - defer rows.Close() - - out, err := os.Create(outFile) - if err != nil { - return fmt.Errorf("failed to create output file: %v", err) - } - defer out.Close() - - w := bufio.NewWriter(out) - - // Get column names - cols, err := rows.Columns() - if err != nil { - return fmt.Errorf("failed to get columns: %v", err) - } - - // Write header - fmt.Fprintf(w, "# %s\n", strings.Join(cols, "\t")) - - // Prepare values holders - vals := make([]interface{}, len(cols)) - for i := range vals { - vals[i] = new(sql.RawBytes) - } - - // Write data - for rows.Next() { - if err := rows.Scan(vals...); err != nil { - return fmt.Errorf("failed to scan row: %v", err) - } - - for i, val := range vals { - if i > 0 { - w.WriteString("\t") - } - if rb, ok := val.(*sql.RawBytes); ok { - w.Write(*rb) - } - } - w.WriteString("\n") - } + "github.com/spf13/cobra" +) - return w.Flush() +type CollectorRegistration struct { + Name string + 
AddFlags func(*cobra.Command, map[string]interface{}) + NewCollector func(*Config) Collector } -func (c *Collector) collectGDBStacktrace() error { - // Find MySQL process ID - var pid int - err := c.db.QueryRow("SELECT @@pid").Scan(&pid) - if err != nil { - return fmt.Errorf("failed to get MySQL PID: %v", err) - } - - outFile := filepath.Join(c.outDir, fmt.Sprintf("%s-gdb.txt", c.prefix)) - - gdbCommands := fmt.Sprintf("attach %d\nthread apply all bt\ndetach\nquit", pid) - cmd := exec.Command("gdb", "-batch", "-nx", "-ex", gdbCommands) - - out, err := os.Create(outFile) - if err != nil { - return fmt.Errorf("failed to create output file: %v", err) - } - defer out.Close() - - cmd.Stdout = out - cmd.Stderr = out +var registeredCollectors = make(map[string]CollectorRegistration) - return cmd.Run() +func RegisterCollector(reg CollectorRegistration) { + registeredCollectors[reg.Name] = reg } -func (c *Collector) collectTcpdump() error { - // Get MySQL port - var port int - err := c.db.QueryRow("SELECT @@port").Scan(&port) - if err != nil { - return fmt.Errorf("failed to get MySQL port: %v", err) - } - - outFile := filepath.Join(c.outDir, fmt.Sprintf("%s-tcpdump.cap", c.prefix)) - - cmd := exec.Command("tcpdump", "-i", "any", fmt.Sprintf("port %d", port), "-w", outFile) - - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start tcpdump: %v", err) - } - - time.Sleep(time.Duration(c.stalker.config.RunTime) * time.Second) - - if err := cmd.Process.Signal(syscall.SIGTERM); err != nil { - return fmt.Errorf("failed to stop tcpdump: %v", err) - } - - return cmd.Wait() +// Base interface that all collectors must implement +type Collector interface { + Collect(ctx context.Context) error } diff --git a/src/go/pt-stalk/collect_mysql.go b/src/go/pt-stalk/collect_mysql.go new file mode 100644 index 000000000..a83da745a --- /dev/null +++ b/src/go/pt-stalk/collect_mysql.go @@ -0,0 +1,152 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "os" + 
"path/filepath" + "sync" + + "github.com/spf13/cobra" +) + +type MySQLCollector struct { + stalker *Stalker + db *sql.DB + outDir string + prefix string + wg sync.WaitGroup + mysqlCfg *MySQLConfig +} + +func NewMySQLCollector(config *Config) Collector { + mysqlCfg := config.CollectorConfigs["mysql"].(*MySQLConfig) + return &MySQLCollector{ + stalker: nil, + db: nil, + outDir: config.Dest, + prefix: config.Prefix, + mysqlCfg: mysqlCfg, + } +} + +func (c *MySQLCollector) Collect(ctx context.Context) error { + if c.db == nil { + mysqlCfg := c.mysqlCfg + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/", mysqlCfg.User, mysqlCfg.Password, mysqlCfg.Host, mysqlCfg.Port) + + db, err := sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("failed to connect to MySQL: %v", err) + } + c.db = db + defer db.Close() + } + + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.collectStatus(ctx) + c.collectVariables(ctx) + c.collectProcesslist(ctx) + }() + + c.wg.Wait() + return nil +} + +func (c *MySQLCollector) collectStatus(ctx context.Context) error { + rows, err := c.db.QueryContext(ctx, "SHOW GLOBAL STATUS") + if err != nil { + return err + } + defer rows.Close() + + return c.writeResults(rows, c.prefix+"_status.txt") +} + +func (c *MySQLCollector) collectVariables(ctx context.Context) error { + rows, err := c.db.QueryContext(ctx, "SHOW GLOBAL VARIABLES") + if err != nil { + return err + } + defer rows.Close() + + return c.writeResults(rows, c.prefix+"_variables.txt") +} + +func (c *MySQLCollector) collectProcesslist(ctx context.Context) error { + rows, err := c.db.QueryContext(ctx, "SHOW FULL PROCESSLIST") + if err != nil { + return err + } + defer rows.Close() + + return c.writeProcesslist(rows, c.prefix+"_processlist.txt") +} + +func (c *MySQLCollector) writeResults(rows *sql.Rows, filename string) error { + f, err := os.Create(filepath.Join(c.outDir, filename)) + if err != nil { + return err + } + defer f.Close() + + for rows.Next() { + var name, value string + if err := 
rows.Scan(&name, &value); err != nil { + return err + } + fmt.Fprintf(f, "%s\t%s\n", name, value) + } + return rows.Err() +} + +func (c *MySQLCollector) writeProcesslist(rows *sql.Rows, filename string) error { + f, err := os.Create(filepath.Join(c.outDir, filename)) + if err != nil { + return err + } + defer f.Close() + + for rows.Next() { + var id, user, host, db, command, time, state, info sql.NullString + if err := rows.Scan(&id, &user, &host, &db, &command, &time, &state, &info); err != nil { + return err + } + fmt.Fprintf(f, "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n", + id.String, user.String, host.String, db.String, + command.String, time.String, state.String, info.String) + } + return rows.Err() +} + +type MySQLConfig struct { + Host string + Port int + User string + Password string + Socket string + DefaultsFile string +} + +func addMySQLFlags(cmd *cobra.Command, cfg map[string]interface{}) { + mysqlCfg := &MySQLConfig{} + cfg["mysql"] = mysqlCfg + + cmd.PersistentFlags().StringVar(&mysqlCfg.Host, "mysql-host", "", "MySQL host") + cmd.PersistentFlags().IntVar(&mysqlCfg.Port, "mysql-port", 3306, "MySQL port") + cmd.PersistentFlags().StringVar(&mysqlCfg.User, "mysql-user", "", "MySQL user") + cmd.PersistentFlags().StringVar(&mysqlCfg.Password, "mysql-password", "", "MySQL password") + cmd.PersistentFlags().StringVar(&mysqlCfg.Socket, "mysql-socket", "", "MySQL socket") + cmd.PersistentFlags().StringVar(&mysqlCfg.DefaultsFile, "mysql-defaults-file", "", "MySQL defaults file") +} + +func init() { + RegisterCollector(CollectorRegistration{ + Name: "mysql", + AddFlags: addMySQLFlags, + NewCollector: NewMySQLCollector, + }) +} diff --git a/src/go/pt-stalk/collect_mysql_test.go b/src/go/pt-stalk/collect_mysql_test.go new file mode 100644 index 000000000..805e853f9 --- /dev/null +++ b/src/go/pt-stalk/collect_mysql_test.go @@ -0,0 +1,48 @@ +package main + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestMySQLCollector(t *testing.T) { + tmpDir, err 
:= os.MkdirTemp("", "pt-stalk-mysql-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + cfg := &Config{ + Dest: tmpDir, + Prefix: "test", + CollectorConfigs: map[string]interface{}{ + "mysql": &MySQLConfig{ + Host: "localhost", + Port: 3306, + User: "root", + }, + }, + } + + collector := NewMySQLCollector(cfg) + err = collector.Collect(context.Background()) + if err != nil { + t.Fatal(err) + } + + // Verify MySQL specific files + expectedFiles := []string{ + "test_status.txt", + "test_variables.txt", + "test_processlist.txt", + } + + for _, file := range expectedFiles { + path := filepath.Join(tmpDir, file) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("Expected file not found: %s", file) + } + } +} diff --git a/src/go/pt-stalk/collect_system.go b/src/go/pt-stalk/collect_system.go new file mode 100644 index 000000000..1518d9fba --- /dev/null +++ b/src/go/pt-stalk/collect_system.go @@ -0,0 +1,103 @@ +package main + +import ( + "context" + "os" + "path/filepath" + "sync" + + "github.com/spf13/cobra" +) + +type SystemCollector struct { + stalker *Stalker + outDir string + prefix string + wg sync.WaitGroup + systemCfg *SystemConfig +} + +func NewSystemCollector(config *Config) Collector { + systemCfg := config.CollectorConfigs["system"].(*SystemConfig) + return &SystemCollector{ + stalker: nil, + outDir: config.Dest, + prefix: config.Prefix, + systemCfg: systemCfg, + } +} + +func (c *SystemCollector) Collect(ctx context.Context) error { + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.collectDiskStats(ctx) + c.collectMemInfo(ctx) + c.collectLoadAvg(ctx) + if c.systemCfg.CollectGDB { + c.collectGDB(ctx) + } + if c.systemCfg.CollectTcpdump { + c.collectTcpdump(ctx) + } + }() + + c.wg.Wait() + return nil +} + +func (c *SystemCollector) collectDiskStats(ctx context.Context) error { + return c.readAndWriteFile("/proc/diskstats", c.prefix+"_diskstats.txt") +} + +func (c *SystemCollector) collectMemInfo(ctx context.Context) error { + 
return c.readAndWriteFile("/proc/meminfo", c.prefix+"_meminfo.txt") +} + +func (c *SystemCollector) collectLoadAvg(ctx context.Context) error { + return c.readAndWriteFile("/proc/loadavg", c.prefix+"_loadavg.txt") +} + +func (c *SystemCollector) collectGDB(ctx context.Context) error { + // GDB collection implementation + return nil +} + +func (c *SystemCollector) collectTcpdump(ctx context.Context) error { + // Tcpdump collection implementation + return nil +} + +func (c *SystemCollector) readAndWriteFile(srcPath, destName string) error { + content, err := os.ReadFile(srcPath) + if err != nil { + return err + } + + return os.WriteFile(filepath.Join(c.outDir, destName), content, 0644) +} + +type SystemConfig struct { + CollectGDB bool + CollectOProfile bool + CollectStrace bool + CollectTcpdump bool +} + +func addSystemFlags(cmd *cobra.Command, cfg map[string]interface{}) { + systemCfg := &SystemConfig{} + cfg["system"] = systemCfg + + cmd.PersistentFlags().BoolVar(&systemCfg.CollectGDB, "collect-gdb", false, "Collect GDB stacktraces") + cmd.PersistentFlags().BoolVar(&systemCfg.CollectOProfile, "collect-oprofile", false, "Collect OProfile data") + cmd.PersistentFlags().BoolVar(&systemCfg.CollectStrace, "collect-strace", false, "Collect strace data") + cmd.PersistentFlags().BoolVar(&systemCfg.CollectTcpdump, "collect-tcpdump", false, "Collect tcpdump data") +} + +func init() { + RegisterCollector(CollectorRegistration{ + Name: "system", + AddFlags: addSystemFlags, + NewCollector: NewSystemCollector, + }) +} diff --git a/src/go/pt-stalk/collect_system_test.go b/src/go/pt-stalk/collect_system_test.go new file mode 100644 index 000000000..220f8367d --- /dev/null +++ b/src/go/pt-stalk/collect_system_test.go @@ -0,0 +1,62 @@ +package main + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestSystemCollector(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-system-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + cfg 
:= &Config{ + Dest: tmpDir, + Prefix: "test", + CollectorConfigs: map[string]interface{}{ + "system": &SystemConfig{ + CollectGDB: true, + CollectTcpdump: true, + }, + }, + } + + collector := NewSystemCollector(cfg) + err = collector.Collect(context.Background()) + if err != nil { + t.Fatal(err) + } + + // Verify system specific files + expectedFiles := []string{ + "test_diskstats.txt", + "test_meminfo.txt", + "test_loadavg.txt", + } + + for _, file := range expectedFiles { + path := filepath.Join(tmpDir, file) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("Expected file not found: %s", file) + } + } + + // Verify optional collectors + if cfg.CollectorConfigs["system"].(*SystemConfig).CollectGDB { + gdbFile := filepath.Join(tmpDir, "test_gdb.txt") + if _, err := os.Stat(gdbFile); os.IsNotExist(err) { + t.Error("Expected GDB file not found") + } + } + + if cfg.CollectorConfigs["system"].(*SystemConfig).CollectTcpdump { + tcpdumpFile := filepath.Join(tmpDir, "test_tcpdump.cap") + if _, err := os.Stat(tcpdumpFile); os.IsNotExist(err) { + t.Error("Expected tcpdump file not found") + } + } +} diff --git a/src/go/pt-stalk/collect_test.go b/src/go/pt-stalk/collect_test.go index a4356780b..53b6228b9 100644 --- a/src/go/pt-stalk/collect_test.go +++ b/src/go/pt-stalk/collect_test.go @@ -1,292 +1,20 @@ package main import ( - "context" - "database/sql" - "os" - "path/filepath" - "strings" "testing" - "time" ) -func TestCollectionFunctionality(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-collect-*") - if err != nil { - t.Fatal(err) +func TestCollectorRegistry(t *testing.T) { + // Test collector registration + if len(registeredCollectors) == 0 { + t.Error("No collectors registered") } - defer os.RemoveAll(tmpDir) - cfg := &Config{ - Dest: tmpDir, - RunTime: 2, - CollectGDB: true, - CollectStrace: true, - CollectTcpdump: true, - MySQLOnly: false, - SystemOnly: false, - } - - logger, _ := NewLogger("", 3) - stalker := &Stalker{ - config: cfg, - 
logger: logger, - } - - db, err := setupTestDB(t) - if err != nil { - t.Skip("MySQL not available:", err) - } - defer db.Close() - - prefix := time.Now().Format("2006_01_02_15_04_05") - ctx := context.Background() - if err := stalker.collect(ctx, db, prefix); err != nil { - t.Fatal(err) - } - - // Verify collection files - expectedFiles := []string{ - "mysql-variables.txt", - "mysql-status.txt", - "mysql-processlist.txt", - "vmstat.txt", - "iostat.txt", - "mpstat.txt", - "tcpdump.cap", - } - - outDir := filepath.Join(tmpDir, prefix) - for _, file := range expectedFiles { - path := filepath.Join(outDir, file) - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Errorf("Expected file not found: %s", file) - } - } -} - -func TestMySQLOnlyCollection(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-mysql-*") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tmpDir) - - cfg := &Config{ - Dest: tmpDir, - RunTime: 2, - MySQLOnly: true, - } - - logger, _ := NewLogger("", 3) - stalker := &Stalker{ - config: cfg, - logger: logger, - } - - db, err := setupTestDB(t) - if err != nil { - t.Skip("MySQL not available:", err) - } - defer db.Close() - - prefix := time.Now().Format("2006_01_02_15_04_05") - ctx := context.Background() - if err := stalker.collect(ctx, db, prefix); err != nil { - t.Fatal(err) - } - - // Verify only MySQL files exist - files, err := os.ReadDir(filepath.Join(tmpDir, prefix)) - if err != nil { - t.Fatal(err) - } - - for _, file := range files { - if !strings.HasPrefix(file.Name(), "mysql-") { - t.Errorf("Found non-MySQL file: %s", file.Name()) - } - } -} - -func TestSystemOnlyCollection(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-system-*") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tmpDir) - - cfg := &Config{ - Dest: tmpDir, - RunTime: 2, - SystemOnly: true, - } - - logger, _ := NewLogger("", 3) - stalker := &Stalker{ - config: cfg, - logger: logger, - } - - prefix := 
time.Now().Format("2006_01_02_15_04_05") - ctx := context.Background() - if err := stalker.collect(ctx, nil, prefix); err != nil { - t.Fatal(err) - } - - // Verify only system files exist - files, err := os.ReadDir(filepath.Join(tmpDir, prefix)) - if err != nil { - t.Fatal(err) - } - - for _, file := range files { - if strings.HasPrefix(file.Name(), "mysql-") { - t.Errorf("Found MySQL file in system-only mode: %s", file.Name()) - } - } -} - -func TestDaemonization(t *testing.T) { - if os.Getenv("TEST_DAEMONIZATION") != "1" { - t.Skip("Skipping daemonization test") - } - - tmpDir, err := os.MkdirTemp("", "pt-stalk-daemon-*") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tmpDir) - - pidFile := filepath.Join(tmpDir, "pt-stalk.pid") - logFile := filepath.Join(tmpDir, "pt-stalk.log") - - cfg := &Config{ - Dest: tmpDir, - Pid: pidFile, - Log: logFile, - Daemonize: true, - } - - logger, _ := NewLogger(logFile, 3) - stalker := &Stalker{ - config: cfg, - logger: logger, - } - - if err := stalker.daemonize(); err != nil { - t.Fatal(err) - } - - // Verify PID file - if _, err := os.Stat(pidFile); os.IsNotExist(err) { - t.Error("PID file not created") - } - - // Verify log file - if _, err := os.Stat(logFile); os.IsNotExist(err) { - t.Error("Log file not created") - } -} - -func TestCollectionErrors(t *testing.T) { - tests := []struct { - name string - config *Config - setupDB bool - expectError bool - }{ - { - name: "invalid_dest", - config: &Config{ - Dest: "/nonexistent/directory", - }, - setupDB: true, - expectError: true, - }, - { - name: "no_db_mysql_only", - config: &Config{ - MySQLOnly: true, - }, - setupDB: false, - expectError: true, - }, - { - name: "invalid_command", - config: &Config{ - CollectGDB: true, - }, - setupDB: true, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - var db *sql.DB - if tc.setupDB { - var err error - db, err = setupTestDB(t) - if err != nil { - t.Skip("MySQL not available:", 
err) - } - defer db.Close() - } - - logger, _ := NewLogger("", 3) - stalker := &Stalker{ - config: tc.config, - logger: logger, - } - - ctx := context.Background() - err := stalker.collect(ctx, db, "test") - if tc.expectError && err == nil { - t.Error("Expected error but got none") - } else if !tc.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - }) - } -} - -func TestCollectionTimeout(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-timeout-*") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tmpDir) - - cfg := &Config{ - Dest: tmpDir, - RunTime: 1, - } - - logger, _ := NewLogger("", 3) - stalker := &Stalker{ - config: cfg, - logger: logger, - } - - db, err := setupTestDB(t) - if err != nil { - t.Skip("MySQL not available:", err) - } - defer db.Close() - - // Create a long-running collection - done := make(chan error) - go func() { - done <- stalker.collect(context.Background(), db, "test") - }() - - select { - case err := <-done: - if err == nil { - t.Error("Expected timeout error") + // Verify expected collectors are registered + expectedCollectors := []string{"mysql", "system"} + for _, name := range expectedCollectors { + if _, ok := registeredCollectors[name]; !ok { + t.Errorf("Expected collector %s not registered", name) } - case <-time.After(time.Duration(cfg.RunTime+1) * time.Second): - t.Error("Collection did not timeout as expected") } } diff --git a/src/go/pt-stalk/main.go b/src/go/pt-stalk/main.go index a1c35328f..f4da9bf10 100644 --- a/src/go/pt-stalk/main.go +++ b/src/go/pt-stalk/main.go @@ -1,215 +1,89 @@ package main import ( - "context" - "flag" "fmt" + "log" "os" - "os/signal" - "path/filepath" - "syscall" - "time" - "github.com/sevlyar/go-daemon" + "github.com/spf13/cobra" ) type Config struct { - Function string - Variable string - Match string - Threshold float64 - Cycles int - Interval int - RunTime int - Sleep int - SleepCollect int - Dest string - Prefix string - CollectGDB bool - CollectOProfile 
bool - CollectStrace bool - CollectTcpdump bool - Socket string - Host string - Port int - User string - Password string - DefaultsFile string - Log string - Pid string - Plugin string - Daemonize bool - SystemOnly bool - MySQLOnly bool - RetentionTime int - RetentionCount int - RetentionSize int - DiskBytesFree int64 - DiskPctFree int - NotifyByEmail string - Verbose int + // Common configuration only + Collectors string + Interval int + RunTime int + Sleep int + SleepCollect int + Dest string + Prefix string + Log string + Pid string + Daemonize bool + RetentionTime int + RetentionCount int + RetentionSize int + DiskBytesFree int64 + DiskPctFree int + NotifyByEmail string + Verbose int + Plugin string + + // Collector configs + CollectorConfigs map[string]interface{} } -type Stalker struct { - config *Config - ctx context.Context - cancel context.CancelFunc - logger *Logger - plugin *Plugin -} - -func main() { - cfg := &Config{} - - // Parse command line flags - flag.StringVar(&cfg.Function, "function", "status", "Trigger function (status|processlist)") - flag.StringVar(&cfg.Variable, "variable", "Threads_running", "Variable to monitor") - flag.StringVar(&cfg.Match, "match", "", "Pattern to match (for processlist)") - flag.Float64Var(&cfg.Threshold, "threshold", 25, "Threshold value") - flag.IntVar(&cfg.Cycles, "cycles", 5, "Number of cycles before collecting") - flag.IntVar(&cfg.Interval, "interval", 1, "Check interval in seconds") - flag.IntVar(&cfg.RunTime, "run-time", 30, "How long to collect data in seconds") - flag.IntVar(&cfg.Sleep, "sleep", 300, "How long to sleep after collection") - flag.IntVar(&cfg.SleepCollect, "sleep-collect", 1, "How long to sleep between collection cycles") - flag.StringVar(&cfg.Dest, "dest", "/var/lib/pt-stalk", "Output destination directory") - flag.StringVar(&cfg.Prefix, "prefix", "", "Filename prefix for samples") - flag.BoolVar(&cfg.CollectGDB, "collect-gdb", false, "Collect GDB stacktraces") - 
flag.BoolVar(&cfg.CollectOProfile, "collect-oprofile", false, "Collect OProfile data") - flag.BoolVar(&cfg.CollectStrace, "collect-strace", false, "Collect strace data") - flag.BoolVar(&cfg.CollectTcpdump, "collect-tcpdump", false, "Collect tcpdump data") - flag.StringVar(&cfg.Socket, "socket", "", "MySQL socket file") - flag.StringVar(&cfg.Host, "host", "", "MySQL host") - flag.IntVar(&cfg.Port, "port", 3306, "MySQL port") - flag.StringVar(&cfg.User, "user", "", "MySQL user") - flag.StringVar(&cfg.Password, "password", "", "MySQL password") - flag.StringVar(&cfg.DefaultsFile, "defaults-file", "", "MySQL defaults file") - flag.StringVar(&cfg.Log, "log", "/var/log/pt-stalk.log", "Log file when daemonized") - flag.StringVar(&cfg.Pid, "pid", "/var/run/pt-stalk.pid", "PID file") - flag.BoolVar(&cfg.Daemonize, "daemonize", false, "Run as daemon") - flag.BoolVar(&cfg.SystemOnly, "system-only", false, "Collect only system metrics") - flag.BoolVar(&cfg.MySQLOnly, "mysql-only", false, "Collect only MySQL metrics") - flag.IntVar(&cfg.RetentionTime, "retention-time", 30, "Days to retain samples") - flag.IntVar(&cfg.RetentionCount, "retention-count", 0, "Number of samples to retain") - flag.IntVar(&cfg.RetentionSize, "retention-size", 0, "Maximum size in MB to retain") - flag.Int64Var(&cfg.DiskBytesFree, "disk-bytes-free", 100*1024*1024, "Minimum bytes free") - flag.IntVar(&cfg.DiskPctFree, "disk-pct-free", 5, "Minimum percent free") - flag.StringVar(&cfg.NotifyByEmail, "notify-by-email", "", "Email address for notifications") - flag.IntVar(&cfg.Verbose, "verbose", 2, "Verbosity level (0-3)") - flag.Parse() - - // Setup signal handling - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - - // Create context with cancellation - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Initialize logger - logger, err := NewLogger(cfg.Log, cfg.Verbose) - if err != nil { - fmt.Fprintf(os.Stderr, 
"Error initializing logger: %v\n", err) - os.Exit(1) - } - - stalker := &Stalker{ - config: cfg, - ctx: ctx, - cancel: cancel, - logger: logger, - } - - // Initialize plugin if specified - if err := stalker.initPlugin(); err != nil { - fmt.Fprintf(os.Stderr, "Error initializing plugin: %v\n", err) - os.Exit(1) - } - - // Create PID file if daemonizing - if cfg.Daemonize { - if err := createPIDFile(cfg.Pid); err != nil { - stalker.logger.Error("Failed to create PID file: %v", err) - os.Exit(1) - } - defer os.Remove(cfg.Pid) - } - - // Create destination directory if it doesn't exist - if err := os.MkdirAll(cfg.Dest, 0755); err != nil { - stalker.logger.Error("Failed to create destination directory: %v", err) - os.Exit(1) - } - - // Start stalking in a goroutine - errChan := make(chan error, 1) - go func() { - errChan <- stalker.Stalk() - }() - - // Wait for signal or error - select { - case sig := <-sigChan: - stalker.logger.Info("Received signal: %v", sig) - cancel() - case err := <-errChan: - if err != nil { - stalker.logger.Error("Stalking error: %v", err) - os.Exit(1) - } +func newRootCmd() *cobra.Command { + rootCmd := &cobra.Command{ + Use: "pt-stalk", + Short: "MySQL and system metrics collector", + RunE: func(cmd *cobra.Command, args []string) error { + cfg := cmd.Context().Value("config").(*Config) + logger := log.New(os.Stderr, "", log.LstdFlags) + + stalker, err := NewStalker(cfg, logger) + if err != nil { + return fmt.Errorf("failed to initialize stalker: %v", err) + } + + return stalker.Run(cmd.Context()) + }, } - // Wait for cleanup - cleanup := make(chan struct{}) - go func() { - stalker.cleanup() - close(cleanup) - }() - - select { - case <-cleanup: - case <-time.After(time.Duration(cfg.RunTime*3) * time.Second): - stalker.logger.Warn("Cleanup timed out") + cfg := &Config{ + CollectorConfigs: make(map[string]interface{}), } -} - -func createPIDFile(pidFile string) error { - if _, err := os.Stat(pidFile); err == nil { - // PID file exists, check if 
process is running - pidBytes, err := os.ReadFile(pidFile) - if err != nil { - return fmt.Errorf("failed to read PID file: %v", err) - } - pid := string(pidBytes) - if _, err := os.Stat(filepath.Join("/proc", pid)); err == nil { - return fmt.Errorf("process %s is already running", pid) - } + rootCmd.PersistentFlags().StringVar(&cfg.Collectors, "collectors", "", "Comma-separated list of collectors to enable (mysql,system)") + rootCmd.PersistentFlags().IntVar(&cfg.Interval, "interval", 1, "Check interval in seconds") + rootCmd.PersistentFlags().IntVar(&cfg.RunTime, "run-time", 30, "How long to collect data in seconds") + rootCmd.PersistentFlags().IntVar(&cfg.Sleep, "sleep", 1, "Sleep time between collections in seconds") + rootCmd.PersistentFlags().StringVar(&cfg.Dest, "dest", "/var/lib/pt-stalk", "Destination directory for collected data") + rootCmd.PersistentFlags().StringVar(&cfg.Prefix, "prefix", "", "Filename prefix for samples") + rootCmd.PersistentFlags().StringVar(&cfg.Log, "log", "/var/log/pt-stalk.log", "Log file when daemonized") + rootCmd.PersistentFlags().StringVar(&cfg.Pid, "pid", "/var/run/pt-stalk.pid", "PID file") + rootCmd.PersistentFlags().BoolVar(&cfg.Daemonize, "daemonize", false, "Run as daemon") + rootCmd.PersistentFlags().IntVar(&cfg.RetentionTime, "retention-time", 30, "Days to retain samples") + rootCmd.PersistentFlags().IntVar(&cfg.RetentionCount, "retention-count", 0, "Number of samples to retain") + rootCmd.PersistentFlags().IntVar(&cfg.RetentionSize, "retention-size", 0, "Maximum size in MB to retain") + rootCmd.PersistentFlags().Int64Var(&cfg.DiskBytesFree, "disk-bytes-free", 100*1024*1024, "Minimum bytes free") + rootCmd.PersistentFlags().IntVar(&cfg.DiskPctFree, "disk-pct-free", 5, "Minimum percent free") + rootCmd.PersistentFlags().StringVar(&cfg.NotifyByEmail, "notify-by-email", "", "Email address for notifications") + rootCmd.PersistentFlags().IntVar(&cfg.Verbose, "verbose", 2, "Verbosity level (0-3)") + 
rootCmd.PersistentFlags().StringVar(&cfg.Plugin, "plugin", "", "Path to plugin script") + + // Add collector-specific flags + for _, reg := range registeredCollectors { + reg.AddFlags(rootCmd, cfg.CollectorConfigs) } - return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0644) + return rootCmd } -func (s *Stalker) daemonize() error { - if !s.config.Daemonize { - return nil - } - - cntxt := &daemon.Context{ - PidFileName: s.config.Pid, - PidFilePerm: 0644, - LogFileName: s.config.Log, - LogFilePerm: 0640, - WorkDir: "./", - Umask: 027, - } - - d, err := cntxt.Reborn() - if err != nil { - return fmt.Errorf("failed to daemonize: %v", err) - } - if d != nil { - os.Exit(0) +func main() { + cmd := newRootCmd() + if err := cmd.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) } - - return nil } diff --git a/src/go/pt-stalk/main_test.go b/src/go/pt-stalk/main_test.go index 2301cf08f..7b37f74ef 100644 --- a/src/go/pt-stalk/main_test.go +++ b/src/go/pt-stalk/main_test.go @@ -2,98 +2,54 @@ package main import ( "context" - "database/sql" "os" "path/filepath" "testing" "time" - - _ "github.com/go-sql-driver/mysql" ) -func TestStalker(t *testing.T) { - // Create temporary directory for test outputs - tmpDir, err := os.MkdirTemp("", "pt-stalk-test-*") +func TestMainCommand(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-main-test-*") if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) + t.Fatal(err) } defer os.RemoveAll(tmpDir) // Test configuration cfg := &Config{ - Function: "status", - Variable: "Threads_running", - Threshold: 5, - Cycles: 2, - Interval: 1, - RunTime: 5, - Sleep: 2, - Dest: tmpDir, - Host: "localhost", - Port: 3306, - User: os.Getenv("MYSQL_TEST_USER"), - Password: os.Getenv("MYSQL_TEST_PASS"), - Verbose: 3, - DiskBytesFree: 1024 * 1024, // 1MB - DiskPctFree: 1, - } - - // Initialize logger - logger, err := NewLogger("", cfg.Verbose) - if err != nil { - t.Fatalf("Failed to create logger: 
%v", err) + Collectors: "mysql,system", + Interval: 1, + RunTime: 2, + Sleep: 1, + Dest: tmpDir, + Prefix: "test", + CollectorConfigs: map[string]interface{}{ + "mysql": &MySQLConfig{ + Host: "localhost", + Port: 3306, + User: os.Getenv("MYSQL_TEST_USER"), + Password: os.Getenv("MYSQL_TEST_PASS"), + }, + "system": &SystemConfig{ + CollectGDB: true, + }, + }, } // Create context with timeout - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Initialize stalker - stalker := &Stalker{ - config: cfg, - ctx: ctx, - cancel: cancel, - logger: logger, - } + // Set up command + cmd := newRootCmd() + cmd.SetContext(context.WithValue(ctx, "config", cfg)) - // Test MySQL connection - db, err := sql.Open("mysql", stalker.buildDSN()) - if err != nil { - t.Skipf("Skipping test, could not connect to MySQL: %v", err) - } - defer db.Close() - - if err := db.Ping(); err != nil { - t.Skipf("Skipping test, MySQL not responding: %v", err) + // Execute command + if err := cmd.Execute(); err != nil && err != context.DeadlineExceeded { + t.Errorf("Command execution failed: %v", err) } - // Run stalker in goroutine - errChan := make(chan error, 1) - go func() { - errChan <- stalker.Stalk() - }() - - // Create some test load - go func() { - for i := 0; i < 10; i++ { - db.Exec("SELECT SLEEP(1)") - time.Sleep(time.Second) - } - }() - - // Wait for stalker to finish or timeout - select { - case err := <-errChan: - if err != nil { - t.Errorf("Stalker failed: %v", err) - } - case <-ctx.Done(): - if ctx.Err() != context.DeadlineExceeded { - t.Errorf("Unexpected context error: %v", ctx.Err()) - } - } - - // Verify outputs + // Verify output directory structure files, err := os.ReadDir(tmpDir) if err != nil { t.Fatalf("Failed to read output directory: %v", err) @@ -103,110 +59,85 @@ func TestStalker(t *testing.T) { t.Error("No output files were created") } - // Check specific files - 
expectedFiles := []string{ - "mysql-variables.txt", - "mysql-status.txt", - "mysql-processlist.txt", - "uptime.txt", - "vmstat.txt", - "iostat.txt", + // Check for collector outputs + expectedFiles := map[string]bool{ + "mysql": false, + "system": false, } - for _, dir := range files { - if !dir.IsDir() { - continue + for _, file := range files { + if file.IsDir() { + for collector := range expectedFiles { + if _, err := os.Stat(filepath.Join(tmpDir, file.Name(), collector)); !os.IsNotExist(err) { + expectedFiles[collector] = true + } + } } + } - for _, expected := range expectedFiles { - path := filepath.Join(tmpDir, dir.Name(), expected) - if _, err := os.Stat(path); os.IsNotExist(err) { - t.Errorf("Expected file %s not found", path) - } + for collector, found := range expectedFiles { + if !found { + t.Errorf("Expected output for %s collector not found", collector) } } } -func TestPluginExecution(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-plugin-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) +func TestMainCommandFlags(t *testing.T) { + cmd := newRootCmd() - // Create test plugin - pluginContent := `#!/bin/bash -before_stalk() { - echo "before_stalk called" - return 0 -} -before_collect() { - echo "before_collect called with $1" - return 0 -} -after_collect() { - echo "after_collect called with $1" - return 0 -}` - - pluginPath := filepath.Join(tmpDir, "test-plugin.sh") - if err := os.WriteFile(pluginPath, []byte(pluginContent), 0755); err != nil { - t.Fatalf("Failed to write test plugin: %v", err) + // Test required flags + if cmd.Flag("collectors") == nil { + t.Error("Required flag 'collectors' not found") } - logger, err := NewLogger("", 3) - if err != nil { - t.Fatalf("Failed to create logger: %v", err) + // Test MySQL collector flags + if cmd.Flag("mysql-host") == nil { + t.Error("MySQL flag 'mysql-host' not found") } - - plugin, err := NewPlugin(pluginPath, logger) - if err != nil 
{ - t.Fatalf("Failed to create plugin: %v", err) + if cmd.Flag("mysql-port") == nil { + t.Error("MySQL flag 'mysql-port' not found") } - // Test each hook - hooks := []struct { - hook PluginHook - args []string - }{ - {BeforeStalk, nil}, - {BeforeCollect, []string{"test_prefix"}}, - {AfterCollect, []string{"test_prefix"}}, + // Test System collector flags + if cmd.Flag("collect-gdb") == nil { + t.Error("System flag 'collect-gdb' not found") } - - for _, tc := range hooks { - err := plugin.Execute(tc.hook, tc.args...) - if err != nil { - t.Errorf("Plugin execution failed for %s: %v", tc.hook, err) - } + if cmd.Flag("collect-tcpdump") == nil { + t.Error("System flag 'collect-tcpdump' not found") } } -func TestSizeParser(t *testing.T) { +func TestMainCommandValidation(t *testing.T) { tests := []struct { - input string - expected int64 - hasError bool + name string + args []string + wantErr bool }{ - {"1K", 1024, false}, - {"1M", 1024 * 1024, false}, - {"1G", 1024 * 1024 * 1024, false}, - {"1T", 1024 * 1024 * 1024 * 1024, false}, - {"1.5G", 1610612736, false}, - {"1024", 1024, false}, - {"invalid", 0, true}, - } - - for _, tc := range tests { - result, err := ParseSize(tc.input) - if tc.hasError && err == nil { - t.Errorf("Expected error for input %s, got none", tc.input) - } - if !tc.hasError && err != nil { - t.Errorf("Unexpected error for input %s: %v", tc.input, err) - } - if !tc.hasError && result != tc.expected { - t.Errorf("For input %s, expected %d, got %d", tc.input, tc.expected, result) - } + { + name: "no collectors", + args: []string{}, + wantErr: true, + }, + { + name: "invalid collector", + args: []string{"--collectors=invalid"}, + wantErr: true, + }, + { + name: "valid collectors", + args: []string{"--collectors=mysql,system"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs(tt.args) + err := cmd.Execute() + if (err != nil) != tt.wantErr { + t.Errorf("Command execution 
error = %v, wantErr %v", err, tt.wantErr) + } + }) } } diff --git a/src/go/pt-stalk/plugin.go b/src/go/pt-stalk/plugin.go index d873a61d0..3812a5763 100644 --- a/src/go/pt-stalk/plugin.go +++ b/src/go/pt-stalk/plugin.go @@ -1,165 +1,75 @@ package main import ( + "context" "fmt" "os" "os/exec" "path/filepath" - "strings" ) type Plugin struct { path string + config *Config env map[string]string - logger *Logger } -type PluginHook string - -const ( - BeforeStalk PluginHook = "before_stalk" - BeforeCollect PluginHook = "before_collect" - AfterCollect PluginHook = "after_collect" - AfterCollectSleep PluginHook = "after_collect_sleep" - AfterIntervalSleep PluginHook = "after_interval_sleep" - AfterStalk PluginHook = "after_stalk" -) - -func NewPlugin(path string, logger *Logger) (*Plugin, error) { +func NewPlugin(path string, config *Config) (*Plugin, error) { if path == "" { - return nil, nil - } - - absPath, err := filepath.Abs(path) - if err != nil { - return nil, fmt.Errorf("failed to resolve plugin path: %v", err) + return nil, nil // No plugin configured } - if _, err := os.Stat(absPath); err != nil { - return nil, fmt.Errorf("plugin file not found: %v", err) + if _, err := os.Stat(path); err != nil { + return nil, fmt.Errorf("plugin not found: %v", err) } return &Plugin{ - path: absPath, + path: path, + config: config, env: make(map[string]string), - logger: logger, }, nil } func (p *Plugin) SetEnv(key, value string) { - if p != nil { - p.env[key] = value - } + p.env[key] = value } -func (p *Plugin) Execute(hook PluginHook, args ...string) error { +func (p *Plugin) Execute(ctx context.Context) error { if p == nil { - return nil + return nil // No plugin configured } - p.logger.Debug("Executing plugin hook: %s", hook) + cmd := exec.CommandContext(ctx, p.path) - // Prepare environment variables - env := os.Environ() + // Set up environment + cmd.Env = os.Environ() // Start with current environment for k, v := range p.env { - env = append(env, fmt.Sprintf("%s=%s", k, 
v)) - } - - // Add hook name to environment - env = append(env, fmt.Sprintf("PT_HOOK=%s", hook)) - - // Create temporary script to execute the plugin - tmpScript, err := os.CreateTemp("", "pt-stalk-plugin-*.sh") - if err != nil { - return fmt.Errorf("failed to create temporary script: %v", err) - } - defer os.Remove(tmpScript.Name()) - - // Write plugin execution script - script := fmt.Sprintf(`#!/bin/bash -source "%s" -if type %s >/dev/null 2>&1; then - %s "$@" - exit $? -else - exit 0 -fi -`, p.path, hook, hook) - - if _, err := tmpScript.WriteString(script); err != nil { - return fmt.Errorf("failed to write plugin script: %v", err) - } - - if err := tmpScript.Close(); err != nil { - return fmt.Errorf("failed to close plugin script: %v", err) - } - - if err := os.Chmod(tmpScript.Name(), 0755); err != nil { - return fmt.Errorf("failed to make plugin script executable: %v", err) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } - // Execute the plugin - cmd := exec.Command(tmpScript.Name(), args...) 
- cmd.Env = env - cmd.Dir = filepath.Dir(p.path) - - // Capture output - output, err := cmd.CombinedOutput() + // Add standard plugin environment variables + cmd.Env = append(cmd.Env, + fmt.Sprintf("PT_DEST=%s", p.config.Dest), + fmt.Sprintf("PT_PREFIX=%s", p.config.Prefix), + fmt.Sprintf("PT_INTERVAL=%d", p.config.Interval), + fmt.Sprintf("PT_RUNTIME=%d", p.config.RunTime), + ) + + // Set up output + outputFile := filepath.Join(p.config.Dest, p.config.Prefix+"_plugin.txt") + output, err := os.Create(outputFile) if err != nil { - return fmt.Errorf("plugin hook %s failed: %v\nOutput: %s", hook, err, output) + return fmt.Errorf("failed to create plugin output file: %v", err) } + defer output.Close() - if len(output) > 0 { - p.logger.Debug("Plugin output (%s):\n%s", hook, strings.TrimSpace(string(output))) - } - - return nil -} + cmd.Stdout = output + cmd.Stderr = output -// Helper methods for the Stalker struct to handle plugins -func (s *Stalker) initPlugin() error { - if s.config.Plugin != "" { - plugin, err := NewPlugin(s.config.Plugin, s.logger) - if err != nil { - return fmt.Errorf("failed to initialize plugin: %v", err) - } - s.plugin = plugin - - // Set up environment variables for the plugin - s.plugin.SetEnv("PT_DEST", s.config.Dest) - s.plugin.SetEnv("PT_MYSQL_USER", s.config.User) - s.plugin.SetEnv("PT_MYSQL_HOST", s.config.Host) - s.plugin.SetEnv("PT_MYSQL_PORT", fmt.Sprintf("%d", s.config.Port)) - s.plugin.SetEnv("PT_INTERVAL", fmt.Sprintf("%d", s.config.Interval)) - s.plugin.SetEnv("PT_SLEEP", fmt.Sprintf("%d", s.config.Sleep)) - s.plugin.SetEnv("PT_FUNCTION", s.config.Function) - s.plugin.SetEnv("PT_VARIABLE", s.config.Variable) - s.plugin.SetEnv("PT_THRESHOLD", fmt.Sprintf("%f", s.config.Threshold)) + // Execute plugin + if err := cmd.Run(); err != nil { + return fmt.Errorf("plugin execution failed: %v", err) } - return nil -} -func (s *Stalker) executePluginHook(hook PluginHook, args ...string) error { - if s.plugin != nil { - return 
s.plugin.Execute(hook, args...) - } return nil } - -// Example plugin usage in the Stalker.Stalk() method: -/* - // Before starting to stalk - if err := s.executePluginHook(BeforeStalk); err != nil { - return fmt.Errorf("plugin before_stalk hook failed: %v", err) - } - - // Before collecting metrics - if err := s.executePluginHook(BeforeCollect, prefix); err != nil { - return fmt.Errorf("plugin before_collect hook failed: %v", err) - } - - // After collecting metrics - if err := s.executePluginHook(AfterCollect, prefix); err != nil { - s.logger.Warn("Plugin after_collect hook failed: %v", err) - } -*/ diff --git a/src/go/pt-stalk/plugin_test.go b/src/go/pt-stalk/plugin_test.go new file mode 100644 index 000000000..34e80191b --- /dev/null +++ b/src/go/pt-stalk/plugin_test.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestPlugin(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-plugin-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + // Create test plugin + pluginContent := `#!/bin/sh +echo "Test plugin output" +echo "PT_DEST=$PT_DEST" +echo "PT_PREFIX=$PT_PREFIX" +` + pluginPath := filepath.Join(tmpDir, "test-plugin.sh") + if err := os.WriteFile(pluginPath, []byte(pluginContent), 0755); err != nil { + t.Fatal(err) + } + + cfg := &Config{ + Dest: tmpDir, + Prefix: "test", + Interval: 1, + RunTime: 30, + } + + // Test plugin creation + plugin, err := NewPlugin(pluginPath, cfg) + if err != nil { + t.Fatal(err) + } + + // Test plugin environment + plugin.SetEnv("TEST_VAR", "test_value") + + // Test plugin execution + err = plugin.Execute(context.Background()) + if err != nil { + t.Fatal(err) + } + + // Verify plugin output + outputFile := filepath.Join(tmpDir, "test_plugin.txt") + content, err := os.ReadFile(outputFile) + if err != nil { + t.Fatal(err) + } + + if len(content) == 0 { + t.Error("Plugin output is empty") + } +} + +func TestPluginNotFound(t *testing.T) { + cfg := 
&Config{ + Dest: "/tmp", + Prefix: "test", + } + + _, err := NewPlugin("/nonexistent/plugin", cfg) + if err == nil { + t.Error("Expected error for nonexistent plugin") + } +} + +func TestNoPlugin(t *testing.T) { + cfg := &Config{ + Dest: "/tmp", + Prefix: "test", + } + + plugin, err := NewPlugin("", cfg) + if err != nil { + t.Fatal(err) + } + if plugin != nil { + t.Error("Expected nil plugin when no path provided") + } +} diff --git a/src/go/pt-stalk/stalk.go b/src/go/pt-stalk/stalk.go index 7caca0c39..eff32fb5d 100644 --- a/src/go/pt-stalk/stalk.go +++ b/src/go/pt-stalk/stalk.go @@ -2,290 +2,76 @@ package main import ( "context" - "database/sql" "fmt" + "log" "os" - "path/filepath" - "strconv" - "syscall" + "strings" "time" - - _ "github.com/go-sql-driver/mysql" -) - -const ( - defaultInterval = 1 - defaultCycles = 1 - defaultRetention = 30 - defaultDiskPctFree = 5 - timeFormat = "2006_01_02_15_04_05" ) -type MetricCollector interface { - Collect(ctx context.Context, prefix string) error -} - -type MetricChecker interface { - Check(ctx context.Context) (bool, error) +type Stalker struct { + config *Config + logger *log.Logger + plugin *Plugin } -type StalkConfig interface { - Validate() error -} - -type LogEntry struct { - Level string - Message string - Time time.Time - Fields map[string]interface{} -} - -func (s *Stalker) buildDSN() string { - dsn := "" - if s.config.DefaultsFile != "" { - dsn += fmt.Sprintf("defaults-file=%s", s.config.DefaultsFile) - } - if s.config.User != "" { - if dsn != "" { - dsn += "&" - } - dsn += fmt.Sprintf("user=%s", s.config.User) - } - if s.config.Password != "" { - if dsn != "" { - dsn += "&" - } - dsn += fmt.Sprintf("password=%s", s.config.Password) - } - if s.config.Socket != "" { - if dsn != "" { - dsn += "&" - } - dsn += fmt.Sprintf("socket=%s", s.config.Socket) - } else { - if s.config.Host != "" { - if dsn != "" { - dsn += "&" - } - dsn += fmt.Sprintf("host=%s", s.config.Host) - } - if s.config.Port != 0 { - if dsn != "" { - 
dsn += "&" - } - dsn += fmt.Sprintf("port=%d", s.config.Port) - } +func NewStalker(config *Config, logger *log.Logger) (*Stalker, error) { + s := &Stalker{ + config: config, + logger: logger, } - return dsn -} - -func (s *Stalker) Stalk() error { - s.logger.Info("Starting stalker with config: %+v", s.config) - // Don't connect to MySQL if we're only collecting system metrics - var db *sql.DB - var err error - if !s.config.SystemOnly { - db, err = sql.Open("mysql", s.buildDSN()) + // Initialize plugin if configured + if config.Plugin != "" { + plugin, err := NewPlugin(config.Plugin, config) if err != nil { - return fmt.Errorf("failed to connect to MySQL: %v", err) - } - defer db.Close() - - // Test the connection - if err := db.Ping(); err != nil { - return fmt.Errorf("failed to ping MySQL: %v", err) + return nil, fmt.Errorf("failed to initialize plugin: %v", err) } + s.plugin = plugin } - triggerCount := 0 - iteration := 0 + return s, nil +} +func (s *Stalker) Run(ctx context.Context) error { + // Create destination directory if it doesn't exist + if err := os.MkdirAll(s.config.Dest, 0755); err != nil { + return fmt.Errorf("failed to create destination directory: %v", err) + } + + // Main collection loop for { select { - case <-s.ctx.Done(): - s.logger.Info("Stalker received shutdown signal") - return nil + case <-ctx.Done(): + return ctx.Err() default: - if s.config.SystemOnly { - // For system-only mode, we treat it as always triggered - triggerCount++ - } else { - triggered, err := s.checkTrigger(db) - if err != nil { - return fmt.Errorf("failed to check trigger: %v", err) - } - - if triggered { - triggerCount++ - s.logger.Info("Trigger condition met (%d/%d)", triggerCount, s.config.Cycles) - } else { - if triggerCount > 0 { - s.logger.Debug("Trigger condition reset (was %d/%d)", triggerCount, s.config.Cycles) - } - triggerCount = 0 - } - } - - if triggerCount >= s.config.Cycles { - s.logger.Info("Trigger threshold reached, starting collection") - - // Generate 
collection prefix - prefix := s.config.Prefix - if prefix == "" { - prefix = time.Now().Format("2006_01_02_15_04_05") - } - - // Check disk space - if err := s.checkDiskSpace(prefix); err != nil { - s.logger.Error("Disk space check failed: %v", err) - return err - } - - // Start collection - if err := s.collectWithTimeout(db, prefix); err != nil { - s.logger.Error("Collection failed: %v", err) - return err - } - - // Reset trigger count - triggerCount = 0 - iteration++ - - // Sleep after collection - s.logger.Info("Sleeping for %d seconds after collection", s.config.Sleep) - time.Sleep(time.Duration(s.config.Sleep) * time.Second) - - // Check if we've reached max iterations - if s.config.RetentionCount > 0 && iteration >= s.config.RetentionCount { - s.logger.Info("Reached maximum iterations (%d), shutting down", s.config.RetentionCount) - return nil - } + if err := s.runCollectors(ctx); err != nil { + s.logger.Printf("Collection error: %v", err) } - - // Sleep before next check - time.Sleep(time.Duration(s.config.Interval) * time.Second) + time.Sleep(time.Duration(s.config.Sleep) * time.Second) } } } -func (s *Stalker) checkTrigger(db *sql.DB) (bool, error) { - switch s.config.Function { - case "status": - return s.checkStatusTrigger(db) - case "processlist": - return s.checkProcesslistTrigger(db) - default: - return false, fmt.Errorf("unknown function: %s", s.config.Function) - } -} - -func (s *Stalker) checkStatusTrigger(db *sql.DB) (bool, error) { - query := "SHOW GLOBAL STATUS WHERE Variable_name = ?" 
- var name, value string - err := db.QueryRow(query, s.config.Variable).Scan(&name, &value) - if err != nil { - return false, fmt.Errorf("failed to query status: %v", err) - } - - val, err := strconv.ParseFloat(value, 64) - if err != nil { - return false, fmt.Errorf("failed to parse value %s: %v", value, err) - } - - s.logger.Debug("Status check: %s = %v (threshold: %v)", s.config.Variable, val, s.config.Threshold) - return val > s.config.Threshold, nil -} - -func (s *Stalker) checkProcesslistTrigger(db *sql.DB) (bool, error) { - query := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.PROCESSLIST WHERE State = ?` - var count int - err := db.QueryRow(query, s.config.Match).Scan(&count) - if err != nil { - return false, fmt.Errorf("failed to query processlist: %v", err) - } - - s.logger.Debug("Processlist check: count = %d (threshold: %v)", count, s.config.Threshold) - return float64(count) > s.config.Threshold, nil -} - -func (s *Stalker) checkDiskSpace(prefix string) error { - // Get disk usage information - var stat syscall.Statfs_t - err := syscall.Statfs(s.config.Dest, &stat) - if err != nil { - return fmt.Errorf("failed to get disk stats: %v", err) - } - - // Calculate free space - blockSize := uint64(stat.Bsize) - totalBlocks := stat.Blocks - freeBlocks := stat.Bfree - - totalBytes := totalBlocks * blockSize - freeBytes := freeBlocks * blockSize - freePercent := float64(freeBytes) / float64(totalBytes) * 100 - - // Check if we have enough free space - if freeBytes < uint64(s.config.DiskBytesFree) { - return fmt.Errorf("insufficient free disk space: %d bytes (need %d)", freeBytes, s.config.DiskBytesFree) - } - - if freePercent < float64(s.config.DiskPctFree) { - return fmt.Errorf("insufficient free disk space: %.2f%% (need %d%%)", freePercent, s.config.DiskPctFree) - } - - s.logger.Debug("Disk space check passed: %.2f%% (%.2f GB) free", freePercent, float64(freeBytes)/(1024*1024*1024)) - return nil -} - -func (s *Stalker) cleanup() error { - s.logger.Info("Starting 
cleanup") - - // Clean up based on retention time - if s.config.RetentionTime > 0 { - cutoff := time.Now().AddDate(0, 0, -s.config.RetentionTime) - err := filepath.Walk(s.config.Dest, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err +func (s *Stalker) runCollectors(ctx context.Context) error { + // Run collectors + enabledCollectors := strings.Split(s.config.Collectors, ",") + for _, name := range enabledCollectors { + name = strings.TrimSpace(name) + if reg, ok := registeredCollectors[name]; ok { + collector := reg.NewCollector(s.config) + if err := collector.Collect(ctx); err != nil { + return fmt.Errorf("collector %s failed: %v", name, err) } - if info.IsDir() && info.ModTime().Before(cutoff) { - if err := os.RemoveAll(path); err != nil { - s.logger.Error("Failed to remove old directory %s: %v", path, err) - return err - } - s.logger.Info("Removed old directory %s", path) - } - return nil - }) - if err != nil { - s.logger.Error("Error during retention cleanup: %v", err) - return err } } - return nil -} -func (s *Stalker) collectWithTimeout(db *sql.DB, prefix string) error { - ctx, cancel := context.WithTimeout(context.Background(), - time.Duration(s.config.RunTime)*time.Second) - defer cancel() - - return s.collect(ctx, db, prefix) -} - -type Metrics struct { - TriggersTotal int64 - CollectionsTotal int64 - ErrorsTotal int64 - // ... 
etc -} - -type CollectionError struct { - Prefix string - Err error -} + // Run plugin if configured + if s.plugin != nil { + if err := s.plugin.Execute(ctx); err != nil { + return fmt.Errorf("plugin execution failed: %v", err) + } + } -func (e *CollectionError) Error() string { - return fmt.Sprintf("collection failed for %s: %v", e.Prefix, e.Err) + return nil } diff --git a/src/go/pt-stalk/stalk_test.go b/src/go/pt-stalk/stalk_test.go index daf1460b0..62ae5c914 100644 --- a/src/go/pt-stalk/stalk_test.go +++ b/src/go/pt-stalk/stalk_test.go @@ -1,197 +1,110 @@ package main import ( - "database/sql" + "context" + "log" "os" "path/filepath" "testing" "time" ) -func TestTriggerFunctions(t *testing.T) { - tests := []struct { - name string - function string - variable string - match string - threshold float64 - expected bool - }{ - {"status_threads_running", "status", "Threads_running", "", 5, false}, - {"processlist_sleep", "processlist", "", "Sleep", 10, false}, - {"invalid_function", "invalid", "", "", 0, false}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - cfg := &Config{ - Function: tc.function, - Variable: tc.variable, - Match: tc.match, - Threshold: tc.threshold, - } - stalker := &Stalker{config: cfg} - - // Setup test database connection - db, err := setupTestDB(t) - if err != nil { - t.Skip("MySQL not available:", err) - } - defer db.Close() - - triggered, err := stalker.checkTrigger(db) - if tc.function == "invalid" { - if err == nil { - t.Error("Expected error for invalid function") - } - } else if err != nil { - t.Errorf("Unexpected error: %v", err) - } else if triggered != tc.expected { - t.Errorf("Expected triggered=%v, got %v", tc.expected, triggered) - } - }) - } -} - -func TestRetention(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-retention-*") +func TestStalkerBasicOperation(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-test-*") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpDir) - // 
Create some test files with different dates - dates := []struct { - dir string - time time.Time - }{ - {"old", time.Now().AddDate(0, 0, -31)}, - {"new", time.Now()}, - } - - for _, d := range dates { - dir := filepath.Join(tmpDir, d.dir) - if err := os.MkdirAll(dir, 0755); err != nil { - t.Fatal(err) - } - if err := os.Chtimes(dir, d.time, d.time); err != nil { - t.Fatal(err) - } - } + logger := log.New(os.Stderr, "", log.LstdFlags) cfg := &Config{ - Dest: tmpDir, - RetentionTime: 30, - } - stalker := &Stalker{config: cfg} - - if err := stalker.cleanup(); err != nil { - t.Fatal(err) - } - - // Check that old directory was removed and new remains - if _, err := os.Stat(filepath.Join(tmpDir, "old")); !os.IsNotExist(err) { - t.Error("Old directory should have been removed") - } - if _, err := os.Stat(filepath.Join(tmpDir, "new")); os.IsNotExist(err) { - t.Error("New directory should still exist") - } -} - -func TestDiskSpace(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-disk-*") + Collectors: "mysql,system", + Interval: 1, + Sleep: 1, + Dest: tmpDir, + Prefix: "test", + CollectorConfigs: map[string]interface{}{ + "mysql": &MySQLConfig{ + Host: "localhost", + Port: 3306, + }, + "system": &SystemConfig{ + CollectGDB: true, + }, + }, + } + + stalker, err := NewStalker(cfg, logger) if err != nil { t.Fatal(err) } - defer os.RemoveAll(tmpDir) - tests := []struct { - name string - bytesFree int64 - pctFree int - shouldError bool - }{ - {"sufficient_space", 1024 * 1024 * 1024, 10, false}, - {"insufficient_bytes", 1024, 10, true}, - {"insufficient_percent", 1024 * 1024 * 1024, 99, true}, + // Run stalker with timeout + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = stalker.Run(ctx) + if err != context.DeadlineExceeded { + t.Errorf("Expected deadline exceeded error, got: %v", err) } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - cfg := &Config{ - Dest: tmpDir, - DiskBytesFree: tc.bytesFree, 
- DiskPctFree: tc.pctFree, - } - stalker := &Stalker{config: cfg} - - err := stalker.checkDiskSpace("test") - if tc.shouldError && err == nil { - t.Error("Expected disk space error") - } else if !tc.shouldError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - }) + // Verify destination directory was created + if _, err := os.Stat(tmpDir); os.IsNotExist(err) { + t.Error("Destination directory was not created") } } -func TestPluginHooks(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "pt-stalk-plugin-*") +func TestStalkerWithPluginExecution(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "pt-stalk-plugin-test-*") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpDir) // Create test plugin - pluginContent := `#!/bin/bash -before_stalk() { echo "before_stalk"; } -before_collect() { echo "before_collect $1"; } -after_collect() { echo "after_collect $1"; } -after_collect_sleep() { echo "after_collect_sleep"; } -after_interval_sleep() { echo "after_interval_sleep"; } -after_stalk() { echo "after_stalk"; } + pluginContent := `#!/bin/sh +echo "Test plugin output" ` - pluginPath := filepath.Join(tmpDir, "test.sh") + pluginPath := filepath.Join(tmpDir, "test-plugin.sh") if err := os.WriteFile(pluginPath, []byte(pluginContent), 0755); err != nil { t.Fatal(err) } - logger, _ := NewLogger("", 3) - cfg := &Config{ - Plugin: pluginPath, - } - stalker := &Stalker{ - config: cfg, - logger: logger, - } + logger := log.New(os.Stderr, "", log.LstdFlags) - if err := stalker.initPlugin(); err != nil { + cfg := &Config{ + Collectors: "mysql", + Interval: 1, + Sleep: 1, + Dest: tmpDir, + Prefix: "test", + Plugin: pluginPath, + CollectorConfigs: map[string]interface{}{ + "mysql": &MySQLConfig{ + Host: "localhost", + Port: 3306, + }, + }, + } + + stalker, err := NewStalker(cfg, logger) + if err != nil { t.Fatal(err) } - hooks := []struct { - hook PluginHook - args []string - }{ - {BeforeStalk, nil}, - {BeforeCollect, []string{"test"}}, - {AfterCollect, []string{"test"}}, 
- {AfterCollectSleep, nil}, - {AfterIntervalSleep, nil}, - {AfterStalk, nil}, - } + // Run stalker with timeout + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() - for _, h := range hooks { - if err := stalker.executePluginHook(h.hook, h.args...); err != nil { - t.Errorf("Hook %s failed: %v", h.hook, err) - } + err = stalker.Run(ctx) + if err != context.DeadlineExceeded { + t.Errorf("Expected deadline exceeded error, got: %v", err) } -} -func setupTestDB(t *testing.T) (*sql.DB, error) { - dsn := os.Getenv("MYSQL_TEST_DSN") - if dsn == "" { - dsn = "root@tcp(localhost:3306)/test" + // Verify plugin output file exists + pluginOutput := filepath.Join(tmpDir, "test_plugin.txt") + if _, err := os.Stat(pluginOutput); os.IsNotExist(err) { + t.Error("Plugin output file was not created") } - return sql.Open("mysql", dsn) } From e0e2756efb8158cfab965aecbf71b8c9077b483e Mon Sep 17 00:00:00 2001 From: David Murphy Date: Tue, 31 Dec 2024 14:17:19 -0600 Subject: [PATCH 3/3] feat(stalk): add mongodb support --- src/go/pt-stalk/README.md | 22 +++- src/go/pt-stalk/collect_mongodb.go | 140 ++++++++++++++++++++++++ src/go/pt-stalk/collect_mongodb_test.go | 86 +++++++++++++++ 3 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 src/go/pt-stalk/collect_mongodb.go create mode 100644 src/go/pt-stalk/collect_mongodb_test.go diff --git a/src/go/pt-stalk/README.md b/src/go/pt-stalk/README.md index 7371fd0d1..bf36f2262 100644 --- a/src/go/pt-stalk/README.md +++ b/src/go/pt-stalk/README.md @@ -70,6 +70,17 @@ Example plugin script (plugin.sh): echo "Custom collection started" > "$PT_DEST/${PT_PREFIX}_custom.txt" # Add your custom collection logic here +### Basic MongoDB Monitoring + +Collect MongoDB metrics: + + pt-stalk --collectors=mongodb \ + --mongodb-host=localhost \ + --mongodb-user=myuser \ + --mongodb-password=secret \ + --dest=/var/log/mongodb/samples \ + --interval=1 + ## Configuration Options ### Common Options @@ -106,6 
+117,12 @@ Example plugin script (plugin.sh):
 - --notify-by-email: Email address for notifications
 - --verbose: Verbosity level (0-3) (default: 2)
 
+### MongoDB Collector Options
+- --mongodb-host: MongoDB host (default: localhost)
+- --mongodb-port: MongoDB port (default: 27017)
+- --mongodb-user: MongoDB user
+- --mongodb-password: MongoDB password
+
 ## Output Files
 
 Each collection creates files with the specified prefix and timestamps:
@@ -115,4 +132,7 @@ Each collection creates files with the specified prefix and timestamps:
 - {prefix}_diskstats.txt: System disk statistics
 - {prefix}_meminfo.txt: System memory information
 - {prefix}_loadavg.txt: System load average
-- {prefix}_plugin.txt: Custom plugin output (if configured)
\ No newline at end of file
+- {prefix}_plugin.txt: Custom plugin output (if configured)
+- {prefix}_server_status.txt: MongoDB server status metrics
+- {prefix}_current_op.txt: MongoDB currently running operations
+- {prefix}_db_stats.txt: MongoDB database statistics
\ No newline at end of file
diff --git a/src/go/pt-stalk/collect_mongodb.go b/src/go/pt-stalk/collect_mongodb.go
new file mode 100644
index 000000000..c2f099dc2
--- /dev/null
+++ b/src/go/pt-stalk/collect_mongodb.go
@@ -0,0 +1,185 @@
+package main
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/url"
+	"os"
+	"path/filepath"
+	"strconv"
+
+	"github.com/spf13/cobra"
+	"go.mongodb.org/mongo-driver/bson"
+	"go.mongodb.org/mongo-driver/mongo"
+	"go.mongodb.org/mongo-driver/mongo/options"
+)
+
+// MongoDBCollector gathers diagnostic snapshots (serverStatus, currentOp and
+// per-database dbStats) from a MongoDB server and writes each of them as
+// Extended JSON into the output directory.
+type MongoDBCollector struct {
+	stalker  *Stalker
+	client   *mongo.Client // nil until Collect connects; reset when it disconnects
+	outDir   string
+	prefix   string
+	mongoCfg *MongoDBConfig
+}
+
+// NewMongoDBCollector builds a collector from the "mongodb" entry of
+// config.CollectorConfigs. When that entry is missing or has an unexpected
+// type, the flag defaults (localhost:27017, no credentials) are used instead
+// of panicking on a failed type assertion.
+func NewMongoDBCollector(config *Config) Collector {
+	mongoCfg, ok := config.CollectorConfigs["mongodb"].(*MongoDBConfig)
+	if !ok || mongoCfg == nil {
+		mongoCfg = &MongoDBConfig{Host: "localhost", Port: 27017}
+	}
+	return &MongoDBCollector{
+		outDir:   config.Dest,
+		prefix:   config.Prefix,
+		mongoCfg: mongoCfg,
+	}
+}
+
+// connectionURI renders the connection string for the configured server.
+// Credentials are percent-encoded, and the userinfo section is omitted when no
+// user is configured so that servers without authentication stay reachable.
+func (c *MongoDBCollector) connectionURI() string {
+	uri := url.URL{
+		Scheme: "mongodb",
+		Host:   net.JoinHostPort(c.mongoCfg.Host, strconv.Itoa(c.mongoCfg.Port)),
+	}
+	if c.mongoCfg.User != "" {
+		uri.User = url.UserPassword(c.mongoCfg.User, c.mongoCfg.Password)
+	}
+	return uri.String()
+}
+
+// Collect connects to MongoDB when no client is attached yet, runs every
+// collection step and returns the joined errors of the steps that failed. A
+// failing step does not stop the remaining ones, so partial output is still
+// written.
+func (c *MongoDBCollector) Collect(ctx context.Context) error {
+	if c.client == nil {
+		client, err := mongo.Connect(ctx, options.Client().ApplyURI(c.connectionURI()))
+		if err != nil {
+			return fmt.Errorf("failed to connect to MongoDB: %w", err)
+		}
+		c.client = client
+		defer func() {
+			// Forget the client together with its connection; otherwise the
+			// next Collect call would reuse a disconnected client.
+			c.client = nil
+			// Best-effort cleanup: a disconnect failure must not hide the
+			// collection result.
+			_ = client.Disconnect(ctx)
+		}()
+
+		// Connect does not wait for the server, so ping once to report an
+		// unreachable deployment up front instead of in every step.
+		if err := client.Ping(ctx, nil); err != nil {
+			return fmt.Errorf("failed to connect to MongoDB: %w", err)
+		}
+	}
+
+	steps := []struct {
+		name string
+		run  func(context.Context) error
+	}{
+		{"serverStatus", c.collectServerStatus},
+		{"currentOp", c.collectCurrentOp},
+		{"dbStats", c.collectDatabaseStats},
+	}
+
+	var errs []error
+	for _, step := range steps {
+		if err := step.run(ctx); err != nil {
+			errs = append(errs, fmt.Errorf("collecting %s: %w", step.name, err))
+		}
+	}
+	return errors.Join(errs...)
+}
+
+// runAdminCommand runs {command: 1} against the admin database and writes the
+// reply to filename.
+func (c *MongoDBCollector) runAdminCommand(ctx context.Context, command, filename string) error {
+	result := bson.M{}
+	cmd := bson.D{{Key: command, Value: 1}}
+	if err := c.client.Database("admin").RunCommand(ctx, cmd).Decode(&result); err != nil {
+		return err
+	}
+	return c.writeResults(result, filename)
+}
+
+// collectServerStatus stores the serverStatus reply in {prefix}_server_status.txt.
+func (c *MongoDBCollector) collectServerStatus(ctx context.Context) error {
+	return c.runAdminCommand(ctx, "serverStatus", c.prefix+"_server_status.txt")
+}
+
+// collectCurrentOp stores the currentOp reply in {prefix}_current_op.txt.
+func (c *MongoDBCollector) collectCurrentOp(ctx context.Context) error {
+	return c.runAdminCommand(ctx, "currentOp", c.prefix+"_current_op.txt")
+}
+
+// collectDatabaseStats runs dbStats on every database and stores the replies,
+// keyed by database name, in {prefix}_db_stats.txt.
+func (c *MongoDBCollector) collectDatabaseStats(ctx context.Context) error {
+	dbs, err := c.client.ListDatabaseNames(ctx, bson.D{})
+	if err != nil {
+		return err
+	}
+
+	stats := make(map[string]bson.M, len(dbs))
+	for _, dbName := range dbs {
+		result := bson.M{}
+		err := c.client.Database(dbName).RunCommand(ctx, bson.D{{Key: "dbStats", Value: 1}}).Decode(&result)
+		if err != nil {
+			return err
+		}
+		stats[dbName] = result
+	}
+	return c.writeResults(stats, c.prefix+"_db_stats.txt")
+}
+
+// writeResults encodes data as canonical Extended JSON and writes it to
+// filename inside the output directory. Encoding happens before the file is
+// touched, so a marshalling failure leaves no empty file behind.
+func (c *MongoDBCollector) writeResults(data interface{}, filename string) error {
+	formatted, err := bson.MarshalExtJSON(data, true, false)
+	if err != nil {
+		return err
+	}
+	return os.WriteFile(filepath.Join(c.outDir, filename), formatted, 0o666)
+}
+
+// MongoDBConfig holds the connection settings filled in from the --mongodb-*
+// command-line flags.
+type MongoDBConfig struct {
+	Host     string
+	Port     int
+	User     string
+	Password string
+}
+
+// addMongoDBFlags registers the --mongodb-* flags on cmd and stores the config
+// they populate under cfg["mongodb"].
+func addMongoDBFlags(cmd *cobra.Command, cfg map[string]interface{}) {
+	mongoCfg := &MongoDBConfig{}
+	cfg["mongodb"] = mongoCfg
+
+	cmd.PersistentFlags().StringVar(&mongoCfg.Host, "mongodb-host", "localhost", "MongoDB host")
+	cmd.PersistentFlags().IntVar(&mongoCfg.Port, "mongodb-port", 27017, "MongoDB port")
+	cmd.PersistentFlags().StringVar(&mongoCfg.User, "mongodb-user", "", "MongoDB user")
+	cmd.PersistentFlags().StringVar(&mongoCfg.Password, "mongodb-password", "", "MongoDB password")
+}
+
+// init registers the MongoDB collector with the collector registry.
+func init() {
+	RegisterCollector(CollectorRegistration{
+		Name:         "mongodb",
+		AddFlags:     addMongoDBFlags,
+		NewCollector: NewMongoDBCollector,
+	})
+}
diff --git a/src/go/pt-stalk/collect_mongodb_test.go b/src/go/pt-stalk/collect_mongodb_test.go
new file mode 100644
index 000000000..a88db2976
--- /dev/null
+++ b/src/go/pt-stalk/collect_mongodb_test.go
@@ -0,0 +1,107 @@
+package main
+
+import (
+	"context"
+	"net/url"
+	"os"
+	"path/filepath"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// mongoConfigFromURI converts a single-host mongodb:// URI into the collector
+// config so the integration test talks to the server named by TEST_MONGODB_URI.
+func mongoConfigFromURI(t *testing.T, rawURI string) *MongoDBConfig {
+	t.Helper()
+
+	parsed, err := url.Parse(rawURI)
+	if err != nil {
+		t.Fatalf("invalid TEST_MONGODB_URI: %v", err)
+	}
+
+	cfg := &MongoDBConfig{Host: parsed.Hostname(), Port: 27017}
+	if port := parsed.Port(); port != "" {
+		if cfg.Port, err = strconv.Atoi(port); err != nil {
+			t.Fatalf("invalid port in TEST_MONGODB_URI: %v", err)
+		}
+	}
+	if parsed.User != nil {
+		cfg.User = parsed.User.Username()
+		cfg.Password, _ = parsed.User.Password()
+	}
+	return cfg
+}
+
+func TestMongoDBCollector(t *testing.T) {
+	// Skip if no MongoDB connection available
+	mongoURI := os.Getenv("TEST_MONGODB_URI")
+	if mongoURI == "" {
+		t.Skip("Skipping MongoDB tests: TEST_MONGODB_URI not set")
+	}
+
+	// t.TempDir is removed automatically when the test finishes.
+	tmpDir := t.TempDir()
+
+	// Create test config
+	cfg := &Config{
+		Dest:   tmpDir,
+		Prefix: "test",
+		CollectorConfigs: map[string]interface{}{
+			"mongodb": mongoConfigFromURI(t, mongoURI),
+		},
+	}
+
+	// Create collector
+	collector := NewMongoDBCollector(cfg)
+	assert.NotNil(t, collector)
+
+	// Test collection
+	err := collector.Collect(context.Background())
+	assert.NoError(t, err)
+
+	// Verify output files exist
+	expectedFiles := []string{
+		"test_server_status.txt",
+		"test_current_op.txt",
+		"test_db_stats.txt",
+	}
+
+	for _, file := range expectedFiles {
+		path := filepath.Join(tmpDir, file)
+		_, err := os.Stat(path)
+		assert.NoError(t, err, "Expected file %s to exist", file)
+
+		// Verify file is not empty
+		content, err := os.ReadFile(path)
+		assert.NoError(t, err)
+		assert.NotEmpty(t, content)
+	}
+}
+
+func TestMongoDBCollectorConnection(t *testing.T) {
+	// Test invalid connection
+	cfg := &Config{
+		Dest:   t.TempDir(),
+		Prefix: "test",
+		CollectorConfigs: map[string]interface{}{
+			"mongodb": &MongoDBConfig{
+				Host:     "nonexistent",
+				Port:     27017,
+				User:     "invalid",
+				Password: "invalid",
+			},
+		},
+	}
+
+	// Bound the attempt: without a deadline the driver only gives up after
+	// its default server selection timeout.
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	collector := NewMongoDBCollector(cfg)
+	err := collector.Collect(ctx)
+	assert.Error(t, err)
+}