add --shared
[pylucene.git] / lucene-java-3.4.0 / lucene / contrib / grouping / src / java / org / apache / lucene / search / grouping / TermAllGroupHeadsCollector.java
1 package org.apache.lucene.search.grouping;
2
3 /*
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import org.apache.lucene.index.IndexReader;
21 import org.apache.lucene.search.*;
22
23 import java.io.IOException;
24 import java.util.*;
25
26 /**
27  * A base implementation of {@link AbstractAllGroupHeadsCollector} for retrieving the most relevant groups when grouping
28  * on a string based group field. More specifically this all concrete implementations of this base implementation
29  * use {@link org.apache.lucene.search.FieldCache.StringIndex}.
30  *
31  * @lucene.experimental
32  */
33 public abstract class TermAllGroupHeadsCollector<GH extends AbstractAllGroupHeadsCollector.GroupHead> extends AbstractAllGroupHeadsCollector<GH> {
34
35   private static final int DEFAULT_INITIAL_SIZE = 128;
36
37   final String groupField;
38   FieldCache.StringIndex groupIndex;
39   IndexReader indexReader;
40   int docBase;
41
42   protected TermAllGroupHeadsCollector(String groupField, int numberOfSorts) {
43     super(numberOfSorts);
44     this.groupField = groupField;
45   }
46
47   /**
48    * Creates an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments.
49    * This factory method decides with implementation is best suited.
50    *
51    * Delegates to {@link #create(String, org.apache.lucene.search.Sort, int)} with an initialSize of 128.
52    *
53    * @param groupField      The field to group by
54    * @param sortWithinGroup The sort within each group
55    * @return an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments
56    * @throws IOException If I/O related errors occur
57    */
58   public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup) throws IOException {
59     return create(groupField, sortWithinGroup, DEFAULT_INITIAL_SIZE);
60   }
61
62   /**
63    * Creates an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments.
64    * This factory method decides with implementation is best suited.
65    *
66    * @param groupField      The field to group by
67    * @param sortWithinGroup The sort within each group
68    * @param initialSize The initial allocation size of the internal int set and group list which should roughly match
69    *                    the total number of expected unique groups. Be aware that the heap usage is
70    *                    4 bytes * initialSize.
71    * @return an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments
72    * @throws IOException If I/O related errors occur
73    */
74   public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup, int initialSize) throws IOException {
75     boolean sortAllScore = true;
76     boolean sortAllFieldValue = true;
77
78     for (SortField sortField : sortWithinGroup.getSort()) {
79       if (sortField.getType() == SortField.SCORE) {
80         sortAllFieldValue = false;
81       } else if (needGeneralImpl(sortField)) {
82         return new GeneralAllGroupHeadsCollector(groupField, sortWithinGroup);
83       } else {
84         sortAllScore = false;
85       }
86     }
87
88     if (sortAllScore) {
89       return new ScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
90     } else if (sortAllFieldValue) {
91       return new OrdAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
92     } else {
93       return new OrdScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
94     }
95   }
96
97   // Returns when a sort field needs the general impl.
98   private static boolean needGeneralImpl(SortField sortField) {
99     int sortType = sortField.getType();
100     // Note (MvG): We can also make an optimized impl when sorting is SortField.DOC
101     return sortType != SortField.STRING_VAL && sortType != SortField.STRING && sortType != SortField.SCORE;
102   }
103
104   // A general impl that works for any group sort.
105   static class GeneralAllGroupHeadsCollector extends TermAllGroupHeadsCollector<GeneralAllGroupHeadsCollector.GroupHead> {
106
107     private final Sort sortWithinGroup;
108     private final Map<String, GroupHead> groups;
109
110     private Scorer scorer;
111
112     GeneralAllGroupHeadsCollector(String groupField, Sort sortWithinGroup) throws IOException {
113       super(groupField, sortWithinGroup.getSort().length);
114       this.sortWithinGroup = sortWithinGroup;
115       groups = new HashMap<String, GroupHead>();
116
117       final SortField[] sortFields = sortWithinGroup.getSort();
118       for (int i = 0; i < sortFields.length; i++) {
119         reversed[i] = sortFields[i].getReverse() ? -1 : 1;
120       }
121     }
122
123     protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
124       final int ord = groupIndex.order[doc];
125       final String groupValue = ord == 0 ? null : groupIndex.lookup[ord];
126       GroupHead groupHead = groups.get(groupValue);
127       if (groupHead == null) {
128         groupHead = new GroupHead(groupValue, sortWithinGroup, doc);
129         groups.put(groupValue == null ? null : groupValue, groupHead);
130         temporalResult.stop = true;
131       } else {
132         temporalResult.stop = false;
133       }
134       temporalResult.groupHead = groupHead;
135     }
136
137     protected Collection<GroupHead> getCollectedGroupHeads() {
138       return groups.values();
139     }
140
141     public void setNextReader(IndexReader reader, int docBase) throws IOException {
142       this.indexReader = reader;
143       this.docBase = docBase;
144       groupIndex = FieldCache.DEFAULT.getStringIndex(reader, groupField);
145
146       for (GroupHead groupHead : groups.values()) {
147         for (int i = 0; i < groupHead.comparators.length; i++) {
148           groupHead.comparators[i].setNextReader(reader, docBase);
149         }
150       }
151     }
152
153     public void setScorer(Scorer scorer) throws IOException {
154       this.scorer = scorer;
155       for (GroupHead groupHead : groups.values()) {
156         for (FieldComparator comparator : groupHead.comparators) {
157           comparator.setScorer(scorer);
158         }
159       }
160     }
161
162     class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<String> {
163
164       final FieldComparator[] comparators;
165
166       private GroupHead(String groupValue, Sort sort, int doc) throws IOException {
167         super(groupValue, doc + docBase);
168         final SortField[] sortFields = sort.getSort();
169         comparators = new FieldComparator[sortFields.length];
170         for (int i = 0; i < sortFields.length; i++) {
171           comparators[i] = sortFields[i].getComparator(1, i);
172           comparators[i].setNextReader(indexReader, docBase);
173           comparators[i].setScorer(scorer);
174           comparators[i].copy(0, doc);
175           comparators[i].setBottom(0);
176         }
177       }
178
179       public int compare(int compIDX, int doc) throws IOException {
180         return comparators[compIDX].compareBottom(doc);
181       }
182
183       public void updateDocHead(int doc) throws IOException {
184         for (FieldComparator comparator : comparators) {
185           comparator.copy(0, doc);
186           comparator.setBottom(0);
187         }
188         this.doc = doc + docBase;
189       }
190     }
191   }
192
193
194   // AbstractAllGroupHeadsCollector optimized for ord fields and scores.
195   static class OrdScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector<OrdScoreAllGroupHeadsCollector.GroupHead> {
196
197     private final SentinelIntSet ordSet;
198     private final List<GroupHead> collectedGroups;
199     private final SortField[] fields;
200
201     private FieldCache.StringIndex[] sortsIndex;
202     private Scorer scorer;
203     private GroupHead[] segmentGroupHeads;
204
205     OrdScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
206       super(groupField, sortWithinGroup.getSort().length);
207       ordSet = new SentinelIntSet(initialSize, -1);
208       collectedGroups = new ArrayList<GroupHead>(initialSize);
209
210       final SortField[] sortFields = sortWithinGroup.getSort();
211       fields = new SortField[sortFields.length];
212       sortsIndex = new FieldCache.StringIndex[sortFields.length];
213       for (int i = 0; i < sortFields.length; i++) {
214         reversed[i] = sortFields[i].getReverse() ? -1 : 1;
215         fields[i] = sortFields[i];
216       }
217     }
218
219     protected Collection<GroupHead> getCollectedGroupHeads() {
220       return collectedGroups;
221     }
222
223     public void setScorer(Scorer scorer) throws IOException {
224       this.scorer = scorer;
225     }
226
227     protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
228       int key = groupIndex.order[doc];
229       GroupHead groupHead;
230       if (!ordSet.exists(key)) {
231         ordSet.put(key);
232         String term = key == 0 ? null : groupIndex.lookup[key];
233         groupHead = new GroupHead(doc, term);
234         collectedGroups.add(groupHead);
235         segmentGroupHeads[key] = groupHead;
236         temporalResult.stop = true;
237       } else {
238         temporalResult.stop = false;
239         groupHead = segmentGroupHeads[key];
240       }
241       temporalResult.groupHead = groupHead;
242     }
243
244     public void setNextReader(IndexReader reader, int docBase) throws IOException {
245       this.indexReader = reader;
246       this.docBase = docBase;
247       groupIndex = FieldCache.DEFAULT.getStringIndex(reader, groupField);
248       for (int i = 0; i < fields.length; i++) {
249         if (fields[i].getType() == SortField.SCORE) {
250           continue;
251         }
252
253         sortsIndex[i] = FieldCache.DEFAULT.getStringIndex(reader, fields[i].getField());
254       }
255
256       // Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
257       ordSet.clear();
258       segmentGroupHeads = new GroupHead[groupIndex.lookup.length];
259       for (GroupHead collectedGroup : collectedGroups) {
260         int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue);
261         if (ord >= 0) {
262           ordSet.put(ord);
263           segmentGroupHeads[ord] = collectedGroup;
264
265           for (int i = 0; i < sortsIndex.length; i++) {
266             if (fields[i].getType() == SortField.SCORE) {
267               continue;
268             }
269
270             collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i]);
271           }
272         }
273       }
274     }
275
276     class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<String> {
277
278       String[] sortValues;
279       int[] sortOrds;
280       float[] scores;
281
282       private GroupHead(int doc, String groupValue) throws IOException {
283         super(groupValue, doc + docBase);
284         sortValues = new String[sortsIndex.length];
285         sortOrds = new int[sortsIndex.length];
286         scores = new float[sortsIndex.length];
287         for (int i = 0; i < sortsIndex.length; i++) {
288           if (fields[i].getType() == SortField.SCORE) {
289             scores[i] = scorer.score();
290           } else {
291             sortValues[i] = sortsIndex[i].lookup[sortsIndex[i].order[doc]];
292             sortOrds[i] = sortsIndex[i].order[doc];
293           }
294         }
295
296       }
297
298       public int compare(int compIDX, int doc) throws IOException {
299         if (fields[compIDX].getType() == SortField.SCORE) {
300           float score = scorer.score();
301           if (scores[compIDX] < score) {
302             return 1;
303           } else if (scores[compIDX] > score) {
304             return -1;
305           }
306           return 0;
307         } else {
308           if (sortOrds[compIDX] < 0) {
309             // The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative.
310             final String val1 = sortValues[compIDX];
311             final String val2 = sortsIndex[compIDX].lookup[sortsIndex[compIDX].order[doc]];
312             if (val1 == null) {
313               if (val2 == null) {
314                 return 0;
315               }
316               return -1;
317             } else if (val2 == null) {
318               return 1;
319             }
320             return val1.compareTo(val2);
321           } else {
322             return sortOrds[compIDX] - sortsIndex[compIDX].order[doc];
323           }
324         }
325       }
326
327       public void updateDocHead(int doc) throws IOException {
328         for (int i = 0; i < sortsIndex.length; i++) {
329           if (fields[i].getType() == SortField.SCORE) {
330             scores[i] = scorer.score();
331           } else {
332             sortValues[i] = sortsIndex[i].lookup[sortsIndex[i].order[doc]];
333             sortOrds[i] = sortsIndex[i].order[doc];
334           }
335         }
336         this.doc = doc + docBase;
337       }
338
339     }
340
341   }
342
343
344   // AbstractAllGroupHeadsCollector optimized for ord fields.
345   static class OrdAllGroupHeadsCollector extends TermAllGroupHeadsCollector<OrdAllGroupHeadsCollector.GroupHead> {
346
347     private final SentinelIntSet ordSet;
348     private final List<GroupHead> collectedGroups;
349     private final SortField[] fields;
350
351     private FieldCache.StringIndex[] sortsIndex;
352     private GroupHead[] segmentGroupHeads;
353
354     OrdAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
355       super(groupField, sortWithinGroup.getSort().length);
356       ordSet = new SentinelIntSet(initialSize, -1);
357       collectedGroups = new ArrayList<GroupHead>(initialSize);
358
359       final SortField[] sortFields = sortWithinGroup.getSort();
360       fields = new SortField[sortFields.length];
361       sortsIndex = new FieldCache.StringIndex[sortFields.length];
362       for (int i = 0; i < sortFields.length; i++) {
363         reversed[i] = sortFields[i].getReverse() ? -1 : 1;
364         fields[i] = sortFields[i];
365       }
366     }
367
368     protected Collection<GroupHead> getCollectedGroupHeads() {
369       return collectedGroups;
370     }
371
372     public void setScorer(Scorer scorer) throws IOException {
373     }
374
375     protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
376       int key = groupIndex.order[doc];
377       GroupHead groupHead;
378       if (!ordSet.exists(key)) {
379         ordSet.put(key);
380         String term = key == 0 ? null : groupIndex.lookup[key];
381         groupHead = new GroupHead(doc, term);
382         collectedGroups.add(groupHead);
383         segmentGroupHeads[key] = groupHead;
384         temporalResult.stop = true;
385       } else {
386         temporalResult.stop = false;
387         groupHead = segmentGroupHeads[key];
388       }
389       temporalResult.groupHead = groupHead;
390     }
391
392     public void setNextReader(IndexReader reader, int docBase) throws IOException {
393       this.indexReader = reader;
394       this.docBase = docBase;
395       groupIndex = FieldCache.DEFAULT.getStringIndex(reader, groupField);
396       for (int i = 0; i < fields.length; i++) {
397         sortsIndex[i] = FieldCache.DEFAULT.getStringIndex(reader, fields[i].getField());
398       }
399
400       // Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
401       ordSet.clear();
402       segmentGroupHeads = new GroupHead[groupIndex.lookup.length];
403       for (GroupHead collectedGroup : collectedGroups) {
404         int groupOrd = groupIndex.binarySearchLookup(collectedGroup.groupValue);
405         if (groupOrd >= 0) {
406           ordSet.put(groupOrd);
407           segmentGroupHeads[groupOrd] = collectedGroup;
408
409           for (int i = 0; i < sortsIndex.length; i++) {
410             collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i]);
411           }
412         }
413       }
414     }
415
416     class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<String> {
417
418       String[] sortValues;
419       int[] sortOrds;
420
421       private GroupHead(int doc, String groupValue) throws IOException {
422         super(groupValue, doc + docBase);
423         sortValues = new String[sortsIndex.length];
424         sortOrds = new int[sortsIndex.length];
425         for (int i = 0; i < sortsIndex.length; i++) {
426           sortValues[i] = sortsIndex[i].lookup[sortsIndex[i].order[doc]];
427           sortOrds[i] = sortsIndex[i].order[doc];
428         }
429       }
430
431       public int compare(int compIDX, int doc) throws IOException {
432         if (sortOrds[compIDX] < 0) {
433           // The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative.
434           final String val1 = sortValues[compIDX];
435           final String val2 = sortsIndex[compIDX].lookup[sortsIndex[compIDX].order[doc]];
436           if (val1 == null) {
437             if (val2 == null) {
438               return 0;
439             }
440             return -1;
441           } else if (val2 == null) {
442             return 1;
443           }
444           return val1.compareTo(val2);
445         } else {
446           return sortOrds[compIDX] - sortsIndex[compIDX].order[doc];
447         }
448       }
449
450       public void updateDocHead(int doc) throws IOException {
451         for (int i = 0; i < sortsIndex.length; i++) {
452           sortValues[i] = sortsIndex[i].lookup[sortsIndex[i].order[doc]];
453           sortOrds[i] = sortsIndex[i].order[doc];
454         }
455         this.doc = doc + docBase;
456       }
457
458     }
459
460   }
461
462
463   // AbstractAllGroupHeadsCollector optimized for scores.
464   static class ScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector<ScoreAllGroupHeadsCollector.GroupHead> {
465
466     private final SentinelIntSet ordSet;
467     private final List<GroupHead> collectedGroups;
468     private final SortField[] fields;
469
470     private Scorer scorer;
471     private GroupHead[] segmentGroupHeads;
472
473     ScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
474       super(groupField, sortWithinGroup.getSort().length);
475       ordSet = new SentinelIntSet(initialSize, -1);
476       collectedGroups = new ArrayList<GroupHead>(initialSize);
477
478       final SortField[] sortFields = sortWithinGroup.getSort();
479       fields = new SortField[sortFields.length];
480       for (int i = 0; i < sortFields.length; i++) {
481         reversed[i] = sortFields[i].getReverse() ? -1 : 1;
482         fields[i] = sortFields[i];
483       }
484     }
485
486     protected Collection<GroupHead> getCollectedGroupHeads() {
487       return collectedGroups;
488     }
489
490     public void setScorer(Scorer scorer) throws IOException {
491       this.scorer = scorer;
492     }
493
494     protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
495       int key = groupIndex.order[doc];
496       GroupHead groupHead;
497       if (!ordSet.exists(key)) {
498         ordSet.put(key);
499         String term = key == 0 ? null : groupIndex.lookup[key];
500         groupHead = new GroupHead(doc, term);
501         collectedGroups.add(groupHead);
502         segmentGroupHeads[key] = groupHead;
503         temporalResult.stop = true;
504       } else {
505         temporalResult.stop = false;
506         groupHead = segmentGroupHeads[key];
507       }
508       temporalResult.groupHead = groupHead;
509     }
510
511     public void setNextReader(IndexReader reader, int docBase) throws IOException {
512       this.indexReader = reader;
513       this.docBase = docBase;
514       groupIndex = FieldCache.DEFAULT.getStringIndex(reader, groupField);
515
516       // Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
517       ordSet.clear();
518       segmentGroupHeads = new GroupHead[groupIndex.lookup.length];
519       for (GroupHead collectedGroup : collectedGroups) {
520         int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue);
521         if (ord >= 0) {
522           ordSet.put(ord);
523           segmentGroupHeads[ord] = collectedGroup;
524         }
525       }
526     }
527
528     class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<String> {
529
530       float[] scores;
531
532       private GroupHead(int doc, String groupValue) throws IOException {
533         super(groupValue, doc + docBase);
534         scores = new float[fields.length];
535         float score = scorer.score();
536         for (int i = 0; i < scores.length; i++) {
537           scores[i] = score;
538         }
539       }
540
541       public int compare(int compIDX, int doc) throws IOException {
542         float score = scorer.score();
543         if (scores[compIDX] < score) {
544           return 1;
545         } else if (scores[compIDX] > score) {
546           return -1;
547         }
548         return 0;
549       }
550
551       public void updateDocHead(int doc) throws IOException {
552         float score = scorer.score();
553         for (int i = 0; i < scores.length; i++) {
554           scores[i] = score;
555         }
556         this.doc = doc + docBase;
557       }
558
559     }
560
561   }
562
563 }