pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / contrib / grouping / src / java / org / apache / lucene / search / grouping / TopGroups.java
1 package org.apache.lucene.search.grouping;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import org.apache.lucene.search.ScoreDoc;
21 import org.apache.lucene.search.Sort;
22 import org.apache.lucene.search.SortField;
23 import org.apache.lucene.search.TopDocs;
24
25 import java.io.IOException;
26
27 /** Represents result returned by a grouping search.
28  *
29  * @lucene.experimental */
30 public class TopGroups<GROUP_VALUE_TYPE> {
31   /** Number of documents matching the search */
32   public final int totalHitCount;
33
34   /** Number of documents grouped into the topN groups */
35   public final int totalGroupedHitCount;
36
37   /** The total number of unique groups. If <code>null</code> this value is not computed. */
38   public final Integer totalGroupCount;
39
40   /** Group results in groupSort order */
41   public final GroupDocs<GROUP_VALUE_TYPE>[] groups;
42
43   /** How groups are sorted against each other */
44   public final SortField[] groupSort;
45
46   /** How docs are sorted within each group */
47   public final SortField[] withinGroupSort;
48
49   public TopGroups(SortField[] groupSort, SortField[] withinGroupSort, int totalHitCount, int totalGroupedHitCount, GroupDocs<GROUP_VALUE_TYPE>[] groups) {
50     this.groupSort = groupSort;
51     this.withinGroupSort = withinGroupSort;
52     this.totalHitCount = totalHitCount;
53     this.totalGroupedHitCount = totalGroupedHitCount;
54     this.groups = groups;
55     this.totalGroupCount = null;
56   }
57
58   public TopGroups(TopGroups<GROUP_VALUE_TYPE> oldTopGroups, Integer totalGroupCount) {
59     this.groupSort = oldTopGroups.groupSort;
60     this.withinGroupSort = oldTopGroups.withinGroupSort;
61     this.totalHitCount = oldTopGroups.totalHitCount;
62     this.totalGroupedHitCount = oldTopGroups.totalGroupedHitCount;
63     this.groups = oldTopGroups.groups;
64     this.totalGroupCount = totalGroupCount;
65   }
66
67   /** Merges an array of TopGroups, for example obtained
68    *  from the second-pass collector across multiple
69    *  shards.  Each TopGroups must have been sorted by the
70    *  same groupSort and docSort, and the top groups passed
71    *  to all second-pass collectors must be the same.
72    *
73    * <b>NOTE</b>: We can't always compute an exact totalGroupCount.
74    * Documents belonging to a group may occur on more than
75    * one shard and thus the merged totalGroupCount can be
76    * higher than the actual totalGroupCount. In this case the
77    * totalGroupCount represents a upper bound. If the documents
78    * of one group do only reside in one shard then the
79    * totalGroupCount is exact.
80    *
81    * <b>NOTE</b>: the topDocs in each GroupDocs is actually
82    * an instance of TopDocsAndShards
83    */
84   public static <T> TopGroups<T> merge(TopGroups<T>[] shardGroups, Sort groupSort, Sort docSort, int docOffset, int docTopN)
85     throws IOException {
86
87     //System.out.println("TopGroups.merge");
88
89     if (shardGroups.length == 0) {
90       return null;
91     }
92
93     int totalHitCount = 0;
94     int totalGroupedHitCount = 0;
95     // Optionally merge the totalGroupCount.
96     Integer totalGroupCount = null;
97
98     final int numGroups = shardGroups[0].groups.length;
99     for(TopGroups<T> shard : shardGroups) {
100       if (numGroups != shard.groups.length) {
101         throw new IllegalArgumentException("number of groups differs across shards; you must pass same top groups to all shards' second-pass collector");
102       }
103       totalHitCount += shard.totalHitCount;
104       totalGroupedHitCount += shard.totalGroupedHitCount;
105       if (shard.totalGroupCount != null) {
106         if (totalGroupCount == null) {
107           totalGroupCount = 0;
108         }
109
110         totalGroupCount += shard.totalGroupCount;
111       }
112     }
113
114     @SuppressWarnings("unchecked")
115     final GroupDocs<T>[] mergedGroupDocs = new GroupDocs[numGroups];
116
117     final TopDocs[] shardTopDocs = new TopDocs[shardGroups.length];
118
119     for(int groupIDX=0;groupIDX<numGroups;groupIDX++) {
120       final T groupValue = shardGroups[0].groups[groupIDX].groupValue;
121       //System.out.println("  merge groupValue=" + groupValue + " sortValues=" + Arrays.toString(shardGroups[0].groups[groupIDX].groupSortValues));
122       float maxScore = Float.MIN_VALUE;
123       int totalHits = 0;
124       for(int shardIDX=0;shardIDX<shardGroups.length;shardIDX++) {
125         //System.out.println("    shard=" + shardIDX);
126         final TopGroups<T> shard = shardGroups[shardIDX];
127         final GroupDocs shardGroupDocs = shard.groups[groupIDX];
128         if (groupValue == null) {
129           if (shardGroupDocs.groupValue != null) {
130             throw new IllegalArgumentException("group values differ across shards; you must pass same top groups to all shards' second-pass collector");
131           }
132         } else if (!groupValue.equals(shardGroupDocs.groupValue)) {
133           throw new IllegalArgumentException("group values differ across shards; you must pass same top groups to all shards' second-pass collector");
134         }
135
136         /*
137         for(ScoreDoc sd : shardGroupDocs.scoreDocs) {
138           System.out.println("      doc=" + sd.doc);
139         }
140         */
141
142         shardTopDocs[shardIDX] = new TopDocs(shardGroupDocs.totalHits,
143                                              shardGroupDocs.scoreDocs,
144                                              shardGroupDocs.maxScore);
145         maxScore = Math.max(maxScore, shardGroupDocs.maxScore);
146         totalHits += shardGroupDocs.totalHits;
147       }
148
149       final TopDocs mergedTopDocs = TopDocs.merge(docSort, docOffset + docTopN, shardTopDocs);
150
151       // Slice;
152       final ScoreDoc[] mergedScoreDocs;
153       if (docOffset == 0) {
154         mergedScoreDocs = mergedTopDocs.scoreDocs;
155       } else if (docOffset >= mergedTopDocs.scoreDocs.length) {
156         mergedScoreDocs = new ScoreDoc[0];
157       } else {
158         mergedScoreDocs = new ScoreDoc[mergedTopDocs.scoreDocs.length - docOffset];
159         System.arraycopy(mergedTopDocs.scoreDocs,
160                          docOffset,
161                          mergedScoreDocs,
162                          0,
163                          mergedTopDocs.scoreDocs.length - docOffset);
164       }
165       //System.out.println("SHARDS=" + Arrays.toString(mergedTopDocs.shardIndex));
166       mergedGroupDocs[groupIDX] = new GroupDocs<T>(maxScore,
167                                                    totalHits,
168                                                    mergedScoreDocs,
169                                                    groupValue,
170                                                    shardGroups[0].groups[groupIDX].groupSortValues);
171     }
172
173     if (totalGroupCount != null) {
174       TopGroups<T> result = new TopGroups<T>(groupSort.getSort(),
175                               docSort == null ? null : docSort.getSort(),
176                               totalHitCount,
177                               totalGroupedHitCount,
178                               mergedGroupDocs);
179       return new TopGroups<T>(result, totalGroupCount);
180     } else {
181       return new TopGroups<T>(groupSort.getSort(),
182                               docSort == null ? null : docSort.getSort(),
183                               totalHitCount,
184                               totalGroupedHitCount,
185                               mergedGroupDocs);
186     }
187   }
188 }