diff --git a/cns/deviceplugin/server.go b/cns/deviceplugin/server.go index c5243368b9..725f57c543 100644 --- a/cns/deviceplugin/server.go +++ b/cns/deviceplugin/server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "os" "time" "github.com/pkg/errors" @@ -45,6 +46,11 @@ func (s *Server) Run(ctx context.Context) error { defer cancel() s.shutdownCh = childCtx.Done() + // remove the socket if it already exists + if err := os.Remove(s.address); err != nil && !os.IsNotExist(err) { + return errors.Wrap(err, "error removing socket") + } + l, err := net.Listen("unix", s.address) if err != nil { return errors.Wrap(err, "error listening on socket") diff --git a/cns/deviceplugin/server_test.go b/cns/deviceplugin/server_test.go new file mode 100644 index 0000000000..b3d044baca --- /dev/null +++ b/cns/deviceplugin/server_test.go @@ -0,0 +1,69 @@ +package deviceplugin + +import ( + "context" + "net" + "os" + "path/filepath" + "testing" + "time" + + "go.uber.org/zap" +) + +type mockDeviceCounter struct { + count int +} + +func (m *mockDeviceCounter) getDeviceCount() int { + return m.count +} + +func TestServer_Run_CleansUpExistingSocket(t *testing.T) { + // Create a temporary directory for the socket + socketPath := filepath.Join("testdata", "test.sock") + defer os.Remove(socketPath) + + // Create a dummy file at the socket path to simulate a stale socket + if err := os.WriteFile(socketPath, []byte("stale socket"), 0o600); err != nil { + t.Fatalf("failed to create dummy socket file: %v", err) + } + + logger := zap.NewNop() + counter := &mockDeviceCounter{count: 1} + server := NewServer(logger, socketPath, counter, time.Second) + + // Create a context that we can cancel to stop the server + ctx, cancel := context.WithCancel(context.Background()) + + // Run the server in a goroutine + errChan := make(chan error) + go func() { + errChan <- server.Run(ctx) + }() + + // Wait for the server to start up, delete the pre-existing file and recreate it as a socket + // We verify this by trying to connect to the socket repeatedly until success or timeout + var conn net.Conn + var err error + // Retry for up to 2 seconds + for start := time.Now(); time.Since(start) < 2*time.Second; time.Sleep(200 * time.Millisecond) { + conn, err = net.Dial("unix", socketPath) + if err == nil { + conn.Close() + break + } + } + + if err != nil { + t.Errorf("failed to connect to socket: %v", err) + } + + // Stop the server + cancel() + + // Wait for Run to return + if err := <-errChan; err != nil { + t.Errorf("server.Run returned error: %v", err) + } +} diff --git a/cns/deviceplugin/socketwatcher.go b/cns/deviceplugin/socketwatcher.go index 05b7df602b..5d1c7d621b 100644 --- a/cns/deviceplugin/socketwatcher.go +++ b/cns/deviceplugin/socketwatcher.go @@ -56,7 +56,12 @@ func (s *SocketWatcher) WatchSocket(ctx context.Context, socket string) <-chan s socketChan := make(chan struct{}) s.socketChans[socket] = socketChan go func() { - defer close(socketChan) + defer func() { + s.mutex.Lock() + delete(s.socketChans, socket) + s.mutex.Unlock() + close(socketChan) + }() ticker := time.NewTicker(s.options.statInterval) defer ticker.Stop() for { diff --git a/cns/deviceplugin/socketwatcher_test.go b/cns/deviceplugin/socketwatcher_test.go index e987358481..4275734726 100644 --- a/cns/deviceplugin/socketwatcher_test.go +++ b/cns/deviceplugin/socketwatcher_test.go @@ -12,12 +12,23 @@ import ( ) func TestWatchContextCancelled(t *testing.T) { + socket := filepath.Join("testdata", "socket.sock") + f, createErr := os.Create(socket) + if createErr != nil { + t.Fatalf("error creating test file %s: %v", socket, createErr) + } + f.Close() + defer os.Remove(socket) + ctx, cancel := context.WithCancel(context.Background()) - logger, _ := zap.NewDevelopment() + logger, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("failed to create logger: %v", err) + } s := deviceplugin.NewSocketWatcher(logger) done := make(chan struct{}) go func(done chan struct{}) { - <-s.WatchSocket(ctx, "testdata/socket.sock") + <-s.WatchSocket(ctx, socket) close(done) }(done) @@ -39,19 +50,18 @@ func TestWatchContextCancelled(t *testing.T) { } func TestWatchSocketDeleted(t *testing.T) { - // Create a temporary directory - tempDir, err := os.MkdirTemp("", "socket-watcher-test-") - if err != nil { - t.Fatalf("error creating temporary directory: %v", err) + socket := filepath.Join("testdata", "to-be-deleted.sock") + f, createErr := os.Create(socket) + if createErr != nil { + t.Fatalf("error creating test file %s: %v", socket, createErr) } - defer os.RemoveAll(tempDir) // Ensure the directory is cleaned up + f.Close() + defer os.Remove(socket) - socket := filepath.Join(tempDir, "to-be-deleted.sock") - if _, err := os.Create(socket); err != nil { - t.Fatalf("error creating test file %s: %v", socket, err) + logger, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("failed to create logger: %v", err) } - - logger, _ := zap.NewDevelopment() s := deviceplugin.NewSocketWatcher(logger, deviceplugin.SocketWatcherStatInterval(time.Second)) done := make(chan struct{}) go func(done chan struct{}) { @@ -79,19 +89,18 @@ func TestWatchSocketDeleted(t *testing.T) { } func TestWatchSocketTwice(t *testing.T) { - // Create a temporary directory - tempDir, err := os.MkdirTemp("", "socket-watcher-test-") - if err != nil { - t.Fatalf("error creating temporary directory: %v", err) + socket := filepath.Join("testdata", "to-be-deleted.sock") + f, createErr := os.Create(socket) + if createErr != nil { + t.Fatalf("error creating test file %s: %v", socket, createErr) } - defer os.RemoveAll(tempDir) // Ensure the directory is cleaned up + f.Close() + defer os.Remove(socket) - socket := filepath.Join(tempDir, "to-be-deleted.sock") - if _, err := os.Create(socket); err != nil { - t.Fatalf("error creating test file %s: %v", socket, err) + logger, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("failed to create logger: %v", err) } - - logger, _ := zap.NewDevelopment() s := deviceplugin.NewSocketWatcher(logger, deviceplugin.SocketWatcherStatInterval(time.Second)) done1 := make(chan struct{}) done2 := make(chan struct{}) @@ -134,3 +143,61 @@ func TestWatchSocketTwice(t *testing.T) { t.Fatal("socket watcher is still watching 5 seconds after file is deleted") } } + +func TestWatchSocketCleanup(t *testing.T) { + socket := filepath.Join("testdata", "to-be-deleted.sock") + f, createErr := os.Create(socket) + if createErr != nil { + t.Fatalf("error creating test file %s: %v", socket, createErr) + } + f.Close() + defer os.Remove(socket) + + logger, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("failed to create logger: %v", err) + } + // Use a short interval for faster test execution + s := deviceplugin.NewSocketWatcher(logger, deviceplugin.SocketWatcherStatInterval(100*time.Millisecond)) + + // 1. Watch the socket + ch1 := s.WatchSocket(context.Background(), socket) + + // Verify it's open + select { + case <-ch1: + t.Fatal("channel should be open initially") + default: + } + + // 2. Delete the socket to trigger watcher exit + if removeErr := os.Remove(socket); removeErr != nil { + t.Fatalf("failed to remove socket: %v", removeErr) + } + + // 3. Wait for ch1 to close + select { + case <-ch1: + // Expected + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for watcher to detect socket deletion") + } + + // 4. Recreate the socket + f, err = os.Create(socket) + if err != nil { + t.Fatalf("error recreating test file %s: %v", socket, err) + } + f.Close() + + // 5. Watch the socket again + ch2 := s.WatchSocket(context.Background(), socket) + + // 6. Verify ch2 is open + select { + case <-ch2: + t.Fatal("channel is closed but expected to be open") + case <-time.After(200 * time.Millisecond): + // Wait for at least one tick to ensure the watcher has had a chance to run. + } +} diff --git a/cns/deviceplugin/testdata/.gitignore b/cns/deviceplugin/testdata/.gitignore new file mode 100644 index 0000000000..c74d682773 --- /dev/null +++ b/cns/deviceplugin/testdata/.gitignore @@ -0,0 +1 @@ +*.sock diff --git a/cns/deviceplugin/testdata/socket.sock b/cns/deviceplugin/testdata/socket.sock deleted file mode 100644 index e69de29bb2..0000000000