Skip to content

Commit b456117

Browse files
committed
pkg/driver/qemu: Wait for SSH to be ready in AdditionalSetupForSSH()
- pkg/sshutil: Add `WaitSSHReady()` `WaitSSHReady` waits until the SSH port is ready to accept connections. The `dialContext` function is used to create a connection to the SSH server. The `addressForLogging` parameter is used for logging purposes. The `timeoutSeconds` parameter specifies the maximum number of seconds to wait. Signed-off-by: Norio Nomura <[email protected]>
1 parent 6ead149 commit b456117

File tree

3 files changed

+78
-23
lines changed

3 files changed

+78
-23
lines changed

pkg/driver/qemu/qemu_driver.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/lima-vm/lima/v2/pkg/osutil"
3838
"github.com/lima-vm/lima/v2/pkg/ptr"
3939
"github.com/lima-vm/lima/v2/pkg/reflectutil"
40+
"github.com/lima-vm/lima/v2/pkg/sshutil"
4041
"github.com/lima-vm/lima/v2/pkg/version/versionutil"
4142
)
4243

@@ -721,6 +722,17 @@ func (l *LimaQemuDriver) ForwardGuestAgent() bool {
721722
return l.vSockPort == 0 && l.virtioPort == ""
722723
}
723724

724-
func (l *LimaQemuDriver) AdditionalSetupForSSH(_ context.Context) error {
725+
func (l *LimaQemuDriver) AdditionalSetupForSSH(ctx context.Context) error {
726+
// Ensure that the QEMU instance is ready to accept SSH connections.
727+
time.Sleep(10 * time.Second)
728+
// Wait until the port is available.
729+
addr := net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", l.SSHLocalPort))
730+
dialContext := func(ctx context.Context) (net.Conn, error) {
731+
dialer := net.Dialer{Timeout: 1 * time.Second}
732+
return dialer.DialContext(ctx, "tcp", addr)
733+
}
734+
if err := sshutil.WaitSSHReady(ctx, dialContext, addr, 600); err != nil {
735+
return err
736+
}
725737
return nil
726738
}

pkg/networks/usernet/gvproxy.go

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
"github.com/containers/gvisor-tap-vsock/pkg/virtualnetwork"
2323
"github.com/sirupsen/logrus"
2424
"golang.org/x/sync/errgroup"
25+
26+
"github.com/lima-vm/lima/v2/pkg/sshutil"
2527
)
2628

2729
type GVisorNetstackOpts struct {
@@ -255,8 +257,7 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
255257
http.Error(w, err.Error(), http.StatusBadRequest)
256258
return
257259
}
258-
port := uint16(port16)
259-
addr := fmt.Sprintf("%s:%d", ip, port)
260+
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", uint16(port16)))
260261

261262
timeoutSeconds := 10
262263
if timeoutString := r.URL.Query().Get("timeout"); timeoutString != "" {
@@ -267,27 +268,14 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
267268
}
268269
timeoutSeconds = int(timeout16)
269270
}
270-
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds)*time.Second)
271-
defer cancel()
271+
dialContext := func(ctx context.Context) (net.Conn, error) {
272+
return n.DialContextTCP(ctx, addr)
273+
}
272274
// Wait until the port is available.
273-
for {
274-
conn, err := n.DialContextTCP(ctx, addr)
275-
if err == nil {
276-
conn.Close()
277-
logrus.Debugf("Port is available on %s", addr)
278-
w.WriteHeader(http.StatusOK)
279-
break
280-
}
281-
select {
282-
case <-ctx.Done():
283-
msg := fmt.Sprintf("timed out waiting for port to become available on %s", addr)
284-
logrus.Warn(msg)
285-
http.Error(w, msg, http.StatusRequestTimeout)
286-
return
287-
default:
288-
}
289-
logrus.Debugf("Waiting for port to become available on %s", addr)
290-
time.Sleep(1 * time.Second)
275+
if err = sshutil.WaitSSHReady(r.Context(), dialContext, addr, timeoutSeconds); err != nil {
276+
http.Error(w, err.Error(), http.StatusRequestTimeout)
277+
} else {
278+
w.WriteHeader(http.StatusOK)
291279
}
292280
})
293281
return m

pkg/sshutil/sshutil.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"encoding/binary"
1111
"errors"
1212
"fmt"
13+
"io"
1314
"io/fs"
1415
"net"
1516
"os"
@@ -608,3 +609,57 @@ func findRegexpInSSHArgs(sshArgs []string, re *regexp.Regexp) string {
608609
}
609610
return ""
610611
}
612+
613+
// WaitSSHReady waits until the SSH port is ready to accept connections.
614+
// The dialContext function is used to create a connection to the SSH server.
615+
// The addressForLogging parameter is used for logging purposes.
616+
// The timeoutSeconds parameter specifies the maximum number of seconds to wait.
617+
func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addressForLogging string, timeoutSeconds int) error {
618+
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
619+
defer cancel()
620+
// Wait until the SSH port is available.
621+
for {
622+
conn, err := dialContext(ctx)
623+
if err == nil {
624+
// Check if the SSH banner is received.
625+
buf := make([]byte, 128)
626+
RetryRead:
627+
for {
628+
if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
629+
logrus.WithError(err).Debugf("Failed to set read deadline on SSH connection to %s", addressForLogging)
630+
break RetryRead
631+
}
632+
n, err := conn.Read(buf)
633+
if err == nil {
634+
if bytes.HasPrefix(buf[:n], []byte("SSH-")) {
635+
conn.Close()
636+
logrus.Debugf("SSH port is available on %s", addressForLogging)
637+
return nil // SSH ready!
638+
}
639+
return fmt.Errorf("invalid SSH banner from %s: %q", addressForLogging, string(buf[:n]))
640+
} else if ne, ok := err.(net.Error); ok && ne.Timeout() {
641+
logrus.Debugf("Timeout reading SSH banner from %s, retrying...", addressForLogging)
642+
} else if errors.Is(err, io.EOF) {
643+
logrus.Debugf("EOF reading SSH banner from %s, retrying...", addressForLogging)
644+
break RetryRead
645+
} else {
646+
return fmt.Errorf("failed to read SSH banner from %s: %w", addressForLogging, err)
647+
}
648+
select {
649+
case <-ctx.Done():
650+
break RetryRead
651+
case <-time.After(1 * time.Second):
652+
continue
653+
}
654+
}
655+
conn.Close()
656+
}
657+
logrus.Debugf("Waiting for SSH port to accept connections on %s", addressForLogging)
658+
select {
659+
case <-ctx.Done():
660+
return fmt.Errorf("failed to waiting for SSH port to become available on %s: %w", addressForLogging, ctx.Err())
661+
case <-time.After(1 * time.Second):
662+
continue
663+
}
664+
}
665+
}

0 commit comments

Comments
 (0)