diff --git a/pkg/services/forwarder/ports.go b/pkg/services/forwarder/ports.go index 50d50b880..528fc7dfc 100644 --- a/pkg/services/forwarder/ports.go +++ b/pkg/services/forwarder/ports.go @@ -6,16 +6,21 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "net/http" + "net/url" + "os" "sort" "strconv" "strings" "sync" + "time" "github.com/containers/gvisor-tap-vsock/pkg/types" "github.com/google/tcpproxy" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -50,31 +55,174 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote return errors.New("proxy already running") } - split := strings.Split(remote, ":") - if len(split) != 2 { - return errors.New("invalid remote addr") - } - port, err := strconv.Atoi(split[1]) - if err != nil { - return err - } - address := tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(net.ParseIP(split[0]).To4()), - Port: uint16(port), - } - switch protocol { case types.UNIX: + // parse URI for remote + remoteURI, err := url.Parse(remote) + if err != nil { + return fmt.Errorf("failed to parse remote uri :%s : %w", remote, err) + } + + // build the address from remoteURI + remoteAddr := fmt.Sprintf("%s:%s", remoteURI.Hostname(), remoteURI.Port()) + + // dialFn opens remote connection for the proxy + var dialFn func(ctx context.Context, network, addr string) (conn net.Conn, e error) + + // dialFn is set based on the protocol provided by remoteURI.Scheme + switch remoteURI.Scheme { + case "ssh-tunnel": // unix-to-unix proxy (over SSH) + // query string to map for the remoteURI contains ssh config info + remoteQuery := remoteURI.Query() + + // username + sshuser := firstValueOrEmpty(remoteQuery["user"]) + if sshuser == "" { + return fmt.Errorf("user not provided for unix-ssh connection") + } + + // key + sshkeypath := firstValueOrEmpty(remoteQuery["key"]) + if sshkeypath == "" { + return fmt.Errorf("key not provided for unix-ssh connection") + } + + sshkeyBytes, err := ioutil.ReadFile(sshkeypath) + if err != nil { + return fmt.Errorf("failed to read ssh key: %s: %w", sshkeypath, err) + } + + // passphrase + passphrase := firstValueOrEmpty(remoteQuery["passphrase"]) + + var sshsigner ssh.Signer + + if passphrase == "" { + sshsigner, err = ssh.ParsePrivateKey(sshkeyBytes) + } else { + sshsigner, err = ssh.ParsePrivateKeyWithPassphrase(sshkeyBytes, []byte(passphrase)) + } + + // parse private key error? + if err != nil { + return fmt.Errorf("failed to parse ssh key: %s: %w", sshkeypath, err) + } + + // default ssh port if not set + if remoteURI.Port() == "" { + remoteAddr = fmt.Sprintf("%s:%s", remoteURI.Hostname(), "22") + } + + // build address + address, err := tcpipAddress(1, remoteAddr) + if err != nil { + return err + } + + // check the remoteURI path provided for nonsense + if remoteURI.Path == "" || remoteURI.Path == "/" { + return fmt.Errorf("remote uri must contain a path to a socket file") + } + + // captured and used by dialFn + var tcpConn *gonet.TCPConn + var sshClient *ssh.Client + var connLock sync.Mutex + + // handles getting underlying ssh connection, having this outside of + // dialFn limits connLock to only the parts it's needed for in a way + // that doesn't get racy. + sshConnFn := func(ctx context.Context, network, addr string) (client *ssh.Client, err error) { + connLock.Lock() + defer connLock.Unlock() + + // check underlying tcpConn to see if it's closed + if tcpConn != nil { + if _, err := tcpConn.Read(make([]byte, 0)); err == io.EOF { + tcpConn = nil // set back to nil to force reconnect + } + } + + // connect or reconnect to ssh + if tcpConn == nil || sshClient == nil { + // underlying connection to endpoint for the ssh client + tcpConn, err := gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber) + if err != nil { + return sshClient, err + } + + // ssh client config that uses key authentication + config := &ssh.ClientConfig{ + User: sshuser, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(sshsigner), + }, + // #nosec G106 + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyAlgorithms: []string{ + ssh.KeyAlgoRSA, + ssh.KeyAlgoDSA, + ssh.KeyAlgoECDSA256, + ssh.KeyAlgoECDSA384, + ssh.KeyAlgoECDSA521, + ssh.KeyAlgoED25519, + }, + Timeout: 5 * time.Second, + } + + // get an sshConn using the underlying gonet.TCPConn + sshConn, chans, reqs, err := ssh.NewClientConn(tcpConn, addr, config) + if err != nil { + return sshClient, err + } + + // build an ssh client using sshConn + sshClient = ssh.NewClient(sshConn, chans, reqs) + } + + return sshClient, err + } + + // the dialFn for unix-to-unix over SSH + dialFn = func(ctx context.Context, network, addr string) (conn net.Conn, e error) { + // check or create new ssh connection + sshClient, err = sshConnFn(ctx, network, addr) + if err != nil { + return nil, err + } + + // connection using sshclient's dialer + return sshClient.Dial("unix", remoteURI.Path) + } + + case "tcp": // unix-to-tcp proxy + // build address + address, err := tcpipAddress(1, remoteAddr) + if err != nil { + return err + } + + dialFn = func(ctx context.Context, network, addr string) (conn net.Conn, e error) { + return gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber) + } + + default: + return fmt.Errorf("remote protocol for unix forwarder is not implemented: %s", remoteURI.Scheme) + } + + // build the tcp proxy var p tcpproxy.Proxy p.ListenFunc = func(_, socketPath string) (net.Listener, error) { + // remove existing socket file + if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { + return nil, err + } + return net.Listen("unix", socketPath) // override tcp to use unix socket } p.AddRoute(local, &tcpproxy.DialProxy{ - Addr: remote, - DialContext: func(ctx context.Context, network, addr string) (conn net.Conn, e error) { - return gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber) - }, + Addr: remoteAddr, + DialContext: dialFn, }) if err := p.Start(); err != nil { return err @@ -91,7 +239,13 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote Remote: remote, underlying: &p, } + case types.UDP: + address, err := tcpipAddress(1, remote) + if err != nil { + return err + } + addr, err := net.ResolveUDPAddr("udp", local) if err != nil { return err @@ -114,6 +268,11 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote underlying: p, } case types.TCP: + address, err := tcpipAddress(1, remote) + if err != nil { + return err + } + var p tcpproxy.Proxy p.AddRoute(local, &tcpproxy.DialProxy{ Addr: remote, @@ -186,12 +345,21 @@ func (f *PortsForwarder) Mux() http.Handler { if req.Protocol == "" { req.Protocol = types.TCP } - remote, err := remote(req, r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + + // contains unparsed remote field + remoteAddr := req.Remote + + // TCP and UDP rely on remote() to preparse the remote field + if req.Protocol != types.UNIX { + var err error + remoteAddr, err = remote(req, r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } } - if err := f.Expose(req.Protocol, req.Local, remote); err != nil { + + if err := f.Expose(req.Protocol, req.Local, remoteAddr); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -234,3 +402,35 @@ func remote(req types.ExposeRequest, ip string) (string, error) { } return req.Remote, nil } + +// helper function for parsed URL query strings +func firstValueOrEmpty(x []string) string { + if len(x) > 0 { + return x[0] + } + return "" +} + +// helper function to build tcpip address +func tcpipAddress(nicID tcpip.NICID, remote string) (address tcpip.FullAddress, err error) { + + // build the address manual way + split := strings.Split(remote, ":") + if len(split) != 2 { + return address, errors.New("invalid remote addr") + } + + port, err := strconv.Atoi(split[1]) + if err != nil { + return address, err + + } + + address = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(net.ParseIP(split[0]).To4()), + Port: uint16(port), + } + + return address, err +} diff --git a/test/port_forwarding_test.go b/test/port_forwarding_test.go index 31f2dd841..62c4408aa 100644 --- a/test/port_forwarding_test.go +++ b/test/port_forwarding_test.go @@ -2,6 +2,7 @@ package e2e import ( "context" + "fmt" "io" "net" "net/http" @@ -191,7 +192,7 @@ var _ = Describe("port forwarding", func() { unix2tcpfwdsock, _ := filepath.Abs(filepath.Join(tmpDir, "podman-unix-to-unix-forwarding.sock")) - out, err := sshExec(`curl http://gateway.containers.internal/services/forwarder/expose -X POST -d'{"protocol":"unix","local":"` + unix2tcpfwdsock + `","remote":"192.168.127.2:8080"}'`) + out, err := sshExec(`curl http://gateway.containers.internal/services/forwarder/expose -X POST -d'{"protocol":"unix","local":"` + unix2tcpfwdsock + `","remote":"tcp://192.168.127.2:8080"}'`) Expect(string(out)).Should(Equal("")) Expect(err).ShouldNot(HaveOccurred()) @@ -215,4 +216,43 @@ var _ = Describe("port forwarding", func() { g.Expect(resp.StatusCode).To(Equal(http.StatusOK)) }).Should(Succeed()) }) + + It("should expose and reach rootless podman API using unix to unix forwarding over ssh", func() { + if runtime.GOOS == "windows" { + Skip("AF_UNIX not supported on Windows") + } + + unix2unixfwdsock, _ := filepath.Abs(filepath.Join(tmpDir, "podman-unix-to-unix-forwarding.sock")) + + remoteuri := fmt.Sprintf(`ssh-tunnel://%s:%d%s?user=root&key=%s`, "192.168.127.2", 22, podmanSock, privateKeyFile) + _, err := sshExec(`curl http://192.168.127.1/services/forwarder/expose -X POST -d'{"protocol":"unix","local":"` + unix2unixfwdsock + `","remote":"` + remoteuri + `"}'`) + Expect(err).ShouldNot(HaveOccurred()) + + Eventually(func(g Gomega) { + sockfile, err := os.Stat(unix2unixfwdsock) + g.Expect(err).ShouldNot(HaveOccurred()) + g.Expect(sockfile.Mode().Type().String()).To(Equal(os.ModeSocket.String())) + }).Should(Succeed()) + + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", unix2unixfwdsock) + }, + }, + } + + Eventually(func(g Gomega) { + resp, err := httpClient.Get("http://host/_ping") + g.Expect(err).ShouldNot(HaveOccurred()) + g.Expect(resp.StatusCode).To(Equal(http.StatusOK)) + g.Expect(resp.ContentLength).To(Equal(int64(2))) + + reply := make([]byte, resp.ContentLength) + _, err = io.ReadAtLeast(resp.Body, reply, len(reply)) + + g.Expect(err).ShouldNot(HaveOccurred()) + g.Expect(string(reply)).To(Equal("OK")) + }).Should(Succeed()) + }) })