add --shared
[pylucene.git] / lucene-java-3.4.0 / lucene / contrib / grouping / src / java / org / apache / lucene / search / grouping / TopGroups.java
1 package org.apache.lucene.search.grouping;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import java.io.IOException;
21
22 import org.apache.lucene.search.ScoreDoc;
23 import org.apache.lucene.search.Sort;
24 import org.apache.lucene.search.SortField;
25 import org.apache.lucene.search.TopDocs;
26
27 /** Represents result returned by a grouping search.
28  *
29  * @lucene.experimental */
30 public class TopGroups<GROUP_VALUE_TYPE> {
31   /** Number of documents matching the search */
32   public final int totalHitCount;
33
34   /** Number of documents grouped into the topN groups */
35   public final int totalGroupedHitCount;
36
37   /** The total number of unique groups. If <code>null</code> this value is not computed. */
38   public final Integer totalGroupCount;
39
40   /** Group results in groupSort order */
41   public final GroupDocs<GROUP_VALUE_TYPE>[] groups;
42
43   /** How groups are sorted against each other */
44   public final SortField[] groupSort;
45
46   /** How docs are sorted within each group */
47   public final SortField[] withinGroupSort;
48
49   public TopGroups(SortField[] groupSort, SortField[] withinGroupSort, int totalHitCount, int totalGroupedHitCount, GroupDocs<GROUP_VALUE_TYPE>[] groups) {
50     this.groupSort = groupSort;
51     this.withinGroupSort = withinGroupSort;
52     this.totalHitCount = totalHitCount;
53     this.totalGroupedHitCount = totalGroupedHitCount;
54     this.groups = groups;
55     this.totalGroupCount = null;
56   }
57
58   public TopGroups(TopGroups<GROUP_VALUE_TYPE> oldTopGroups, Integer totalGroupCount) {
59     this.groupSort = oldTopGroups.groupSort;
60     this.withinGroupSort = oldTopGroups.withinGroupSort;
61     this.totalHitCount = oldTopGroups.totalHitCount;
62     this.totalGroupedHitCount = oldTopGroups.totalGroupedHitCount;
63     this.groups = oldTopGroups.groups;
64     this.totalGroupCount = totalGroupCount;
65   }
66
67   /** Merges an array of TopGroups, for example obtained
68    *  from the second-pass collector across multiple
69    *  shards.  Each TopGroups must have been sorted by the
70    *  same groupSort and docSort, and the top groups passed
71    *  to all second-pass collectors must be the same.
72    *
73    * <b>NOTE</b>: this cannot merge totalGroupCount; ie the
74    * returned TopGroups will have null totalGroupCount.
75    *
76    * <b>NOTE</b>: the topDocs in each GroupDocs is actually
77    * an instance of TopDocsAndShards
78    */
79   public static <T> TopGroups<T> merge(TopGroups<T>[] shardGroups, Sort groupSort, Sort docSort, int docOffset, int docTopN)
80     throws IOException {
81
82     //System.out.println("TopGroups.merge");
83
84     if (shardGroups.length == 0) {
85       return null;
86     }
87
88     int totalHitCount = 0;
89     int totalGroupedHitCount = 0;
90
91     final int numGroups = shardGroups[0].groups.length;
92     for(TopGroups<T> shard : shardGroups) {
93       if (numGroups != shard.groups.length) {
94         throw new IllegalArgumentException("number of groups differs across shards; you must pass same top groups to all shards' second-pass collector");
95       }
96       totalHitCount += shard.totalHitCount;
97       totalGroupedHitCount += shard.totalGroupedHitCount;
98     }
99
100     @SuppressWarnings("unchecked")
101     final GroupDocs<T>[] mergedGroupDocs = new GroupDocs[numGroups];
102
103     final TopDocs[] shardTopDocs = new TopDocs[shardGroups.length];
104
105     for(int groupIDX=0;groupIDX<numGroups;groupIDX++) {
106       final T groupValue = shardGroups[0].groups[groupIDX].groupValue;
107       //System.out.println("  merge groupValue=" + groupValue + " sortValues=" + Arrays.toString(shardGroups[0].groups[groupIDX].groupSortValues));
108       float maxScore = Float.MIN_VALUE;
109       int totalHits = 0;
110       for(int shardIDX=0;shardIDX<shardGroups.length;shardIDX++) {
111         //System.out.println("    shard=" + shardIDX);
112         final TopGroups<T> shard = shardGroups[shardIDX];
113         final GroupDocs shardGroupDocs = shard.groups[groupIDX];
114         if (groupValue == null) {
115           if (shardGroupDocs.groupValue != null) {
116             throw new IllegalArgumentException("group values differ across shards; you must pass same top groups to all shards' second-pass collector");
117           }
118         } else if (!groupValue.equals(shardGroupDocs.groupValue)) {
119           throw new IllegalArgumentException("group values differ across shards; you must pass same top groups to all shards' second-pass collector");
120         }
121
122         /*
123         for(ScoreDoc sd : shardGroupDocs.scoreDocs) {
124           System.out.println("      doc=" + sd.doc);
125         }
126         */
127
128         shardTopDocs[shardIDX] = new TopDocs(shardGroupDocs.totalHits,
129                                              shardGroupDocs.scoreDocs,
130                                              shardGroupDocs.maxScore);
131         maxScore = Math.max(maxScore, shardGroupDocs.maxScore);
132         totalHits += shardGroupDocs.totalHits;
133       }
134
135       final TopDocs mergedTopDocs = TopDocs.merge(docSort, docOffset + docTopN, shardTopDocs);
136
137       // Slice;
138       final ScoreDoc[] mergedScoreDocs;
139       if (docOffset == 0) {
140         mergedScoreDocs = mergedTopDocs.scoreDocs;
141       } else if (docOffset >= mergedTopDocs.scoreDocs.length) {
142         mergedScoreDocs = new ScoreDoc[0];
143       } else {
144         mergedScoreDocs = new ScoreDoc[mergedTopDocs.scoreDocs.length - docOffset];
145         System.arraycopy(mergedTopDocs.scoreDocs,
146                          docOffset,
147                          mergedScoreDocs,
148                          0,
149                          mergedTopDocs.scoreDocs.length - docOffset);
150       }
151       //System.out.println("SHARDS=" + Arrays.toString(mergedTopDocs.shardIndex));
152       mergedGroupDocs[groupIDX] = new GroupDocs<T>(maxScore,
153                                                    totalHits,
154                                                    mergedScoreDocs,
155                                                    groupValue,
156                                                    shardGroups[0].groups[groupIDX].groupSortValues);
157     }
158
159     return new TopGroups<T>(groupSort.getSort(),
160                             docSort == null ? null : docSort.getSort(),
161                             totalHitCount,
162                             totalGroupedHitCount,
163                             mergedGroupDocs);
164   }
165 }