diff --git a/frac/processor/aggregator.go b/frac/processor/aggregator.go index dc3f99f3..a7ca1582 100644 --- a/frac/processor/aggregator.go +++ b/frac/processor/aggregator.go @@ -401,7 +401,7 @@ func NewSourcedNodeIterator(sourced node.Sourced, ti tokenIndex, tids []uint32, func (s *SourcedNodeIterator) ConsumeTokenSource(lid node.LID) (uint32, bool, error) { for s.lastID.Less(lid) { - s.lastID, s.lastSource = s.sourcedNode.NextSourced() + s.lastID, s.lastSource = s.sourcedNode.NextSourcedGeq(lid) } exists := !s.lastID.IsNull() && s.lastID == lid diff --git a/frac/processor/aggregator_test.go b/frac/processor/aggregator_test.go index b0a173d3..aefb10da 100644 --- a/frac/processor/aggregator_test.go +++ b/frac/processor/aggregator_test.go @@ -179,6 +179,18 @@ func (m *MockNode) NextSourced() (node.LID, uint32) { return first.LID, first.Source } +func (m *MockNode) NextSourcedGeq(minLID node.LID) (node.LID, uint32) { + for len(m.Pairs) > 0 && m.Pairs[0].LID.Less(minLID) { + m.Pairs = m.Pairs[1:] + } + if len(m.Pairs) == 0 { + return node.NullLID(), 0 + } + first := m.Pairs[0] + m.Pairs = m.Pairs[1:] + return first.LID, first.Source +} + func TestTwoSourceAggregator(t *testing.T) { r := require.New(t) diff --git a/frac/sealed/lids/iterator_asc.go b/frac/sealed/lids/iterator_asc.go index 783a3d1f..a2fd0c5a 100644 --- a/frac/sealed/lids/iterator_asc.go +++ b/frac/sealed/lids/iterator_asc.go @@ -72,3 +72,34 @@ func (it *IteratorAsc) Next() node.LID { it.lids = it.lids[:i] return node.NewAscLID(lid) } + +// NextGeq returns the next (in reverse iteration order) LID that is <= nextID. 
+func (it *IteratorAsc) NextGeq(nextID node.LID) node.LID { + for { + for len(it.lids) == 0 { + if !it.tryNextBlock { + return node.NullLID() + } + + it.loadNextLIDsBlock() + it.lids, it.tryNextBlock = it.narrowLIDsRange(it.lids, it.tryNextBlock) + it.counter.AddLIDsCount(len(it.lids)) + } + + // fast path: smallest remaining > nextID => skip entire block + // TODO(cheb0): We could also pass LID into narrowLIDsRange to perform block skipping once we add something like MinLID to LID block header + if it.lids[0] > nextID.Unpack() { + it.lids = it.lids[:0] + continue + } + + idx := sort.Search(len(it.lids), func(i int) bool { return it.lids[i] > nextID.Unpack() }) - 1 + if idx >= 0 { + lid := it.lids[idx] + it.lids = it.lids[:idx] + return node.NewAscLID(lid) + } + + it.lids = it.lids[:0] + } +} diff --git a/frac/sealed/lids/iterator_desc.go b/frac/sealed/lids/iterator_desc.go index 0485d41c..bd3edb91 100644 --- a/frac/sealed/lids/iterator_desc.go +++ b/frac/sealed/lids/iterator_desc.go @@ -71,3 +71,34 @@ func (it *IteratorDesc) Next() node.LID { it.lids = it.lids[1:] return node.NewDescLID(lid) } + +// NextGeq finds next greater or equal +func (it *IteratorDesc) NextGeq(nextID node.LID) node.LID { + for { + for len(it.lids) == 0 { + if !it.tryNextBlock { + return node.NullLID() + } + + it.loadNextLIDsBlock() // last chunk in block but not last for tid; need load next block + it.lids, it.tryNextBlock = it.narrowLIDsRange(it.lids, it.tryNextBlock) + it.counter.AddLIDsCount(len(it.lids)) // inc loaded LIDs count + } + + // fast path: last LID < nextID => skip the entire block + if nextID.Unpack() > it.lids[len(it.lids)-1] { + it.lids = it.lids[:0] + continue + } + + idx := sort.Search(len(it.lids), func(i int) bool { return it.lids[i] >= nextID.Unpack() }) + if idx < len(it.lids) { + it.lids = it.lids[idx:] + lid := it.lids[0] + it.lids = it.lids[1:] + return node.NewDescLID(lid) + } + + it.lids = it.lids[:0] + } +} diff --git a/node/bench_test.go b/node/bench_test.go 
index cf517686..ae60e84d 100644 --- a/node/bench_test.go +++ b/node/bench_test.go @@ -15,6 +15,11 @@ func newNodeStaticSize(r *rand.Rand, size int) *staticAsc { return &staticAsc{staticCursor: staticCursor{data: data}} } +func newNodeStaticSizeFixedDelta(size, start, delta int) *staticAsc { + data, _ := GenerateFixedDelta(size, start, delta) + return &staticAsc{staticCursor: staticCursor{data: data}} +} + func Generate(r *rand.Rand, n int) ([]uint32, uint32) { v := make([]uint32, n) last := uint32(1) @@ -25,6 +30,16 @@ func Generate(r *rand.Rand, n int) ([]uint32, uint32) { return v, last } +func GenerateFixedDelta(n, start, step int) ([]uint32, uint32) { + v := make([]uint32, n) + last := uint32(start) + for i := 0; i < len(v); i++ { + v[i] = last + last += uint32(step) + } + return v, last +} + func BenchmarkNot(b *testing.B) { sizes := []int{1000, 10_000, 1_000_000} @@ -165,6 +180,42 @@ func BenchmarkOrTree(b *testing.B) { } } +// BenchmarkOrTreeNextGeq checks the performance of NextGeq vs Next when no skipping occur and all node +// yield distinct values (no intersection between nodes) +func BenchmarkOrTreeNextGeq(b *testing.B) { + sizes := []int{1000, 10_000, 1_000_000} + // step is equal to total number of nodes, so that every node produces distinct values + step := 8 + + for _, s := range sizes { + b.Run(fmt.Sprintf("size=%d", s), func(b *testing.B) { + n1 := NewOr( + newNodeStaticSizeFixedDelta(s, 1, step), + newNodeStaticSizeFixedDelta(s, 5, step)) + n2 := NewOr( + newNodeStaticSizeFixedDelta(s, 2, step), + newNodeStaticSizeFixedDelta(s, 6, step)) + n3 := NewOr( + newNodeStaticSizeFixedDelta(s, 3, step), + newNodeStaticSizeFixedDelta(s, 8, step)) + n4 := NewOr( + newNodeStaticSizeFixedDelta(s, 4, step), + newNodeStaticSizeFixedDelta(s, 7, step)) + n12 := NewOr(n1, n2) + n34 := NewOr(n3, n4) + n := NewOr(n12, n34) + res := make([]uint32, 0, s*8) + + for b.Loop() { + res = readAllIntoGeq(n, res) + } + + assert.Equal(b, cap(res), s*8) + + }) + } +} + func 
BenchmarkComplex(b *testing.B) { sizes := []int{1000, 10_000, 1_000_000} diff --git a/node/lid.go b/node/lid.go index 620d4c1d..db242acc 100644 --- a/node/lid.go +++ b/node/lid.go @@ -41,6 +41,14 @@ func NewAscLID(lid uint32) LID { } } +func NewLID(lid uint32, asc bool) LID { + if asc { + return NewAscLID(lid) + } else { + return NewDescLID(lid) + } +} + // Less compares two values. It also does an implicit null check, since we store math.MaxUint32 for null values. // Which means if we call x.Less(y), then we know for sure that x is not null. Therefore, this Less call can work // as both "null check + less" combo. @@ -61,6 +69,22 @@ func (c LID) Eq(other LID) bool { return c.lid == other.lid } +func Max(left, right LID) LID { + if left.lid > right.lid { + return left + } else { + return right + } +} + +func Min(left, right LID) LID { + if left.lid < right.lid { + return left + } else { + return right + } +} + func (c LID) Unpack() uint32 { return c.lid ^ c.mask } diff --git a/node/node.go b/node/node.go index 6f87e9c3..99649c90 100644 --- a/node/node.go +++ b/node/node.go @@ -7,10 +7,13 @@ import ( type Node interface { fmt.Stringer // for testing Next() LID + // NextGeq returns next greater or equal (GEQ) lid + NextGeq(nextID LID) LID } type Sourced interface { fmt.Stringer // for testing // aggregation need source NextSourced() (id LID, source uint32) + NextSourcedGeq(nextLID LID) (id LID, source uint32) } diff --git a/node/node_and.go b/node/node_and.go index 51918f46..a978bcf9 100644 --- a/node/node_and.go +++ b/node/node_and.go @@ -31,10 +31,18 @@ func (n *nodeAnd) readRight() { n.rightID = n.right.Next() } +func (n *nodeAnd) readLeftGeq(nextID LID) { + n.leftID = n.left.NextGeq(nextID) +} + +func (n *nodeAnd) readRightGeq(nextID LID) { + n.rightID = n.right.NextGeq(nextID) +} + func (n *nodeAnd) Next() LID { for !n.leftID.IsNull() && !n.rightID.IsNull() && n.leftID != n.rightID { for !n.rightID.IsNull() && n.leftID.Less(n.rightID) { - n.readLeft() + 
n.readLeftGeq(n.rightID) } for !n.leftID.IsNull() && n.rightID.Less(n.leftID) { n.readRight() @@ -48,3 +56,26 @@ func (n *nodeAnd) Next() LID { n.readRight() return cur } + +func (n *nodeAnd) NextGeq(nextID LID) LID { + for { + for !n.leftID.IsNull() && !n.rightID.IsNull() && !n.leftID.Eq(n.rightID) { + for !n.rightID.IsNull() && n.leftID.Less(n.rightID) { + n.readLeftGeq(Max(n.rightID, nextID)) + } + for !n.leftID.IsNull() && n.rightID.Less(n.leftID) { + n.readRightGeq(Max(n.leftID, nextID)) + } + } + + if n.leftID.IsNull() || n.rightID.IsNull() { + return NullLID() + } + cur := n.leftID + n.readLeft() + n.readRight() + if nextID.LessOrEq(cur) { + return cur + } + } +} diff --git a/node/node_and_test.go b/node/node_and_test.go new file mode 100644 index 00000000..fbb1acdb --- /dev/null +++ b/node/node_and_test.go @@ -0,0 +1,60 @@ +package node + +import ( + "math" + "math/rand/v2" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNodeAnd_NextGeqAscending(t *testing.T) { + left := NewStatic([]uint32{1, 2, 7, 10, 20, 25, 26, 30, 50, 80, 90, 100}, false) + right := NewStatic([]uint32{1, 3, 4, 7, 9, 30, 40, 45, 60, 80, 110}, false) + + node := NewAnd(left, right) + + id := node.NextGeq(NewDescLID(7)) + assert.Equal(t, uint32(7), id.Unpack()) + + id = node.NextGeq(NewDescLID(50)) + assert.Equal(t, uint32(80), id.Unpack()) + + id = node.NextGeq(NewDescLID(50)) + assert.True(t, id.IsNull()) +} + +// TestNodeAnd_NextGeqCompatibility tests that just calling NextGeq with 0 passed as argument is equivalent to +// calling Next +func TestNodeAnd_NextGeqCompatibility(t *testing.T) { + for _, asc := range []bool{true, false} { + left := []uint32{rand.Uint32N(10)} + right := []uint32{rand.Uint32N(10)} + + for i := 1; i < 1000; i++ { + left = append(left, left[i-1]+rand.Uint32N(10)) + right = append(right, right[i-1]+rand.Uint32N(10)) + } + + node := NewAnd(NewStatic(left, asc), NewStatic(right, asc)) + nodeGeq := NewAnd(NewStatic(left, asc), NewStatic(right, 
asc)) + + var zero uint32 + if asc { + zero = math.MaxUint32 + } else { + zero = 0 + } + + for { + lid := node.Next() + lidGeq := nodeGeq.NextGeq(NewLID(zero, asc)) + + assert.Equal(t, lid, lidGeq) + + if lid.IsNull() { + break + } + } + } +} diff --git a/node/node_nand.go b/node/node_nand.go index 42c679ac..52f5ff01 100644 --- a/node/node_nand.go +++ b/node/node_nand.go @@ -43,3 +43,11 @@ func (n *nodeNAnd) Next() LID { } return NullLID() } + +func (n *nodeNAnd) NextGeq(nextID LID) LID { + lid := n.Next() + for lid.Less(nextID) { + lid = n.Next() + } + return lid +} diff --git a/node/node_nand_test.go b/node/node_nand_test.go new file mode 100644 index 00000000..61bbc1f4 --- /dev/null +++ b/node/node_nand_test.go @@ -0,0 +1,52 @@ +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNodeNAnd_NextGeq(t *testing.T) { + neg := NewStatic([]uint32{1, 2, 7, 10, 20, 25, 26, 30, 50, 80, 90, 100}, false) + reg := NewStatic([]uint32{1, 3, 4, 7, 9, 30, 40, 45, 60, 80, 110}, false) + + node := NewNAnd(neg, reg) + + id := node.NextGeq(NewDescLID(7)) + assert.Equal(t, uint32(9), id.Unpack()) + + id = node.NextGeq(NewDescLID(50)) + assert.Equal(t, uint32(60), id.Unpack()) + + id = node.NextGeq(NewDescLID(100)) + assert.Equal(t, uint32(110), id.Unpack()) + + id = node.NextGeq(NewDescLID(100)) + assert.True(t, id.IsNull()) +} + +func TestNodeNAnd_NextGeq_Reverse(t *testing.T) { + neg := NewStatic([]uint32{1, 2, 7, 10, 20, 25, 26, 30, 50, 80, 90, 100}, true) + reg := NewStatic([]uint32{1, 3, 4, 7, 9, 30, 40, 45, 60, 80, 110}, true) + + node := NewNAnd(neg, reg) + + id := node.NextGeq(NewAscLID(80)) + assert.Equal(t, uint32(60), id.Unpack()) + + id = node.NextGeq(NewAscLID(49)) + assert.Equal(t, uint32(45), id.Unpack()) + + // call with same nextID, should just return next value + id = node.NextGeq(NewAscLID(49)) + assert.Equal(t, uint32(40), id.Unpack()) + + id = node.NextGeq(NewAscLID(49)) + assert.Equal(t, uint32(9), id.Unpack()) + + id = 
node.NextGeq(NewAscLID(4)) + assert.Equal(t, uint32(4), id.Unpack()) + + id = node.NextGeq(NewAscLID(1)) + assert.True(t, id.IsNull()) +} diff --git a/node/node_or.go b/node/node_or.go index a31773a3..ab0bf30f 100644 --- a/node/node_or.go +++ b/node/node_or.go @@ -1,6 +1,8 @@ package node -import "fmt" +import ( + "fmt" +) type nodeOr struct { left Node @@ -29,6 +31,14 @@ func (n *nodeOr) readRight() { n.rightID = n.right.Next() } +func (n *nodeOr) readLeftGeq(nextID LID) { + n.leftID = n.left.NextGeq(nextID) +} + +func (n *nodeOr) readRightGeq(nextID LID) { + n.rightID = n.right.NextGeq(nextID) +} + func (n *nodeOr) Next() LID { if n.leftID.IsNull() && n.rightID.IsNull() { return n.leftID @@ -50,6 +60,23 @@ func (n *nodeOr) Next() LID { return cur } +func (n *nodeOr) NextGeq(nextID LID) LID { + // Fast path: if at least one of left or right is present and there is nothing to skip, then choose lowest and return. + minID := Min(n.leftID, n.rightID) + if nextID.LessOrEq(minID) { + return n.Next() + } + + if n.leftID.Less(nextID) { + n.readLeftGeq(nextID) + } + if n.rightID.Less(nextID) { + n.readRightGeq(nextID) + } + + return n.Next() +} + type nodeOrAgg struct { left Sourced right Sourced @@ -80,6 +107,14 @@ func (n *nodeOrAgg) readRight() { n.rightID, n.rightSource = n.right.NextSourced() } +func (n *nodeOrAgg) readLeftGeq(nextID LID) { + n.leftID, n.leftSource = n.left.NextSourcedGeq(nextID) +} + +func (n *nodeOrAgg) readRightGeq(nextID LID) { + n.rightID, n.rightSource = n.right.NextSourcedGeq(nextID) +} + func (n *nodeOrAgg) NextSourced() (LID, uint32) { if n.leftID.IsNull() && n.rightID.IsNull() { return n.leftID, 0 @@ -95,3 +130,31 @@ func (n *nodeOrAgg) NextSourced() (LID, uint32) { n.readRight() return cur, curSource } + +func (n *nodeOrAgg) NextSourcedGeq(nextID LID) (LID, uint32) { + // Fast path: if at least one of left or right is present and there is nothing to skip, then choose lowest and return. 
+ minID := Min(n.leftID, n.rightID) + if nextID.LessOrEq(minID) { + if n.leftID.Less(n.rightID) { + cur := n.leftID + curSource := n.leftSource + n.readLeft() + return cur, curSource + } else { + // we don't need deduplication + cur := n.rightID + curSource := n.rightSource + n.readRight() + return cur, curSource + } + } + + if n.leftID.Less(nextID) { + n.readLeftGeq(nextID) + } + if n.rightID.Less(nextID) { + n.readRightGeq(nextID) + } + + return n.NextSourced() +} diff --git a/node/node_or_test.go b/node/node_or_test.go new file mode 100644 index 00000000..f484b2cd --- /dev/null +++ b/node/node_or_test.go @@ -0,0 +1,257 @@ +package node + +import ( + "math" + "math/rand/v2" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNodeOr_NextGeqAscending(t *testing.T) { + left := NewStatic([]uint32{2, 7, 10, 20, 25, 26, 30, 50}, false) + right := NewStatic([]uint32{1, 3, 4, 7, 9, 30, 40}, false) + + node := NewOr(left, right) + + id := node.NextGeq(NewDescLID(7)) + assert.Equal(t, uint32(7), id.Unpack()) + + id = node.NextGeq(NewDescLID(7)) + assert.Equal(t, uint32(9), id.Unpack()) + + id = node.NextGeq(NewDescLID(24)) + assert.Equal(t, uint32(25), id.Unpack()) + + id = node.NextGeq(NewDescLID(30)) + assert.Equal(t, uint32(30), id.Unpack()) + + id = node.NextGeq(NewDescLID(51)) + assert.True(t, id.IsNull()) +} + +// TestNodeOr_NextGeqCompatibility tests that just calling NextGeq with LID zero value passed as argument is equivalent to +// calling Next +func TestNodeOr_NextGeqCompatibility(t *testing.T) { + for _, asc := range []bool{true, false} { + left := []uint32{rand.Uint32N(10)} + right := []uint32{rand.Uint32N(10)} + + for i := 1; i < 1000; i++ { + left = append(left, left[i-1]+rand.Uint32N(10)) + right = append(right, right[i-1]+rand.Uint32N(10)) + } + + node := NewOr(NewStatic(left, asc), NewStatic(right, asc)) + nodeGeq := NewOr(NewStatic(left, asc), NewStatic(right, asc)) + + var zero uint32 + if asc { + 
zero = math.MaxUint32 + } else { + zero = 0 + } + + for { + lid := node.Next() + lidGeq := nodeGeq.NextGeq(NewLID(zero, asc)) + + assert.Equal(t, lid, lidGeq) + + if lid.IsNull() { + break + } + } + } +} + +// TestNodeOrAgg_NoDedup tests that nodeOrAgg yields values both from left and right for same lid. +func TestNodeOrAgg_NoDedup(t *testing.T) { + left := NewSourcedNodeWrapper(NewStatic([]uint32{1, 5, 7}, false), 1) + right := NewSourcedNodeWrapper(NewStatic([]uint32{5, 8}, false), 2) + + orAgg := NewNodeOrAgg(left, right) + pairs := readAllSourced(orAgg) + + // expected sources for lid=5 + var sources []uint32 + + for _, p := range pairs { + id, src := p[0], p[1] + if id == 5 { + sources = append(sources, src) + } + } + + require.Len(t, sources, 2, "expected id 5 to be returned twice from both children") + assert.ElementsMatch(t, []uint32{1, 2}, sources, "expected id 5 from both left and right sources") +} + +func TestNodeOrAgg_MergeAscending(t *testing.T) { + left := NewSourcedNodeWrapper(NewStatic([]uint32{1, 3, 5}, false), 0) + right := NewSourcedNodeWrapper(NewStatic([]uint32{2, 4, 6}, false), 1) + + orAgg := NewNodeOrAgg(left, right) + got := readAllSourced(orAgg) + + want := [][2]uint32{ + {1, 0}, + {2, 1}, + {3, 0}, + {4, 1}, + {5, 0}, + {6, 1}, + } + + assert.Equal(t, want, got) +} + +func TestNodeOrAgg_MergeAscendingWithDups(t *testing.T) { + left := NewSourcedNodeWrapper(NewStatic([]uint32{1, 2, 3, 5, 8}, false), 0) + right := NewSourcedNodeWrapper(NewStatic([]uint32{2, 3, 4, 6, 8}, false), 1) + + orAgg := NewNodeOrAgg(left, right) + got := readAllSourced(orAgg) + + want := [][2]uint32{ + {1, 0}, + {2, 1}, + {2, 0}, + {3, 1}, + {3, 0}, + {4, 1}, + {5, 0}, + {6, 1}, + {8, 1}, + {8, 0}, + } + + assert.Equal(t, want, got) +} + +// TestNodeOrAgg_NextSourcedGeq tests we can navigate to a lid with NextGeq and do not skip it from +// both left and right sides (no deduplication like in ordinary OR tree) +func TestNodeOrAgg_NextSourcedGeq(t *testing.T) { + left 
:= NewSourcedNodeWrapper(NewStatic([]uint32{1, 2, 3, 5, 8, 15, 19}, false), 0) + right := NewSourcedNodeWrapper(NewStatic([]uint32{2, 3, 4, 6, 8, 14, 20}, false), 1) + + orAgg := NewNodeOrAgg(left, right) + + id, source := orAgg.NextSourcedGeq(NewDescLID(3)) + assert.Equal(t, uint32(3), id.Unpack()) + assert.Equal(t, uint32(1), source) + + // 3 returned again, but with different source - no deduplication + id, source = orAgg.NextSourcedGeq(NewDescLID(3)) + assert.Equal(t, uint32(3), id.Unpack()) + assert.Equal(t, uint32(0), source) + + id, source = orAgg.NextSourcedGeq(NewDescLID(6)) + assert.Equal(t, uint32(6), id.Unpack()) + assert.Equal(t, uint32(1), source) + + id, source = orAgg.NextSourcedGeq(NewDescLID(17)) + assert.Equal(t, uint32(19), id.Unpack()) + assert.Equal(t, uint32(0), source) +} + +// TestNodeOrAgg_NextSourcedGeq tests we can navigate to a lid with NextGeq in reverse way and do not skip it from +// both left and right sides (no deduplication like in ordinary OR tree) +func TestNodeOrAgg_NextSourcedGeq_Reverse(t *testing.T) { + left := NewSourcedNodeWrapper(NewStatic([]uint32{1, 2, 3, 5, 8, 15, 19}, true), 0) + right := NewSourcedNodeWrapper(NewStatic([]uint32{2, 3, 4, 6, 8, 14, 20}, true), 1) + + orAgg := NewNodeOrAgg(left, right) + + id, source := orAgg.NextSourcedGeq(NewAscLID(8)) + assert.Equal(t, uint32(8), id.Unpack()) + assert.Equal(t, uint32(1), source) + + // 8 returned again, but with different source - no deduplication + id, source = orAgg.NextSourcedGeq(NewAscLID(8)) + assert.Equal(t, uint32(8), id.Unpack()) + assert.Equal(t, uint32(0), source) + + id, source = orAgg.NextSourcedGeq(NewAscLID(4)) + assert.Equal(t, uint32(4), id.Unpack()) + assert.Equal(t, uint32(1), source) + + id, source = orAgg.NextSourcedGeq(NewAscLID(1)) + assert.Equal(t, uint32(1), id.Unpack()) + assert.Equal(t, uint32(0), source) + + id, _ = orAgg.NextSourcedGeq(NewAscLID(1)) + assert.True(t, id.IsNull()) +} + +func TestNodeOrAgg_MergeDescending(t *testing.T) { + 
left := NewSourcedNodeWrapper(NewStatic([]uint32{1, 3, 5}, true), 0) + right := NewSourcedNodeWrapper(NewStatic([]uint32{2, 4, 6}, true), 1) + + orAgg := NewNodeOrAgg(left, right) + got := readAllSourced(orAgg) + + want := [][2]uint32{ + {6, 1}, + {5, 0}, + {4, 1}, + {3, 0}, + {2, 1}, + {1, 0}, + } + + assert.Equal(t, want, got) +} + +func TestNodeOrAgg_EmptySide(t *testing.T) { + t.Run("empty_left", func(t *testing.T) { + left := NewSourcedNodeWrapper(NewStatic(nil, false), 0) + right := NewSourcedNodeWrapper(NewStatic([]uint32{10, 20}, false), 1) + + orAgg := NewNodeOrAgg(left, right) + got := readAllSourced(orAgg) + + want := [][2]uint32{ + {10, 1}, + {20, 1}, + } + + assert.Equal(t, want, got) + }) + + t.Run("empty_right", func(t *testing.T) { + left := NewSourcedNodeWrapper(NewStatic([]uint32{10, 20}, false), 0) + right := NewSourcedNodeWrapper(NewStatic(nil, false), 1) + + orAgg := NewNodeOrAgg(left, right) + got := readAllSourced(orAgg) + + want := [][2]uint32{ + {10, 0}, + {20, 0}, + } + + assert.Equal(t, want, got) + }) + + t.Run("both_empty", func(t *testing.T) { + left := NewSourcedNodeWrapper(NewStatic(nil, false), 0) + right := NewSourcedNodeWrapper(NewStatic(nil, false), 1) + + orAgg := NewNodeOrAgg(left, right) + id, _ := orAgg.NextSourced() + + assert.True(t, id.IsNull()) + }) +} + +func readAllSourced(n Sourced) [][2]uint32 { + var res [][2]uint32 + id, src := n.NextSourced() + for !id.IsNull() { + res = append(res, [2]uint32{id.Unpack(), src}) + id, src = n.NextSourced() + } + return res +} diff --git a/node/node_range.go b/node/node_range.go index a6f75467..bf31e4a5 100644 --- a/node/node_range.go +++ b/node/node_range.go @@ -24,3 +24,8 @@ func (n *nodeRange) Next() LID { n.curID = n.curID.Inc() return result } + +func (n *nodeRange) NextGeq(nextID LID) LID { + n.curID = Max(n.curID, nextID) + return n.Next() +} diff --git a/node/node_range_test.go b/node/node_range_test.go new file mode 100644 index 00000000..08bcc8cd --- /dev/null +++ 
b/node/node_range_test.go @@ -0,0 +1,56 @@ +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNodeRange_NextGeq_JumpToLast(t *testing.T) { + node := NewRange(NewDescLID(3), NewDescLID(10)) + + id := node.NextGeq(NewDescLID(10)) + assert.Equal(t, uint32(10), id.Unpack()) +} + +func TestNodeRange_NextGeq_SkipsWholeRange(t *testing.T) { + node := NewRange(NewDescLID(3), NewDescLID(10)) + + id := node.NextGeq(NewDescLID(11)) + assert.True(t, id.IsNull()) +} + +func TestNodeRange_NextGeq(t *testing.T) { + node := NewRange(NewDescLID(3), NewDescLID(10)) + + id := node.NextGeq(NewDescLID(5)) + assert.Equal(t, uint32(5), id.Unpack()) + + id = node.NextGeq(NewDescLID(5)) + assert.Equal(t, uint32(6), id.Unpack()) + + id = node.NextGeq(NewDescLID(10)) + assert.Equal(t, uint32(10), id.Unpack()) + + id = node.NextGeq(NewDescLID(10)) + assert.True(t, id.IsNull()) +} + +func TestNodeRange_NextGeq_Reverse(t *testing.T) { + node := NewRange(NewAscLID(10), NewAscLID(3)) + + id := node.NextGeq(NewAscLID(9)) + assert.Equal(t, uint32(9), id.Unpack()) + + id = node.NextGeq(NewAscLID(9)) + assert.Equal(t, uint32(8), id.Unpack()) + + id = node.NextGeq(NewAscLID(4)) + assert.Equal(t, uint32(4), id.Unpack()) + + id = node.NextGeq(NewAscLID(3)) + assert.Equal(t, uint32(3), id.Unpack()) + + id = node.NextGeq(NewAscLID(3)) + assert.True(t, id.IsNull()) +} diff --git a/node/node_static.go b/node/node_static.go index d0b17eab..baabfa37 100644 --- a/node/node_static.go +++ b/node/node_static.go @@ -1,16 +1,21 @@ package node -import "math" +import ( + "math" + "sort" +) type staticCursor struct { ptr int data []uint32 } +// staticAsc stores lids in data slice in ascending order, and iterates in increasing order type staticAsc struct { staticCursor } +// staticDesc stores lids in data slice in ascending order, but iterates from the end (in descending order) type staticDesc struct { staticCursor } @@ -43,6 +48,24 @@ func (n *staticAsc) Next() LID { return
NewDescLID(cur) } +// NextGeq finds next greater or equals since iteration is in ascending order +func (n *staticAsc) NextGeq(nextID LID) LID { + if n.ptr >= len(n.data) { + return NullLID() + } + + from := n.ptr + idx := sort.Search(len(n.data)-from, func(i int) bool { return n.data[from+i] >= nextID.Unpack() }) + if idx >= len(n.data)-from { + return NullLID() + } + + i := from + idx + cur := n.data[i] + n.ptr = i + 1 + return NewDescLID(cur) +} + func (n *staticDesc) Next() LID { // staticDesc is used in docs order asc, hence we return LID with asc order if n.ptr < 0 { @@ -53,7 +76,22 @@ func (n *staticDesc) Next() LID { return NewAscLID(cur) } -// MakeStaticNodes is currently used only for tests +// NextGeq finds next less or equals since iteration is in descending order +func (n *staticDesc) NextGeq(nextID LID) LID { + if n.ptr < 0 { + return NullLID() + } + idx := sort.Search(n.ptr+1, func(i int) bool { return n.data[i] > nextID.Unpack() }) - 1 + if idx < 0 { + return NullLID() + } + + cur := n.data[idx] + n.ptr = idx - 1 + return NewAscLID(cur) +} + +// MakeStaticNodes is currently used only for tests func MakeStaticNodes(data [][]uint32) []Node { nodes := make([]Node, len(data)) for i, values := range data { diff --git a/node/node_static_test.go b/node/node_static_test.go new file mode 100644 index 00000000..9c061af3 --- /dev/null +++ b/node/node_static_test.go @@ -0,0 +1,69 @@ +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStaticAscNextGeq(t *testing.T) { + lids := []uint32{1, 3, 5, 7, 9} + n := NewStatic(lids, false).(*staticAsc) + + id := n.NextGeq(NewDescLID(0)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(1), id.Unpack()) + + id = n.NextGeq(NewDescLID(4)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(5), id.Unpack()) + + // 5 has already been returned, so the next value >= 5 is 7. 
+ id = n.NextGeq(NewDescLID(5)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(7), id.Unpack()) + + id = n.NextGeq(NewDescLID(10)) + assert.True(t, id.IsNull()) +} + +func TestStaticDescNextGeq(t *testing.T) { + lids := []uint32{1, 3, 5, 7, 9} + n := NewStatic(lids, true).(*staticDesc) + + id := n.NextGeq(NewDescLID(10)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(9), id.Unpack()) + + id = n.NextGeq(NewDescLID(10)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(7), id.Unpack()) + + id = n.NextGeq(NewDescLID(10)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(5), id.Unpack()) +} + +func TestStaticDescNextGeq_WithThreshold(t *testing.T) { + lids := []uint32{1, 3, 5, 7, 9} + n := NewStatic(lids, true).(*staticDesc) + + id := n.NextGeq(NewDescLID(8)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(7), id.Unpack()) + + id = n.NextGeq(NewDescLID(8)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(5), id.Unpack()) + + id = n.NextGeq(NewDescLID(8)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(3), id.Unpack()) + + id = n.NextGeq(NewDescLID(8)) + assert.False(t, id.IsNull()) + assert.Equal(t, uint32(1), id.Unpack()) + + id = n.NextGeq(NewDescLID(8)) + assert.True(t, id.IsNull()) +} diff --git a/node/node_test.go b/node/node_test.go index b6730e21..0d036a32 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -14,6 +14,15 @@ func readAllInto(node Node, ids []uint32) []uint32 { return ids } +func readAllIntoGeq(node Node, ids []uint32) []uint32 { + id := node.Next() + for !id.IsNull() { + ids = append(ids, id.Unpack()) + id = node.NextGeq(id) + } + return ids +} + func readAll(node Node) []uint32 { return readAllInto(node, nil) } diff --git a/node/sourced_node_wrapper.go b/node/sourced_node_wrapper.go index 82b52449..7cc60015 100644 --- a/node/sourced_node_wrapper.go +++ b/node/sourced_node_wrapper.go @@ -14,6 +14,11 @@ func (w *sourcedNodeWrapper) NextSourced() (LID, uint32) { return cmp, w.source } +func 
(w *sourcedNodeWrapper) NextSourcedGeq(nextID LID) (LID, uint32) { + id := w.node.NextGeq(nextID) + return id, w.source +} + func NewSourcedNodeWrapper(d Node, source int) Sourced { return &sourcedNodeWrapper{node: d, source: uint32(source)} }