Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cns/deviceplugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"os"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -45,6 +46,11 @@ func (s *Server) Run(ctx context.Context) error {
defer cancel()
s.shutdownCh = childCtx.Done()

// remove the socket if it already exists
if err := os.Remove(s.address); err != nil && !os.IsNotExist(err) {
return errors.Wrap(err, "error removing socket")
}

l, err := net.Listen("unix", s.address)
if err != nil {
return errors.Wrap(err, "error listening on socket")
Expand Down
69 changes: 69 additions & 0 deletions cns/deviceplugin/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package deviceplugin

import (
"context"
"net"
"os"
"path/filepath"
"testing"
"time"

"go.uber.org/zap"
)

type mockDeviceCounter struct {
count int
}

func (m *mockDeviceCounter) getDeviceCount() int {
return m.count
}

func TestServer_Run_CleansUpExistingSocket(t *testing.T) {
// Create a temporary directory for the socket
socketPath := filepath.Join("testdata", "test.sock")
defer os.Remove(socketPath)

// Create a dummy file at the socket path to simulate a stale socket
if err := os.WriteFile(socketPath, []byte("stale socket"), 0o600); err != nil {
t.Fatalf("failed to create dummy socket file: %v", err)
}

logger := zap.NewNop()
counter := &mockDeviceCounter{count: 1}
server := NewServer(logger, socketPath, counter, time.Second)

// Create a context that we can cancel to stop the server
ctx, cancel := context.WithCancel(context.Background())

// Run the server in a goroutine
errChan := make(chan error)
go func() {
errChan <- server.Run(ctx)
}()

// Wait for the server to start up, delete the pre-existing file and recreate it as a socket
// We verify this by trying to connect to the socket repeatedly until success or timeout
var conn net.Conn
var err error
// Retry for up to 2 seconds
for start := time.Now(); time.Since(start) < 2*time.Second; time.Sleep(200 * time.Millisecond) {
conn, err = net.Dial("unix", socketPath)
if err == nil {
conn.Close()
break
}
}

if err != nil {
t.Errorf("failed to connect to socket: %v", err)
}

// Stop the server
cancel()

// Wait for Run to return
if err := <-errChan; err != nil {
t.Errorf("server.Run returned error: %v", err)
}
}
7 changes: 6 additions & 1 deletion cns/deviceplugin/socketwatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ func (s *SocketWatcher) WatchSocket(ctx context.Context, socket string) <-chan s
socketChan := make(chan struct{})
s.socketChans[socket] = socketChan
go func() {
defer close(socketChan)
defer func() {
s.mutex.Lock()
delete(s.socketChans, socket)
s.mutex.Unlock()
close(socketChan)
}()
ticker := time.NewTicker(s.options.statInterval)
defer ticker.Stop()
for {
Expand Down
111 changes: 89 additions & 22 deletions cns/deviceplugin/socketwatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@ import (
)

func TestWatchContextCancelled(t *testing.T) {
socket := filepath.Join("testdata", "socket.sock")
f, createErr := os.Create(socket)
if createErr != nil {
t.Fatalf("error creating test file %s: %v", socket, createErr)
}
f.Close()
defer os.Remove(socket)

ctx, cancel := context.WithCancel(context.Background())
logger, _ := zap.NewDevelopment()
logger, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("failed to create logger: %v", err)
}
s := deviceplugin.NewSocketWatcher(logger)
done := make(chan struct{})
go func(done chan struct{}) {
<-s.WatchSocket(ctx, "testdata/socket.sock")
<-s.WatchSocket(ctx, socket)
close(done)
}(done)

Expand All @@ -39,19 +50,18 @@ func TestWatchContextCancelled(t *testing.T) {
}

func TestWatchSocketDeleted(t *testing.T) {
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "socket-watcher-test-")
if err != nil {
t.Fatalf("error creating temporary directory: %v", err)
socket := filepath.Join("testdata", "to-be-deleted.sock")
f, createErr := os.Create(socket)
if createErr != nil {
t.Fatalf("error creating test file %s: %v", socket, createErr)
}
defer os.RemoveAll(tempDir) // Ensure the directory is cleaned up
f.Close()
defer os.Remove(socket)

socket := filepath.Join(tempDir, "to-be-deleted.sock")
if _, err := os.Create(socket); err != nil {
t.Fatalf("error creating test file %s: %v", socket, err)
logger, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("failed to create logger: %v", err)
}

logger, _ := zap.NewDevelopment()
s := deviceplugin.NewSocketWatcher(logger, deviceplugin.SocketWatcherStatInterval(time.Second))
done := make(chan struct{})
go func(done chan struct{}) {
Expand Down Expand Up @@ -79,19 +89,18 @@ func TestWatchSocketDeleted(t *testing.T) {
}

func TestWatchSocketTwice(t *testing.T) {
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "socket-watcher-test-")
if err != nil {
t.Fatalf("error creating temporary directory: %v", err)
socket := filepath.Join("testdata", "to-be-deleted.sock")
f, createErr := os.Create(socket)
if createErr != nil {
t.Fatalf("error creating test file %s: %v", socket, createErr)
}
defer os.RemoveAll(tempDir) // Ensure the directory is cleaned up
f.Close()
defer os.Remove(socket)

socket := filepath.Join(tempDir, "to-be-deleted.sock")
if _, err := os.Create(socket); err != nil {
t.Fatalf("error creating test file %s: %v", socket, err)
logger, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("failed to create logger: %v", err)
}

logger, _ := zap.NewDevelopment()
s := deviceplugin.NewSocketWatcher(logger, deviceplugin.SocketWatcherStatInterval(time.Second))
done1 := make(chan struct{})
done2 := make(chan struct{})
Expand Down Expand Up @@ -134,3 +143,61 @@ func TestWatchSocketTwice(t *testing.T) {
t.Fatal("socket watcher is still watching 5 seconds after file is deleted")
}
}

func TestWatchSocketCleanup(t *testing.T) {
socket := filepath.Join("testdata", "to-be-deleted.sock")
f, createErr := os.Create(socket)
if createErr != nil {
t.Fatalf("error creating test file %s: %v", socket, createErr)
}
f.Close()
defer os.Remove(socket)

logger, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("failed to create logger: %v", err)
}
// Use a short interval for faster test execution
s := deviceplugin.NewSocketWatcher(logger, deviceplugin.SocketWatcherStatInterval(100*time.Millisecond))

// 1. Watch the socket
ch1 := s.WatchSocket(context.Background(), socket)

// Verify it's open
select {
case <-ch1:
t.Fatal("channel should be open initially")
default:
}

// 2. Delete the socket to trigger watcher exit
if removeErr := os.Remove(socket); removeErr != nil {
t.Fatalf("failed to remove socket: %v", removeErr)
}

// 3. Wait for ch1 to close
select {
case <-ch1:
// Expected
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for watcher to detect socket deletion")
}

// 4. Recreate the socket
f, err = os.Create(socket)
if err != nil {
t.Fatalf("error recreating test file %s: %v", socket, err)
}
f.Close()

// 5. Watch the socket again
ch2 := s.WatchSocket(context.Background(), socket)

// 6. Verify ch2 is open
select {
case <-ch2:
t.Fatal("channel is closed but expected to be open")
case <-time.After(200 * time.Millisecond):
// Wait for at least one tick to ensure the watcher has had a chance to run.
}
}
1 change: 1 addition & 0 deletions cns/deviceplugin/testdata/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.sock
Empty file.
Loading