55 changes: 45 additions & 10 deletions internal/streams/producer.go
@@ -34,6 +34,8 @@ type Producer struct {
state state
mu sync.Mutex
workerID int
// Add channel to signal worker to stop
stopChan chan struct{}
}

const SourceTemplate = "{input}"
@@ -154,22 +156,50 @@ func (p *Producer) start() {
p.state = stateStart
p.workerID++

go p.worker(p.conn, p.workerID)
// Create stop channel for this worker
p.stopChan = make(chan struct{})

go p.worker(p.conn, p.workerID, p.stopChan)
}

func (p *Producer) worker(conn core.Producer, workerID int) {
if err := conn.Start(); err != nil {
p.mu.Lock()
closed := p.workerID != workerID
p.mu.Unlock()
func (p *Producer) worker(conn core.Producer, workerID int, stopChan chan struct{}) {
// Create done channel to track Start() completion
done := make(chan error, 1)

// Run Start() in a goroutine so we can monitor stopChan
go func() {
done <- conn.Start()
}()

// Wait for either completion or stop signal
var err error
select {
case err = <-done:
// Start() completed (either success or error)
case <-stopChan:
// Stop was called - force stop the connection
_ = conn.Stop()
// Wait for Start() to finish after Stop() so the goroutine doesn't leak
<-done
log.Debug().Msgf("[streams] worker stopped by signal url=%s", p.url)
return
}

if closed {
return
}
// Check if this worker was cancelled during Start()
p.mu.Lock()
closed := p.workerID != workerID
p.mu.Unlock()

if closed {
return
}

// Log error if any
if err != nil {
log.Warn().Err(err).Str("url", p.url).Caller().Send()
}

// Attempt reconnection
p.reconnect(workerID, 0)
}

@@ -239,7 +269,7 @@ func (p *Producer) reconnect(workerID, retry int) {
// swap connections
p.conn = conn

go p.worker(conn, workerID)
go p.worker(conn, workerID, p.stopChan)
}

func (p *Producer) stop() {
@@ -254,6 +284,11 @@ func (p *Producer) stop() {
log.Trace().Msgf("[streams] skip stop none producer")
return
case stateStart:
// Signal the worker to stop before incrementing workerID
if p.stopChan != nil {
close(p.stopChan)
p.stopChan = nil
}
p.workerID++
}

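A minimal standalone sketch of the stop-channel pattern introduced in worker(): Start() runs in its own goroutine writing to a buffered done channel, so the caller can select on either completion or the stop signal without leaking the goroutine. The blocking start/stop callbacks below are hypothetical stand-ins for conn.Start() and conn.Stop(), not the real Producer API.

package main

import (
	"errors"
	"fmt"
	"time"
)

// runUntilStopped mirrors the worker() flow: it waits on start() and a stop channel at once.
func runUntilStopped(start func() error, stop func(), stopChan <-chan struct{}) error {
	done := make(chan error, 1) // buffered: the goroutine can finish even if nobody reads yet
	go func() { done <- start() }()

	select {
	case err := <-done:
		return err // start finished on its own (success or error)
	case <-stopChan:
		stop()  // force the blocking start() to return
		<-done  // drain done so the inner goroutine exits
		return nil
	}
}

func main() {
	quit := make(chan struct{})
	stopChan := make(chan struct{})

	go func() {
		time.Sleep(50 * time.Millisecond)
		close(stopChan) // analogous to Producer.stop() closing p.stopChan
	}()

	err := runUntilStopped(
		func() error { <-quit; return errors.New("stopped") }, // blocks until stop() fires
		func() { close(quit) },
		stopChan,
	)
	fmt.Println("worker returned:", err)
}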
5 changes: 3 additions & 2 deletions internal/streams/stream.go
@@ -103,13 +103,14 @@ func (s *Stream) stopProducers() {
s.mu.Lock()
producers:
for _, producer := range s.producers {
// Use thread-safe HasSenders() method instead of deprecated Senders()
for _, track := range producer.receivers {
if len(track.Senders()) > 0 {
if track.HasSenders() {
continue producers
}
}
for _, track := range producer.senders {
if len(track.Senders()) > 0 {
if track.HasSenders() {
continue producers
}
}
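For context on the stream.go change, a short sketch of why a boolean accessor taken under the lock is safer than counting the deprecated Senders() slice: callers only need a yes/no answer, so the child slice never escapes the mutex. The Track type here is a hypothetical stand-in, not the real core.Receiver/core.Sender.

package main

import (
	"fmt"
	"sync"
)

type Track struct {
	mu     sync.Mutex
	childs []struct{}
}

// HasSenders answers the only question stopProducers() asks, without ever
// exposing the childs slice outside the lock.
func (t *Track) HasSenders() bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	return len(t.childs) > 0
}

func (t *Track) attach() {
	t.mu.Lock()
	t.childs = append(t.childs, struct{}{})
	t.mu.Unlock()
}

func main() {
	t := &Track{}
	fmt.Println(t.HasSenders()) // false
	t.attach()
	fmt.Println(t.HasSenders()) // true
}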
80 changes: 76 additions & 4 deletions pkg/core/node.go
@@ -30,6 +30,7 @@ type Node struct {
id uint32
childs []*Node
parent *Node
closed bool // Track if node is closed

mu sync.Mutex
}
@@ -41,10 +42,21 @@ func (n *Node) WithParent(parent *Node) *Node {

func (n *Node) AppendChild(child *Node) {
n.mu.Lock()

// Don't add children to closed nodes
if n.closed {
n.mu.Unlock()
// Parent is closed, close the orphaned child
child.Close()
return
}

n.childs = append(n.childs, child)
n.mu.Unlock()

child.mu.Lock()
child.parent = n
child.mu.Unlock()
}

func (n *Node) RemoveChild(child *Node) {
@@ -58,16 +70,67 @@ func (n *Node) RemoveChild(child *Node) {
n.mu.Unlock()
}

// Clean up ghost children (children that were closed but not removed)
func (n *Node) cleanGhostChildren() {
n.mu.Lock()
defer n.mu.Unlock()

// Filter out closed children
alive := make([]*Node, 0, len(n.childs))
for _, child := range n.childs {
child.mu.Lock()
isClosed := child.closed
child.mu.Unlock()

if !isClosed {
alive = append(alive, child)
}
}

if len(alive) != len(n.childs) {
n.childs = alive
}
}

func (n *Node) Close() {
if parent := n.parent; parent != nil {
// Lock to safely read parent
n.mu.Lock()

// Prevent double-close
if n.closed {
n.mu.Unlock()
return
}
n.closed = true

parent := n.parent
n.parent = nil // Clear parent reference
n.mu.Unlock()

if parent != nil {
parent.RemoveChild(n)

if len(parent.childs) == 0 {
// Clean ghost children before checking
parent.cleanGhostChildren()

// Check if parent should close
parent.mu.Lock()
hasChildren := len(parent.childs) > 0
parent.mu.Unlock()

if !hasChildren {
parent.Close()
}
} else {
for _, childs := range n.childs {
childs.Close()
// This is a root node, close all children
n.mu.Lock()
children := make([]*Node, len(n.childs))
copy(children, n.childs)
n.childs = nil
n.mu.Unlock()

for _, child := range children {
child.Close()
}
}
}
@@ -79,6 +142,15 @@ func MoveNode(dst, src *Node) {
src.mu.Unlock()

dst.mu.Lock()
// Don't move to closed node
if dst.closed {
dst.mu.Unlock()
// Close orphaned children
for _, child := range childs {
child.Close()
}
return
}
dst.childs = childs
dst.mu.Unlock()

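A reduced sketch of the two idioms the node.go change relies on: a closed flag that makes Close() idempotent, and a sweep that drops "ghost" children that closed themselves but were never removed from the parent. The node type below is a hypothetical miniature, not pkg/core's Node.

package main

import (
	"fmt"
	"sync"
)

type node struct {
	mu     sync.Mutex
	closed bool
	childs []*node
}

// close flips the flag exactly once; later calls are no-ops.
func (n *node) close() bool {
	n.mu.Lock()
	defer n.mu.Unlock()
	if n.closed {
		return false
	}
	n.closed = true
	return true
}

// cleanGhostChildren keeps only children that are still open.
func (n *node) cleanGhostChildren() {
	n.mu.Lock()
	defer n.mu.Unlock()
	alive := n.childs[:0]
	for _, c := range n.childs {
		c.mu.Lock()
		open := !c.closed
		c.mu.Unlock()
		if open {
			alive = append(alive, c)
		}
	}
	n.childs = alive
}

func main() {
	parent := &node{}
	ghost := &node{}
	parent.childs = append(parent.childs, ghost)

	ghost.close()               // closed but never removed from parent
	parent.cleanGhostChildren() // sweep drops it
	fmt.Println(len(parent.childs), parent.close(), parent.close()) // 0 true false
}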
52 changes: 41 additions & 11 deletions pkg/core/track.go
@@ -29,7 +29,15 @@ func NewReceiver(media *Media, codec *Codec) *Receiver {
r.Input = func(packet *Packet) {
r.Bytes += len(packet.Payload)
r.Packets++
for _, child := range r.childs {

// Lock and copy children to prevent race condition
r.mu.Lock()
children := make([]*Node, len(r.childs))
copy(children, r.childs)
r.mu.Unlock()

// Now iterate safely without holding the lock
for _, child := range children {
child.Input(packet)
}
}
@@ -41,7 +49,17 @@ func (r *Receiver) WriteRTP(packet *rtp.Packet) {
r.Input(packet)
}

// Deprecated: should be removed
// HasSenders reports whether any senders are attached; it sweeps ghost children first.
func (r *Receiver) HasSenders() bool {
// Clean up any ghost children first
r.cleanGhostChildren()

r.mu.Lock()
defer r.mu.Unlock()
return len(r.childs) > 0
}

// Deprecated: should be removed, use HasSenders instead
func (r *Receiver) Senders() []*Sender {
if len(r.childs) > 0 {
return []*Sender{{}}
@@ -73,6 +91,7 @@ type Sender struct {

buf chan *Packet
done chan struct{}
closed bool // Track closed state
}

func NewSender(media *Media, codec *Codec) *Sender {
@@ -98,12 +117,17 @@ func NewSender(media *Media, codec *Codec) *Sender {
}
s.Input = func(packet *Packet) {
s.mu.Lock()
// unblock write to nil chan - OK, write to closed chan - panic
select {
case s.buf <- packet:
s.Bytes += len(packet.Payload)
s.Packets++
default:
// Check closed state before writing
if !s.closed && s.buf != nil {
// unblock write to nil chan - OK, write to closed chan - panic
select {
case s.buf <- packet:
s.Bytes += len(packet.Payload)
s.Packets++
default:
s.Drops++
}
} else {
s.Drops++
}
s.mu.Unlock()
@@ -167,12 +191,18 @@ func (s *Sender) State() string {
func (s *Sender) Close() {
// close buffer if exists
s.mu.Lock()
if s.buf != nil {
close(s.buf) // exit from for range loop
s.buf = nil // prevent writing to closed chan
if !s.closed {
// Mark as closed first, then close channel
s.closed = true
if s.buf != nil {
close(s.buf)
s.buf = nil
}
}
s.mu.Unlock()

// Always call Node.Close() to ensure proper cleanup
// even if the sender-specific cleanup already happened
s.Node.Close()
}

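A self-contained sketch of the closed-flag idiom the Sender changes use to avoid "send on closed channel" panics: input checks the flag and the nil buffer under the same mutex that close holds while closing, so the two can never interleave badly. The sender type below is a hypothetical miniature of core.Sender.

package main

import (
	"fmt"
	"sync"
)

type sender struct {
	mu     sync.Mutex
	buf    chan int
	closed bool
	drops  int
}

func (s *sender) input(v int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed || s.buf == nil {
		s.drops++ // late packet after close: drop, never panic
		return
	}
	select {
	case s.buf <- v:
	default:
		s.drops++ // queue full: drop instead of blocking the producer
	}
}

func (s *sender) close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return // idempotent
	}
	s.closed = true
	if s.buf != nil {
		close(s.buf) // readers ranging over buf drain it and exit
		s.buf = nil
	}
}

func main() {
	s := &sender{buf: make(chan int, 1)}
	s.input(1)
	s.close()
	s.input(2)           // safely dropped
	fmt.Println(s.drops) // 1
}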
28 changes: 28 additions & 0 deletions pkg/tapo/backchannel.go
@@ -10,6 +10,10 @@ import (
)

func (c *Client) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiver) error {
// Check if client is already closed/cancelled
if c.ctx.Err() != nil {
return c.ctx.Err()
}
if c.sender == nil {
if err := c.SetupBackchannel(); err != nil {
return err
@@ -23,6 +27,11 @@ func (c *Client) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiver

c.sender = core.NewSender(media, track.Codec)
c.sender.Handler = func(packet *rtp.Packet) {
// Check context before writing
if c.ctx.Err() != nil {
return
}

b := muxer.GetPayload(pid, packet.Timestamp, packet.Payload)
_ = c.WriteBackchannel(b)
}
@@ -33,6 +42,11 @@ func (c *Client) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiver
}

func (c *Client) SetupBackchannel() (err error) {
// Check context
if c.ctx.Err() != nil {
return c.ctx.Err()
}

// if conn1 is not used - we will use it for backchannel
// or we need to start another conn for session2
if c.session1 != "" {
@@ -48,6 +62,20 @@ func (c *Client) WriteBackchannel(body []byte) (err error) {
}

func (c *Client) WriteBackchannel(body []byte) (err error) {
// Check if closed before writing
c.closeMutex.Lock()
closed := c.closed
conn := c.conn2
c.closeMutex.Unlock()

if closed || c.ctx.Err() != nil {
return c.ctx.Err()
}

if conn == nil {
return nil
}

// TODO: fixme (size)
buf := bytes.NewBuffer(nil)
buf.WriteString("----client-stream-boundary--\r\n")
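The backchannel.go additions all follow one guard pattern: consult ctx.Err() (and the closed flag) before touching the connection, so packets arriving during teardown are ignored rather than racing the close. A sketch under that assumption, with a hypothetical write callback standing in for WriteBackchannel:

package main

import (
	"context"
	"fmt"
)

// writeIfAlive refuses the write once ctx is done, returning the cancellation reason.
func writeIfAlive(ctx context.Context, write func([]byte) error, b []byte) error {
	if err := ctx.Err(); err != nil {
		return err // context.Canceled or context.DeadlineExceeded
	}
	return write(b)
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	write := func(b []byte) error {
		fmt.Println("wrote", len(b), "bytes")
		return nil
	}

	fmt.Println(writeIfAlive(ctx, write, []byte("ok")))   // writes, then <nil>
	cancel()
	fmt.Println(writeIfAlive(ctx, write, []byte("late"))) // context canceled
}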