+++ /dev/null
-package org.apache.lucene.search.grouping;
-
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import java.io.IOException;
-import java.util.*;
-
-import org.apache.lucene.search.FieldComparator;
-import org.apache.lucene.search.Sort;
-import org.apache.lucene.search.SortField;
-
-/**
- * Represents a group that is found during the first pass search.
- *
- * @lucene.experimental
- */
-public class SearchGroup<GROUP_VALUE_TYPE> {
-
- /** The value that defines this group. May be
- * <code>null</code>. */
- public GROUP_VALUE_TYPE groupValue;
-
- /** The sort values used during sorting. These are the
- * groupSort field values of the highest rank document
- * (by the groupSort) within the group. Can be
- * <code>null</code> if <code>fillFields=false</code> had
- * been passed to {@link AbstractFirstPassGroupingCollector#getTopGroups};
- * such groups cannot be merged, since merging throws an
- * {@link IllegalArgumentException} when it pulls a group
- * whose sort values are <code>null</code>. */
- public Object[] sortValues;
-
- /** Renders the group value and its sort values for debugging. */
- @Override
- public String toString() {
-   final StringBuilder text = new StringBuilder("SearchGroup(groupValue=");
-   text.append(groupValue);
-   text.append(" sortValues=");
-   text.append(Arrays.toString(sortValues));
-   text.append(")");
-   return text.toString();
- }
-
- /** Iterates the groups of one shard, remembering which shard they came from. */
- private static class ShardIter<T> {
-   public final Iterator<SearchGroup<T>> iter;
-   public final int shardIndex;
-
-   /** Wraps a shard's groups; the shard is expected to be non-empty (asserted). */
-   public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) {
-     iter = shard.iterator();
-     this.shardIndex = shardIndex;
-     assert iter.hasNext();
-   }
-
-   /**
-    * Returns the shard's next group.
-    *
-    * @throws IllegalArgumentException if that group carries no sort values
-    */
-   public SearchGroup<T> next() {
-     assert iter.hasNext();
-     final SearchGroup<T> nextGroup = iter.next();
-     if (nextGroup.sortValues != null) {
-       return nextGroup;
-     }
-     throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector");
-   }
-
-   @Override
-   public String toString() {
-     final StringBuilder text = new StringBuilder("ShardIter(shard=");
-     text.append(shardIndex);
-     text.append(")");
-     return text.toString();
-   }
- }
-
- // Holds all shards currently on the same group
- private static class MergedGroup<T> {
-
-   // groupValue may be null!
-   public final T groupValue;
-
-   // Sort values currently representing this group in the merge
-   public Object[] topValues;
-   // Shard iterators whose current position is on this group
-   public final List<ShardIter<T>> shards = new ArrayList<ShardIter<T>>();
-   // Index of the shard that supplied topValues; used as the tie breaker
-   public int minShardIndex;
-   // Set once the group has been pulled off the merge queue
-   public boolean processed;
-   // Whether the group is currently held in the merge queue
-   public boolean inQueue;
-
-   public MergedGroup(T groupValue) {
-     this.groupValue = groupValue;
-   }
-
-   // Only for assert: checks that a different MergedGroup never carries
-   // the same groupValue as this one. Always returns true so it can be
-   // used as the condition of an assert statement.
-   private boolean neverEquals(Object _other) {
-     if (_other instanceof MergedGroup) {
-       final MergedGroup<?> other = (MergedGroup<?>) _other;
-       if (groupValue == null) {
-         assert other.groupValue != null;
-       } else {
-         assert !groupValue.equals(other.groupValue);
-       }
-     }
-     return true;
-   }
-
-   /**
-    * Two merged groups are equal when their group values are equal
-    * (both <code>null</code>, or equal per <code>equals</code>), which
-    * keeps this method consistent with {@link #hashCode}.
-    */
-   @Override
-   public boolean equals(Object _other) {
-     // Reflexive case: an instance always equals itself, and must not
-     // trip the "never equals" assertion below.
-     if (_other == this) {
-       return true;
-     }
-
-     // We never have another MergedGroup instance with
-     // same groupValue
-     assert neverEquals(_other);
-
-     if (_other instanceof MergedGroup) {
-       final MergedGroup<?> other = (MergedGroup<?>) _other;
-       // Compare the group values themselves, not this group's value
-       // against the other MergedGroup object.
-       if (groupValue == null) {
-         return other.groupValue == null;
-       } else {
-         return groupValue.equals(other.groupValue);
-       }
-     } else {
-       return false;
-     }
-   }
-
-   @Override
-   public int hashCode() {
-     if (groupValue == null) {
-       return 0;
-     } else {
-       return groupValue.hashCode();
-     }
-   }
- }
-
- /**
-  * Orders merged groups by their top sort values, field by field,
-  * and breaks remaining ties by the lower shard index.
-  */
- private static class GroupComparator<T> implements Comparator<MergedGroup<T>> {
-
-   public final FieldComparator[] comparators;
-   public final int[] reversed;
-
-   /** Builds one comparator and one direction multiplier (1 or -1) per sort field. */
-   public GroupComparator(Sort groupSort) throws IOException {
-     final SortField[] fields = groupSort.getSort();
-     final int numFields = fields.length;
-     comparators = new FieldComparator[numFields];
-     reversed = new int[numFields];
-     int slot = 0;
-     for (SortField field : fields) {
-       comparators[slot] = field.getComparator(1, slot);
-       if (field.getReverse()) {
-         reversed[slot] = -1;
-       } else {
-         reversed[slot] = 1;
-       }
-       slot++;
-     }
-   }
-
-   @SuppressWarnings("unchecked")
-   public int compare(MergedGroup<T> group, MergedGroup<T> other) {
-     if (group != other) {
-       final Object[] lhs = group.topValues;
-       final Object[] rhs = other.topValues;
-       final int numComparators = comparators.length;
-       for (int i = 0; i < numComparators; i++) {
-         final int order = reversed[i] * comparators[i].compareValues(lhs[i], rhs[i]);
-         if (order != 0) {
-           return order;
-         }
-       }
-
-       // All sort values tie; the lower shard index wins:
-       assert group.minShardIndex != other.minShardIndex;
-       return group.minShardIndex - other.minShardIndex;
-     }
-     return 0;
-   }
- }
-
- /**
-  * Merges the per-shard group lists into one ranked list, keeping a
-  * bounded sorted queue of the currently best groups.
-  */
- private static class GroupMerger<T> {
-
-   private final GroupComparator<T> groupComp;
-   private final SortedSet<MergedGroup<T>> queue;
-   private final Map<T,MergedGroup<T>> groupsSeen;
-
-   public GroupMerger(Sort groupSort) throws IOException {
-     groupComp = new GroupComparator<T>(groupSort);
-     queue = new TreeSet<MergedGroup<T>>(groupComp);
-     groupsSeen = new HashMap<T,MergedGroup<T>>();
-   }
-
-   // Advances the shard to its next group that has not been processed
-   // yet, folds that group into the merge state, then trims the queue
-   // back down to topN entries.
-   private void updateNextGroup(int topN, ShardIter<T> shard) {
-     while (shard.iter.hasNext()) {
-       final SearchGroup<T> candidate = shard.next();
-       MergedGroup<T> merged = groupsSeen.get(candidate.groupValue);
-
-       if (merged != null && merged.processed) {
-         // This shard produced a group that we already
-         // processed; move on to next group...
-         continue;
-       }
-
-       if (merged == null) {
-         // Start a new group:
-         merged = new MergedGroup<T>(candidate.groupValue);
-         merged.minShardIndex = shard.shardIndex;
-         assert candidate.sortValues != null;
-         merged.topValues = candidate.sortValues;
-         groupsSeen.put(candidate.groupValue, merged);
-         merged.inQueue = true;
-         queue.add(merged);
-       } else if (competes(candidate, merged, shard.shardIndex)) {
-         // Group's sort changed -- remove & re-insert. The removal must
-         // happen before topValues changes, because the queue locates
-         // the entry through the comparator.
-         if (merged.inQueue) {
-           queue.remove(merged);
-         }
-         merged.topValues = candidate.sortValues;
-         merged.minShardIndex = shard.shardIndex;
-         queue.add(merged);
-         merged.inQueue = true;
-       }
-
-       merged.shards.add(shard);
-       break;
-     }
-
-     pruneQueue(topN);
-   }
-
-   // True when the candidate's sort values rank ahead of the values the
-   // merged group currently holds; a complete tie is decided in favor
-   // of the lower shard index.
-   @SuppressWarnings("unchecked")
-   private boolean competes(SearchGroup<T> candidate, MergedGroup<T> merged, int shardIndex) {
-     final int numComparators = groupComp.comparators.length;
-     for (int i = 0; i < numComparators; i++) {
-       final int cmp = groupComp.reversed[i] * groupComp.comparators[i].compareValues(candidate.sortValues[i],
-                                                                                       merged.topValues[i]);
-       if (cmp != 0) {
-         // Negative: definitely competes; positive: definitely does not
-         return cmp < 0;
-       }
-     }
-     // Only a tie on the last comparator reaches the shard-index tie
-     // break; with no comparators at all nothing competes.
-     return numComparators > 0 && shardIndex < merged.minShardIndex;
-   }
-
-   // Prune un-competitive groups until at most topN remain queued:
-   private void pruneQueue(int topN) {
-     while (queue.size() > topN) {
-       // TODO java 1.6: .pollLast
-       final MergedGroup<T> worst = queue.last();
-       queue.remove(worst);
-       worst.inQueue = false;
-     }
-   }
-
-   /**
-    * Merges the shards' groups and returns up to topN of them after
-    * skipping the first offset merged groups, or <code>null</code>
-    * when no group is left to return.
-    */
-   public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) {
-
-     final int maxQueueSize = offset + topN;
-
-     // Init queue with the first group of every non-empty shard:
-     int shardIDX = 0;
-     for (Collection<SearchGroup<T>> shard : shards) {
-       if (!shard.isEmpty()) {
-         updateNextGroup(maxQueueSize, new ShardIter<T>(shard, shardIDX));
-       }
-       shardIDX++;
-     }
-
-     // Pull merged topN groups:
-     final List<SearchGroup<T>> merged = new ArrayList<SearchGroup<T>>();
-
-     int popped = 0;
-
-     while (!queue.isEmpty()) {
-       // TODO Java 1.6: pollFirst()
-       final MergedGroup<T> best = queue.first();
-       queue.remove(best);
-       best.processed = true;
-
-       final boolean pastOffset = popped >= offset;
-       popped++;
-       if (pastOffset) {
-         final SearchGroup<T> out = new SearchGroup<T>();
-         out.groupValue = best.groupValue;
-         out.sortValues = best.topValues;
-         merged.add(out);
-         if (merged.size() == topN) {
-           break;
-         }
-       }
-
-       // Advance all iters in this group:
-       for (ShardIter<T> shardIter : best.shards) {
-         updateNextGroup(maxQueueSize, shardIter);
-       }
-     }
-
-     return merged.isEmpty() ? null : merged;
-   }
- }
-
- /** Merges multiple collections of top groups, for example
-  * obtained from separate index shards. The provided
-  * groupSort must match how the groups were sorted, and
-  * the provided SearchGroups must have been computed
-  * with fillFields=true passed to {@link
-  * AbstractFirstPassGroupingCollector#getTopGroups}.
-  *
-  * <p>NOTE: this returns null if the topGroups is empty,
-  * and also when merging leaves no group to return.
-  */
- public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort)
-   throws IOException {
-   // Nothing to merge at all:
-   if (topGroups.isEmpty()) {
-     return null;
-   }
-   final GroupMerger<T> merger = new GroupMerger<T>(groupSort);
-   return merger.merge(topGroups, offset, topN);
- }
-}