pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / contrib / grouping / src / java / org / apache / lucene / search / grouping / SearchGroup.java
1 package org.apache.lucene.search.grouping;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import org.apache.lucene.search.FieldComparator;
21 import org.apache.lucene.search.Sort;
22 import org.apache.lucene.search.SortField;
23
24 import java.io.IOException;
25 import java.util.*;
26
27 /**
28  * Represents a group that is found during the first pass search.
29  *
30  * @lucene.experimental
31  */
32 public class SearchGroup<GROUP_VALUE_TYPE> {
33
34   /** The value that defines this group  */
35   public GROUP_VALUE_TYPE groupValue;
36
37   /** The sort values used during sorting. These are the
38    *  groupSort field values of the highest rank document
39    *  (by the groupSort) within the group.  Can be
40    * <code>null</code> if <code>fillFields=false</code> had
41    * been passed to {@link AbstractFirstPassGroupingCollector#getTopGroups} */
42   public Object[] sortValues;
43
44   @Override
45   public String toString() {
46     return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")");
47   }
48
49   @Override
50   public boolean equals(Object o) {
51     if (this == o) return true;
52     if (o == null || getClass() != o.getClass()) return false;
53
54     SearchGroup that = (SearchGroup) o;
55
56     if (groupValue == null) {
57       if (that.groupValue != null) {
58         return false;
59       }
60     } else if (!groupValue.equals(that.groupValue)) {
61       return false;
62     }
63
64     return true;
65   }
66
67   @Override
68   public int hashCode() {
69     return groupValue != null ? groupValue.hashCode() : 0;
70   }
71
72   private static class ShardIter<T> {
73     public final Iterator<SearchGroup<T>> iter;
74     public final int shardIndex;
75
76     public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) {
77       this.shardIndex = shardIndex;
78       iter = shard.iterator();
79       assert iter.hasNext();
80     }
81
82     public SearchGroup<T> next() {
83       assert iter.hasNext();
84       final SearchGroup<T> group = iter.next();
85       if (group.sortValues == null) {
86         throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector");
87       }
88       return group;
89     }
90     
91     @Override
92     public String toString() {
93       return "ShardIter(shard=" + shardIndex + ")";
94     }
95   }
96
97   // Holds all shards currently on the same group
98   private static class MergedGroup<T> {
99
100     // groupValue may be null!
101     public final T groupValue;
102
103     public Object[] topValues;
104     public final List<ShardIter<T>> shards = new ArrayList<ShardIter<T>>();
105     public int minShardIndex;
106     public boolean processed;
107     public boolean inQueue;
108
109     public MergedGroup(T groupValue) {
110       this.groupValue = groupValue;
111     }
112
113     // Only for assert
114     private boolean neverEquals(Object _other) {
115       if (_other instanceof MergedGroup) {
116         MergedGroup other = (MergedGroup) _other;
117         if (groupValue == null) {
118           assert other.groupValue != null;
119         } else {
120           assert !groupValue.equals(other.groupValue);
121         }
122       }
123       return true;
124     }
125
126     @Override
127     public boolean equals(Object _other) {
128       // We never have another MergedGroup instance with
129       // same groupValue
130       assert neverEquals(_other);
131
132       if (_other instanceof MergedGroup) {
133         MergedGroup other = (MergedGroup) _other;
134         if (groupValue == null) {
135           return other == null;
136         } else {
137           return groupValue.equals(other);
138         }
139       } else {
140         return false;
141       }
142     }
143
144     @Override
145     public int hashCode() {
146       if (groupValue == null) {
147         return 0;
148       } else {
149         return groupValue.hashCode();
150       }
151     }
152   }
153
154   private static class GroupComparator<T> implements Comparator<MergedGroup<T>> {
155
156     public final FieldComparator[] comparators;
157     public final int[] reversed;
158
159     public GroupComparator(Sort groupSort) throws IOException {
160       final SortField[] sortFields = groupSort.getSort();
161       comparators = new FieldComparator[sortFields.length];
162       reversed = new int[sortFields.length];
163       for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
164         final SortField sortField = sortFields[compIDX];
165         comparators[compIDX] = sortField.getComparator(1, compIDX);
166         reversed[compIDX] = sortField.getReverse() ? -1 : 1;
167       }
168     }
169
170     @SuppressWarnings("unchecked")
171     public int compare(MergedGroup<T> group, MergedGroup<T> other) {
172       if (group == other) {
173         return 0;
174       }
175       //System.out.println("compare group=" + group + " other=" + other);
176       final Object[] groupValues = group.topValues;
177       final Object[] otherValues = other.topValues;
178       //System.out.println("  groupValues=" + groupValues + " otherValues=" + otherValues);
179       for (int compIDX = 0;compIDX < comparators.length; compIDX++) {
180         final int c = reversed[compIDX] * comparators[compIDX].compareValues(groupValues[compIDX],
181                                                                              otherValues[compIDX]);
182         if (c != 0) {
183           return c;
184         }
185       }
186
187       // Tie break by min shard index:
188       assert group.minShardIndex != other.minShardIndex;
189       return group.minShardIndex - other.minShardIndex;
190     }
191   }
192
193   private static class GroupMerger<T> {
194
195     private final GroupComparator<T> groupComp;
196     private final SortedSet<MergedGroup<T>> queue;
197     private final Map<T,MergedGroup<T>> groupsSeen;
198
199     public GroupMerger(Sort groupSort) throws IOException {
200       groupComp = new GroupComparator<T>(groupSort);
201       queue = new TreeSet<MergedGroup<T>>(groupComp);
202       groupsSeen = new HashMap<T,MergedGroup<T>>();
203     }
204
205     @SuppressWarnings("unchecked")
206     private void updateNextGroup(int topN, ShardIter<T> shard) {
207       while(shard.iter.hasNext()) {
208         final SearchGroup<T> group = shard.next();
209         MergedGroup<T> mergedGroup = groupsSeen.get(group.groupValue);
210         final boolean isNew = mergedGroup == null;
211         //System.out.println("    next group=" + (group.groupValue == null ? "null" : ((BytesRef) group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues));
212
213         if (isNew) {
214           // Start a new group:
215           //System.out.println("      new");
216           mergedGroup = new MergedGroup<T>(group.groupValue);
217           mergedGroup.minShardIndex = shard.shardIndex;
218           assert group.sortValues != null;
219           mergedGroup.topValues = group.sortValues;
220           groupsSeen.put(group.groupValue, mergedGroup);
221           mergedGroup.inQueue = true;
222           queue.add(mergedGroup);
223         } else if (mergedGroup.processed) {
224           // This shard produced a group that we already
225           // processed; move on to next group...
226           continue;
227         } else {
228           //System.out.println("      old");
229           boolean competes = false;
230           for(int compIDX=0;compIDX<groupComp.comparators.length;compIDX++) {
231             final int cmp = groupComp.reversed[compIDX] * groupComp.comparators[compIDX].compareValues(group.sortValues[compIDX],
232                                                                                                        mergedGroup.topValues[compIDX]);
233             if (cmp < 0) {
234               // Definitely competes
235               competes = true;
236               break;
237             } else if (cmp > 0) {
238               // Definitely does not compete
239               break;
240             } else if (compIDX == groupComp.comparators.length-1) {
241               if (shard.shardIndex < mergedGroup.minShardIndex) {
242                 competes = true;
243               }
244             }
245           }
246
247           //System.out.println("      competes=" + competes);
248
249           if (competes) {
250             // Group's sort changed -- remove & re-insert
251             if (mergedGroup.inQueue) {
252               queue.remove(mergedGroup);
253             }
254             mergedGroup.topValues = group.sortValues;
255             mergedGroup.minShardIndex = shard.shardIndex;
256             queue.add(mergedGroup);
257             mergedGroup.inQueue = true;
258           }
259         }
260
261         mergedGroup.shards.add(shard);
262         break;
263       }
264
265       // Prune un-competitive groups:
266       while(queue.size() > topN) {
267         // TODO java 1.6: .pollLast
268         final MergedGroup<T> group = queue.last();
269         //System.out.println("PRUNE: " + group);
270         queue.remove(group);
271         group.inQueue = false;
272       }
273     }
274
275     public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) {
276
277       final int maxQueueSize = offset + topN;
278
279       //System.out.println("merge");
280       // Init queue:
281       for(int shardIDX=0;shardIDX<shards.size();shardIDX++) {
282         final Collection<SearchGroup<T>> shard = shards.get(shardIDX);
283         if (!shard.isEmpty()) {
284           //System.out.println("  insert shard=" + shardIDX);
285           updateNextGroup(maxQueueSize, new ShardIter<T>(shard, shardIDX));
286         }
287       }
288
289       // Pull merged topN groups:
290       final List<SearchGroup<T>> newTopGroups = new ArrayList<SearchGroup<T>>();
291
292       int count = 0;
293
294       while(queue.size() != 0) {
295         // TODO Java 1.6: pollFirst()
296         final MergedGroup<T> group = queue.first();
297         queue.remove(group);
298         group.processed = true;
299         //System.out.println("  pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues));
300         if (count++ >= offset) {
301           final SearchGroup<T> newGroup = new SearchGroup<T>();
302           newGroup.groupValue = group.groupValue;
303           newGroup.sortValues = group.topValues;
304           newTopGroups.add(newGroup);
305           if (newTopGroups.size() == topN) {
306             break;
307           }
308         //} else {
309         // System.out.println("    skip < offset");
310         }
311
312         // Advance all iters in this group:
313         for(ShardIter<T> shardIter : group.shards) {
314           updateNextGroup(maxQueueSize, shardIter);
315         }
316       }
317
318       if (newTopGroups.size() == 0) {
319         return null;
320       } else {
321         return newTopGroups;
322       }
323     }
324   }
325
326   /** Merges multiple collections of top groups, for example
327    *  obtained from separate index shards.  The provided
328    *  groupSort must match how the groups were sorted, and
329    *  the provided SearchGroups must have been computed
330    *  with fillFields=true passed to {@link
331    *  AbstractFirstPassGroupingCollector#getTopGroups}.
332    *
333    * <p>NOTE: this returns null if the topGroups is empty.
334    */
335   public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort)
336     throws IOException {
337     if (topGroups.size() == 0) {
338       return null;
339     } else {
340       return new GroupMerger<T>(groupSort).merge(topGroups, offset, topN);
341     }
342   }
343 }