Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tuic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
Expand All @@ -23,6 +24,7 @@ import (
type ClientOptions struct {
Context context.Context
Dialer N.Dialer
Logger logger.Logger
ServerAddress M.Socksaddr
TLSConfig aTLS.Config
UUID [16]byte
Expand All @@ -36,6 +38,7 @@ type ClientOptions struct {
type Client struct {
ctx context.Context
dialer N.Dialer
logger logger.Logger
serverAddr M.Socksaddr
tlsConfig aTLS.Config
quicConfig *quic.Config
Expand Down Expand Up @@ -69,6 +72,7 @@ func NewClient(options ClientOptions) (*Client, error) {
return &Client{
ctx: options.Context,
dialer: options.Dialer,
logger: options.Logger,
serverAddr: options.ServerAddress,
tlsConfig: options.TLSConfig,
quicConfig: quicConfig,
Expand Down Expand Up @@ -125,8 +129,9 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
}()
if c.udpStream {
go c.loopUniStreams(conn)
} else {
go c.loopMessages(conn)
}
go c.loopMessages(conn)
go c.loopHeartbeats(conn)
c.conn = conn
return conn, nil
Expand Down
32 changes: 18 additions & 14 deletions tuic/client_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
return E.Cause(err, "decode UDP message")
}
conn.handleUDPMessage(message)
return nil
case CommandHeartbeat:
return nil
default:
return E.New("unknown command ", data[0])
if c.logger != nil {
c.logger.Warn("unknown command ", data[1])
}
}
return nil
}

func (c *Client) loopUniStreams(conn *clientQUICConnection) {
Expand Down Expand Up @@ -78,17 +78,21 @@ func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.Receive
return E.New("unknown version ", version)
}
command, _ := buffer.ReadByte()
if command != CommandPacket {
return E.New("unknown command ", command)
}
reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
message := allocMessage()
err = readUDPMessage(message, reader)
if err != nil {
message.release()
return err
switch command {
case CommandPacket:
reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
message := allocMessage()
err = readUDPMessage(message, reader)
if err != nil {
message.release()
return err
}
conn.handleUDPMessage(message)
default:
if c.logger != nil {
c.logger.Warn("unknown command ", command)
}
}
conn.handleUDPMessage(message)
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion tuic/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (m *udpMessage) pack() *buf.Buffer {
}

func (m *udpMessage) headerSize() int {
return 10 + AddressSerializer.AddrPortLen(m.destination)
return 12 + AddressSerializer.AddrPortLen(m.destination)
}

func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
Expand Down
25 changes: 2 additions & 23 deletions tuic/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ type ServiceOptions struct {
CongestionControl string
AuthTimeout time.Duration
ZeroRTTHandshake bool
Heartbeat time.Duration
UDPTimeout time.Duration
Handler ServiceHandler
// Deprecated: no longer used.
Heartbeat time.Duration
}

type ServiceHandler interface {
Expand All @@ -48,7 +49,6 @@ type Service[U comparable] struct {
ctx context.Context
logger logger.Logger
tlsConfig aTLS.ServerConfig
heartbeat time.Duration
quicConfig *quic.Config
userMap map[[16]byte]U
passwordMap map[U]string
Expand All @@ -64,9 +64,6 @@ func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
if options.AuthTimeout == 0 {
options.AuthTimeout = 3 * time.Second
}
if options.Heartbeat == 0 {
options.Heartbeat = 10 * time.Second
}
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
EnableDatagrams: true,
Expand All @@ -86,7 +83,6 @@ func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
ctx: options.Context,
logger: options.Logger,
tlsConfig: options.TLSConfig,
heartbeat: options.Heartbeat,
quicConfig: quicConfig,
userMap: make(map[[16]byte]U),
congestionControl: options.CongestionControl,
Expand Down Expand Up @@ -198,7 +194,6 @@ func (s *serverSession[U]) handle() {
go s.loopStreams()
go s.loopMessages()
go s.handleAuthTimeout()
go s.loopHeartbeats()
}

func (s *serverSession[U]) loopUniStreams() {
Expand Down Expand Up @@ -363,22 +358,6 @@ func (s *serverSession[U]) handleStream(stream quic.Stream) error {
return nil
}

func (s *serverSession[U]) loopHeartbeats() {
ticker := time.NewTicker(s.heartbeat)
defer ticker.Stop()
for {
select {
case <-s.connDone:
return
case <-ticker.C:
err := s.quicConn.SendDatagram([]byte{Version, CommandHeartbeat})
if err != nil {
s.closeWithError(E.Cause(err, "send heartbeat"))
}
}
}
}

func (s *serverSession[U]) closeWithError(err error) {
s.connAccess.Lock()
defer s.connAccess.Unlock()
Expand Down