add --shared
[pylucene.git] / lucene-java-3.4.0 / lucene / contrib / grouping / src / java / org / apache / lucene / search / grouping / SearchGroup.java
1 package org.apache.lucene.search.grouping;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import java.io.IOException;
21 import java.util.*;
22
23 import org.apache.lucene.search.FieldComparator;
24 import org.apache.lucene.search.Sort;
25 import org.apache.lucene.search.SortField;
26
27 /**
28  * Represents a group that is found during the first pass search.
29  *
30  * @lucene.experimental
31  */
32 public class SearchGroup<GROUP_VALUE_TYPE> {
33
34   /** The value that defines this group  */
35   public GROUP_VALUE_TYPE groupValue;
36
37   /** The sort values used during sorting. These are the
38    *  groupSort field values of the highest rank document
39    *  (by the groupSort) within the group.  Can be
40    * <code>null</code> if <code>fillFields=false</code> had
41    * been passed to {@link AbstractFirstPassGroupingCollector#getTopGroups} */
42   public Object[] sortValues;
43
44   @Override
45   public String toString() {
46     return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")");
47   }
48
49   private static class ShardIter<T> {
50     public final Iterator<SearchGroup<T>> iter;
51     public final int shardIndex;
52
53     public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) {
54       this.shardIndex = shardIndex;
55       iter = shard.iterator();
56       assert iter.hasNext();
57     }
58
59     public SearchGroup<T> next() {
60       assert iter.hasNext();
61       final SearchGroup<T> group = iter.next();
62       if (group.sortValues == null) {
63         throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector");
64       }
65       return group;
66     }
67     
68     @Override
69     public String toString() {
70       return "ShardIter(shard=" + shardIndex + ")";
71     }
72   }
73
74   // Holds all shards currently on the same group
75   private static class MergedGroup<T> {
76
77     // groupValue may be null!
78     public final T groupValue;
79
80     public Object[] topValues;
81     public final List<ShardIter<T>> shards = new ArrayList<ShardIter<T>>();
82     public int minShardIndex;
83     public boolean processed;
84     public boolean inQueue;
85
86     public MergedGroup(T groupValue) {
87       this.groupValue = groupValue;
88     }
89
90     // Only for assert
91     private boolean neverEquals(Object _other) {
92       if (_other instanceof MergedGroup) {
93         MergedGroup other = (MergedGroup) _other;
94         if (groupValue == null) {
95           assert other.groupValue != null;
96         } else {
97           assert !groupValue.equals(other.groupValue);
98         }
99       }
100       return true;
101     }
102
103     @Override
104     public boolean equals(Object _other) {
105       // We never have another MergedGroup instance with
106       // same groupValue
107       assert neverEquals(_other);
108
109       if (_other instanceof MergedGroup) {
110         MergedGroup other = (MergedGroup) _other;
111         if (groupValue == null) {
112           return other == null;
113         } else {
114           return groupValue.equals(other);
115         }
116       } else {
117         return false;
118       }
119     }
120
121     @Override
122     public int hashCode() {
123       if (groupValue == null) {
124         return 0;
125       } else {
126         return groupValue.hashCode();
127       }
128     }
129   }
130
131   private static class GroupComparator<T> implements Comparator<MergedGroup<T>> {
132
133     public final FieldComparator[] comparators;
134     public final int[] reversed;
135
136     public GroupComparator(Sort groupSort) throws IOException {
137       final SortField[] sortFields = groupSort.getSort();
138       comparators = new FieldComparator[sortFields.length];
139       reversed = new int[sortFields.length];
140       for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
141         final SortField sortField = sortFields[compIDX];
142         comparators[compIDX] = sortField.getComparator(1, compIDX);
143         reversed[compIDX] = sortField.getReverse() ? -1 : 1;
144       }
145     }
146
147     @SuppressWarnings("unchecked")
148     public int compare(MergedGroup<T> group, MergedGroup<T> other) {
149       if (group == other) {
150         return 0;
151       }
152       //System.out.println("compare group=" + group + " other=" + other);
153       final Object[] groupValues = group.topValues;
154       final Object[] otherValues = other.topValues;
155       //System.out.println("  groupValues=" + groupValues + " otherValues=" + otherValues);
156       for (int compIDX = 0;compIDX < comparators.length; compIDX++) {
157         final int c = reversed[compIDX] * comparators[compIDX].compareValues(groupValues[compIDX],
158                                                                              otherValues[compIDX]);
159         if (c != 0) {
160           return c;
161         }
162       }
163
164       // Tie break by min shard index:
165       assert group.minShardIndex != other.minShardIndex;
166       return group.minShardIndex - other.minShardIndex;
167     }
168   }
169
170   private static class GroupMerger<T> {
171
172     private final GroupComparator<T> groupComp;
173     private final SortedSet<MergedGroup<T>> queue;
174     private final Map<T,MergedGroup<T>> groupsSeen;
175
176     public GroupMerger(Sort groupSort) throws IOException {
177       groupComp = new GroupComparator<T>(groupSort);
178       queue = new TreeSet<MergedGroup<T>>(groupComp);
179       groupsSeen = new HashMap<T,MergedGroup<T>>();
180     }
181
182     @SuppressWarnings("unchecked")
183     private void updateNextGroup(int topN, ShardIter<T> shard) {
184       while(shard.iter.hasNext()) {
185         final SearchGroup<T> group = shard.next();
186         MergedGroup<T> mergedGroup = groupsSeen.get(group.groupValue);
187         final boolean isNew = mergedGroup == null;
188         //System.out.println("    next group=" + (group.groupValue == null ? "null" : ((BytesRef) group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues));
189
190         if (isNew) {
191           // Start a new group:
192           //System.out.println("      new");
193           mergedGroup = new MergedGroup<T>(group.groupValue);
194           mergedGroup.minShardIndex = shard.shardIndex;
195           assert group.sortValues != null;
196           mergedGroup.topValues = group.sortValues;
197           groupsSeen.put(group.groupValue, mergedGroup);
198           mergedGroup.inQueue = true;
199           queue.add(mergedGroup);
200         } else if (mergedGroup.processed) {
201           // This shard produced a group that we already
202           // processed; move on to next group...
203           continue;
204         } else {
205           //System.out.println("      old");
206           boolean competes = false;
207           for(int compIDX=0;compIDX<groupComp.comparators.length;compIDX++) {
208             final int cmp = groupComp.reversed[compIDX] * groupComp.comparators[compIDX].compareValues(group.sortValues[compIDX],
209                                                                                                        mergedGroup.topValues[compIDX]);
210             if (cmp < 0) {
211               // Definitely competes
212               competes = true;
213               break;
214             } else if (cmp > 0) {
215               // Definitely does not compete
216               break;
217             } else if (compIDX == groupComp.comparators.length-1) {
218               if (shard.shardIndex < mergedGroup.minShardIndex) {
219                 competes = true;
220               }
221             }
222           }
223
224           //System.out.println("      competes=" + competes);
225
226           if (competes) {
227             // Group's sort changed -- remove & re-insert
228             if (mergedGroup.inQueue) {
229               queue.remove(mergedGroup);
230             }
231             mergedGroup.topValues = group.sortValues;
232             mergedGroup.minShardIndex = shard.shardIndex;
233             queue.add(mergedGroup);
234             mergedGroup.inQueue = true;
235           }
236         }
237
238         mergedGroup.shards.add(shard);
239         break;
240       }
241
242       // Prune un-competitive groups:
243       while(queue.size() > topN) {
244         // TODO java 1.6: .pollLast
245         final MergedGroup<T> group = queue.last();
246         //System.out.println("PRUNE: " + group);
247         queue.remove(group);
248         group.inQueue = false;
249       }
250     }
251
252     public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) {
253
254       final int maxQueueSize = offset + topN;
255
256       //System.out.println("merge");
257       // Init queue:
258       for(int shardIDX=0;shardIDX<shards.size();shardIDX++) {
259         final Collection<SearchGroup<T>> shard = shards.get(shardIDX);
260         if (!shard.isEmpty()) {
261           //System.out.println("  insert shard=" + shardIDX);
262           updateNextGroup(maxQueueSize, new ShardIter<T>(shard, shardIDX));
263         }
264       }
265
266       // Pull merged topN groups:
267       final List<SearchGroup<T>> newTopGroups = new ArrayList<SearchGroup<T>>();
268
269       int count = 0;
270
271       while(queue.size() != 0) {
272         // TODO Java 1.6: pollFirst()
273         final MergedGroup<T> group = queue.first();
274         queue.remove(group);
275         group.processed = true;
276         //System.out.println("  pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues));
277         if (count++ >= offset) {
278           final SearchGroup<T> newGroup = new SearchGroup<T>();
279           newGroup.groupValue = group.groupValue;
280           newGroup.sortValues = group.topValues;
281           newTopGroups.add(newGroup);
282           if (newTopGroups.size() == topN) {
283             break;
284           }
285         //} else {
286         // System.out.println("    skip < offset");
287         }
288
289         // Advance all iters in this group:
290         for(ShardIter<T> shardIter : group.shards) {
291           updateNextGroup(maxQueueSize, shardIter);
292         }
293       }
294
295       if (newTopGroups.size() == 0) {
296         return null;
297       } else {
298         return newTopGroups;
299       }
300     }
301   }
302
303   /** Merges multiple collections of top groups, for example
304    *  obtained from separate index shards.  The provided
305    *  groupSort must match how the groups were sorted, and
306    *  the provided SearchGroups must have been computed
307    *  with fillFields=true passed to {@link
308    *  AbstractFirstPassGroupingCollector#getTopGroups}.
309    *
310    * <p>NOTE: this returns null if the topGroups is empty.
311    */
312   public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort)
313     throws IOException {
314     if (topGroups.size() == 0) {
315       return null;
316     } else {
317       return new GroupMerger<T>(groupSort).merge(topGroups, offset, topN);
318     }
319   }
320 }