Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ Optimizations
* GITHUB#15474: Use bulk scoring provided by RandomVectorScorers for new scalar quantized formats provided through
Lucene104ScalarQuantizedVectorsFormat and Lucene104HnswScalarQuantizedVectorsFormat (Ben Trent)

* GITHUB#15494: Set minimum competitive score in FirstPassGroupingCollector (Binlong Gao)

Bug Fixes
---------------------
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.search.grouping;

import org.apache.lucene.util.NumericUtils;

/**
* An encoder do encode (doc, score) pair as a long whose sort order is same as {@code (o1, o2) ->
* Float.compare(o1.score, o2.score)).thenComparing(Comparator.comparingInt((ScoreDoc o) ->
* o.doc).reversed())}
*/
class DocScoreEncoder {
static long encode(int docId, float score) {
return (((long) NumericUtils.floatToSortableInt(score)) << 32) | (Integer.MAX_VALUE - docId);
}

static float toScore(long value) {
return NumericUtils.sortableIntToFloat((int) (value >>> 32));
}

static int docId(long value) {
return Integer.MAX_VALUE - ((int) value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Comparator;
import java.util.HashMap;
import java.util.TreeSet;
import java.util.concurrent.atomic.LongAccumulator;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
Expand All @@ -31,6 +32,7 @@
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.CollectionUtil;

/**
Expand All @@ -43,6 +45,8 @@
*/
public class FirstPassGroupingCollector<T> extends SimpleCollector {

private static final int DEFAULT_INTERVAL = 0x3ff;

private final GroupSelector<T> groupSelector;
private final boolean ignoreDocsWithoutGroupField;

Expand All @@ -53,6 +57,14 @@ public class FirstPassGroupingCollector<T> extends SimpleCollector {
private final boolean needsScores;
private final HashMap<T, CollectedSearchGroup<T>> groupMap;
private final int compIDXEnd;
private final boolean canSetMinScore;
private final LongAccumulator minScoreAcc;
private final int totalHitsThreshold;

private Scorable scorer;
private float minCompetitiveScore;
private int totalHitCount;
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;

// Set once we reach topNGroups unique groups:
/**
Expand Down Expand Up @@ -94,6 +106,32 @@ public FirstPassGroupingCollector(
Sort groupSort,
int topNGroups,
boolean ignoreDocsWithoutGroupField) {
this(groupSelector, groupSort, topNGroups, ignoreDocsWithoutGroupField, Integer.MAX_VALUE);
}

/**
* Create the first pass collector with ignoreDocsWithoutGroupField
*
* @param groupSelector a GroupSelector used to defined groups
* @param groupSort The {@link Sort} used to sort the groups. The top sorted document within each
* group according to groupSort, determines how that group sorts against other groups. This
* must be non-null, ie, if you want to groupSort by relevance use Sort.RELEVANCE.
* @param topNGroups How many top groups to keep.
* @param ignoreDocsWithoutGroupField if true, ignore documents that don't have the group field
* instead of putting them in a null group
* @param totalHitsThreshold totalHitsThreshold the number of hits to collect accurately. If the
* number of hits is greater than threshold, hits beyond the threshold will be collected as
* well, but the accuracy of the total hits will be reduced. This is used to reduce the memory
* footprint of the collector when collecting hits beyond the threshold. If set to {@link
* Integer#MAX_VALUE}, total hits is calculated precisely.
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public FirstPassGroupingCollector(
GroupSelector<T> groupSelector,
Sort groupSort,
int topNGroups,
boolean ignoreDocsWithoutGroupField,
int totalHitsThreshold) {
this.groupSelector = groupSelector;
this.ignoreDocsWithoutGroupField = ignoreDocsWithoutGroupField;
if (topNGroups < 1) {
Expand Down Expand Up @@ -121,11 +159,32 @@ public FirstPassGroupingCollector(

spareSlot = topNGroups;
groupMap = CollectionUtil.newHashMap(topNGroups);

this.totalHitsThreshold = totalHitsThreshold;
this.canSetMinScore =
sortFields.length > 0
&& sortFields[0].getType() == SortField.Type.SCORE
&& this.totalHitsThreshold != Integer.MAX_VALUE;
this.minScoreAcc = canSetMinScore ? new LongAccumulator(Math::max, Long.MIN_VALUE) : null;
}

public int getTotalHitCount() {
return this.totalHitCount;
}

public TotalHits.Relation getTotalHitsRelation() {
return this.totalHitsRelation;
}

@Override
public ScoreMode scoreMode() {
return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
if (canSetMinScore) {
return ScoreMode.TOP_SCORES;
} else if (needsScores) {
return ScoreMode.COMPLETE;
} else {
return ScoreMode.COMPLETE_NO_SCORES;
}
}

/**
Expand Down Expand Up @@ -176,10 +235,16 @@ public Collection<SearchGroup<T>> getTopGroups(int groupOffset) throws IOExcepti

@Override
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
groupSelector.setScorer(scorer);
for (LeafFieldComparator comparator : leafComparators) {
comparator.setScorer(scorer);
}
if (minScoreAcc == null) {
updateMinCompetitiveScore();
} else {
updateGlobalMinCompetitiveScore();
}
}

private boolean isCompetitive(int doc) throws IOException {
Expand Down Expand Up @@ -214,6 +279,10 @@ private boolean isCompetitive(int doc) throws IOException {

@Override
public void collect(int doc) throws IOException {
int hitCountSoFar = ++totalHitCount;
if (minScoreAcc != null && (hitCountSoFar & DEFAULT_INTERVAL) == 0) {
updateGlobalMinCompetitiveScore();
}

if (isCompetitive(doc) == false) {
return;
Expand Down Expand Up @@ -262,6 +331,8 @@ private void collectNewGroup(final int doc) throws IOException {
// number of groups; from here on we will drop
// bottom group when we insert new one:
buildSortedSet();
// if groups is full, we can propagate the min competitive score to the scorer
updateMinCompetitiveScore();
}

} else {
Expand Down Expand Up @@ -289,6 +360,7 @@ private void collectNewGroup(final int doc) throws IOException {
for (LeafFieldComparator fc : leafComparators) {
fc.setBottom(lastComparatorSlot);
}
updateMinCompetitiveScore();
}
}

Expand Down Expand Up @@ -354,6 +426,8 @@ private void collectExistingGroup(final int doc, final CollectedSearchGroup<T> g
for (LeafFieldComparator fc : leafComparators) {
fc.setBottom(newLast.comparatorSlot);
}
// now the groups is full, we can propagate the min competitive score to the scorer
updateMinCompetitiveScore();
}
}
}
Expand Down Expand Up @@ -384,6 +458,40 @@ public int compare(CollectedSearchGroup<?> o1, CollectedSearchGroup<?> o2) {
}
}

private void updateMinCompetitiveScore() throws IOException {
if (canSetMinScore
&& orderedGroups != null
&& scorer != null
&& totalHitCount > totalHitsThreshold) {
// Get the score of the bottom group
CollectedSearchGroup<T> bottomGroup = orderedGroups.last();
float bottomScore = (float) comparators[0].value(bottomGroup.comparatorSlot);
if (bottomScore > minCompetitiveScore) {
minCompetitiveScore = bottomScore;
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
scorer.setMinCompetitiveScore(minCompetitiveScore);
if (minScoreAcc != null) {
minScoreAcc.accumulate(DocScoreEncoder.encode(docBase, bottomScore));
}
}
}
}

protected void updateGlobalMinCompetitiveScore() throws IOException {
if (canSetMinScore && scorer != null) {
long maxMinScore = minScoreAcc.get();
if (maxMinScore != Long.MIN_VALUE) {
float score = DocScoreEncoder.toScore(maxMinScore);
score = docBase >= DocScoreEncoder.docId(maxMinScore) ? Math.nextUp(score) : score;
if (score > minCompetitiveScore) {
minCompetitiveScore = score;
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
scorer.setMinCompetitiveScore(score);
}
}
}
}

@Override
protected void doSetNextReader(LeafReaderContext readerContext) throws IOException {
docBase = readerContext.docBase;
Expand Down
Loading