Skip to content

Commit f49951e

Browse files
authored
fix: keep memory state consistent when recovering broadcast task from proto (#45788)
issue: #45782 pr: #45787 - because the zero value of the repeated field and bytes field in proto is ignored or treated as empty value but not nil pointer, so we need to fix the recovery info of the broadcast task from proto to keep the consistency of memory state. Signed-off-by: chyezh <[email protected]>
1 parent 0a89b23 commit f49951e

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

internal/streamingcoord/server/broadcaster/broadcast_task.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ import (
2020
func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask {
2121
msg := message.NewBroadcastMutableMessageBeforeAppend(proto.Message.Payload, proto.Message.Properties)
2222
m := metrics.NewBroadcastTask(msg.MessageType(), proto.GetState(), msg.BroadcastHeader().ResourceKeys.Collect())
23+
24+
fixAckInfoFromProto(proto, len(msg.BroadcastHeader().VChannels))
25+
2326
bt := &broadcastTask{
2427
mu: sync.Mutex{},
2528
taskMetricsGuard: m,
@@ -40,6 +43,24 @@ func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadc
4043
return bt
4144
}
4245

46+
// fixAckInfoFromProto fixes the recovery info of the broadcast task.
47+
// because the zero value of the repeated field and bytes field in proto is ignored or treated as empty value but not nil pointer,
48+
// so we need to fix the recovery info of the broadcast task from proto to keep the consistency of memory state.
49+
func fixAckInfoFromProto(proto *streamingpb.BroadcastTask, vchannelCount int) {
50+
bitmap := make([]byte, vchannelCount)
51+
copy(bitmap, proto.AckedVchannelBitmap)
52+
53+
checkpoints := make([]*streamingpb.AckedCheckpoint, vchannelCount)
54+
for i, cp := range proto.AckedCheckpoints {
55+
if cp != nil && cp.TimeTick == 0 {
56+
cp = nil
57+
}
58+
checkpoints[i] = cp
59+
}
60+
proto.AckedVchannelBitmap = bitmap
61+
proto.AckedCheckpoints = checkpoints
62+
}
63+
4364
// newBroadcastTaskFromBroadcastMessage creates a new broadcast task from the broadcast message.
4465
func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask {
4566
m := metrics.NewBroadcastTask(msg.MessageType(), streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, msg.BroadcastHeader().ResourceKeys.Collect())

internal/streamingcoord/server/broadcaster/broadcaster_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import (
99
"github.com/cockroachdb/errors"
1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/mock"
12+
"github.com/stretchr/testify/require"
1213
"go.uber.org/atomic"
14+
"google.golang.org/protobuf/proto"
1315

1416
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
1517
"github.com/milvus-io/milvus/internal/distributed/streaming"
@@ -267,3 +269,82 @@ func createNewWaitAckBroadcastTaskFromMessage(
267269
AckedCheckpoints: acks,
268270
}
269271
}
272+
273+
func TestRecoverBroadcastTaskFromProto(t *testing.T) {
274+
task := createNewBroadcastTask(8, []string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c1"))
275+
b, err := proto.Marshal(task)
276+
require.NoError(t, err)
277+
278+
task = unmarshalTask(t, b, 3)
279+
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x00, 0x00, 0x00})
280+
assert.Len(t, task.AckedCheckpoints, 3)
281+
assert.Nil(t, task.AckedCheckpoints[0])
282+
assert.Nil(t, task.AckedCheckpoints[1])
283+
assert.Nil(t, task.AckedCheckpoints[2])
284+
285+
cp := &streamingpb.AckedCheckpoint{
286+
MessageId: walimplstest.NewTestMessageID(1).IntoProto(),
287+
LastConfirmedMessageId: walimplstest.NewTestMessageID(1).IntoProto(),
288+
TimeTick: 1,
289+
}
290+
291+
task.AckedCheckpoints[2] = cp
292+
task.AckedVchannelBitmap[2] = 0x01
293+
b, err = proto.Marshal(task)
294+
require.NoError(t, err)
295+
task = unmarshalTask(t, b, 3)
296+
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x00, 0x00, 0x01})
297+
assert.Len(t, task.AckedCheckpoints, 3)
298+
assert.Nil(t, task.AckedCheckpoints[0])
299+
assert.Nil(t, task.AckedCheckpoints[1])
300+
assert.NotNil(t, task.AckedCheckpoints[2])
301+
302+
task.AckedCheckpoints[2] = nil
303+
task.AckedVchannelBitmap[2] = 0x0
304+
task.AckedCheckpoints[0] = cp
305+
task.AckedVchannelBitmap[0] = 0x01
306+
b, err = proto.Marshal(task)
307+
require.NoError(t, err)
308+
task = unmarshalTask(t, b, 3)
309+
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x01, 0x00, 0x00})
310+
assert.Len(t, task.AckedCheckpoints, 3)
311+
assert.NotNil(t, task.AckedCheckpoints[0])
312+
assert.Nil(t, task.AckedCheckpoints[1])
313+
assert.Nil(t, task.AckedCheckpoints[2])
314+
315+
task.AckedCheckpoints[0] = nil
316+
task.AckedVchannelBitmap[0] = 0x0
317+
task.AckedCheckpoints[1] = cp
318+
task.AckedVchannelBitmap[1] = 0x01
319+
b, err = proto.Marshal(task)
320+
require.NoError(t, err)
321+
task = unmarshalTask(t, b, 3)
322+
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x00, 0x01, 0x00})
323+
assert.Len(t, task.AckedCheckpoints, 3)
324+
assert.Nil(t, task.AckedCheckpoints[0])
325+
assert.NotNil(t, task.AckedCheckpoints[1])
326+
assert.Nil(t, task.AckedCheckpoints[2])
327+
328+
task.AckedVchannelBitmap = []byte{0x01, 0x01, 0x01}
329+
task.AckedCheckpoints = []*streamingpb.AckedCheckpoint{
330+
cp,
331+
cp,
332+
cp,
333+
}
334+
b, err = proto.Marshal(task)
335+
require.NoError(t, err)
336+
task = unmarshalTask(t, b, 3)
337+
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x01, 0x01, 0x01})
338+
assert.Len(t, task.AckedCheckpoints, 3)
339+
assert.NotNil(t, task.AckedCheckpoints[0])
340+
assert.NotNil(t, task.AckedCheckpoints[1])
341+
assert.NotNil(t, task.AckedCheckpoints[2])
342+
}
343+
344+
func unmarshalTask(t *testing.T, b []byte, vchannelCount int) *streamingpb.BroadcastTask {
345+
task := &streamingpb.BroadcastTask{}
346+
err := proto.Unmarshal(b, task)
347+
require.NoError(t, err)
348+
fixAckInfoFromProto(task, vchannelCount)
349+
return task
350+
}

0 commit comments

Comments
 (0)