diff --git a/cmd/gvproxy/config.go b/cmd/gvproxy/config.go index 462ec4f0a..8551cbf9f 100644 --- a/cmd/gvproxy/config.go +++ b/cmd/gvproxy/config.go @@ -12,6 +12,7 @@ import ( "slices" "strings" + "github.com/containers/gvisor-tap-vsock/pkg/notification" "github.com/containers/gvisor-tap-vsock/pkg/types" log "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v3" @@ -26,25 +27,26 @@ const ( ) type GvproxyArgs struct { - config string - endpoints arrayFlags - debug bool - mtu int - sshPort int - vpnkitSocket string - qemuSocket string - bessSocket string - stdioSocket string - vfkitSocket string - forwardSocket arrayFlags - forwardDest arrayFlags - forwardUser arrayFlags - forwardIdentify arrayFlags - pidFile string - pcapFile string - logFile string - servicesEndpoint string - ec2MetadataAccess bool + config string + endpoints arrayFlags + debug bool + mtu int + sshPort int + vpnkitSocket string + qemuSocket string + bessSocket string + stdioSocket string + vfkitSocket string + notificationSocket string + forwardSocket arrayFlags + forwardDest arrayFlags + forwardUser arrayFlags + forwardIdentify arrayFlags + pidFile string + pcapFile string + logFile string + servicesEndpoint string + ec2MetadataAccess bool } type GvproxyConfig struct { @@ -58,11 +60,13 @@ type GvproxyConfig struct { Stdio string `yaml:"stdio,omitempty"` Vfkit string `yaml:"vfkit,omitempty"` } `yaml:"interfaces,omitempty"` - Forwards []GvproxyConfigForward `yaml:"forwards,omitempty"` - PIDFile string `yaml:"pid-file,omitempty"` - LogFile string `yaml:"log-file,omitempty"` - Services string `yaml:"services,omitempty"` - Ec2MetadataAccess bool `yaml:"ec2-metadata-access,omitempty"` + Forwards []GvproxyConfigForward `yaml:"forwards,omitempty"` + PIDFile string `yaml:"pid-file,omitempty"` + LogFile string `yaml:"log-file,omitempty"` + Services string `yaml:"services,omitempty"` + Ec2MetadataAccess bool `yaml:"ec2-metadata-access,omitempty"` + NotificationSocket string `yaml:"notification,omitempty"` + NotificationSender *notification.NotificationSender `yaml:"-"` } type GvproxyConfigForward struct { @@ -130,6 +134,7 @@ func GvproxyArgParse(flagSet *flag.FlagSet, args *GvproxyArgs, argv []string) (* flagSet.StringVar(&args.logFile, "log-file", "", "Output log messages (logrus) to a given file path") flagSet.StringVar(&args.servicesEndpoint, "services", "", "Exposes the same HTTP API as the --listen flag, without the /connect endpoint") flagSet.BoolVar(&args.ec2MetadataAccess, "ec2-metadata-access", false, "Permits access to EC2 Metadata Service (TCP only)") + flagSet.StringVar(&args.notificationSocket, "notification", "", "Socket to be used to send network-ready notifications") if err := flagSet.Parse(argv); err != nil { return nil, err } @@ -237,6 +242,17 @@ func GvproxyConfigure(config *GvproxyConfig, args *GvproxyArgs, version string) if args.pidFile != "" { config.PIDFile = args.pidFile } + if args.notificationSocket != "" { + log.Debugf("notification socket: %s", args.notificationSocket) + uri, err := url.Parse(args.notificationSocket) + if err != nil { + return config, fmt.Errorf("invalid value for notification listen address: %w", err) + } + if uri.Scheme != "unix" { + return config, errors.New("notification listen address must be unix:// address") + } + config.NotificationSocket = uri.Path + } if len(args.endpoints) > 0 { config.Listen = args.endpoints } diff --git a/cmd/gvproxy/main.go b/cmd/gvproxy/main.go index af77ce5a2..e657439f1 100644 --- a/cmd/gvproxy/main.go +++ b/cmd/gvproxy/main.go @@ -18,8 +18,10 @@ import ( "time" "github.com/containers/gvisor-tap-vsock/pkg/net/stdio" + "github.com/containers/gvisor-tap-vsock/pkg/notification" "github.com/containers/gvisor-tap-vsock/pkg/sshclient" "github.com/containers/gvisor-tap-vsock/pkg/transport" + "github.com/containers/gvisor-tap-vsock/pkg/types" "github.com/containers/gvisor-tap-vsock/pkg/virtualnetwork" "github.com/containers/winquit/pkg/winquit" humanize "github.com/dustin/go-humanize" @@ -126,6 +128,15 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { } log.Info("waiting for clients...") + // Start the notification sender in a goroutine + notificationSender := notification.NewNotificationSender(config.NotificationSocket) + if config.NotificationSocket != "" { + g.Go(func() error { + notificationSender.Start(ctx) + return nil + }) + } + vn.SetNotificationSender(notificationSender) for _, endpoint := range config.Listen { log.Infof("listening %s", endpoint) ln, err := transport.Listen(endpoint) @@ -134,6 +145,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { } httpServe(ctx, g, ln, withProfiler(vn)) } + notificationSender.Send(types.NotificationMessage{NotificationType: types.Ready}) if config.Services != "" { log.Infof("enabling services API. Listening %s", config.Services) @@ -172,6 +184,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { if config.Interfaces.VPNKit != "" { vpnkitListener, err := transport.Listen(config.Interfaces.VPNKit) if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) return fmt.Errorf("vpnkit listen error: %w", err) } g.Go(func() error { @@ -185,6 +198,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { } conn, err := vpnkitListener.Accept() if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) log.Errorf("vpnkit accept error: %s", err) continue } @@ -199,6 +213,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { if config.Interfaces.Qemu != "" { qemuListener, err := transport.Listen(config.Interfaces.Qemu) if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) return fmt.Errorf("qemu listen error: %w", err) } @@ -213,6 +228,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { g.Go(func() error { conn, err := qemuListener.Accept() if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) return fmt.Errorf("qemu accept error: %w", err) } return vn.AcceptQemu(ctx, conn) @@ -222,6 +238,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { if config.Interfaces.Bess != "" { bessListener, err := transport.Listen(config.Interfaces.Bess) if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) return fmt.Errorf("bess listen error: %w", err) } @@ -236,6 +253,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { g.Go(func() error { conn, err := bessListener.Accept() if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) return fmt.Errorf("bess accept error: %w", err) } return vn.AcceptBess(ctx, conn) @@ -245,6 +263,7 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { if config.Interfaces.Vfkit != "" { conn, err := transport.ListenUnixgram(config.Interfaces.Vfkit) if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) return fmt.Errorf("vfkit listen error: %w", err) } @@ -260,8 +279,10 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { g.Go(func() error { vfkitConn, err := transport.AcceptVfkit(conn) if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) return fmt.Errorf("vfkit accept error: %w", err) } + return vn.AcceptVfkit(ctx, vfkitConn) }) } @@ -269,7 +290,11 @@ func run(ctx context.Context, g *errgroup.Group, config *GvproxyConfig) error { if config.Interfaces.Stdio != "" { g.Go(func() error { conn := stdio.GetStdioConn() - return vn.AcceptStdio(ctx, conn) + err := vn.AcceptStdio(ctx, conn) + if err != nil { + notificationSender.Send(types.NotificationMessage{NotificationType: types.HypervisorError}) + } + return err }) } diff --git a/pkg/notification/sender.go b/pkg/notification/sender.go new file mode 100644 index 000000000..3b3c64003 --- /dev/null +++ b/pkg/notification/sender.go @@ -0,0 +1,75 @@ +package notification + +import ( + "context" + "encoding/json" + "fmt" + "net" + + "github.com/containers/gvisor-tap-vsock/pkg/types" + log "github.com/sirupsen/logrus" +) + +type NotificationSender struct { + notificationCh chan types.NotificationMessage + socket string +} + +func NewNotificationSender(socket string) *NotificationSender { + if socket == "" { + return &NotificationSender{ + socket: "", + notificationCh: nil, + } + } + + return &NotificationSender{ + socket: socket, + notificationCh: make(chan types.NotificationMessage, 100), + } +} + +func (s *NotificationSender) Send(notification types.NotificationMessage) { + if s.notificationCh == nil { + return + } + select { + case s.notificationCh <- notification: + default: + log.Warn("unable to send notification") + } +} + +func (s *NotificationSender) Start(ctx context.Context) { + if s.notificationCh == nil { + return + } + + for { + select { + case <-ctx.Done(): + return + case notification := <-s.notificationCh: + if err := s.sendToSocket(notification); err != nil { + log.Errorf("failed to send notification: %v", err) + continue + } + } + } +} + +func (s *NotificationSender) sendToSocket(notification types.NotificationMessage) error { + if s.socket == "" { + return nil + } + conn, err := net.DialUnix("unix", nil, &net.UnixAddr{Name: s.socket, Net: "unix"}) + if err != nil { + return fmt.Errorf("cannot dial notification socket: %w", err) + } + defer conn.Close() + enc := json.NewEncoder(conn) + if err := enc.Encode(notification); err != nil { + return fmt.Errorf("failed to encode notification: %w", err) + } + return nil +} diff --git a/pkg/tap/switch.go b/pkg/tap/switch.go index 01b491b66..1b07628f7 100644 --- a/pkg/tap/switch.go +++ b/pkg/tap/switch.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "syscall" + "github.com/containers/gvisor-tap-vsock/pkg/notification" "github.com/containers/gvisor-tap-vsock/pkg/types" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -48,6 +49,8 @@ type Switch struct { writeLock sync.Mutex gateway VirtualDevice + + notificationSender *notification.NotificationSender } func NewSwitch(debug bool, mtu int) *Switch { @@ -197,6 +200,12 @@ func (e *Switch) disconnect(id int, conn net.Conn) { for address, targetConn := range e.cam { if targetConn == id { + if e.notificationSender != nil { + e.notificationSender.Send(types.NotificationMessage{ + NotificationType: types.ConnectionClosed, + MacAddress: address.String(), + }) + } delete(e.cam, address) } } @@ -267,9 +276,17 @@ func (e *Switch) rxBuf(_ context.Context, id int, buf []byte) { eth := header.Ethernet(buf) e.camLock.Lock() + _, exists := e.cam[eth.SourceAddress()] e.cam[eth.SourceAddress()] = id e.camLock.Unlock() + if !exists && e.notificationSender != nil { + e.notificationSender.Send(types.NotificationMessage{ + NotificationType: types.ConnectionEstablished, + MacAddress: eth.SourceAddress().String(), + }) + } + if eth.DestinationAddress() != e.gateway.LinkAddress() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(buf), @@ -304,3 +321,7 @@ func protocolImplementation(protocol types.Protocol) protocol { return &hyperkitProtocol{} } } + +func (e *Switch) SetNotificationSender(notificationSender *notification.NotificationSender) { + e.notificationSender = notificationSender +} diff --git a/pkg/types/handshake.go b/pkg/types/handshake.go index e9aa78076..c5f0e6601 100644 --- a/pkg/types/handshake.go +++ b/pkg/types/handshake.go @@ -19,3 +19,18 @@ type UnexposeRequest struct { Local string `json:"local"` Protocol TransportProtocol `json:"protocol"` } + +type NotificationMessage struct { + NotificationType NotificationType `json:"notification_type"` + MacAddress string `json:"mac_address,omitempty"` +} + +type NotificationType string + +const ( + Ready NotificationType = "ready" + ConnectionEstablished NotificationType = "connection_established" + HypervisorWarning NotificationType = "hypervisor_warning" + HypervisorError NotificationType = "hypervisor_error" + ConnectionClosed NotificationType = "connection_closed" +) diff --git a/pkg/virtualnetwork/virtualnetwork.go b/pkg/virtualnetwork/virtualnetwork.go index 1fe03e9ea..a7dd08db2 100644 --- a/pkg/virtualnetwork/virtualnetwork.go +++ b/pkg/virtualnetwork/virtualnetwork.go @@ -8,6 +8,7 @@ import ( "net/http" "os" + "github.com/containers/gvisor-tap-vsock/pkg/notification" "github.com/containers/gvisor-tap-vsock/pkg/tap" "github.com/containers/gvisor-tap-vsock/pkg/types" "gvisor.dev/gvisor/pkg/tcpip" @@ -28,6 +29,10 @@ type VirtualNetwork struct { ipPool *tap.IPPool } +func (n *VirtualNetwork) SetNotificationSender(notificationSender *notification.NotificationSender) { + n.networkSwitch.SetNotificationSender(notificationSender) +} + func New(configuration *types.Configuration) (*VirtualNetwork, error) { _, subnet, err := net.ParseCIDR(configuration.Subnet) if err != nil {