diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 67767f23d839..ac34f710d8e2 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -268,6 +268,8 @@ Optimizations * GITHUB#15474: Use bulk scoring provided by RandomVectorScorers for new scalar quantized formats provided through Lucene104ScalarQuantizedVectorsFormat and Lucene104HnswScalarQuantizedVectorsFormat (Ben Trent) +* GITHUB#15494: Set minimum competitive score in FirstPassGroupingCollector (Binlong Gao) + * GITHUB#15500: Use bulk scoring for filtered HNSW search and for entry-point scoring in the graph. This should provide speed improvements when using vector scorers that satisfy the bulk scoring interface. (Ben Trent) diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java index b36200e48963..bdb1de6abbe9 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsCollector.java @@ -321,7 +321,7 @@ public void testSetMinCompetitiveScore() throws Exception { IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE)); Document doc = new Document(); - w.addDocuments(Arrays.asList(doc, doc, doc, doc)); + w.addDocuments(Arrays.asList(doc, doc, doc, doc, doc)); w.flush(); w.addDocuments(Arrays.asList(doc, doc)); w.flush(); @@ -432,7 +432,7 @@ public void testTotalHits() throws Exception { leafCollector.setScorer(scorer); scorer.score = 3; - leafCollector.collect(1); + leafCollector.collect(0); scorer.score = 4; leafCollector.collect(1); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTopFieldCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestTopFieldCollector.java index bcbc8cac50d2..d0fd17f6cddd 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTopFieldCollector.java +++ 
b/lucene/core/src/test/org/apache/lucene/search/TestTopFieldCollector.java @@ -222,13 +222,13 @@ public void testTotalHits() throws Exception { scorer.score = 3; if (totalHitsThreshold < 3) { - expectThrows(CollectionTerminatedException.class, () -> leafCollector2.collect(1)); + expectThrows(CollectionTerminatedException.class, () -> leafCollector2.collect(0)); TopDocs topDocs = collector.topDocs(); assertEquals( new TotalHits(3, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), topDocs.totalHits); continue; } else { - leafCollector2.collect(1); + leafCollector2.collect(0); } scorer.score = 4; @@ -271,7 +271,7 @@ public void testSetMinCompetitiveScore() throws Exception { IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE)); Document doc = new Document(); - w.addDocuments(Arrays.asList(doc, doc, doc, doc)); + w.addDocuments(Arrays.asList(doc, doc, doc, doc, doc)); w.flush(); w.addDocuments(Arrays.asList(doc, doc)); w.flush(); diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/DocScoreEncoder.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/DocScoreEncoder.java new file mode 100644 index 000000000000..d6b04088c638 --- /dev/null +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/DocScoreEncoder.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.grouping; + +import org.apache.lucene.util.NumericUtils; + +/** + * An encoder to encode (doc, score) pair as a long whose sort order is same as {@code (o1, o2) -> + * Float.compare(o1.score, o2.score)).thenComparing(Comparator.comparingInt((ScoreDoc o) -> + * o.doc).reversed())} + */ +class DocScoreEncoder { + static long encode(int docId, float score) { + return (((long) NumericUtils.floatToSortableInt(score)) << 32) | (Integer.MAX_VALUE - docId); + } + + static float toScore(long value) { + return NumericUtils.sortableIntToFloat((int) (value >>> 32)); + } + + static int docId(long value) { + return Integer.MAX_VALUE - ((int) value); + } +} diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/FirstPassGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/FirstPassGroupingCollector.java index 802ceb4feaa1..ec8f2ff97dbe 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/FirstPassGroupingCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/FirstPassGroupingCollector.java @@ -22,6 +22,7 @@ import java.util.Comparator; import java.util.HashMap; import java.util.TreeSet; +import java.util.concurrent.atomic.LongAccumulator; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.FieldComparator; import org.apache.lucene.search.LeafFieldComparator; @@ -31,6 +32,7 @@ import org.apache.lucene.search.SimpleCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; 
+import org.apache.lucene.search.TotalHits; /** * FirstPassGroupingCollector is the first of two passes necessary to collect grouped hits. This @@ -42,6 +44,8 @@ */ public class FirstPassGroupingCollector extends SimpleCollector { + private static final int DEFAULT_INTERVAL = 0x3ff; + private final GroupSelector groupSelector; private final boolean ignoreDocsWithoutGroupField; @@ -52,6 +56,14 @@ public class FirstPassGroupingCollector extends SimpleCollector { private final boolean needsScores; private final HashMap> groupMap; private final int compIDXEnd; + private final boolean canSetMinScore; + private final LongAccumulator minScoreAcc; + private final int totalHitsThreshold; + + private Scorable scorer; + private float minCompetitiveScore; + private int totalHitCount; + private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO; // Set once we reach topNGroups unique groups: /** @@ -93,6 +105,32 @@ public FirstPassGroupingCollector( Sort groupSort, int topNGroups, boolean ignoreDocsWithoutGroupField) { + this(groupSelector, groupSort, topNGroups, ignoreDocsWithoutGroupField, Integer.MAX_VALUE); + } + + /** + * Create the first pass collector with ignoreDocsWithoutGroupField and totalHitsThreshold + * + * @param groupSelector a GroupSelector used to define groups + * @param groupSort The {@link Sort} used to sort the groups. The top sorted document within each + * group according to groupSort, determines how that group sorts against other groups. This + * must be non-null, ie, if you want to groupSort by relevance use Sort.RELEVANCE. + * @param topNGroups How many top groups to keep. + * @param ignoreDocsWithoutGroupField if true, ignore documents that don't have the group field + * instead of putting them in a null group + * @param totalHitsThreshold the number of hits to count accurately. 
If the first sort field is + * the score and more hits than this threshold have been collected, the collector may set a + * minimum competitive score on the scorer so that non-competitive hits can be skipped, and + * the total hit count becomes a lower bound. If set to {@link Integer#MAX_VALUE}, the total + * hit count is calculated precisely. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public FirstPassGroupingCollector( + GroupSelector groupSelector, + Sort groupSort, + int topNGroups, + boolean ignoreDocsWithoutGroupField, + int totalHitsThreshold) { this.groupSelector = groupSelector; this.ignoreDocsWithoutGroupField = ignoreDocsWithoutGroupField; if (topNGroups < 1) { @@ -120,11 +158,32 @@ public FirstPassGroupingCollector( spareSlot = topNGroups; groupMap = HashMap.newHashMap(topNGroups); + + this.totalHitsThreshold = totalHitsThreshold; + this.canSetMinScore = + sortFields.length > 0 + && sortFields[0].getType() == SortField.Type.SCORE + && this.totalHitsThreshold != Integer.MAX_VALUE; + this.minScoreAcc = canSetMinScore ? new LongAccumulator(Math::max, Long.MIN_VALUE) : null; + } + + public int getTotalHitCount() { + return this.totalHitCount; + } + + public TotalHits.Relation getTotalHitsRelation() { + return this.totalHitsRelation; } @Override public ScoreMode scoreMode() { - return needsScores ? 
ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; + if (canSetMinScore) { + return ScoreMode.TOP_SCORES; + } else if (needsScores) { + return ScoreMode.COMPLETE; + } else { + return ScoreMode.COMPLETE_NO_SCORES; + } } /** @@ -175,10 +234,16 @@ public Collection> getTopGroups(int groupOffset) throws IOExcepti @Override public void setScorer(Scorable scorer) throws IOException { + this.scorer = scorer; groupSelector.setScorer(scorer); for (LeafFieldComparator comparator : leafComparators) { comparator.setScorer(scorer); } + if (minScoreAcc == null) { + updateMinCompetitiveScore(); + } else { + updateGlobalMinCompetitiveScore(); + } } private boolean isCompetitive(int doc) throws IOException { @@ -213,6 +278,10 @@ private boolean isCompetitive(int doc) throws IOException { @Override public void collect(int doc) throws IOException { + int hitCountSoFar = ++totalHitCount; + if (minScoreAcc != null && (hitCountSoFar & DEFAULT_INTERVAL) == 0) { + updateGlobalMinCompetitiveScore(); + } if (isCompetitive(doc) == false) { return; @@ -261,6 +330,8 @@ private void collectNewGroup(final int doc) throws IOException { // number of groups; from here on we will drop // bottom group when we insert new one: buildSortedSet(); + // the groups are now full, so propagate the min competitive score to the scorer + updateMinCompetitiveScore(); } } else { @@ -288,6 +359,7 @@ private void collectNewGroup(final int doc) throws IOException { for (LeafFieldComparator fc : leafComparators) { fc.setBottom(lastComparatorSlot); } + updateMinCompetitiveScore(); } } @@ -353,6 +425,8 @@ private void collectExistingGroup(final int doc, final CollectedSearchGroup g for (LeafFieldComparator fc : leafComparators) { fc.setBottom(newLast.comparatorSlot); } + // bottom was just reset, so propagate the min competitive score to the scorer + updateMinCompetitiveScore(); } } } @@ -383,6 +457,40 @@ public int compare(CollectedSearchGroup o1, CollectedSearchGroup o2) { } } + private void 
updateMinCompetitiveScore() throws IOException { + if (canSetMinScore + && orderedGroups != null + && scorer != null + && totalHitCount > totalHitsThreshold) { + // Get the score of the bottom group + CollectedSearchGroup bottomGroup = orderedGroups.last(); + float bottomScore = (float) comparators[0].value(bottomGroup.comparatorSlot); + if (bottomScore > minCompetitiveScore) { + minCompetitiveScore = bottomScore; + totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + scorer.setMinCompetitiveScore(minCompetitiveScore); + if (minScoreAcc != null) { + minScoreAcc.accumulate(DocScoreEncoder.encode(docBase, bottomScore)); + } + } + } + } + + protected void updateGlobalMinCompetitiveScore() throws IOException { + if (canSetMinScore && scorer != null) { + long maxMinScore = minScoreAcc.get(); + if (maxMinScore != Long.MIN_VALUE) { + float score = DocScoreEncoder.toScore(maxMinScore); + score = docBase >= DocScoreEncoder.docId(maxMinScore) ? Math.nextUp(score) : score; + if (score > minCompetitiveScore) { + minCompetitiveScore = score; + totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + scorer.setMinCompetitiveScore(score); + } + } + } + } + @Override protected void doSetNextReader(LeafReaderContext readerContext) throws IOException { docBase = readerContext.docBase; diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java index 289a7155aadf..867259f929fc 100644 --- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java @@ -36,7 +36,9 @@ import org.apache.lucene.document.TextField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReaderContext; +import org.apache.lucene.index.IndexWriter; import 
org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.MultiDocValues; import org.apache.lucene.index.NumericDocValues; @@ -49,9 +51,11 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MultiCollector; import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; @@ -272,6 +276,170 @@ public void testAllDocsWithoutGroupField() throws IOException { dir.close(); } + public void testTotalHitsThreshold() throws Exception { + Directory dir = newDirectory(); + RandomIndexWriter w = + new RandomIndexWriter(random(), dir, newIndexWriterConfig(new MockAnalyzer(random()))); + + // Add documents with and without group field + Document doc = new Document(); + addGroupField(doc, "group", "group1"); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + doc = new Document(); + addGroupField(doc, "group", "group2"); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + // Document without group field + doc = new Document(); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + IndexSearcher searcher = newSearcher(w.getReader()); + w.close(); + + // Test ignoreDocsWithoutGroupField = true + FirstPassGroupingCollector collector1 = + new FirstPassGroupingCollector<>( + new TermGroupSelector("group"), Sort.RELEVANCE, 10, true, Integer.MAX_VALUE); + searcher.search(new TermQuery(new Term("content", "test")), collector1); + + Collection> groups1 = collector1.getTopGroups(0); + assertEquals(2, groups1.size()); // Should ignore null group + assertEquals(3, collector1.getTotalHitCount()); + 
assertEquals(TotalHits.Relation.EQUAL_TO, collector1.getTotalHitsRelation()); + + // Test ignoreDocsWithoutGroupField = false + FirstPassGroupingCollector collector2 = + new FirstPassGroupingCollector<>( + new TermGroupSelector("group"), Sort.RELEVANCE, 10, false, Integer.MAX_VALUE); + searcher.search(new TermQuery(new Term("content", "test")), collector2); + + Collection> groups2 = collector2.getTopGroups(0); + assertEquals(3, groups2.size()); // Should include null group + assertEquals(3, collector2.getTotalHitCount()); + + // Test totalHitsThreshold with score-based sorting + FirstPassGroupingCollector collector3 = + new FirstPassGroupingCollector<>( + new TermGroupSelector("group"), Sort.RELEVANCE, 10, false, 1); + searcher.search(new TermQuery(new Term("content", "test")), collector3); + + assertEquals(ScoreMode.TOP_SCORES, collector3.scoreMode()); + assertTrue( + collector3.getTotalHitCount() <= 3); // May skip some docs due to min competitive score + + searcher.getIndexReader().close(); + dir.close(); + } + + private static class TestScorer extends Scorable { + float score; + Float minCompetitiveScore = null; + + @Override + public void setMinCompetitiveScore(float minCompetitiveScore) { + this.minCompetitiveScore = minCompetitiveScore; + } + + @Override + public float score() throws IOException { + return score; + } + } + + public void testSetMinCompetitiveScore() throws Exception { + Directory dir = newDirectory(); + IndexWriter w = + new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy())); + + // Add documents with different scores + Document doc = new Document(); + addGroupField(doc, "group", "group1"); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + doc = new Document(); + addGroupField(doc, "group", "group2"); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + doc = new Document(); + addGroupField(doc, "group", "group3"); + doc.add(new 
TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + doc = new Document(); + addGroupField(doc, "group", "group4"); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + doc = new Document(); + addGroupField(doc, "group", "group5"); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + w.flush(); + + doc = new Document(); + addGroupField(doc, "group", "group4"); + doc.add(new TextField("content", "test", Field.Store.YES)); + w.addDocument(doc); + + IndexReader reader = DirectoryReader.open(w); + w.close(); + + // Test with score-based sorting and low totalHitsThreshold to enable min competitive score + FirstPassGroupingCollector collector = + new FirstPassGroupingCollector<>( + new TermGroupSelector("group"), Sort.RELEVANCE, 2, false, 2); + + TestScorer scorer = new TestScorer(); + LeafCollector leafCollector = collector.getLeafCollector(reader.leaves().get(0)); + leafCollector.setScorer(scorer); + assertNull(scorer.minCompetitiveScore); + + scorer.score = 1.0f; + leafCollector.collect(0); + assertNull(scorer.minCompetitiveScore); // Not set yet, need more groups + + scorer.score = 2.0f; + leafCollector.collect(1); + assertNull(scorer.minCompetitiveScore); // Still building up groups + + scorer.score = 3.0f; + leafCollector.collect(2); + assertNotNull(scorer.minCompetitiveScore); + assertTrue(scorer.minCompetitiveScore > 0); + + // Test that low-scoring document doesn't update min competitive score + scorer.score = 0.5f; + scorer.minCompetitiveScore = Float.NaN; + leafCollector.collect(3); // This should be skipped due to isCompetitive check + assertTrue(Float.isNaN(scorer.minCompetitiveScore)); + + // Test higher score updates min competitive score + scorer.score = 4.0f; + leafCollector.collect(4); + assertNotNull(scorer.minCompetitiveScore); + assertTrue(scorer.minCompetitiveScore >= 2.0f); + + // Test that min competitive score is set on new leaf collectors + if 
(reader.leaves().size() > 1) { + TestScorer newScorer = new TestScorer(); + LeafCollector newLeafCollector = collector.getLeafCollector(reader.leaves().get(1)); + newLeafCollector.setScorer(newScorer); + // Min competitive score should be propagated to new scorer + assertNotNull(newScorer.minCompetitiveScore); + assertTrue(newScorer.minCompetitiveScore > 0); + } + + reader.close(); + dir.close(); + } + private void addGroupField(Document doc, String groupField, String value) { doc.add(new SortedDocValuesField(groupField, new BytesRef(value))); }