Skip to content
93 changes: 82 additions & 11 deletions v2/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package shuttle

import (
"context"
"errors"
"fmt"
"reflect"
"sync"
Expand All @@ -20,6 +21,14 @@ const (
// MessageBody is a type to represent that an input message body can be of any type
type MessageBody any

// SendAsBatchOptions contains options for the SendAsBatch method
type SendAsBatchOptions struct {
// AllowMultipleBatch when true, allows splitting large message arrays into multiple batches.
// When false, behaves like the original SendMessageBatch method.
// Default: false
AllowMultipleBatch bool
}

// AzServiceBusSender is satisfied by *azservicebus.Sender
type AzServiceBusSender interface {
SendMessage(ctx context.Context, message *azservicebus.Message, options *azservicebus.SendMessageOptions) error
Expand Down Expand Up @@ -138,31 +147,87 @@ func (d *Sender) ToServiceBusMessage(
return msg, nil
}

// SendMessageBatch sends the array of azservicebus messages as a batch.
func (d *Sender) SendMessageBatch(ctx context.Context, messages []*azservicebus.Message) error {
// SendAsBatch sends the array of azservicebus messages as batches.
// When options.AllowMultipleBatch is true, large message arrays are split into multiple batches.
// When options.AllowMultipleBatch is false, behaves like SendMessageBatch (fails if messages don't fit in single batch).
func (d *Sender) SendAsBatch(ctx context.Context, messages []*azservicebus.Message, options *SendAsBatchOptions) error {
// Check if there is a context error before doing anything since
// we rely on context failures to detect if the sender is dead.
if ctx.Err() != nil {
return fmt.Errorf("failed to send message: %w", ctx.Err())
return fmt.Errorf("failed to send message batch: %w", ctx.Err())
}

if options == nil {
options = &SendAsBatchOptions{AllowMultipleBatch: false}
}

// Apply timeout for the entire operation
if d.options.SendTimeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.options.SendTimeout)
defer cancel()
}

batch, err := d.newMessageBatch(ctx, &azservicebus.MessageBatchOptions{})
if len(messages) == 0 {
return fmt.Errorf("cannot send empty message array")
}

// Create a message batch. It will automatically be sized for the Service Bus
// namespace's maximum message size.
currentMessageBatch, err := d.newMessageBatch(ctx, nil)
if err != nil {
return err
}
for _, msg := range messages {
if err := batch.AddMessage(msg, nil); err != nil {

for i := 0; i < len(messages); i++ {
// Add a message to our message batch. This can be called multiple times.
err = currentMessageBatch.AddMessage(messages[i], nil)

if err != nil && errors.Is(err, azservicebus.ErrMessageTooLarge) {
if currentMessageBatch.NumMessages() == 0 {
// This means the message itself is too large to be sent, even on its own.
// This will require intervention from the user.
return fmt.Errorf("single message is too large to be sent in a batch: %w", err)
}

// Message batch is full. Send it and create a new one.
if !options.AllowMultipleBatch {
// For single batch mode, return error if messages don't fit
return fmt.Errorf("messages do not fit in a single batch: %w", err)
}

// Send what we have since the batch is full
if err := d.sendBatch(ctx, currentMessageBatch); err != nil {
return err
}

// Create a new batch and retry adding this message to our batch.
newBatch, err := d.newMessageBatch(ctx, nil)
if err != nil {
return err
}

currentMessageBatch = newBatch

// rewind the counter and attempt to add the message again (this batch
// was full so it didn't go out with the previous sendBatch call).
i--
} else if err != nil {
return err
}
}
if d.options.SendTimeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.options.SendTimeout)
defer cancel()

// check if any messages are remaining to be sent.
if currentMessageBatch.NumMessages() > 0 {
return d.sendBatch(ctx, currentMessageBatch)
}

errChan := make(chan error)
return nil
}

// sendBatch sends a single message batch with proper error handling and metrics
func (d *Sender) sendBatch(ctx context.Context, batch *azservicebus.MessageBatch) error {
errChan := make(chan error)
go func() {
if err := d.sendMessageBatch(ctx, batch, nil); err != nil {
errChan <- fmt.Errorf("failed to send message batch: %w", err)
Expand All @@ -185,6 +250,12 @@ func (d *Sender) SendMessageBatch(ctx context.Context, messages []*azservicebus.
}
}

// SendMessageBatch sends the array of azservicebus messages as a batch.
// Deprecated: Use SendAsBatch instead. This method will be removed in a future version.
func (d *Sender) SendMessageBatch(ctx context.Context, messages []*azservicebus.Message) error {
return d.SendAsBatch(ctx, messages, &SendAsBatchOptions{AllowMultipleBatch: false})
}

func (d *Sender) sendMessage(ctx context.Context, msg *azservicebus.Message, options *azservicebus.SendMessageOptions) error {
d.mu.RLock()
defer d.mu.RUnlock()
Expand Down
179 changes: 169 additions & 10 deletions v2/sender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ func TestSender_WithDefaultSendTimeout(t *testing.T) {
})
err := sender.SendMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())
err = sender.SendMessageBatch(context.Background(), nil)
g.Expect(err).ToNot(HaveOccurred())
err = sender.SendMessageBatch(context.Background(), []*azservicebus.Message{})
g.Expect(err).To(HaveOccurred())
}

func TestSender_WithSendTimeout(t *testing.T) {
Expand All @@ -186,14 +186,19 @@ func TestSender_WithSendTimeout(t *testing.T) {
})
err := sender.SendMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())
err = sender.SendMessageBatch(context.Background(), nil)
g.Expect(err).ToNot(HaveOccurred())
err = sender.SendMessageBatch(context.Background(), []*azservicebus.Message{})
g.Expect(err).To(HaveOccurred())
err = sender.SendAsBatch(context.Background(), []*azservicebus.Message{}, nil)
g.Expect(err).To(HaveOccurred())
err = sender.SendAsBatch(context.Background(), []*azservicebus.Message{}, &SendAsBatchOptions{AllowMultipleBatch: true})
g.Expect(err).To(HaveOccurred())
}

func TestSender_WithContextCanceled(t *testing.T) {
g := NewWithT(t)
sendTimeout := 1 * time.Second
azSender := &fakeAzSender{
NewMessageBatchReturnValue: &azservicebus.MessageBatch{},
DoSendMessage: func(ctx context.Context, message *azservicebus.Message, options *azservicebus.SendMessageOptions) error {
time.Sleep(2 * time.Second)
return nil
Expand All @@ -210,8 +215,8 @@ func TestSender_WithContextCanceled(t *testing.T) {

err := sender.SendMessage(context.Background(), "test")
g.Expect(err).To(MatchError(context.DeadlineExceeded))
err = sender.SendMessageBatch(context.Background(), nil)
g.Expect(err).To(MatchError(context.DeadlineExceeded))
err = sender.SendMessageBatch(context.Background(), []*azservicebus.Message{})
g.Expect(err).To(HaveOccurred()) // error for empty messages instead of timeout
}

func TestSender_SendWithCanceledContext(t *testing.T) {
Expand All @@ -233,14 +238,15 @@ func TestSender_SendWithCanceledContext(t *testing.T) {

err := sender.SendMessage(ctx, "test")
g.Expect(err).To(MatchError(context.Canceled))
err = sender.SendMessageBatch(ctx, nil)
err = sender.SendMessageBatch(ctx, []*azservicebus.Message{})
g.Expect(err).To(MatchError(context.Canceled))
}

func TestSender_DisabledSendTimeout(t *testing.T) {
g := NewWithT(t)
sendTimeout := -1 * time.Second
azSender := &fakeAzSender{
NewMessageBatchReturnValue: &azservicebus.MessageBatch{},
DoSendMessage: func(ctx context.Context, message *azservicebus.Message, options *azservicebus.SendMessageOptions) error {
_, ok := ctx.Deadline()
g.Expect(ok).To(BeFalse())
Expand All @@ -258,8 +264,8 @@ func TestSender_DisabledSendTimeout(t *testing.T) {
})
err := sender.SendMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())
err = sender.SendMessageBatch(context.Background(), nil)
g.Expect(err).ToNot(HaveOccurred())
err = sender.SendMessageBatch(context.Background(), []*azservicebus.Message{})
g.Expect(err).To(HaveOccurred())
}

func TestSender_SendMessage(t *testing.T) {
Expand All @@ -286,8 +292,8 @@ func TestSender_SendMessageBatch(t *testing.T) {
msg, err := sender.ToServiceBusMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())
err = sender.SendMessageBatch(context.Background(), []*azservicebus.Message{msg})
// no way to create a MessageBatch struct with a non-0 max bytes in test, so the best we can do is expect an error.
g.Expect(err).To(HaveOccurred())
// No way to create a MessageBatch struct with a non-0 max bytes in test, so the best we can do is expect an error.
}

func TestSender_AzSender(t *testing.T) {
Expand Down Expand Up @@ -341,6 +347,150 @@ func TestSender_ConcurrentSendAndSetAzSender(t *testing.T) {
g.Expect(azSender2.SendMessageCalled).To(BeTrue())
}

func TestSender_SendAsBatch_EmptyMessages(t *testing.T) {
g := NewWithT(t)
azSender := &fakeAzSender{}
sender := NewSender(azSender, nil)

options := &SendAsBatchOptions{AllowMultipleBatch: true}
err := sender.SendAsBatch(context.Background(), []*azservicebus.Message{}, options)
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("cannot send empty message array"))
// Should not call send since error returned early
g.Expect(azSender.SendMessageBatchCalled).To(BeFalse())
// No batches should be created
g.Expect(azSender.BatchesCreated).To(Equal(0))
}

func TestSender_SendAsBatch_EmptyMessages_SingleBatch(t *testing.T) {
g := NewWithT(t)
azSender := &fakeAzSender{
NewMessageBatchReturnValue: &azservicebus.MessageBatch{},
}
sender := NewSender(azSender, nil)

options := &SendAsBatchOptions{AllowMultipleBatch: false}
err := sender.SendAsBatch(context.Background(), []*azservicebus.Message{}, options)
// Should fail because empty message array is not allowed
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("cannot send empty message array"))
// Should not attempt to send
g.Expect(azSender.SendMessageBatchCalled).To(BeFalse())
// No batch should be created
g.Expect(azSender.BatchesCreated).To(Equal(0))
}

func TestSender_SendAsBatch_ContextCanceled(t *testing.T) {
g := NewWithT(t)
azSender := &fakeAzSender{}
sender := NewSender(azSender, nil)

ctx, cancel := context.WithCancel(context.Background())
cancel()

msg, err := sender.ToServiceBusMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())

options := &SendAsBatchOptions{AllowMultipleBatch: true}
err = sender.SendAsBatch(ctx, []*azservicebus.Message{msg}, options)
g.Expect(err).To(MatchError(context.Canceled))
}

func TestSender_SendAsBatch_NewMessageBatchError(t *testing.T) {
g := NewWithT(t)
expectedErr := fmt.Errorf("batch creation failed")
azSender := &fakeAzSender{
NewMessageBatchErr: expectedErr,
}
sender := NewSender(azSender, nil)

msg, err := sender.ToServiceBusMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())

options := &SendAsBatchOptions{AllowMultipleBatch: true}
err = sender.SendAsBatch(context.Background(), []*azservicebus.Message{msg}, options)
g.Expect(err).To(Equal(expectedErr))
g.Expect(azSender.BatchesCreated).To(Equal(1))
g.Expect(azSender.SendMessageBatchCalled).To(BeFalse()) // Should not try to send if batch creation fails
}

func TestSender_SendAsBatch_SingleBatch_Success(t *testing.T) {
g := NewWithT(t)

azSender := &fakeAzSender{
NewMessageBatchReturnValue: &azservicebus.MessageBatch{},
DoSendMessageBatch: func(ctx context.Context, batch *azservicebus.MessageBatch, options *azservicebus.SendMessageBatchOptions) error {
return nil
},
}

sender := NewSender(azSender, &SenderOptions{
Marshaller: &DefaultJSONMarshaller{},
})

// Create a message (the real batch will fail to add it due to zero size, but we can test the logic)
msg, err := sender.ToServiceBusMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())

options := &SendAsBatchOptions{AllowMultipleBatch: true}
err = sender.SendAsBatch(context.Background(), []*azservicebus.Message{msg}, options)

// no way to create a MessageBatch struct with a non-0 max bytes in test, so the best we can do is expect an error.
g.Expect(err).To(HaveOccurred()) // Real MessageBatch fails in tests due to zero max size
g.Expect(azSender.BatchesCreated).To(Equal(1))
g.Expect(azSender.BatchesSent).To(Equal(0)) // No batches sent due to AddMessage failure
}

func TestSender_SendAsBatch_MessageTooLarge_SingleMessage(t *testing.T) {
g := NewWithT(t)

azSender := &fakeAzSender{
NewMessageBatchReturnValue: &azservicebus.MessageBatch{}, // Real MessageBatch with 0 max size
}

sender := NewSender(azSender, &SenderOptions{
Marshaller: &DefaultJSONMarshaller{},
})

// Create any message - it will be too large for the real MessageBatch with 0 max size
msg, err := sender.ToServiceBusMessage(context.Background(), "test")
g.Expect(err).ToNot(HaveOccurred())

options := &SendAsBatchOptions{AllowMultipleBatch: true}
err = sender.SendAsBatch(context.Background(), []*azservicebus.Message{msg}, options)

// Should fail because any message is too large for a real MessageBatch in tests
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("single message is too large"))
g.Expect(azSender.BatchesCreated).To(Equal(1))
}

func TestSender_SendAsBatch_SingleBatch_TooManyMessages_AllowMultipleFalse(t *testing.T) {
g := NewWithT(t)
azSender := &fakeAzSender{
NewMessageBatchReturnValue: &azservicebus.MessageBatch{},
}
sender := NewSender(azSender, &SenderOptions{
Marshaller: &DefaultJSONMarshaller{},
})

// Create multiple messages
messages := make([]*azservicebus.Message, 3)
for i := range messages {
msg, err := sender.ToServiceBusMessage(context.Background(), fmt.Sprintf("test%d", i))
g.Expect(err).ToNot(HaveOccurred())
messages[i] = msg
}

options := &SendAsBatchOptions{AllowMultipleBatch: false}
err := sender.SendAsBatch(context.Background(), messages, options)

// Should fail because messages don't fit in single batch and multiple batches not allowed
// The real MessageBatch has max size 0 in tests, so AddMessage will fail immediately
g.Expect(err).To(HaveOccurred())
g.Expect(azSender.BatchesCreated).To(Equal(1))
}

type fakeAzSender struct {
mu sync.RWMutex
DoSendMessage func(ctx context.Context, message *azservicebus.Message, options *azservicebus.SendMessageOptions) error
Expand All @@ -355,6 +505,9 @@ type fakeAzSender struct {
NewMessageBatchErr error
SendMessageBatchReceivedValue *azservicebus.MessageBatch
CloseErr error

BatchesCreated int // Track how many batches were created
BatchesSent int // Track how many batches were sent
}

func (f *fakeAzSender) SendMessage(
Expand Down Expand Up @@ -382,6 +535,8 @@ func (f *fakeAzSender) SendMessageBatch(
defer f.mu.Unlock()
f.SendMessageBatchCalled = true
f.SendMessageBatchReceivedValue = batch
f.BatchesSent++

if f.DoSendMessageBatch != nil {
if err := f.DoSendMessageBatch(ctx, batch, options); err != nil {
return err
Expand All @@ -393,6 +548,10 @@ func (f *fakeAzSender) SendMessageBatch(
func (f *fakeAzSender) NewMessageBatch(
ctx context.Context,
options *azservicebus.MessageBatchOptions) (*azservicebus.MessageBatch, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.BatchesCreated++

return f.NewMessageBatchReturnValue, f.NewMessageBatchErr
}

Expand Down
Loading