diff --git a/internal/streams/producer.go b/internal/streams/producer.go index 09e2dcc58..1a06151c5 100644 --- a/internal/streams/producer.go +++ b/internal/streams/producer.go @@ -34,6 +34,8 @@ type Producer struct { state state mu sync.Mutex workerID int + // Add channel to signal worker to stop + stopChan chan struct{} } const SourceTemplate = "{input}" @@ -154,22 +156,50 @@ func (p *Producer) start() { p.state = stateStart p.workerID++ - go p.worker(p.conn, p.workerID) + // Create stop channel for this worker + p.stopChan = make(chan struct{}) + + go p.worker(p.conn, p.workerID, p.stopChan) } -func (p *Producer) worker(conn core.Producer, workerID int) { - if err := conn.Start(); err != nil { - p.mu.Lock() - closed := p.workerID != workerID - p.mu.Unlock() +func (p *Producer) worker(conn core.Producer, workerID int, stopChan chan struct{}) { + // Create done channel to track Start() completion + done := make(chan error, 1) + + // Run Start() in a goroutine so we can monitor stopChan + go func() { + done <- conn.Start() + }() + + // Wait for either completion or stop signal + var err error + select { + case err = <-done: + // Start() completed (either success or error) + case <-stopChan: + // Stop was called - force stop the connection + _ = conn.Stop() + // Wait for Start() to finish after Stop() + err = <-done + log.Debug().Msgf("[streams] worker stopped by signal url=%s", p.url) + return + } - if closed { - return - } + // Check if this worker was cancelled during Start() + p.mu.Lock() + closed := p.workerID != workerID + p.mu.Unlock() + if closed { + return + } + + // Log error if any + if err != nil { log.Warn().Err(err).Str("url", p.url).Caller().Send() } + // Attempt reconnection p.reconnect(workerID, 0) } @@ -239,7 +269,7 @@ func (p *Producer) reconnect(workerID, retry int) { // swap connections p.conn = conn - go p.worker(conn, workerID) + go p.worker(conn, workerID, p.stopChan) } func (p *Producer) stop() { @@ -254,6 +284,11 @@ func (p *Producer) stop() { log.Trace().Msgf("[streams] skip stop none producer") return case stateStart: + // Signal the worker to stop before incrementing workerID + if p.stopChan != nil { + close(p.stopChan) + p.stopChan = nil + } p.workerID++ } diff --git a/internal/streams/stream.go b/internal/streams/stream.go index 984c73edd..c9b54b9a8 100644 --- a/internal/streams/stream.go +++ b/internal/streams/stream.go @@ -103,13 +103,14 @@ func (s *Stream) stopProducers() { s.mu.Lock() producers: for _, producer := range s.producers { + // Use thread-safe HasSenders() method instead of deprecated Senders() for _, track := range producer.receivers { - if len(track.Senders()) > 0 { + if track.HasSenders() { continue producers } } for _, track := range producer.senders { - if len(track.Senders()) > 0 { + if track.HasSenders() { continue producers } } diff --git a/pkg/core/node.go b/pkg/core/node.go index a9959c3de..1e8250145 100644 --- a/pkg/core/node.go +++ b/pkg/core/node.go @@ -30,6 +30,7 @@ type Node struct { id uint32 childs []*Node parent *Node + closed bool // Track if node is closed mu sync.Mutex } @@ -41,10 +42,21 @@ func (n *Node) WithParent(parent *Node) *Node { func (n *Node) AppendChild(child *Node) { n.mu.Lock() + + // Don't add children to closed nodes + if n.closed { + n.mu.Unlock() + // Parent is closed, close the orphaned child + child.Close() + return + } + n.childs = append(n.childs, child) n.mu.Unlock() + child.mu.Lock() child.parent = n + child.mu.Unlock() } func (n *Node) RemoveChild(child *Node) { @@ -58,16 +70,67 @@ func (n *Node) RemoveChild(child *Node) { n.mu.Unlock() } +// Clean up ghost children (children that were closed but not removed) +func (n *Node) cleanGhostChildren() { + n.mu.Lock() + defer n.mu.Unlock() + + // Filter out closed children + alive := make([]*Node, 0, len(n.childs)) + for _, child := range n.childs { + child.mu.Lock() + isClosed := child.closed + child.mu.Unlock() + + if !isClosed { + alive = append(alive, child) + } + } + + if len(alive) != len(n.childs) { + n.childs = alive + } +} + func (n *Node) Close() { - if parent := n.parent; parent != nil { + // Lock to safely read parent + n.mu.Lock() + + // Prevent double-close + if n.closed { + n.mu.Unlock() + return + } + n.closed = true + + parent := n.parent + n.parent = nil // Clear parent reference + n.mu.Unlock() + + if parent != nil { parent.RemoveChild(n) - if len(parent.childs) == 0 { + // Clean ghost children before checking + parent.cleanGhostChildren() + + // Check if parent should close + parent.mu.Lock() + hasChildren := len(parent.childs) > 0 + parent.mu.Unlock() + + if !hasChildren { parent.Close() } } else { - for _, childs := range n.childs { - childs.Close() + // This is a root node, close all children + n.mu.Lock() + children := make([]*Node, len(n.childs)) + copy(children, n.childs) + n.childs = nil + n.mu.Unlock() + + for _, child := range children { + child.Close() } } } @@ -79,6 +142,15 @@ func MoveNode(dst, src *Node) { src.mu.Unlock() dst.mu.Lock() + // Don't move to closed node + if dst.closed { + dst.mu.Unlock() + // Close orphaned children + for _, child := range childs { + child.Close() + } + return + } dst.childs = childs dst.mu.Unlock() diff --git a/pkg/core/track.go b/pkg/core/track.go index f363a9fd1..5862aa838 100644 --- a/pkg/core/track.go +++ b/pkg/core/track.go @@ -29,7 +29,15 @@ func NewReceiver(media *Media, codec *Codec) *Receiver { r.Input = func(packet *Packet) { r.Bytes += len(packet.Payload) r.Packets++ - for _, child := range r.childs { + + // Lock and copy children to prevent race condition + r.mu.Lock() + children := make([]*Node, len(r.childs)) + copy(children, r.childs) + r.mu.Unlock() + + // Now iterate safely without holding the lock + for _, child := range children { child.Input(packet) } } @@ -41,7 +49,17 @@ func (r *Receiver) WriteRTP(packet *rtp.Packet) { r.Input(packet) } -// Deprecated: should be removed +// Thread-safe check for senders with ghost children cleanup +func (r *Receiver) HasSenders() bool { + // Clean up any ghost children first + r.cleanGhostChildren() + + r.mu.Lock() + defer r.mu.Unlock() + return len(r.childs) > 0 +} + +// Deprecated: should be removed, use HasSenders instead func (r *Receiver) Senders() []*Sender { if len(r.childs) > 0 { return []*Sender{{}} @@ -73,6 +91,7 @@ type Sender struct { buf chan *Packet done chan struct{} + closed bool // Track closed state } func NewSender(media *Media, codec *Codec) *Sender { @@ -98,12 +117,17 @@ func NewSender(media *Media, codec *Codec) *Sender { } s.Input = func(packet *Packet) { s.mu.Lock() - // unblock write to nil chan - OK, write to closed chan - panic - select { - case s.buf <- packet: - s.Bytes += len(packet.Payload) - s.Packets++ - default: + // Check closed state before writing + if !s.closed && s.buf != nil { + // unblock write to nil chan - OK, write to closed chan - panic + select { + case s.buf <- packet: + s.Bytes += len(packet.Payload) + s.Packets++ + default: + s.Drops++ + } + } else { s.Drops++ } s.mu.Unlock() @@ -167,12 +191,18 @@ func (s *Sender) State() string { func (s *Sender) Close() { // close buffer if exists s.mu.Lock() - if s.buf != nil { - close(s.buf) // exit from for range loop - s.buf = nil // prevent writing to closed chan + if !s.closed { + // Mark as closed first, then close channel + s.closed = true + if s.buf != nil { + close(s.buf) + s.buf = nil + } } s.mu.Unlock() + // Always call Node.Close() to ensure proper cleanup + // even if the sender-specific cleanup already happened s.Node.Close() } diff --git a/pkg/tapo/backchannel.go b/pkg/tapo/backchannel.go index b49412605..ffd187ebc 100644 --- a/pkg/tapo/backchannel.go +++ b/pkg/tapo/backchannel.go @@ -10,6 +10,10 @@ import ( ) func (c *Client) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiver) error { + // Check if client is already closed/cancelled + if c.ctx.Err() != nil { + return c.ctx.Err() + } if c.sender == nil { if err := c.SetupBackchannel(); err != nil { return err @@ -23,6 +27,11 @@ func (c *Client) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiver c.sender = core.NewSender(media, track.Codec) c.sender.Handler = func(packet *rtp.Packet) { + // Check context before writing + if c.ctx.Err() != nil { + return + } + b := muxer.GetPayload(pid, packet.Timestamp, packet.Payload) _ = c.WriteBackchannel(b) } @@ -33,6 +42,11 @@ func (c *Client) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiver } func (c *Client) SetupBackchannel() (err error) { + // Check context + if c.ctx.Err() != nil { + return c.ctx.Err() + } + // if conn1 is not used - we will use it for backchannel // or we need to start another conn for session2 if c.session1 != "" { @@ -48,6 +62,20 @@ func (c *Client) SetupBackchannel() (err error) { } func (c *Client) WriteBackchannel(body []byte) (err error) { + // Check if closed before writing + c.closeMutex.Lock() + closed := c.closed + conn := c.conn2 + c.closeMutex.Unlock() + + if closed || c.ctx.Err() != nil { + return c.ctx.Err() + } + + if conn == nil { + return nil + } + // TODO: fixme (size) buf := bytes.NewBuffer(nil) buf.WriteString("----client-stream-boundary--\r\n") diff --git a/pkg/tapo/client.go b/pkg/tapo/client.go index e52250c33..03b8b1319 100644 --- a/pkg/tapo/client.go +++ b/pkg/tapo/client.go @@ -3,6 +3,7 @@ package tapo import ( "bufio" "bytes" + "context" "crypto/aes" "crypto/cipher" "crypto/md5" @@ -17,6 +18,7 @@ import ( "net/url" "strconv" "strings" + "sync" "github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/mpegts" @@ -45,6 +47,12 @@ type Client struct { recv int send int + + // Context for cancellation + ctx context.Context + cancel context.CancelFunc + closeMutex sync.Mutex + closed bool } // block ciphers using cipher block chaining. @@ -70,8 +78,17 @@ func Dial(rawURL string) (*Client, error) { u.Host += ":8800" } - c := &Client{url: u} + // Create context with cancel + ctx, cancel := context.WithCancel(context.Background()) + + c := &Client{ + url: u, + ctx: ctx, + cancel: cancel, + } + if c.conn1, err = c.newConn(); err != nil { + cancel() return nil, err } return c, nil @@ -188,9 +205,39 @@ func (c *Client) Handle() error { var transcode func([]byte) []byte + // Create done channel to signal completion + done := make(chan struct{}) + defer close(done) + + // Monitor context cancellation in separate goroutine + go func() { + select { + case <-c.ctx.Done(): + // Force close connection to unblock Read operations + c.closeMutex.Lock() + if !c.closed && c.conn1 != nil { + _ = c.conn1.Close() + } + c.closeMutex.Unlock() + case <-done: + // Handle() finished normally + } + }() + for { + // Check context before each read + select { + case <-c.ctx.Done(): + return c.ctx.Err() + default: + } + p, err := rd.NextRawPart() if err != nil { + // Check if error is due to cancellation + if c.ctx.Err() != nil { + return c.ctx.Err() + } return err } @@ -210,6 +257,13 @@ func (c *Client) Handle() error { b := body for { + // Check context during read loop + select { + case <-c.ctx.Done(): + return c.ctx.Err() + default: + } + if n, err2 := p.Read(b); err2 == nil { b = b[n:] } else { @@ -257,17 +311,37 @@ func (c *Client) Handle() error { } } -func (c *Client) Close() (err error) { +// Safe close with proper synchronization +func (c *Client) Close() error { + c.closeMutex.Lock() + if c.closed { + c.closeMutex.Unlock() + return nil + } + c.closed = true + c.closeMutex.Unlock() + + // Cancel context first to signal Handle() to stop + c.cancel() + + // Close connections (will unblock any blocking reads) + var err error if c.conn1 != nil { err = c.conn1.Close() } - if c.conn2 != nil { + if c.conn2 != nil && c.conn2 != c.conn1 { _ = c.conn2.Close() } - return + + return err } func (c *Client) Request(conn net.Conn, body []byte) (string, error) { + // Check if context is cancelled + if c.ctx.Err() != nil { + return "", c.ctx.Err() + } + // TODO: fixme (size) buf := bytes.NewBuffer(nil) buf.WriteString("----client-stream-boundary--\r\n") @@ -283,6 +357,11 @@ func (c *Client) Request(conn net.Conn, body []byte) (string, error) { mpReader := multipart.NewReader(conn, "--device-stream-boundary--") for { + // Check cancellation + if c.ctx.Err() != nil { + return "", c.ctx.Err() + } + p, err := mpReader.NextRawPart() if err != nil { return "", err diff --git a/pkg/tapo/producer.go b/pkg/tapo/producer.go index 87a91ff50..6d1b32c39 100644 --- a/pkg/tapo/producer.go +++ b/pkg/tapo/producer.go @@ -68,8 +68,10 @@ func (c *Client) Stop() error { for _, receiver := range c.receivers { receiver.Close() } + // Close and clear sender reference if c.sender != nil { c.sender.Close() + c.sender = nil // Clear reference to prevent ghost sender } return c.Close() }