add --shared
[pylucene.git] / lucene-java-3.4.0 / lucene / contrib / grouping / src / test / org / apache / lucene / search / grouping / TestGrouping.java
1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.lucene.search.grouping;
19
20 import java.io.IOException;
21 import java.util.*;
22
23 import org.apache.lucene.analysis.MockAnalyzer;
24 import org.apache.lucene.document.Document;
25 import org.apache.lucene.document.Field;
26 import org.apache.lucene.document.NumericField;
27 import org.apache.lucene.index.FieldInfo.IndexOptions;
28 import org.apache.lucene.index.IndexReader;
29 import org.apache.lucene.index.RandomIndexWriter;
30 import org.apache.lucene.index.Term;
31 import org.apache.lucene.search.*;
32 import org.apache.lucene.store.Directory;
33 import org.apache.lucene.util.LuceneTestCase;
34 import org.apache.lucene.util.ReaderUtil;
35 import org.apache.lucene.util._TestUtil;
36
37 // TODO
38 //   - should test relevance sort too
39 //   - test null
40 //   - test ties
41 //   - test compound sort
42
43 public class TestGrouping extends LuceneTestCase {
44
45   public void testBasic() throws Exception {
46
47     final String groupField = "author";
48
49     Directory dir = newDirectory();
50     RandomIndexWriter w = new RandomIndexWriter(
51                                random,
52                                dir,
53                                newIndexWriterConfig(TEST_VERSION_CURRENT,
54                                                     new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy()));
55     // 0
56     Document doc = new Document();
57     doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED));
58     doc.add(new Field("content", "random text", Field.Store.YES, Field.Index.ANALYZED));
59     doc.add(new Field("id", "1", Field.Store.YES, Field.Index.NO));
60     w.addDocument(doc);
61
62     // 1
63     doc = new Document();
64     doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED));
65     doc.add(new Field("content", "some more random text", Field.Store.YES, Field.Index.ANALYZED));
66     doc.add(new Field("id", "2", Field.Store.YES, Field.Index.NO));
67     w.addDocument(doc);
68
69     // 2
70     doc = new Document();
71     doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED));
72     doc.add(new Field("content", "some more random textual data", Field.Store.YES, Field.Index.ANALYZED));
73     doc.add(new Field("id", "3", Field.Store.YES, Field.Index.NO));
74     w.addDocument(doc);
75
76     // 3
77     doc = new Document();
78     doc.add(new Field(groupField, "author2", Field.Store.YES, Field.Index.ANALYZED));
79     doc.add(new Field("content", "some random text", Field.Store.YES, Field.Index.ANALYZED));
80     doc.add(new Field("id", "4", Field.Store.YES, Field.Index.NO));
81     w.addDocument(doc);
82
83     // 4
84     doc = new Document();
85     doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED));
86     doc.add(new Field("content", "some more random text", Field.Store.YES, Field.Index.ANALYZED));
87     doc.add(new Field("id", "5", Field.Store.YES, Field.Index.NO));
88     w.addDocument(doc);
89
90     // 5
91     doc = new Document();
92     doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED));
93     doc.add(new Field("content", "random", Field.Store.YES, Field.Index.ANALYZED));
94     doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NO));
95     w.addDocument(doc);
96
97     // 6 -- no author field
98     doc = new Document();
99     doc.add(new Field("content", "random word stuck in alot of other text", Field.Store.YES, Field.Index.ANALYZED));
100     doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NO));
101     w.addDocument(doc);
102
103     IndexSearcher indexSearcher = new IndexSearcher(w.getReader());
104     w.close();
105
106     final Sort groupSort = Sort.RELEVANCE;
107     final TermFirstPassGroupingCollector c1 = new TermFirstPassGroupingCollector(groupField, groupSort, 10);
108     indexSearcher.search(new TermQuery(new Term("content", "random")), c1);
109
110     final TermSecondPassGroupingCollector c2 = new TermSecondPassGroupingCollector(groupField, c1.getTopGroups(0, true), groupSort, null, 5, true, false, true);
111     indexSearcher.search(new TermQuery(new Term("content", "random")), c2);
112     
113     final TopGroups groups = c2.getTopGroups(0);
114
115     assertEquals(7, groups.totalHitCount);
116     assertEquals(7, groups.totalGroupedHitCount);
117     assertEquals(4, groups.groups.length);
118
119     // relevance order: 5, 0, 3, 4, 1, 2, 6
120
121     // the later a document is added the higher this docId
122     // value
123     GroupDocs group = groups.groups[0];
124     assertEquals("author3", group.groupValue);
125     assertEquals(2, group.scoreDocs.length);
126     assertEquals(5, group.scoreDocs[0].doc);
127     assertEquals(4, group.scoreDocs[1].doc);
128     assertTrue(group.scoreDocs[0].score > group.scoreDocs[1].score);
129
130     group = groups.groups[1];
131     assertEquals("author1", group.groupValue);
132     assertEquals(3, group.scoreDocs.length);
133     assertEquals(0, group.scoreDocs[0].doc);
134     assertEquals(1, group.scoreDocs[1].doc);
135     assertEquals(2, group.scoreDocs[2].doc);
136     assertTrue(group.scoreDocs[0].score > group.scoreDocs[1].score);
137     assertTrue(group.scoreDocs[1].score > group.scoreDocs[2].score);
138
139     group = groups.groups[2];
140     assertEquals("author2", group.groupValue);
141     assertEquals(1, group.scoreDocs.length);
142     assertEquals(3, group.scoreDocs[0].doc);
143
144     group = groups.groups[3];
145     assertNull(group.groupValue);
146     assertEquals(1, group.scoreDocs.length);
147     assertEquals(6, group.scoreDocs[0].doc);
148
149     indexSearcher.getIndexReader().close();
150     dir.close();
151   }
152
153   private static class GroupDoc {
154     final int id;
155     final String group;
156     final String sort1;
157     final String sort2;
158     // content must be "realN ..."
159     final String content;
160     float score;
161     float score2;
162
163     public GroupDoc(int id, String group, String sort1, String sort2, String content) {
164       this.id = id;
165       this.group = group;
166       this.sort1 = sort1;
167       this.sort2 = sort2;
168       this.content = content;
169     }
170   }
171
172   private Sort getRandomSort() {
173     final List<SortField> sortFields = new ArrayList<SortField>();
174     if (random.nextInt(7) == 2) {
175       sortFields.add(SortField.FIELD_SCORE);
176     } else {
177       if (random.nextBoolean()) {
178         if (random.nextBoolean()) {
179           sortFields.add(new SortField("sort1", SortField.STRING, random.nextBoolean()));
180         } else {
181           sortFields.add(new SortField("sort2", SortField.STRING, random.nextBoolean()));
182         }
183       } else if (random.nextBoolean()) {
184         sortFields.add(new SortField("sort1", SortField.STRING, random.nextBoolean()));
185         sortFields.add(new SortField("sort2", SortField.STRING, random.nextBoolean()));
186       }
187     }
188     // Break ties:
189     sortFields.add(new SortField("id", SortField.INT));
190     return new Sort(sortFields.toArray(new SortField[sortFields.size()]));
191   }
192
193   private Comparator<GroupDoc> getComparator(Sort sort) {
194     final SortField[] sortFields = sort.getSort();
195     return new Comparator<GroupDoc>() {
196       // @Override -- Not until Java 1.6
197       public int compare(GroupDoc d1, GroupDoc d2) {
198         for(SortField sf : sortFields) {
199           final int cmp;
200           if (sf.getType() == SortField.SCORE) {
201             if (d1.score > d2.score) {
202               cmp = -1;
203             } else if (d1.score < d2.score) {
204               cmp = 1;
205             } else {
206               cmp = 0;
207             }
208           } else if (sf.getField().equals("sort1")) {
209             cmp = d1.sort1.compareTo(d2.sort1);
210           } else if (sf.getField().equals("sort2")) {
211             cmp = d1.sort2.compareTo(d2.sort2);
212           } else {
213             assertEquals(sf.getField(), "id");
214             cmp = d1.id - d2.id;
215           }
216           if (cmp != 0) {
217             return sf.getReverse() ? -cmp : cmp;
218           }
219         }
220         // Our sort always fully tie breaks:
221         fail();
222         return 0;
223       }
224     };
225   }
226
227   private Comparable<?>[] fillFields(GroupDoc d, Sort sort) {
228     final SortField[] sortFields = sort.getSort();
229     final Comparable<?>[] fields = new Comparable[sortFields.length];
230     for(int fieldIDX=0;fieldIDX<sortFields.length;fieldIDX++) {
231       final Comparable<?> c;
232       final SortField sf = sortFields[fieldIDX];
233       if (sf.getType() == SortField.SCORE) {
234         c = new Float(d.score);
235       } else if (sf.getField().equals("sort1")) {
236         c = d.sort1;
237       } else if (sf.getField().equals("sort2")) {
238         c = d.sort2;
239       } else {
240         assertEquals("id", sf.getField());
241         c = new Integer(d.id);
242       }
243       fields[fieldIDX] = c;
244     }
245     return fields;
246   }
247
248   private String groupToString(String b) {
249     if (b == null) {
250       return "null";
251     } else {
252       return b;
253     }
254   }
255
256   private TopGroups<String> slowGrouping(GroupDoc[] groupDocs,
257                                          String searchTerm,
258                                          boolean fillFields,
259                                          boolean getScores,
260                                          boolean getMaxScores,
261                                          boolean doAllGroups,
262                                          Sort groupSort,
263                                          Sort docSort,
264                                          int topNGroups,
265                                          int docsPerGroup,
266                                          int groupOffset,
267                                          int docOffset) {
268
269     final Comparator<GroupDoc> groupSortComp = getComparator(groupSort);
270
271     Arrays.sort(groupDocs, groupSortComp);
272     final HashMap<String,List<GroupDoc>> groups = new HashMap<String,List<GroupDoc>>();
273     final List<String> sortedGroups = new ArrayList<String>();
274     final List<Comparable<?>[]> sortedGroupFields = new ArrayList<Comparable<?>[]>();
275
276     int totalHitCount = 0;
277     Set<String> knownGroups = new HashSet<String>();
278
279     //System.out.println("TEST: slowGrouping");
280     for(GroupDoc d : groupDocs) {
281       // TODO: would be better to filter by searchTerm before sorting!
282       if (!d.content.startsWith(searchTerm)) {
283         continue;
284       }
285       totalHitCount++;
286
287       //System.out.println("  match id=" + d.id + " score=" + d.score);
288
289       if (doAllGroups) {
290         if (!knownGroups.contains(d.group)) {
291           knownGroups.add(d.group);
292           //System.out.println("    add group=" + groupToString(d.group));
293         }
294       }
295
296       List<GroupDoc> l = groups.get(d.group);
297       if (l == null) {
298         //System.out.println("    add sortedGroup=" + groupToString(d.group));
299         sortedGroups.add(d.group);
300         if (fillFields) {
301           sortedGroupFields.add(fillFields(d, groupSort));
302         }
303         l = new ArrayList<GroupDoc>();
304         groups.put(d.group, l);
305       }
306       l.add(d);
307     }
308
309     if (groupOffset >= sortedGroups.size()) {
310       // slice is out of bounds
311       return null;
312     }
313
314     final int limit = Math.min(groupOffset + topNGroups, groups.size());
315
316     final Comparator<GroupDoc> docSortComp = getComparator(docSort);
317     @SuppressWarnings("unchecked")
318     final GroupDocs<String>[] result = new GroupDocs[limit-groupOffset];
319     int totalGroupedHitCount = 0;
320     for(int idx=groupOffset;idx < limit;idx++) {
321       final String group = sortedGroups.get(idx);
322       final List<GroupDoc> docs = groups.get(group);
323       totalGroupedHitCount += docs.size();
324       Collections.sort(docs, docSortComp);
325       final ScoreDoc[] hits;
326       if (docs.size() > docOffset) {
327         final int docIDXLimit = Math.min(docOffset + docsPerGroup, docs.size());
328         hits = new ScoreDoc[docIDXLimit - docOffset];
329         for(int docIDX=docOffset; docIDX < docIDXLimit; docIDX++) {
330           final GroupDoc d = docs.get(docIDX);
331           final FieldDoc fd;
332           if (fillFields) {
333             fd = new FieldDoc(d.id, getScores ? d.score : Float.NaN, fillFields(d, docSort));
334           } else {
335             fd = new FieldDoc(d.id, getScores ? d.score : Float.NaN);
336           }
337           hits[docIDX-docOffset] = fd;
338         }
339       } else  {
340         hits = new ScoreDoc[0];
341       }
342
343       result[idx-groupOffset] = new GroupDocs<String>(0.0f,
344                                                       docs.size(),
345                                                       hits,
346                                                       group,
347                                                       fillFields ? sortedGroupFields.get(idx) : null);
348     }
349
350     if (doAllGroups) {
351       return new TopGroups<String>(
352                   new TopGroups<String>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result),
353                    knownGroups.size()
354       );
355     } else {
356       return new TopGroups<String>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result);
357     }
358   }
359
360   private IndexReader getDocBlockReader(Directory dir, GroupDoc[] groupDocs) throws IOException {
361     // Coalesce by group, but in random order:
362     Collections.shuffle(Arrays.asList(groupDocs), random);
363     final Map<String,List<GroupDoc>> groupMap = new HashMap<String,List<GroupDoc>>();
364     final List<String> groupValues = new ArrayList<String>();
365     
366     for(GroupDoc groupDoc : groupDocs) {
367       if (!groupMap.containsKey(groupDoc.group)) {
368         groupValues.add(groupDoc.group);
369         groupMap.put(groupDoc.group, new ArrayList<GroupDoc>());
370       }
371       groupMap.get(groupDoc.group).add(groupDoc);
372     }
373
374     RandomIndexWriter w = new RandomIndexWriter(
375                                                 random,
376                                                 dir,
377                                                 newIndexWriterConfig(TEST_VERSION_CURRENT,
378                                                                      new MockAnalyzer(random)));
379
380     final List<List<Document>> updateDocs = new ArrayList<List<Document>>();
381     //System.out.println("TEST: index groups");
382     for(String group : groupValues) {
383       final List<Document> docs = new ArrayList<Document>();
384       //System.out.println("TEST:   group=" + (group == null ? "null" : group.utf8ToString()));
385       for(GroupDoc groupValue : groupMap.get(group)) {
386         Document doc = new Document();
387         docs.add(doc);
388         if (groupValue.group != null) {
389           doc.add(newField("group", groupValue.group, Field.Index.NOT_ANALYZED));
390         }
391         doc.add(newField("sort1", groupValue.sort1, Field.Index.NOT_ANALYZED));
392         doc.add(newField("sort2", groupValue.sort2, Field.Index.NOT_ANALYZED));
393         doc.add(new NumericField("id").setIntValue(groupValue.id));
394         doc.add(newField("content", groupValue.content, Field.Index.ANALYZED));
395         //System.out.println("TEST:     doc content=" + groupValue.content + " group=" + (groupValue.group == null ? "null" : groupValue.group.utf8ToString()) + " sort1=" + groupValue.sort1.utf8ToString() + " id=" + groupValue.id);
396       }
397       // So we can pull filter marking last doc in block:
398       final Field groupEnd = newField("groupend", "x", Field.Index.NOT_ANALYZED);
399       groupEnd.setIndexOptions(IndexOptions.DOCS_ONLY);
400       groupEnd.setOmitNorms(true);
401       docs.get(docs.size()-1).add(groupEnd);
402       // Add as a doc block:
403       w.addDocuments(docs);
404       if (group != null && random.nextInt(7) == 4) {
405         updateDocs.add(docs);
406       }
407     }
408
409     for(List<Document> docs : updateDocs) {
410       // Just replaces docs w/ same docs:
411       w.updateDocuments(new Term("group", docs.get(0).get("group")),
412                         docs);
413     }
414
415     final IndexReader r = w.getReader();
416     w.close();
417
418     return r;
419   }
420
421   private static class ShardState {
422
423     public final ShardSearcher[] subSearchers;
424     public final int[] docStarts;
425
426     public ShardState(IndexSearcher s) {
427       IndexReader[] subReaders = s.getIndexReader().getSequentialSubReaders();
428       if (subReaders == null) {
429         subReaders = new IndexReader[] {s.getIndexReader()};
430       }
431       subSearchers = new ShardSearcher[subReaders.length];
432       for(int searcherIDX=0;searcherIDX<subSearchers.length;searcherIDX++) { 
433         subSearchers[searcherIDX] = new ShardSearcher(subReaders[searcherIDX]);
434       }
435
436       docStarts = new int[subSearchers.length];
437       int docBase = 0;
438       for(int subIDX=0;subIDX<docStarts.length;subIDX++) {
439         docStarts[subIDX] = docBase;
440         docBase += subReaders[subIDX].maxDoc();
441         //System.out.println("docStarts[" + subIDX + "]=" + docStarts[subIDX]);
442       }
443     }
444   }
445
446   public void testRandom() throws Exception {
447     for(int iter=0;iter<3;iter++) {
448
449       if (VERBOSE) {
450         System.out.println("TEST: iter=" + iter);
451       }
452
453       final int numDocs = _TestUtil.nextInt(random, 100, 1000) * RANDOM_MULTIPLIER;
454       //final int numDocs = _TestUtil.nextInt(random, 5, 20);
455
456       final int numGroups = _TestUtil.nextInt(random, 1, numDocs);
457
458       if (VERBOSE) {
459         System.out.println("TEST: numDocs=" + numDocs + " numGroups=" + numGroups);
460       }
461       
462       final List<String> groups = new ArrayList<String>();
463       for(int i=0;i<numGroups;i++) {
464         groups.add(_TestUtil.randomRealisticUnicodeString(random));
465         //groups.add(_TestUtil.randomUnicodeString(random));
466         assertEquals(-1, groups.get(groups.size()-1).indexOf(0xffff));
467         //groups.add(new BytesRef(_TestUtil.randomSimpleString(random)));
468       }
469       final String[] contentStrings = new String[_TestUtil.nextInt(random, 2, 20)];
470       if (VERBOSE) {
471         System.out.println("TEST: create fake content");
472       }
473       for(int contentIDX=0;contentIDX<contentStrings.length;contentIDX++) {
474         final StringBuilder sb = new StringBuilder();
475         sb.append("real" + random.nextInt(3)).append(' ');
476         final int fakeCount = random.nextInt(10);
477         for(int fakeIDX=0;fakeIDX<fakeCount;fakeIDX++) {
478           sb.append("fake ");
479         }
480         contentStrings[contentIDX] = sb.toString();
481         if (VERBOSE) {
482           System.out.println("  content=" + sb.toString());
483         }
484       }
485
486       Directory dir = newDirectory();
487       RandomIndexWriter w = new RandomIndexWriter(
488                                                   random,
489                                                   dir,
490                                                   newIndexWriterConfig(TEST_VERSION_CURRENT,
491                                                                        new MockAnalyzer(random)));
492
493       Document doc = new Document();
494       Document docNoGroup = new Document();
495       Field group = newField("group", "", Field.Index.NOT_ANALYZED);
496       doc.add(group);
497       Field sort1 = newField("sort1", "", Field.Index.NOT_ANALYZED);
498       doc.add(sort1);
499       docNoGroup.add(sort1);
500       Field sort2 = newField("sort2", "", Field.Index.NOT_ANALYZED);
501       doc.add(sort2);
502       docNoGroup.add(sort2);
503       Field content = newField("content", "", Field.Index.ANALYZED);
504       doc.add(content);
505       docNoGroup.add(content);
506       NumericField id = new NumericField("id");
507       doc.add(id);
508       docNoGroup.add(id);
509       final GroupDoc[] groupDocs = new GroupDoc[numDocs];
510       for(int i=0;i<numDocs;i++) {
511         final String groupValue;
512         if (random.nextInt(24) == 17) {
513           // So we test the "doc doesn't have the group'd
514           // field" case:
515           groupValue = null;
516         } else {
517           groupValue = groups.get(random.nextInt(groups.size()));
518         }
519         final GroupDoc groupDoc = new GroupDoc(i,
520                                                groupValue,
521                                                groups.get(random.nextInt(groups.size())),
522                                                groups.get(random.nextInt(groups.size())),
523                                                contentStrings[random.nextInt(contentStrings.length)]);
524         if (VERBOSE) {
525           System.out.println("  doc content=" + groupDoc.content + " id=" + i + " group=" + (groupDoc.group == null ? "null" : groupDoc.group) + " sort1=" + groupDoc.sort1 + " sort2=" + groupDoc.sort2);
526         }
527
528         groupDocs[i] = groupDoc;
529         if (groupDoc.group != null) {
530           group.setValue(groupDoc.group);
531         }
532         sort1.setValue(groupDoc.sort1);
533         sort2.setValue(groupDoc.sort2);
534         content.setValue(groupDoc.content);
535         id.setIntValue(groupDoc.id);
536         if (groupDoc.group == null) {
537           w.addDocument(docNoGroup);
538         } else {
539           w.addDocument(doc);
540         }
541       }
542
543       final GroupDoc[] groupDocsByID = new GroupDoc[groupDocs.length];
544       System.arraycopy(groupDocs, 0, groupDocsByID, 0, groupDocs.length);
545
546       final IndexReader r = w.getReader();
547       w.close();
548
549       // NOTE: intentional but temporary field cache insanity!
550       final int[] docIDToID = FieldCache.DEFAULT.getInts(r, "id");
551       IndexReader r2 = null;
552       Directory dir2 = null;
553
554       try {
555         final IndexSearcher s = newSearcher(r);
556         final ShardState shards = new ShardState(s);
557
558         for(int contentID=0;contentID<3;contentID++) {
559           final ScoreDoc[] hits = s.search(new TermQuery(new Term("content", "real"+contentID)), numDocs).scoreDocs;
560           for(ScoreDoc hit : hits) {
561             final GroupDoc gd = groupDocs[docIDToID[hit.doc]];
562             assertTrue(gd.score == 0.0);
563             gd.score = hit.score;
564             assertEquals(gd.id, docIDToID[hit.doc]);
565             //System.out.println("  score=" + hit.score + " id=" + docIDToID[hit.doc]);
566           }
567         }
568
569         for(GroupDoc gd : groupDocs) {
570           assertTrue(gd.score != 0.0);
571         }
572
573         // Build 2nd index, where docs are added in blocks by
574         // group, so we can use single pass collector
575         dir2 = newDirectory();
576         r2 = getDocBlockReader(dir2, groupDocs);
577         final Filter lastDocInBlock = new CachingWrapperFilter(new QueryWrapperFilter(new TermQuery(new Term("groupend", "x"))));
578         final int[] docIDToID2 = FieldCache.DEFAULT.getInts(r2, "id");
579
580         final IndexSearcher s2 = newSearcher(r2);
581         final ShardState shards2 = new ShardState(s2);
582
583         // Reader2 only increases maxDoc() vs reader, which
584         // means a monotonic shift in scores, so we can
585         // reliably remap them w/ Map:
586         final Map<String,Map<Float,Float>> scoreMap = new HashMap<String,Map<Float,Float>>();
587
588         // Tricky: must separately set .score2, because the doc
589         // block index was created with possible deletions!
590         //System.out.println("fixup score2");
591         for(int contentID=0;contentID<3;contentID++) {
592           //System.out.println("  term=real" + contentID);
593           final Map<Float,Float> termScoreMap = new HashMap<Float,Float>();
594           scoreMap.put("real"+contentID, termScoreMap);
595           //System.out.println("term=real" + contentID + " dfold=" + s.docFreq(new Term("content", "real"+contentID)) +
596           //" dfnew=" + s2.docFreq(new Term("content", "real"+contentID)));
597           final ScoreDoc[] hits = s2.search(new TermQuery(new Term("content", "real"+contentID)), numDocs).scoreDocs;
598           for(ScoreDoc hit : hits) {
599             final GroupDoc gd = groupDocsByID[docIDToID2[hit.doc]];
600             assertTrue(gd.score2 == 0.0);
601             gd.score2 = hit.score;
602             assertEquals(gd.id, docIDToID2[hit.doc]);
603             //System.out.println("    score=" + gd.score + " score2=" + hit.score + " id=" + docIDToID2[hit.doc]);
604             termScoreMap.put(gd.score, gd.score2);
605           }
606         }
607
608         for(int searchIter=0;searchIter<100;searchIter++) {
609
610           if (VERBOSE) {
611             System.out.println("TEST: searchIter=" + searchIter);
612           }
613
614           final String searchTerm = "real" + random.nextInt(3);
615           final boolean fillFields = random.nextBoolean();
616           boolean getScores = random.nextBoolean();
617           final boolean getMaxScores = random.nextBoolean();
618           final Sort groupSort = getRandomSort();
619           //final Sort groupSort = new Sort(new SortField[] {new SortField("sort1", SortField.STRING), new SortField("id", SortField.INT)});
620           // TODO: also test null (= sort by relevance)
621           final Sort docSort = getRandomSort();
622
623           for(SortField sf : docSort.getSort()) {
624             if (sf.getType() == SortField.SCORE) {
625               getScores = true;
626             }
627           }
628
629           for(SortField sf : groupSort.getSort()) {
630             if (sf.getType() == SortField.SCORE) {
631               getScores = true;
632             }
633           }
634
635           final int topNGroups = _TestUtil.nextInt(random, 1, 30);
636           //final int topNGroups = 10;
637           final int docsPerGroup = _TestUtil.nextInt(random, 1, 50);
638
639           final int groupOffset = _TestUtil.nextInt(random, 0, (topNGroups-1)/2);
640           //final int groupOffset = 0;
641
642           final int docOffset = _TestUtil.nextInt(random, 0, docsPerGroup-1);
643           //final int docOffset = 0;
644
645           final boolean doCache = random.nextBoolean();
646           final boolean doAllGroups = random.nextBoolean();
647           if (VERBOSE) {
648             System.out.println("TEST: groupSort=" + groupSort + " docSort=" + docSort + " searchTerm=" + searchTerm + " topNGroups=" + topNGroups + " groupOffset=" + groupOffset + " docOffset=" + docOffset + " doCache=" + doCache + " docsPerGroup=" + docsPerGroup + " doAllGroups=" + doAllGroups + " getScores=" + getScores + " getMaxScores=" + getMaxScores);
649           }
650
651           final TermAllGroupsCollector allGroupsCollector;
652           if (doAllGroups) {
653             allGroupsCollector = new TermAllGroupsCollector("group");
654           } else {
655             allGroupsCollector = null;
656           }
657
658           final TermFirstPassGroupingCollector c1 = new TermFirstPassGroupingCollector("group", groupSort, groupOffset+topNGroups);
659           final CachingCollector cCache;
660           final Collector c;
661
662           final boolean useWrappingCollector = random.nextBoolean();
663         
664           if (doCache) {
665             final double maxCacheMB = random.nextDouble();
666             if (VERBOSE) {
667               System.out.println("TEST: maxCacheMB=" + maxCacheMB);
668             }
669
670             if (useWrappingCollector) {
671               if (doAllGroups) {
672                 cCache = CachingCollector.create(c1, true, maxCacheMB);              
673                 c = MultiCollector.wrap(cCache, allGroupsCollector);
674               } else {
675                 c = cCache = CachingCollector.create(c1, true, maxCacheMB);              
676               }
677             } else {
678               // Collect only into cache, then replay multiple times:
679               c = cCache = CachingCollector.create(false, true, maxCacheMB);
680             }
681           } else {
682             cCache = null;
683             if (doAllGroups) {
684               c = MultiCollector.wrap(c1, allGroupsCollector);
685             } else {
686               c = c1;
687             }
688           }
689         
690           // Search top reader:
691           final Query q = new TermQuery(new Term("content", searchTerm));
692           s.search(q, c);
693
694           if (doCache && !useWrappingCollector) {
695             if (cCache.isCached()) {
696               // Replay for first-pass grouping
697               cCache.replay(c1);
698               if (doAllGroups) {
699                 // Replay for all groups:
700                 cCache.replay(allGroupsCollector);
701               }
702             } else {
703               // Replay by re-running search:
704               s.search(new TermQuery(new Term("content", searchTerm)), c1);
705               if (doAllGroups) {
706                 s.search(new TermQuery(new Term("content", searchTerm)), allGroupsCollector);
707               }
708             }
709           }
710
711           final Collection<SearchGroup<String>> topGroups = c1.getTopGroups(groupOffset, fillFields);
712           final TopGroups groupsResult;
713           if (VERBOSE) {
714             System.out.println("TEST: topGroups:");
715             if (topGroups == null) {
716               System.out.println("  null");
717             } else {
718               for(SearchGroup<String> groupx : topGroups) {
719                 System.out.println("    " + groupToString(groupx.groupValue) + " sort=" + Arrays.toString(groupx.sortValues));
720               }
721             }
722           }
723           
724           final TopGroups<String> topGroupsShards = searchShards(s, shards, q, groupSort, docSort, groupOffset, topNGroups, docOffset, docsPerGroup, getScores, getMaxScores);
725
726           if (topGroups != null) {
727
728             if (VERBOSE) {
729               System.out.println("TEST: topGroups");
730               for (SearchGroup<String> searchGroup : topGroups) {
731                 System.out.println("  " + (searchGroup.groupValue == null ? "null" : searchGroup.groupValue) + ": " + Arrays.deepToString(searchGroup.sortValues));
732               }
733             }
734
735             final TermSecondPassGroupingCollector c2 = new TermSecondPassGroupingCollector("group", topGroups, groupSort, docSort, docOffset+docsPerGroup, getScores, getMaxScores, fillFields);
736             if (doCache) {
737               if (cCache.isCached()) {
738                 if (VERBOSE) {
739                   System.out.println("TEST: cache is intact");
740                 }
741                 cCache.replay(c2);
742               } else {
743                 if (VERBOSE) {
744                   System.out.println("TEST: cache was too large");
745                 }
746                 s.search(new TermQuery(new Term("content", searchTerm)), c2);
747               }
748             } else {
749               s.search(new TermQuery(new Term("content", searchTerm)), c2);
750             }
751
752             if (doAllGroups) {
753               TopGroups<String> tempTopGroups = c2.getTopGroups(docOffset);
754               groupsResult = new TopGroups<String>(tempTopGroups, allGroupsCollector.getGroupCount());
755             } else {
756               groupsResult = c2.getTopGroups(docOffset);
757             }
758           } else {
759             groupsResult = null;
760             if (VERBOSE) {
761               System.out.println("TEST:   no results");
762             }
763           }
764         
765           final TopGroups<String> expectedGroups = slowGrouping(groupDocs, searchTerm, fillFields, getScores, getMaxScores, doAllGroups, groupSort, docSort, topNGroups, docsPerGroup, groupOffset, docOffset);
766
767           if (VERBOSE) {
768             if (expectedGroups == null) {
769               System.out.println("TEST: no expected groups");
770             } else {
771               System.out.println("TEST: expected groups");
772               for(GroupDocs<String> gd : expectedGroups.groups) {
773                 System.out.println("  group=" + gd.groupValue);
774                 for(ScoreDoc sd : gd.scoreDocs) {
775                   System.out.println("    id=" + sd.doc + " score=" + sd.score);
776                 }
777               }
778             }
779           }
780           assertEquals(docIDToID, expectedGroups, groupsResult, true, true, true, getScores);
781
782           // Confirm merged shards match: 
783           assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, fillFields, getScores);
784           if (topGroupsShards != null) {
785             verifyShards(shards.docStarts, topGroupsShards);
786           }
787
788           final boolean needsScores = getScores || getMaxScores || docSort == null;
789           final BlockGroupingCollector c3 = new BlockGroupingCollector(groupSort, groupOffset+topNGroups, needsScores, lastDocInBlock);
790           final TermAllGroupsCollector allGroupsCollector2;
791           final Collector c4;
792           if (doAllGroups) {
793             allGroupsCollector2 = new TermAllGroupsCollector("group");
794             c4 = MultiCollector.wrap(c3, allGroupsCollector2);
795           } else {
796             allGroupsCollector2 = null;
797             c4 = c3;
798           }
799           s2.search(new TermQuery(new Term("content", searchTerm)), c4);
800           @SuppressWarnings("unchecked")
801           final TopGroups<String> tempTopGroups2 = c3.getTopGroups(docSort, groupOffset, docOffset, docOffset+docsPerGroup, fillFields);
802           final TopGroups groupsResult2;
803           if (doAllGroups && tempTopGroups2 != null) {
804             assertEquals((int) tempTopGroups2.totalGroupCount, allGroupsCollector2.getGroupCount());
805             groupsResult2 = new TopGroups<String>(tempTopGroups2, allGroupsCollector2.getGroupCount());
806           } else {
807             groupsResult2 = tempTopGroups2;
808           }
809
810           final TopGroups<String> topGroupsBlockShards = searchShards(s2, shards2, q, groupSort, docSort, groupOffset, topNGroups, docOffset, docsPerGroup, getScores, getMaxScores);
811
812           if (expectedGroups != null) {
813             // Fixup scores for reader2
814             for (GroupDocs groupDocsHits : expectedGroups.groups) {
815               for(ScoreDoc hit : groupDocsHits.scoreDocs) {
816                 final GroupDoc gd = groupDocsByID[hit.doc];
817                 assertEquals(gd.id, hit.doc);
818                 //System.out.println("fixup score " + hit.score + " to " + gd.score2 + " vs " + gd.score);
819                 hit.score = gd.score2;
820               }
821             }
822
823             final SortField[] sortFields = groupSort.getSort();
824             final Map<Float,Float> termScoreMap = scoreMap.get(searchTerm);
825             for(int groupSortIDX=0;groupSortIDX<sortFields.length;groupSortIDX++) {
826               if (sortFields[groupSortIDX].getType() == SortField.SCORE) {
827                 for (GroupDocs groupDocsHits : expectedGroups.groups) {
828                   if (groupDocsHits.groupSortValues != null) {
829                     //System.out.println("remap " + groupDocsHits.groupSortValues[groupSortIDX] + " to " + termScoreMap.get(groupDocsHits.groupSortValues[groupSortIDX]));
830                     groupDocsHits.groupSortValues[groupSortIDX] = termScoreMap.get(groupDocsHits.groupSortValues[groupSortIDX]);
831                     assertNotNull(groupDocsHits.groupSortValues[groupSortIDX]);
832                   }
833                 }
834               }
835             }
836
837             final SortField[] docSortFields = docSort.getSort();
838             for(int docSortIDX=0;docSortIDX<docSortFields.length;docSortIDX++) {
839               if (docSortFields[docSortIDX].getType() == SortField.SCORE) {
840                 for (GroupDocs groupDocsHits : expectedGroups.groups) {
841                   for(ScoreDoc _hit : groupDocsHits.scoreDocs) {
842                     FieldDoc hit = (FieldDoc) _hit;
843                     if (hit.fields != null) {
844                       hit.fields[docSortIDX] = termScoreMap.get(hit.fields[docSortIDX]);
845                       assertNotNull(hit.fields[docSortIDX]);
846                     }
847                   }
848                 }
849               }
850             }
851           }
852
853           assertEquals(docIDToID2, expectedGroups, groupsResult2, false, true, true, getScores);
854           assertEquals(docIDToID2, expectedGroups, topGroupsBlockShards, false, false, fillFields, getScores);
855         }
856         s.close();
857         s2.close();
858       } finally {
859         FieldCache.DEFAULT.purge(r);
860         if (r2 != null) {
861           FieldCache.DEFAULT.purge(r2);
862         }
863       }
864
865       r.close();
866       dir.close();
867
868       r2.close();
869       dir2.close();
870     }
871   }
872
873   private void verifyShards(int[] docStarts, TopGroups<String> topGroups) {
874     for(GroupDocs group : topGroups.groups) {
875       for(int hitIDX=0;hitIDX<group.scoreDocs.length;hitIDX++) {
876         final ScoreDoc sd = group.scoreDocs[hitIDX];
877         assertEquals("doc=" + sd.doc + " wrong shard",
878                      ReaderUtil.subIndex(sd.doc, docStarts),
879                      sd.shardIndex);
880       }
881     }
882   }
883
884   private void assertEquals(Collection<SearchGroup<String>> groups1, Collection<SearchGroup<String>> groups2, boolean doSortValues) {
885     assertEquals(groups1.size(), groups2.size());
886     final Iterator<SearchGroup<String>> iter1 = groups1.iterator();
887     final Iterator<SearchGroup<String>> iter2 = groups2.iterator();
888
889     while(iter1.hasNext()) {
890       assertTrue(iter2.hasNext());
891
892       SearchGroup<String> group1 = iter1.next();
893       SearchGroup<String> group2 = iter2.next();
894
895       assertEquals(group1.groupValue, group2.groupValue);
896       if (doSortValues) {
897         assertEquals(group1.sortValues, group2.sortValues);
898       }
899     }
900     assertFalse(iter2.hasNext());
901   }
902
903   private TopGroups<String> searchShards(IndexSearcher topSearcher, ShardState shardState, Query query, Sort groupSort, Sort docSort, int groupOffset, int topNGroups, int docOffset,
904                                          int topNDocs, boolean getScores, boolean getMaxScores) throws Exception {
905
906     // TODO: swap in caching, all groups collector here
907     // too...
908     if (VERBOSE) {
909       System.out.println("TEST: " + shardState.subSearchers.length + " shards: " + Arrays.toString(shardState.subSearchers));
910     }
911     // Run 1st pass collector to get top groups per shard
912     final Weight w = topSearcher.createNormalizedWeight(query);
913     final List<Collection<SearchGroup<String>>> shardGroups = new ArrayList<Collection<SearchGroup<String>>>();
914     for(int shardIDX=0;shardIDX<shardState.subSearchers.length;shardIDX++) {
915       final TermFirstPassGroupingCollector c = new TermFirstPassGroupingCollector("group", groupSort, groupOffset+topNGroups);
916       shardState.subSearchers[shardIDX].search(w, c);
917       final Collection<SearchGroup<String>> topGroups = c.getTopGroups(0, true);
918       if (topGroups != null) {
919         if (VERBOSE) {
920           System.out.println("  shard " + shardIDX + " s=" + shardState.subSearchers[shardIDX] + " " + topGroups.size() + " groups:");
921           for(SearchGroup<String> group : topGroups) {
922             System.out.println("    " + groupToString(group.groupValue) + " sort=" + Arrays.toString(group.sortValues));
923           }
924         }
925         shardGroups.add(topGroups);
926       }
927     }
928
929     final Collection<SearchGroup<String>> mergedTopGroups = SearchGroup.merge(shardGroups, groupOffset, topNGroups, groupSort);
930     if (VERBOSE) {
931       System.out.println("  merged:");
932       if (mergedTopGroups == null) {
933         System.out.println("    null");
934       } else {
935         for(SearchGroup<String> group : mergedTopGroups) {
936           System.out.println("    " + groupToString(group.groupValue) + " sort=" + Arrays.toString(group.sortValues));
937         }
938       }
939     }
940
941     if (mergedTopGroups != null) {
942
943       // Now 2nd pass:
944       @SuppressWarnings("unchecked")
945         final TopGroups<String>[] shardTopGroups = new TopGroups[shardState.subSearchers.length];
946       for(int shardIDX=0;shardIDX<shardState.subSearchers.length;shardIDX++) {
947         final TermSecondPassGroupingCollector c = new TermSecondPassGroupingCollector("group", mergedTopGroups, groupSort, docSort,
948                                                                                       docOffset + topNDocs, getScores, getMaxScores, true);
949         shardState.subSearchers[shardIDX].search(w, c);
950         shardTopGroups[shardIDX] = c.getTopGroups(0);
951         rebaseDocIDs(groupSort, docSort, shardState.docStarts[shardIDX], shardTopGroups[shardIDX]);
952       }
953
954       return TopGroups.merge(shardTopGroups, groupSort, docSort, docOffset, topNDocs);
955     } else {
956       return null;
957     }
958   }
959
960   private List<Integer> getDocIDSortLocs(Sort sort) {
961     List<Integer> docFieldLocs = new ArrayList<Integer>();
962     SortField[] docFields = sort.getSort();
963     for(int fieldIDX=0;fieldIDX<docFields.length;fieldIDX++) {
964       if (docFields[fieldIDX].getType() == SortField.DOC) {
965         docFieldLocs.add(fieldIDX);
966       }
967     }
968
969     return docFieldLocs;
970   }
971
972   private void rebaseDocIDs(Sort groupSort, Sort docSort, int docBase, TopGroups<String> groups) {
973
974     List<Integer> docFieldLocs = getDocIDSortLocs(docSort);
975     List<Integer> docGroupFieldLocs = getDocIDSortLocs(groupSort);
976
977     for(GroupDocs<String> group : groups.groups) {
978       if (group.groupSortValues != null) {
979         for(int idx : docGroupFieldLocs) {
980           group.groupSortValues[idx] = Integer.valueOf(((Integer) group.groupSortValues[idx]).intValue() + docBase);
981         }
982       }
983
984       for(int hitIDX=0;hitIDX<group.scoreDocs.length;hitIDX++) {
985         final ScoreDoc sd = group.scoreDocs[hitIDX];
986         sd.doc += docBase;
987         if (sd instanceof FieldDoc) {
988           final FieldDoc fd = (FieldDoc) sd;
989           if (fd.fields != null) {
990             for(int idx : docFieldLocs) {
991               fd.fields[idx] = Integer.valueOf(((Integer) fd.fields[idx]).intValue() + docBase);
992             }
993           }
994         }
995       }
996     }
997   }
998
999   private void assertEquals(int[] docIDtoID, TopGroups expected, TopGroups actual, boolean verifyGroupValues, boolean verifyTotalGroupCount, boolean verifySortValues, boolean testScores) {
1000     if (expected == null) {
1001       assertNull(actual);
1002       return;
1003     }
1004     assertNotNull(actual);
1005
1006     assertEquals(expected.groups.length, actual.groups.length);
1007     assertEquals(expected.totalHitCount, actual.totalHitCount);
1008     assertEquals(expected.totalGroupedHitCount, actual.totalGroupedHitCount);
1009     if (expected.totalGroupCount != null && verifyTotalGroupCount) {
1010       assertEquals(expected.totalGroupCount, actual.totalGroupCount);
1011     }
1012     
1013     for(int groupIDX=0;groupIDX<expected.groups.length;groupIDX++) {
1014       if (VERBOSE) {
1015         System.out.println("  check groupIDX=" + groupIDX);
1016       }
1017       final GroupDocs expectedGroup = expected.groups[groupIDX];
1018       final GroupDocs actualGroup = actual.groups[groupIDX];
1019       if (verifyGroupValues) {
1020         assertEquals(expectedGroup.groupValue, actualGroup.groupValue);
1021       }
1022       if (verifySortValues) {
1023         assertArrayEquals(expectedGroup.groupSortValues, actualGroup.groupSortValues);
1024       }
1025
1026       // TODO
1027       // assertEquals(expectedGroup.maxScore, actualGroup.maxScore);
1028       assertEquals(expectedGroup.totalHits, actualGroup.totalHits);
1029
1030       final ScoreDoc[] expectedFDs = expectedGroup.scoreDocs;
1031       final ScoreDoc[] actualFDs = actualGroup.scoreDocs;
1032
1033       assertEquals(expectedFDs.length, actualFDs.length);
1034       for(int docIDX=0;docIDX<expectedFDs.length;docIDX++) {
1035         final FieldDoc expectedFD = (FieldDoc) expectedFDs[docIDX];
1036         final FieldDoc actualFD = (FieldDoc) actualFDs[docIDX];
1037         //System.out.println("  actual doc=" + docIDtoID[actualFD.doc] + " score=" + actualFD.score);
1038         assertEquals(expectedFD.doc, docIDtoID[actualFD.doc]);
1039         if (testScores) {
1040           assertEquals(expectedFD.score, actualFD.score);
1041         } else {
1042           // TODO: too anal for now
1043           //assertEquals(Float.NaN, actualFD.score);
1044         }
1045         if (verifySortValues) {
1046           assertArrayEquals(expectedFD.fields, actualFD.fields);
1047         }
1048       }
1049     }
1050   }
1051
1052   private static class ShardSearcher {
1053     private final IndexSearcher subSearcher;
1054
1055     public ShardSearcher(IndexReader subReader) {
1056       this.subSearcher = new IndexSearcher(subReader);
1057     }
1058
1059     public void search(Weight weight, Collector collector) throws IOException {
1060       subSearcher.search(weight, null, collector);
1061     }
1062
1063     public TopDocs search(Weight weight, int topN) throws IOException {
1064       return subSearcher.search(weight, null, topN);
1065     }
1066
1067     @Override
1068     public String toString() {
1069       return "ShardSearcher(" + subSearcher + ")";
1070     }
1071   }
1072 }