661 changes: 389 additions & 272 deletions pool/node.go

Large diffs are not rendered by default.
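The pool/node.go diff is collapsed here, but the tests and worker changes below exercise a new node.getNodeStream helper that returns a cached, per-node ack stream (TestWorkerAckStreams asserts the same instance is returned on repeated calls and that the stream is destroyed on shutdown). A minimal sketch of that shape, assuming a sync.Map cache on the node and the streaming.NewStream constructor used elsewhere in this PR; names and details are illustrative, not the actual implementation:

// getNodeStream returns the stream used to send acks to the node with the
// given ID, creating it on first use and caching it so callers share one instance.
func (node *Node) getNodeStream(nodeID string) (*streaming.Stream, error) {
	if s, ok := node.nodeStreams.Load(nodeID); ok {
		return s.(*streaming.Stream), nil
	}
	stream, err := streaming.NewStream(nodeStreamName(node.PoolName, nodeID), node.rdb,
		options.WithStreamLogger(node.logger))
	if err != nil {
		return nil, fmt.Errorf("failed to create stream for node %q: %w", nodeID, err)
	}
	actual, _ := node.nodeStreams.LoadOrStore(nodeID, stream)
	return actual.(*streaming.Stream), nil
}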

124 changes: 117 additions & 7 deletions pool/node_test.go
@@ -325,10 +325,10 @@ func TestDispatchJobRaceCondition(t *testing.T) {

// Set a stale pending timestamp
staleTS := time.Now().Add(-time.Hour).UnixNano()
_, err := node1.pendingJobsMap.SetAndWait(ctx, jobKey, strconv.FormatInt(staleTS, 10))
_, err := node1.jobPendingMap.SetAndWait(ctx, jobKey, strconv.FormatInt(staleTS, 10))
require.NoError(t, err, "Failed to set stale pending timestamp")
defer func() {
_, err = node1.pendingJobsMap.Delete(ctx, jobKey)
_, err = node1.jobPendingMap.Delete(ctx, jobKey)
assert.NoError(t, err, "Failed to delete pending timestamp")
}()

@@ -346,7 +346,7 @@ func TestDispatchJobRaceCondition(t *testing.T) {
require.NoError(t, err, "Dispatch should succeed")
// Verify pending entry was cleaned up
require.Eventually(t, func() bool {
val, exists := node1.pendingJobsMap.Get(jobKey)
val, exists := node1.jobPendingMap.Get(jobKey)
t.Logf("Got pending value: %q", val)
return !exists
}, max, delay, "Pending entry should be cleaned up after successful dispatch")
@@ -357,7 +357,7 @@ func TestDispatchJobRaceCondition(t *testing.T) {
payload := []byte("test payload")

// Set an invalid pending timestamp
_, err := node1.pendingJobsMap.SetAndWait(ctx, jobKey, "invalid-timestamp")
_, err := node1.jobPendingMap.SetAndWait(ctx, jobKey, "invalid-timestamp")
require.NoError(t, err, "Failed to set invalid pending timestamp")

// Dispatch should succeed (invalid timestamps are logged and ignored)
@@ -380,7 +380,7 @@ func TestDispatchJobRaceCondition(t *testing.T) {

// Verify pending entry was cleaned up
require.Eventually(t, func() bool {
_, exists := node1.pendingJobsMap.Get(jobKey)
_, exists := node1.jobPendingMap.Get(jobKey)
return !exists
}, max, delay, "Pending entry should be cleaned up after failed dispatch")
})
@@ -623,7 +623,7 @@ func TestAckWorkerEventWithMissingPendingEvent(t *testing.T) {
}

// Call ackWorkerEvent with the mock event
node.ackWorkerEvent(ctx, mockEvent)
node.ackWorkerEvent(mockEvent)

// Verify that no panic occurred and the function completed successfully
assert.True(t, true, "ackWorkerEvent should complete without panic")
@@ -681,7 +681,7 @@ func TestStaleEventsAreRemoved(t *testing.T) {
node.pendingEvents.Store(pendingEventKey("worker", mockEventID), mockEvent)

// Call ackWorkerEvent to trigger the stale event cleanup
node.ackWorkerEvent(ctx, mockEvent)
node.ackWorkerEvent(mockEvent)

assert.Eventually(t, func() bool {
_, ok := node.pendingEvents.Load(pendingEventKey("worker", staleEventID))
@@ -834,6 +834,116 @@ func TestShutdownStopsAllJobs(t *testing.T) {
assert.Empty(t, worker2.Jobs(), "Worker2 should have no remaining jobs")
}

func TestWorkerAckStreams(t *testing.T) {
testName := strings.Replace(t.Name(), "/", "_", -1)
ctx := ptesting.NewTestContext(t)
rdb := ptesting.NewRedisClient(t)
node := newTestNode(t, ctx, rdb, testName)
defer ptesting.CleanupRedis(t, rdb, true, testName)

// Create a worker and dispatch a job
worker := newTestWorker(t, ctx, node)
require.NoError(t, node.DispatchJob(ctx, testName, []byte("payload")))

// Wait for the job to start and be acknowledged
require.Eventually(t, func() bool {
return len(worker.Jobs()) == 1
}, max, delay)

// Verify stream is created and cached
stream1, err := node.getNodeStream(node.ID)
require.NoError(t, err)
stream2, err := node.getNodeStream(node.ID)
require.NoError(t, err)
assert.Same(t, stream1, stream2, "Expected same stream instance to be returned")

// Verify stream exists before shutdown
streamKey := "pulse:stream:" + nodeStreamName(testName, node.ID)
exists, err := rdb.Exists(ctx, streamKey).Result()
assert.NoError(t, err)
assert.Equal(t, int64(1), exists, "Expected stream to exist before shutdown")

// Shutdown node
assert.NoError(t, node.Shutdown(ctx))

// Verify stream is destroyed in Redis
exists, err = rdb.Exists(ctx, streamKey).Result()
assert.NoError(t, err)
assert.Equal(t, int64(0), exists, "Expected stream to be destroyed after shutdown")
}

func TestStaleWorkerCleanupAfterJobRequeue(t *testing.T) {
// Setup test environment
ctx := ptesting.NewTestContext(t)
testName := strings.Replace(t.Name(), "/", "_", -1)
rdb := ptesting.NewRedisClient(t)
defer ptesting.CleanupRedis(t, rdb, true, testName)

node := newTestNode(t, ctx, rdb, testName)
defer func() { assert.NoError(t, node.Shutdown(ctx)) }()

// Create a worker that will become stale
staleWorker := newTestWorker(t, ctx, node)

// Dispatch some jobs to the worker
for i := 0; i < 3; i++ {
jobKey := fmt.Sprintf("%s_%d", testName, i)
require.NoError(t, node.DispatchJob(ctx, jobKey, []byte("test-payload")))
}

// Wait for jobs to be assigned
require.Eventually(t, func() bool {
return len(staleWorker.Jobs()) == 3
}, max, delay, "Jobs were not assigned to worker")

// Make the worker stale by stopping it and setting an old keepalive
staleWorker.stop(ctx)
_, err := node.workerKeepAliveMap.Set(ctx, staleWorker.ID,
strconv.FormatInt(time.Now().Add(-2*node.workerTTL).UnixNano(), 10))
require.NoError(t, err)

// Create a new worker to receive requeued jobs
newWorker := newTestWorker(t, ctx, node)

// Wait for cleanup to happen and jobs to be requeued
require.Eventually(t, func() bool {
return len(newWorker.Jobs()) == 3
}, max, delay, "Jobs were not requeued to new worker")

// Verify stale worker was deleted
require.Eventually(t, func() bool {
// Check that worker is removed from all tracking maps
workers := node.Workers()
if len(workers) != 1 {
t.Logf("Expected 1 worker, got %d", len(workers))
return false
}

// Check worker is removed from worker map
workerMap := node.workerMap.Map()
if _, exists := workerMap[staleWorker.ID]; exists {
t.Log("Worker still exists in worker map")
return false
}

// Check keepalive is removed
keepAlive := node.workerKeepAliveMap.Map()
if _, exists := keepAlive[staleWorker.ID]; exists {
t.Log("Worker still has keepalive entry")
return false
}

// Check jobs are removed
jobs := node.jobMap.Map()
if _, exists := jobs[staleWorker.ID]; exists {
t.Log("Worker still has jobs assigned")
return false
}

return true
}, max, delay, "Stale worker was not properly cleaned up")
}

type mockAcker struct {
XAckFunc func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd
}
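The rest of this test double is collapsed below; mocks of this shape normally just delegate to the injected function so each test can control the ack result. A plausible sketch of the method it provides (assuming the node's ack path is consumed through an interface with this XAck signature):

// XAck forwards to XAckFunc, letting individual tests decide how acknowledgements behave.
func (m *mockAcker) XAck(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd {
	return m.XAckFunc(ctx, streamKey, sinkName, ids...)
}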
7 changes: 4 additions & 3 deletions pool/scheduler.go
@@ -99,8 +99,9 @@ func (node *Node) Schedule(ctx context.Context, producer JobProducer, interval t
if err := sched.stopJobs(ctx, plan); err != nil {
return fmt.Errorf("failed to stop jobs: %w", err)
}
pulse.Go(ctx, func() { sched.scheduleJobs(ctx, ticker, producer) })
pulse.Go(ctx, func() { sched.handleStop(ctx) })

pulse.Go(sched.logger, func() { sched.scheduleJobs(ctx, ticker, producer) })
pulse.Go(sched.logger, func() { sched.handleStop() })
return nil
}

@@ -175,7 +176,7 @@ func (sched *scheduler) stopJobs(ctx context.Context, plan *JobPlan) error {
}

// handleStop handles the scheduler stop signal.
func (sched *scheduler) handleStop(_ context.Context) {
func (sched *scheduler) handleStop() {
ch := sched.jobMap.Subscribe()
for ev := range ch {
if ev == rmap.EventReset {
2 changes: 1 addition & 1 deletion pool/ticker.go
@@ -67,7 +67,7 @@ func (node *Node) NewTicker(ctx context.Context, name string, d time.Duration, o
}
t.initTimer()
t.wg.Add(1)
pulse.Go(ctx, func() { t.handleEvents() })
pulse.Go(logger, func() { t.handleEvents() })
return t, nil
}

54 changes: 31 additions & 23 deletions pool/worker.go
@@ -10,10 +10,12 @@ import (
"time"

"github.com/oklog/ulid/v2"
"goa.design/clue/log"

"goa.design/pulse/pulse"
"goa.design/pulse/rmap"
"goa.design/pulse/streaming"
soptions "goa.design/pulse/streaming/options"
"goa.design/pulse/streaming/options"
)

type (
@@ -94,11 +96,14 @@ func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) {
if _, err := node.workerKeepAliveMap.SetAndWait(ctx, wid, now); err != nil {
return nil, fmt.Errorf("failed to update worker keep-alive: %w", err)
}
stream, err := streaming.NewStream(workerStreamName(wid), node.rdb, soptions.WithStreamLogger(node.logger))
stream, err := streaming.NewStream(workerStreamName(wid), node.rdb, options.WithStreamLogger(node.logger))
if err != nil {
return nil, fmt.Errorf("failed to create jobs stream for worker %q: %w", wid, err)
}
reader, err := stream.NewReader(ctx, soptions.WithReaderBlockDuration(node.workerTTL/2), soptions.WithReaderStartAtOldest())
if _, err := stream.Add(ctx, evInit, marshalEnvelope(node.ID, []byte(wid))); err != nil {
return nil, fmt.Errorf("failed to add init event to worker stream %q: %w", workerStreamName(wid), err)
}
reader, err := stream.NewReader(ctx, options.WithReaderBlockDuration(node.workerTTL/2), options.WithReaderStartAtOldest())
if err != nil {
return nil, fmt.Errorf("failed to create reader for worker %q: %w", wid, err)
}
@@ -110,10 +115,10 @@ func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) {
stream: stream,
reader: reader,
done: make(chan struct{}),
jobsMap: node.jobsMap,
jobPayloadsMap: node.jobPayloadsMap,
jobsMap: node.jobMap,
jobPayloadsMap: node.jobPayloadMap,
keepAliveMap: node.workerKeepAliveMap,
shutdownMap: node.shutdownMap,
shutdownMap: node.nodeShutdownMap,
workerTTL: node.workerTTL,
workerShutdownTTL: node.workerShutdownTTL,
logger: node.logger.WithPrefix("worker", wid),
@@ -126,8 +131,13 @@ func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) {
"worker_shutdown_ttl", w.workerShutdownTTL)

w.wg.Add(2)
pulse.Go(ctx, func() { w.handleEvents(ctx, reader.Subscribe()) })
pulse.Go(ctx, func() { w.keepAlive(ctx) })

// Create new context for the worker so that canceling the original one does
// not cancel the worker.
logCtx := context.Background()
logCtx = log.WithContext(logCtx, ctx)
pulse.Go(w.logger, func() { w.handleEvents(logCtx, reader.Subscribe()) })
pulse.Go(w.logger, func() { w.keepAlive(logCtx) })

return w, nil
}
@@ -178,6 +188,9 @@ func (w *Worker) handleEvents(ctx context.Context, c <-chan *streaming.Event) {
nodeID, payload := unmarshalEnvelope(ev.Payload)
var err error
switch ev.EventName {
case evInit:
w.logger.Debug("handleEvents: received init", "event", ev.EventName, "id", ev.ID)
continue
case evStartJob:
w.logger.Debug("handleEvents: received start job", "event", ev.EventName, "id", ev.ID)
err = w.startJob(ctx, unmarshalJob(payload))
@@ -200,6 +213,7 @@ func (w *Worker) handleEvents(ctx context.Context, c <-chan *streaming.Event) {
}
w.ackPoolEvent(ctx, nodeID, ev.ID, nil)
case <-w.done:
w.logger.Debug("handleEvents: done")
return
}
}
@@ -291,22 +305,18 @@ func (w *Worker) notify(_ context.Context, key string, payload []byte) error {
// ackPoolEvent acknowledges the pool event that originated from the node with
// the given ID.
func (w *Worker) ackPoolEvent(ctx context.Context, nodeID, eventID string, ackerr error) {
stream, ok := w.nodeStreams.Load(nodeID)
if !ok {
var err error
stream, err = streaming.NewStream(nodeStreamName(w.node.PoolName, nodeID), w.node.rdb, soptions.WithStreamLogger(w.logger))
if err != nil {
w.logger.Error(fmt.Errorf("failed to create stream for node %q: %w", nodeID, err))
return
}
w.nodeStreams.Store(nodeID, stream)
stream, err := w.node.getNodeStream(nodeID)
if err != nil {
w.logger.Error(fmt.Errorf("failed to get ack stream for node %q: %w", nodeID, err))
return
}

var msg string
if ackerr != nil {
msg = ackerr.Error()
}
ack := &ack{EventID: eventID, Error: msg}
if _, err := stream.(*streaming.Stream).Add(ctx, evAck, marshalEnvelope(w.ID, marshalAck(ack))); err != nil {
if _, err := stream.Add(ctx, evAck, marshalEnvelope(w.ID, marshalAck(ack)), options.WithOnlyIfStreamExists()); err != nil {
w.logger.Error(fmt.Errorf("failed to ack event %q from node %q: %w", eventID, nodeID, err))
}
}
@@ -328,6 +338,7 @@ func (w *Worker) keepAlive(ctx context.Context) {
w.logger.Error(fmt.Errorf("failed to update worker keep-alive: %w", err))
}
case <-w.done:
w.logger.Debug("keepAlive: done")
return
}
}
@@ -350,15 +361,14 @@ func (w *Worker) rebalance(ctx context.Context, activeWorkers []string) {
w.logger.Debug("rebalance: no jobs to rebalance")
return
}
cherrs := make(map[string]chan error, total)
for key, job := range rebalanced {
if err := w.handler.Stop(key); err != nil {
w.logger.Error(fmt.Errorf("rebalance: failed to stop job: %w", err), "job", key)
continue
}
w.logger.Debug("stopped job", "job", key)
w.jobs.Delete(key)
cherr, err := w.node.requeueJob(ctx, w.ID, job)
err := w.node.dispatchJob(ctx, key, marshalJob(job), true)
if err != nil {
w.logger.Error(fmt.Errorf("rebalance: failed to requeue job: %w", err), "job", key)
if err := w.handler.Start(job); err != nil {
Expand All @@ -367,9 +377,7 @@ func (w *Worker) rebalance(ctx context.Context, activeWorkers []string) {
continue
}
delete(rebalanced, key)
cherrs[key] = cherr
}
pulse.Go(ctx, func() { w.node.processRequeuedJobs(ctx, w.ID, cherrs, false) })
}

// requeueJobs requeues the jobs handled by the worker.
@@ -432,7 +440,7 @@ func (w *Worker) attemptRequeue(ctx context.Context, jobsToRequeue map[string]*J

wg.Add(len(jobsToRequeue))
for key, job := range jobsToRequeue {
pulse.Go(ctx, func() {
pulse.Go(w.logger, func() {
defer wg.Done()
err := w.requeueJob(ctx, job)
if err != nil {
15 changes: 6 additions & 9 deletions pulse/goroutine.go
@@ -1,11 +1,8 @@
package pulse

import (
"context"
"fmt"
"runtime/debug"

"goa.design/clue/log"
)

// Go runs the given function in a separate goroutine and recovers from any panic,
@@ -16,14 +13,14 @@ import (
// Go(ctx, func() {
// // Your code here
// })
func Go(ctx context.Context, f func()) {
go func() {
defer func(ctx context.Context) {
func Go(logger Logger, f func()) {
go func(logger Logger) {
defer func() {
if r := recover(); r != nil {
panicErr := fmt.Errorf("Panic recovered: %v\n%s", r, debug.Stack())
log.Error(ctx, panicErr)
logger.Error(panicErr)
}
}(ctx)
}()
f()
}()
}(logger)
}
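With the context parameter gone, panic recovery now reports through whatever logger the goroutine's owner passes in, rather than through a context-scoped logger. Call sites elsewhere in this PR follow the same pattern; an illustrative example mirroring the worker changes above (assuming w.logger satisfies the pulse Logger interface):

	// The owner of the goroutine supplies its own logger; a panic inside
	// keepAlive is recovered and logged via w.logger instead of crashing the process.
	pulse.Go(w.logger, func() { w.keepAlive(ctx) })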