1 package org.apache.lucene.search.grouping;
4 * Licensed to the Apache Software Foundation (ASF) under one or more
5 * contributor license agreements. See the NOTICE file distributed with
6 * this work for additional information regarding copyright ownership.
7 * The ASF licenses this file to You under the Apache License, Version 2.0
8 * (the "License"); you may not use this file except in compliance with
9 * the License. You may obtain a copy of the License at
11 * http://www.apache.org/licenses/LICENSE-2.0
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
20 import org.apache.lucene.search.FieldComparator;
21 import org.apache.lucene.search.Sort;
22 import org.apache.lucene.search.SortField;
24 import java.io.IOException;
28 * Represents a group that is found during the first pass search.
30 * @lucene.experimental
32 public class SearchGroup<GROUP_VALUE_TYPE> {
34 /** The value that defines this group */
35 public GROUP_VALUE_TYPE groupValue;
37 /** The sort values used during sorting. These are the
38 * groupSort field values of the highest rank document
39 * (by the groupSort) within the group. Can be
40 * <code>null</code> if <code>fillFields=false</code> had
41 * been passed to {@link AbstractFirstPassGroupingCollector#getTopGroups} */
42 public Object[] sortValues;
45 public String toString() {
46 return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")");
50 public boolean equals(Object o) {
51 if (this == o) return true;
52 if (o == null || getClass() != o.getClass()) return false;
54 SearchGroup that = (SearchGroup) o;
56 if (groupValue == null) {
57 if (that.groupValue != null) {
60 } else if (!groupValue.equals(that.groupValue)) {
68 public int hashCode() {
69 return groupValue != null ? groupValue.hashCode() : 0;
72 private static class ShardIter<T> {
73 public final Iterator<SearchGroup<T>> iter;
74 public final int shardIndex;
76 public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) {
77 this.shardIndex = shardIndex;
78 iter = shard.iterator();
79 assert iter.hasNext();
82 public SearchGroup<T> next() {
83 assert iter.hasNext();
84 final SearchGroup<T> group = iter.next();
85 if (group.sortValues == null) {
86 throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector");
92 public String toString() {
93 return "ShardIter(shard=" + shardIndex + ")";
97 // Holds all shards currently on the same group
98 private static class MergedGroup<T> {
100 // groupValue may be null!
101 public final T groupValue;
103 public Object[] topValues;
104 public final List<ShardIter<T>> shards = new ArrayList<ShardIter<T>>();
105 public int minShardIndex;
106 public boolean processed;
107 public boolean inQueue;
109 public MergedGroup(T groupValue) {
110 this.groupValue = groupValue;
114 private boolean neverEquals(Object _other) {
115 if (_other instanceof MergedGroup) {
116 MergedGroup other = (MergedGroup) _other;
117 if (groupValue == null) {
118 assert other.groupValue != null;
120 assert !groupValue.equals(other.groupValue);
127 public boolean equals(Object _other) {
128 // We never have another MergedGroup instance with
130 assert neverEquals(_other);
132 if (_other instanceof MergedGroup) {
133 MergedGroup other = (MergedGroup) _other;
134 if (groupValue == null) {
135 return other == null;
137 return groupValue.equals(other);
145 public int hashCode() {
146 if (groupValue == null) {
149 return groupValue.hashCode();
154 private static class GroupComparator<T> implements Comparator<MergedGroup<T>> {
156 public final FieldComparator[] comparators;
157 public final int[] reversed;
159 public GroupComparator(Sort groupSort) throws IOException {
160 final SortField[] sortFields = groupSort.getSort();
161 comparators = new FieldComparator[sortFields.length];
162 reversed = new int[sortFields.length];
163 for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
164 final SortField sortField = sortFields[compIDX];
165 comparators[compIDX] = sortField.getComparator(1, compIDX);
166 reversed[compIDX] = sortField.getReverse() ? -1 : 1;
170 @SuppressWarnings("unchecked")
171 public int compare(MergedGroup<T> group, MergedGroup<T> other) {
172 if (group == other) {
175 //System.out.println("compare group=" + group + " other=" + other);
176 final Object[] groupValues = group.topValues;
177 final Object[] otherValues = other.topValues;
178 //System.out.println(" groupValues=" + groupValues + " otherValues=" + otherValues);
179 for (int compIDX = 0;compIDX < comparators.length; compIDX++) {
180 final int c = reversed[compIDX] * comparators[compIDX].compareValues(groupValues[compIDX],
181 otherValues[compIDX]);
187 // Tie break by min shard index:
188 assert group.minShardIndex != other.minShardIndex;
189 return group.minShardIndex - other.minShardIndex;
193 private static class GroupMerger<T> {
195 private final GroupComparator<T> groupComp;
196 private final SortedSet<MergedGroup<T>> queue;
197 private final Map<T,MergedGroup<T>> groupsSeen;
199 public GroupMerger(Sort groupSort) throws IOException {
200 groupComp = new GroupComparator<T>(groupSort);
201 queue = new TreeSet<MergedGroup<T>>(groupComp);
202 groupsSeen = new HashMap<T,MergedGroup<T>>();
205 @SuppressWarnings("unchecked")
206 private void updateNextGroup(int topN, ShardIter<T> shard) {
207 while(shard.iter.hasNext()) {
208 final SearchGroup<T> group = shard.next();
209 MergedGroup<T> mergedGroup = groupsSeen.get(group.groupValue);
210 final boolean isNew = mergedGroup == null;
211 //System.out.println(" next group=" + (group.groupValue == null ? "null" : ((BytesRef) group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues));
214 // Start a new group:
215 //System.out.println(" new");
216 mergedGroup = new MergedGroup<T>(group.groupValue);
217 mergedGroup.minShardIndex = shard.shardIndex;
218 assert group.sortValues != null;
219 mergedGroup.topValues = group.sortValues;
220 groupsSeen.put(group.groupValue, mergedGroup);
221 mergedGroup.inQueue = true;
222 queue.add(mergedGroup);
223 } else if (mergedGroup.processed) {
224 // This shard produced a group that we already
225 // processed; move on to next group...
228 //System.out.println(" old");
229 boolean competes = false;
230 for(int compIDX=0;compIDX<groupComp.comparators.length;compIDX++) {
231 final int cmp = groupComp.reversed[compIDX] * groupComp.comparators[compIDX].compareValues(group.sortValues[compIDX],
232 mergedGroup.topValues[compIDX]);
234 // Definitely competes
237 } else if (cmp > 0) {
238 // Definitely does not compete
240 } else if (compIDX == groupComp.comparators.length-1) {
241 if (shard.shardIndex < mergedGroup.minShardIndex) {
247 //System.out.println(" competes=" + competes);
250 // Group's sort changed -- remove & re-insert
251 if (mergedGroup.inQueue) {
252 queue.remove(mergedGroup);
254 mergedGroup.topValues = group.sortValues;
255 mergedGroup.minShardIndex = shard.shardIndex;
256 queue.add(mergedGroup);
257 mergedGroup.inQueue = true;
261 mergedGroup.shards.add(shard);
265 // Prune un-competitive groups:
266 while(queue.size() > topN) {
267 // TODO java 1.6: .pollLast
268 final MergedGroup<T> group = queue.last();
269 //System.out.println("PRUNE: " + group);
271 group.inQueue = false;
275 public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) {
277 final int maxQueueSize = offset + topN;
279 //System.out.println("merge");
281 for(int shardIDX=0;shardIDX<shards.size();shardIDX++) {
282 final Collection<SearchGroup<T>> shard = shards.get(shardIDX);
283 if (!shard.isEmpty()) {
284 //System.out.println(" insert shard=" + shardIDX);
285 updateNextGroup(maxQueueSize, new ShardIter<T>(shard, shardIDX));
289 // Pull merged topN groups:
290 final List<SearchGroup<T>> newTopGroups = new ArrayList<SearchGroup<T>>();
294 while(queue.size() != 0) {
295 // TODO Java 1.6: pollFirst()
296 final MergedGroup<T> group = queue.first();
298 group.processed = true;
299 //System.out.println(" pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues));
300 if (count++ >= offset) {
301 final SearchGroup<T> newGroup = new SearchGroup<T>();
302 newGroup.groupValue = group.groupValue;
303 newGroup.sortValues = group.topValues;
304 newTopGroups.add(newGroup);
305 if (newTopGroups.size() == topN) {
309 // System.out.println(" skip < offset");
312 // Advance all iters in this group:
313 for(ShardIter<T> shardIter : group.shards) {
314 updateNextGroup(maxQueueSize, shardIter);
318 if (newTopGroups.size() == 0) {
326 /** Merges multiple collections of top groups, for example
327 * obtained from separate index shards. The provided
328 * groupSort must match how the groups were sorted, and
329 * the provided SearchGroups must have been computed
330 * with fillFields=true passed to {@link
331 * AbstractFirstPassGroupingCollector#getTopGroups}.
333 * <p>NOTE: this returns null if the topGroups is empty.
335 public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort)
337 if (topGroups.size() == 0) {
340 return new GroupMerger<T>(groupSort).merge(topGroups, offset, topN);