add --shared
[pylucene.git] / lucene-java-3.4.0 / lucene / contrib / grouping / src / test / org / apache / lucene / search / grouping / TermAllGroupHeadsCollectorTest.java
1 package org.apache.lucene.search.grouping;
2
3 /*
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import org.apache.lucene.analysis.MockAnalyzer;
21 import org.apache.lucene.document.Document;
22 import org.apache.lucene.document.Field;
23 import org.apache.lucene.document.NumericField;
24 import org.apache.lucene.index.IndexReader;
25 import org.apache.lucene.index.RandomIndexWriter;
26 import org.apache.lucene.index.Term;
27 import org.apache.lucene.search.*;
28 import org.apache.lucene.store.Directory;
29 import org.apache.lucene.util.FixedBitSet;
30 import org.apache.lucene.util.LuceneTestCase;
31 import org.apache.lucene.util._TestUtil;
32
33 import java.io.IOException;
34 import java.util.*;
35
36 public class TermAllGroupHeadsCollectorTest extends LuceneTestCase {
37
38   public void testBasic() throws Exception {
39     final String groupField = "author";
40     Directory dir = newDirectory();
41     RandomIndexWriter w = new RandomIndexWriter(
42         random,
43         dir,
44         newIndexWriterConfig(TEST_VERSION_CURRENT,
45             new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy()));
46
47     // 0
48     Document doc = new Document();
49     doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED));
50     doc.add(new Field("content", "random text", Field.Store.YES, Field.Index.ANALYZED));
51     doc.add(new Field("id", "1", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
52     w.addDocument(doc);
53
54     // 1
55     doc = new Document();
56     doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED));
57     doc.add(new Field("content", "some more random text blob", Field.Store.YES, Field.Index.ANALYZED));
58     doc.add(new Field("id", "2", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
59     w.addDocument(doc);
60
61     // 2
62     doc = new Document();
63     doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED));
64     doc.add(new Field("content", "some more random textual data", Field.Store.YES, Field.Index.ANALYZED));
65     doc.add(new Field("id", "3", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
66     w.addDocument(doc);
67     w.commit(); // To ensure a second segment
68
69     // 3
70     doc = new Document();
71     doc.add(new Field(groupField, "author2", Field.Store.YES, Field.Index.ANALYZED));
72     doc.add(new Field("content", "some random text", Field.Store.YES, Field.Index.ANALYZED));
73     doc.add(new Field("id", "4", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
74     w.addDocument(doc);
75
76     // 4
77     doc = new Document();
78     doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED));
79     doc.add(new Field("content", "some more random text", Field.Store.YES, Field.Index.ANALYZED));
80     doc.add(new Field("id", "5", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
81     w.addDocument(doc);
82
83     // 5
84     doc = new Document();
85     doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED));
86     doc.add(new Field("content", "random blob", Field.Store.YES, Field.Index.ANALYZED));
87     doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
88     w.addDocument(doc);
89
90     // 6 -- no author field
91     doc = new Document();
92     doc.add(new Field("content", "random word stuck in alot of other text", Field.Store.YES, Field.Index.ANALYZED));
93     doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
94     w.addDocument(doc);
95
96     // 7 -- no author field
97     doc = new Document();
98     doc.add(new Field("content", "random word stuck in alot of other text", Field.Store.YES, Field.Index.ANALYZED));
99     doc.add(new Field("id", "7", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS));
100     w.addDocument(doc);
101
102     IndexSearcher indexSearcher = new IndexSearcher(w.getReader());
103     w.close();
104     int maxDoc = indexSearcher.maxDoc();
105
106     Sort sortWithinGroup = new Sort(new SortField("id", SortField.INT, true));
107     AbstractAllGroupHeadsCollector c1 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
108     indexSearcher.search(new TermQuery(new Term("content", "random")), c1);
109     assertTrue(arrayContains(new int[]{2, 3, 5, 7}, c1.retrieveGroupHeads()));
110     assertTrue(openBitSetContains(new int[]{2, 3, 5, 7}, c1.retrieveGroupHeads(maxDoc), maxDoc));
111
112     AbstractAllGroupHeadsCollector c2 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
113     indexSearcher.search(new TermQuery(new Term("content", "some")), c2);
114     assertTrue(arrayContains(new int[]{2, 3, 4}, c2.retrieveGroupHeads()));
115     assertTrue(openBitSetContains(new int[]{2, 3, 4}, c2.retrieveGroupHeads(maxDoc), maxDoc));
116
117     AbstractAllGroupHeadsCollector c3 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
118     indexSearcher.search(new TermQuery(new Term("content", "blob")), c3);
119     assertTrue(arrayContains(new int[]{1, 5}, c3.retrieveGroupHeads()));
120     assertTrue(openBitSetContains(new int[]{1, 5}, c3.retrieveGroupHeads(maxDoc), maxDoc));
121
122     // STRING sort type triggers different implementation
123     Sort sortWithinGroup2 = new Sort(new SortField("id", SortField.STRING, true));
124     AbstractAllGroupHeadsCollector c4 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup2);
125     indexSearcher.search(new TermQuery(new Term("content", "random")), c4);
126     assertTrue(arrayContains(new int[]{2, 3, 5, 7}, c4.retrieveGroupHeads()));
127     assertTrue(openBitSetContains(new int[]{2, 3, 5, 7}, c4.retrieveGroupHeads(maxDoc), maxDoc));
128
129     Sort sortWithinGroup3 = new Sort(new SortField("id", SortField.STRING, false));
130     AbstractAllGroupHeadsCollector c5 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup3);
131     indexSearcher.search(new TermQuery(new Term("content", "random")), c5);
132     // 7 b/c higher doc id wins, even if order of field is in not in reverse.
133     assertTrue(arrayContains(new int[]{0, 3, 4, 6}, c5.retrieveGroupHeads()));
134     assertTrue(openBitSetContains(new int[]{0, 3, 4, 6}, c5.retrieveGroupHeads(maxDoc), maxDoc));
135
136     indexSearcher.getIndexReader().close();
137     dir.close();
138   }
139
140   public void testRandom() throws Exception {
141     int numberOfRuns = _TestUtil.nextInt(random, 3, 6);
142     for (int iter = 0; iter < numberOfRuns; iter++) {
143       if (VERBOSE) {
144         System.out.println(String.format("TEST: iter=%d total=%d", iter, numberOfRuns));
145       }
146
147       final int numDocs = _TestUtil.nextInt(random, 100, 1000) * RANDOM_MULTIPLIER;
148       final int numGroups = _TestUtil.nextInt(random, 1, numDocs);
149
150       if (VERBOSE) {
151         System.out.println("TEST: numDocs=" + numDocs + " numGroups=" + numGroups);
152       }
153
154       final List<String> groups = new ArrayList<String>();
155       for (int i = 0; i < numGroups; i++) {
156         groups.add(_TestUtil.randomRealisticUnicodeString(random));
157       }
158       final String[] contentStrings = new String[_TestUtil.nextInt(random, 2, 20)];
159       if (VERBOSE) {
160         System.out.println("TEST: create fake content");
161       }
162       for (int contentIDX = 0; contentIDX < contentStrings.length; contentIDX++) {
163         final StringBuilder sb = new StringBuilder();
164         sb.append("real").append(random.nextInt(3)).append(' ');
165         final int fakeCount = random.nextInt(10);
166         for (int fakeIDX = 0; fakeIDX < fakeCount; fakeIDX++) {
167           sb.append("fake ");
168         }
169         contentStrings[contentIDX] = sb.toString();
170         if (VERBOSE) {
171           System.out.println("  content=" + sb.toString());
172         }
173       }
174
175       Directory dir = newDirectory();
176       RandomIndexWriter w = new RandomIndexWriter(
177           random,
178           dir,
179           newIndexWriterConfig(TEST_VERSION_CURRENT,
180               new MockAnalyzer(random)));
181
182       Document doc = new Document();
183       Document docNoGroup = new Document();
184       Field group = newField("group", "", Field.Index.NOT_ANALYZED);
185       doc.add(group);
186       Field sort1 = newField("sort1", "", Field.Index.NOT_ANALYZED);
187       doc.add(sort1);
188       docNoGroup.add(sort1);
189       Field sort2 = newField("sort2", "", Field.Index.NOT_ANALYZED);
190       doc.add(sort2);
191       docNoGroup.add(sort2);
192       Field sort3 = newField("sort3", "", Field.Index.NOT_ANALYZED);
193       doc.add(sort3);
194       docNoGroup.add(sort3);
195       Field content = newField("content", "", Field.Index.ANALYZED);
196       doc.add(content);
197       docNoGroup.add(content);
198       NumericField id = new NumericField("id");
199       doc.add(id);
200       docNoGroup.add(id);
201       final GroupDoc[] groupDocs = new GroupDoc[numDocs];
202       for (int i = 0; i < numDocs; i++) {
203         final String groupValue;
204         if (random.nextInt(24) == 17) {
205           // So we test the "doc doesn't have the group'd
206           // field" case:
207           groupValue = null;
208         } else {
209           groupValue = groups.get(random.nextInt(groups.size()));
210         }
211
212         final GroupDoc groupDoc = new GroupDoc(
213             i,
214             groupValue,
215             groups.get(random.nextInt(groups.size())),
216             groups.get(random.nextInt(groups.size())),
217             String.format("%05d", i),
218             contentStrings[random.nextInt(contentStrings.length)]
219         );
220
221         if (VERBOSE) {
222           System.out.println("  doc content=" + groupDoc.content + " id=" + i + " group=" + (groupDoc.group == null ? "null" : groupDoc.group) + " sort1=" + groupDoc.sort1 + " sort2=" + groupDoc.sort2 + " sort3=" + groupDoc.sort3);
223         }
224
225         groupDocs[i] = groupDoc;
226         if (groupDoc.group != null) {
227           group.setValue(groupDoc.group);
228         }
229         sort1.setValue(groupDoc.sort1);
230         sort2.setValue(groupDoc.sort2);
231         sort3.setValue(groupDoc.sort3);
232         content.setValue(groupDoc.content);
233         id.setIntValue(groupDoc.id);
234         if (groupDoc.group == null) {
235           w.addDocument(docNoGroup);
236         } else {
237           w.addDocument(doc);
238         }
239       }
240
241       final IndexReader r = w.getReader();
242       w.close();
243
244       // NOTE: intentional but temporary field cache insanity!
245       final int[] docIdToFieldId = FieldCache.DEFAULT.getInts(r, "id");
246       final int[] fieldIdToDocID = new int[numDocs];
247       for (int i = 0; i < docIdToFieldId.length; i++) {
248         int fieldId = docIdToFieldId[i];
249         fieldIdToDocID[fieldId] = i;
250       }
251
252       try {
253         final IndexSearcher s = newSearcher(r);
254
255         for (int contentID = 0; contentID < 3; contentID++) {
256           final ScoreDoc[] hits = s.search(new TermQuery(new Term("content", "real" + contentID)), numDocs).scoreDocs;
257           for (ScoreDoc hit : hits) {
258             final GroupDoc gd = groupDocs[docIdToFieldId[hit.doc]];
259             assertTrue(gd.score == 0.0);
260             gd.score = hit.score;
261             int docId = gd.id;
262             assertEquals(docId, docIdToFieldId[hit.doc]);
263           }
264         }
265
266         for (GroupDoc gd : groupDocs) {
267           assertTrue(gd.score != 0.0);
268         }
269
270         for (int searchIter = 0; searchIter < 100; searchIter++) {
271
272           if (VERBOSE) {
273             System.out.println("TEST: searchIter=" + searchIter);
274           }
275
276           final String searchTerm = "real" + random.nextInt(3);
277           boolean sortByScoreOnly = random.nextBoolean();
278           Sort sortWithinGroup = getRandomSort(sortByScoreOnly);
279           AbstractAllGroupHeadsCollector allGroupHeadsCollector = TermAllGroupHeadsCollector.create("group", sortWithinGroup);
280           s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollector);
281           int[] expectedGroupHeads = createExpectedGroupHeads(searchTerm, groupDocs, sortWithinGroup, sortByScoreOnly, fieldIdToDocID);
282           int[] actualGroupHeads = allGroupHeadsCollector.retrieveGroupHeads();
283           // The actual group heads contains Lucene ids. Need to change them into our id value.
284           for (int i = 0; i < actualGroupHeads.length; i++) {
285             actualGroupHeads[i] = docIdToFieldId[actualGroupHeads[i]];
286           }
287           // Allows us the easily iterate and assert the actual and expected results.
288           Arrays.sort(expectedGroupHeads);
289           Arrays.sort(actualGroupHeads);
290
291           if (VERBOSE) {
292             System.out.println("Collector: " + allGroupHeadsCollector.getClass().getSimpleName());
293             System.out.println("Sort within group: " + sortWithinGroup);
294             System.out.println("Num group: " + numGroups);
295             System.out.println("Num doc: " + numDocs);
296             System.out.println("\n=== Expected: \n");
297             for (int expectedDocId : expectedGroupHeads) {
298               GroupDoc expectedGroupDoc = groupDocs[expectedDocId];
299               String expectedGroup = expectedGroupDoc.group == null ? null : expectedGroupDoc.group;
300               System.out.println(
301                   String.format(
302                       "Group:%10s score%5f Sort1:%10s Sort2:%10s Sort3:%10s doc:%5d",
303                       expectedGroup, expectedGroupDoc.score, expectedGroupDoc.sort1,
304                       expectedGroupDoc.sort2, expectedGroupDoc.sort3, expectedDocId
305                   )
306               );
307             }
308             System.out.println("\n=== Actual: \n");
309             for (int actualDocId : actualGroupHeads) {
310               GroupDoc actualGroupDoc = groupDocs[actualDocId];
311               String actualGroup = actualGroupDoc.group == null ? null : actualGroupDoc.group;
312               System.out.println(
313                   String.format(
314                       "Group:%10s score%5f Sort1:%10s Sort2:%10s Sort3:%10s doc:%5d",
315                       actualGroup, actualGroupDoc.score, actualGroupDoc.sort1,
316                       actualGroupDoc.sort2, actualGroupDoc.sort3, actualDocId
317                   )
318               );
319             }
320             System.out.println("\n===================================================================================");
321           }
322
323           assertEquals(expectedGroupHeads.length, actualGroupHeads.length);
324           for (int i = 0; i < expectedGroupHeads.length; i++) {
325             assertEquals(expectedGroupHeads[i], actualGroupHeads[i]);
326           }
327         }
328         s.close();
329       } finally {
330         FieldCache.DEFAULT.purge(r);
331       }
332
333       r.close();
334       dir.close();
335     }
336   }
337
338
339   private boolean arrayContains(int[] expected, int[] actual) {
340     if (expected.length != actual.length) {
341       return false;
342     }
343
344     for (int e : expected) {
345       boolean found = false;
346       for (int a : actual) {
347         if (e == a) {
348           found = true;
349         }
350       }
351
352       if (!found) {
353         return false;
354       }
355     }
356
357     return true;
358   }
359
360   private boolean openBitSetContains(int[] expectedDocs, FixedBitSet actual, int maxDoc) throws IOException {
361     if (expectedDocs.length != actual.cardinality()) {
362       return false;
363     }
364
365     FixedBitSet expected = new FixedBitSet(maxDoc);
366     for (int expectedDoc : expectedDocs) {
367       expected.set(expectedDoc);
368     }
369
370     int docId;
371     DocIdSetIterator iterator = expected.iterator();
372     while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
373       if (!actual.get(docId)) {
374         return false;
375       }
376     }
377
378     return true;
379   }
380
381   private int[] createExpectedGroupHeads(String searchTerm, GroupDoc[] groupDocs, Sort docSort, boolean sortByScoreOnly, int[] fieldIdToDocID) throws IOException {
382     Map<String, List<GroupDoc>> groupHeads = new HashMap<String, List<GroupDoc>>();
383     for (GroupDoc groupDoc : groupDocs) {
384       if (!groupDoc.content.startsWith(searchTerm)) {
385         continue;
386       }
387
388       if (!groupHeads.containsKey(groupDoc.group)) {
389         List<GroupDoc> list = new ArrayList<GroupDoc>();
390         list.add(groupDoc);
391         groupHeads.put(groupDoc.group, list);
392         continue;
393       }
394       groupHeads.get(groupDoc.group).add(groupDoc);
395     }
396
397     int[] allGroupHeads = new int[groupHeads.size()];
398     int i = 0;
399     for (String groupValue : groupHeads.keySet()) {
400       List<GroupDoc> docs = groupHeads.get(groupValue);
401       Collections.sort(docs, getComparator(docSort, sortByScoreOnly, fieldIdToDocID));
402       allGroupHeads[i++] = docs.get(0).id;
403     }
404
405     return allGroupHeads;
406   }
407
408   private Sort getRandomSort(boolean scoreOnly) {
409     final List<SortField> sortFields = new ArrayList<SortField>();
410     if (random.nextInt(7) == 2 || scoreOnly) {
411       sortFields.add(SortField.FIELD_SCORE);
412     } else {
413       if (random.nextBoolean()) {
414         if (random.nextBoolean()) {
415           sortFields.add(new SortField("sort1", SortField.STRING, random.nextBoolean()));
416         } else {
417           sortFields.add(new SortField("sort2", SortField.STRING, random.nextBoolean()));
418         }
419       } else if (random.nextBoolean()) {
420         sortFields.add(new SortField("sort1", SortField.STRING, random.nextBoolean()));
421         sortFields.add(new SortField("sort2", SortField.STRING, random.nextBoolean()));
422       }
423     }
424     // Break ties:
425     if (random.nextBoolean() && !scoreOnly) {
426       sortFields.add(new SortField("sort3", SortField.STRING));
427     } else if (!scoreOnly) {
428       sortFields.add(new SortField("id", SortField.INT));
429     }
430     return new Sort(sortFields.toArray(new SortField[sortFields.size()]));
431   }
432
433   private Comparator<GroupDoc> getComparator(Sort sort, final boolean sortByScoreOnly, final int[] fieldIdToDocID) {
434     final SortField[] sortFields = sort.getSort();
435     return new Comparator<GroupDoc>() {
436       public int compare(GroupDoc d1, GroupDoc d2) {
437         for (SortField sf : sortFields) {
438           final int cmp;
439           if (sf.getType() == SortField.SCORE) {
440             if (d1.score > d2.score) {
441               cmp = -1;
442             } else if (d1.score < d2.score) {
443               cmp = 1;
444             } else {
445               cmp = sortByScoreOnly ? fieldIdToDocID[d1.id] - fieldIdToDocID[d2.id] : 0;
446             }
447           } else if (sf.getField().equals("sort1")) {
448             cmp = d1.sort1.compareTo(d2.sort1);
449           } else if (sf.getField().equals("sort2")) {
450             cmp = d1.sort2.compareTo(d2.sort2);
451           } else if (sf.getField().equals("sort3")) {
452             cmp = d1.sort3.compareTo(d2.sort3);
453           } else {
454             assertEquals(sf.getField(), "id");
455             cmp = d1.id - d2.id;
456           }
457           if (cmp != 0) {
458             return sf.getReverse() ? -cmp : cmp;
459           }
460         }
461         // Our sort always fully tie breaks:
462         fail();
463         return 0;
464       }
465     };
466   }
467
468
469   private static class GroupDoc {
470     final int id;
471     final String group;
472     final String sort1;
473     final String sort2;
474     final String sort3;
475     // content must be "realN ..."
476     final String content;
477     float score;
478
479     public GroupDoc(int id, String group, String sort1, String sort2, String sort3, String content) {
480       this.id = id;
481       this.group = group;
482       this.sort1 = sort1;
483       this.sort2 = sort2;
484       this.sort3 = sort3;
485       this.content = content;
486     }
487
488   }
489
490 }