pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / contrib / join / src / java / org / apache / lucene / search / join / BlockJoinCollector.java
1 package org.apache.lucene.search.join;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import java.io.IOException;
21 import java.util.Arrays;
22 import java.util.HashMap;
23 import java.util.Map;
24
25 import org.apache.lucene.index.IndexReader;
26 import org.apache.lucene.index.IndexWriter;       // javadocs
27 import org.apache.lucene.search.Collector;
28 import org.apache.lucene.search.FieldComparator;
29 import org.apache.lucene.search.FieldValueHitQueue;
30 import org.apache.lucene.search.Query;
31 import org.apache.lucene.search.ScoreCachingWrappingScorer;
32 import org.apache.lucene.search.Scorer;
33 import org.apache.lucene.search.Sort;
34 import org.apache.lucene.search.TopDocs;
35 import org.apache.lucene.search.TopDocsCollector;
36 import org.apache.lucene.search.TopFieldCollector;
37 import org.apache.lucene.search.TopScoreDocCollector;
38 import org.apache.lucene.search.Weight;
39 import org.apache.lucene.search.grouping.GroupDocs;
40 import org.apache.lucene.search.grouping.TopGroups;
41 import org.apache.lucene.util.ArrayUtil;
42
43
44 /** Collects parent document hits for a Query containing one more more
45  *  BlockJoinQuery clauses, sorted by the
46  *  specified parent Sort.  Note that this cannot perform
47  *  arbitrary joins; rather, it requires that all joined
48  *  documents are indexed as a doc block (using {@link
49  *  IndexWriter#addDocuments} or {@link
50  *  IndexWriter#updateDocuments}).  Ie, the join is computed
51  *  at index time.
52  *
53  *  <p>The parent Sort must only use
54  *  fields from the parent documents; sorting by field in
55  *  the child documents is not supported.</p>
56  *
57  *  <p>You should only use this
58  *  collector if one or more of the clauses in the query is
59  *  a {@link BlockJoinQuery}.  This collector will find those query
60  *  clauses and record the matching child documents for the
61  *  top scoring parent documents.</p>
62  *
63  *  <p>Multiple joins (star join) and nested joins and a mix
64  *  of the two are allowed, as long as in all cases the
65  *  documents corresponding to a single row of each joined
66  *  parent table were indexed as a doc block.</p>
67  *
68  *  <p>For the simple star join you can retrieve the
69  *  {@link TopGroups} instance containing each {@link BlockJoinQuery}'s
70  *  matching child documents for the top parent groups,
71  *  using {@link #getTopGroups}.  Ie,
72  *  a single query, which will contain two or more
73  *  {@link BlockJoinQuery}'s as clauses representing the star join,
74  *  can then retrieve two or more {@link TopGroups} instances.</p>
75  *
76  *  <p>For nested joins, the query will run correctly (ie,
77  *  match the right parent and child documents), however,
78  *  because TopGroups is currently unable to support nesting
79  *  (each group is not able to hold another TopGroups), you
80  *  are only able to retrieve the TopGroups of the first
81  *  join.  The TopGroups of the nested joins will not be
82  *  correct.
83  *
84  *  See {@link org.apache.lucene.search.join} for a code
85  *  sample.
86  *
87  * @lucene.experimental
88  */
89 public class BlockJoinCollector extends Collector {
90
91   private final Sort sort;
92
93   // Maps each BlockJoinQuery instance to its "slot" in
94   // joinScorers and in OneGroup's cached doc/scores/count:
95   private final Map<Query,Integer> joinQueryID = new HashMap<Query,Integer>();
96   private final int numParentHits;
97   private final FieldValueHitQueue<OneGroup> queue;
98   private final FieldComparator[] comparators;
99   private final int[] reverseMul;
100   private final int compEnd;
101   private final boolean trackMaxScore;
102   private final boolean trackScores;
103
104   private int docBase;
105   private BlockJoinQuery.BlockJoinScorer[] joinScorers = new BlockJoinQuery.BlockJoinScorer[0];
106   private IndexReader currentReader;
107   private Scorer scorer;
108   private boolean queueFull;
109
110   private OneGroup bottom;
111   private int totalHitCount;
112   private float maxScore = Float.NaN;
113
114   /*  Creates a BlockJoinCollector.  The provided sort must
115    *  not be null. */
116   public BlockJoinCollector(Sort sort, int numParentHits, boolean trackScores, boolean trackMaxScore) throws IOException {
117     // TODO: allow null sort to be specialized to relevance
118     // only collector
119     this.sort = sort;
120     this.trackMaxScore = trackMaxScore;
121     this.trackScores = trackScores;
122     this.numParentHits = numParentHits;
123     queue = FieldValueHitQueue.create(sort.getSort(), numParentHits);
124     comparators = queue.getComparators();
125     reverseMul = queue.getReverseMul();
126     compEnd = comparators.length - 1;
127   }
128   
129   private static final class OneGroup extends FieldValueHitQueue.Entry {
130     public OneGroup(int comparatorSlot, int parentDoc, float parentScore, int numJoins, boolean doScores) {
131       super(comparatorSlot, parentDoc, parentScore);
132       docs = new int[numJoins][];
133       for(int joinID=0;joinID<numJoins;joinID++) {
134         docs[joinID] = new int[5];
135       }
136       if (doScores) {
137         scores = new float[numJoins][];
138         for(int joinID=0;joinID<numJoins;joinID++) {
139           scores[joinID] = new float[5];
140         }
141       }
142       counts = new int[numJoins];
143     }
144     IndexReader reader;
145     int docBase;
146     int[][] docs;
147     float[][] scores;
148     int[] counts;
149   };
150
151   @Override
152   public void collect(int parentDoc) throws IOException {
153     //System.out.println("C parentDoc=" + parentDoc);
154     totalHitCount++;
155
156     float score = Float.NaN;
157
158     if (trackMaxScore) {
159       score = scorer.score();
160       if (score > maxScore) {
161         maxScore = score;
162       }
163     }
164
165     // TODO: we could sweep all joinScorers here and
166     // aggregate total child hit count, so we can fill this
167     // in getTopGroups (we wire it to 0 now)
168
169     if (queueFull) {
170       //System.out.println("  queueFull");
171       // Fastmatch: return if this hit is not competitive
172       for (int i = 0;; i++) {
173         final int c = reverseMul[i] * comparators[i].compareBottom(parentDoc);
174         if (c < 0) {
175           // Definitely not competitive.
176           //System.out.println("    skip");
177           return;
178         } else if (c > 0) {
179           // Definitely competitive.
180           break;
181         } else if (i == compEnd) {
182           // Here c=0. If we're at the last comparator, this doc is not
183           // competitive, since docs are visited in doc Id order, which means
184           // this doc cannot compete with any other document in the queue.
185           //System.out.println("    skip");
186           return;
187         }
188       }
189
190       //System.out.println("    competes!  doc=" + (docBase + parentDoc));
191
192       // This hit is competitive - replace bottom element in queue & adjustTop
193       for (int i = 0; i < comparators.length; i++) {
194         comparators[i].copy(bottom.slot, parentDoc);
195       }
196       if (!trackMaxScore && trackScores) {
197         score = scorer.score();
198       }
199       bottom.doc = docBase + parentDoc;
200       bottom.reader = currentReader;
201       bottom.docBase = docBase;
202       bottom.score = score;
203       copyGroups(bottom);
204       bottom = queue.updateTop();
205
206       for (int i = 0; i < comparators.length; i++) {
207         comparators[i].setBottom(bottom.slot);
208       }
209     } else {
210       // Startup transient: queue is not yet full:
211       final int comparatorSlot = totalHitCount - 1;
212
213       // Copy hit into queue
214       for (int i = 0; i < comparators.length; i++) {
215         comparators[i].copy(comparatorSlot, parentDoc);
216       }
217       //System.out.println("  startup: new OG doc=" + (docBase+parentDoc));
218       final OneGroup og = new OneGroup(comparatorSlot, docBase+parentDoc, score, joinScorers.length, trackScores);
219       og.reader = currentReader;
220       og.docBase = docBase;
221       copyGroups(og);
222       bottom = queue.add(og);
223       queueFull = totalHitCount == numParentHits;
224       if (queueFull) {
225         // End of startup transient: queue just filled up:
226         for (int i = 0; i < comparators.length; i++) {
227           comparators[i].setBottom(bottom.slot);
228         }
229       }
230     }
231   }
232
233   // Pulls out child doc and scores for all join queries:
234   private void copyGroups(OneGroup og) {
235     // While rare, it's possible top arrays could be too
236     // short if join query had null scorer on first
237     // segment(s) but then became non-null on later segments
238     final int numSubScorers = joinScorers.length;
239     if (og.docs.length < numSubScorers) {
240       // While rare, this could happen if join query had
241       // null scorer on first segment(s) but then became
242       // non-null on later segments
243       og.docs = ArrayUtil.grow(og.docs);
244     }
245     if (og.counts.length < numSubScorers) {
246       og.counts = ArrayUtil.grow(og.counts);
247     }
248     if (trackScores && og.scores.length < numSubScorers) {
249       og.scores = ArrayUtil.grow(og.scores);
250     }
251
252     //System.out.println("copyGroups parentDoc=" + og.doc);
253     for(int scorerIDX = 0;scorerIDX < numSubScorers;scorerIDX++) {
254       final BlockJoinQuery.BlockJoinScorer joinScorer = joinScorers[scorerIDX];
255       //System.out.println("  scorer=" + joinScorer);
256       if (joinScorer != null) {
257         og.counts[scorerIDX] = joinScorer.getChildCount();
258         //System.out.println("    count=" + og.counts[scorerIDX]);
259         og.docs[scorerIDX] = joinScorer.swapChildDocs(og.docs[scorerIDX]);
260         /*
261         for(int idx=0;idx<og.counts[scorerIDX];idx++) {
262           System.out.println("    docs[" + idx + "]=" + og.docs[scorerIDX][idx]);
263         }
264         */
265         if (trackScores) {
266           og.scores[scorerIDX] = joinScorer.swapChildScores(og.scores[scorerIDX]);
267         }
268       }
269     }
270   }
271
272   @Override
273   public void setNextReader(IndexReader reader, int docBase) throws IOException {
274     currentReader = reader;
275     this.docBase = docBase;
276     for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
277       comparators[compIDX].setNextReader(reader, docBase);
278     }
279   }
280
281   @Override
282   public boolean acceptsDocsOutOfOrder() {
283     return false;
284   }
285
286   @Override
287   public void setScorer(Scorer scorer) {
288     //System.out.println("C.setScorer scorer=" + scorer);
289     // Since we invoke .score(), and the comparators likely
290     // do as well, cache it so it's only "really" computed
291     // once:
292     this.scorer = new ScoreCachingWrappingScorer(scorer);
293     for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
294       comparators[compIDX].setScorer(this.scorer);
295     }
296     Arrays.fill(joinScorers, null);
297
298     // Find any BlockJoinScorers out there:
299     scorer.visitScorers(new Scorer.ScorerVisitor<Query,Query,Scorer>() {
300         private void enroll(BlockJoinQuery query, BlockJoinQuery.BlockJoinScorer scorer) {
301           final Integer slot = joinQueryID.get(query);
302           if (slot == null) {
303             joinQueryID.put(query, joinScorers.length);
304             //System.out.println("found JQ: " + query + " slot=" + joinScorers.length);
305             final BlockJoinQuery.BlockJoinScorer[] newArray = new BlockJoinQuery.BlockJoinScorer[1+joinScorers.length];
306             System.arraycopy(joinScorers, 0, newArray, 0, joinScorers.length);
307             joinScorers = newArray;
308             joinScorers[joinScorers.length-1] = scorer;
309           } else {
310             joinScorers[slot] = scorer;
311           }
312         }
313
314         @Override
315         public void visitOptional(Query parent, Query child, Scorer scorer) {
316           //System.out.println("visitOpt");
317           if (child instanceof BlockJoinQuery) {
318             enroll((BlockJoinQuery) child,
319                    (BlockJoinQuery.BlockJoinScorer) scorer);
320           }
321         }
322
323         @Override
324         public void visitRequired(Query parent, Query child, Scorer scorer) {
325           //System.out.println("visitReq parent=" + parent + " child=" + child + " scorer=" + scorer);
326           if (child instanceof BlockJoinQuery) {
327             enroll((BlockJoinQuery) child,
328                    (BlockJoinQuery.BlockJoinScorer) scorer);
329           }
330         }
331
332         @Override
333         public void visitProhibited(Query parent, Query child, Scorer scorer) {
334           //System.out.println("visitProh");
335           if (child instanceof BlockJoinQuery) {
336             enroll((BlockJoinQuery) child,
337                    (BlockJoinQuery.BlockJoinScorer) scorer);
338           }
339         }
340       });
341   }
342
343   private final static class FakeScorer extends Scorer {
344
345     float score;
346     int doc;
347
348     public FakeScorer() {
349       super((Weight) null);
350     }
351
352     @Override
353     public float score() {
354       return score;
355     }
356
357     @Override
358     public int docID() {
359       return doc;
360     }
361
362     @Override
363     public int advance(int target) {
364       throw new UnsupportedOperationException();
365     }
366
367     @Override
368     public int nextDoc() {
369       throw new UnsupportedOperationException();
370     }
371   }
372
373   private OneGroup[] sortedGroups;
374
375   private void sortQueue() {
376     sortedGroups = new OneGroup[queue.size()];
377     for(int downTo=queue.size()-1;downTo>=0;downTo--) {
378       sortedGroups[downTo] = queue.pop();
379     }
380   }
381
382   /** Return the TopGroups for the specified
383    *  BlockJoinQuery.  The groupValue of each GroupDocs will
384    *  be the parent docID for that group.  Note that the
385    *  {@link GroupDocs#totalHits}, which would be the
386    *  total number of child documents matching that parent,
387    *  is not computed (will always be 0).  Returns null if
388    *  no groups matched. */
389   @SuppressWarnings("unchecked")
390   public TopGroups<Integer> getTopGroups(BlockJoinQuery query, Sort withinGroupSort, int offset, int maxDocsPerGroup, int withinGroupOffset, boolean fillSortFields) 
391
392     throws IOException {
393
394     final Integer _slot = joinQueryID.get(query);
395     if (_slot == null) {
396       if (totalHitCount == 0) {
397         return null;
398       } else {
399         throw new IllegalArgumentException("the Query did not contain the provided BlockJoinQuery");
400       }
401     }
402
403     // unbox once
404     final int slot = _slot;
405
406     if (sortedGroups == null) {
407       if (offset >= queue.size()) {
408         return null;
409       }
410       sortQueue();
411     } else if (offset > sortedGroups.length) {
412       return null;
413     }
414
415     int totalGroupedHitCount = 0;
416
417     final FakeScorer fakeScorer = new FakeScorer();
418
419     final GroupDocs<Integer>[] groups = new GroupDocs[sortedGroups.length - offset];
420
421     for(int groupIDX=offset;groupIDX<sortedGroups.length;groupIDX++) {
422       final OneGroup og = sortedGroups[groupIDX];
423
424       // At this point we hold all docs w/ in each group,
425       // unsorted; we now sort them:
426       final TopDocsCollector collector;
427       if (withinGroupSort == null) {
428         // Sort by score
429         if (!trackScores) {
430           throw new IllegalArgumentException("cannot sort by relevance within group: trackScores=false");
431         }
432         collector = TopScoreDocCollector.create(maxDocsPerGroup, true);
433       } else {
434         // Sort by fields
435         collector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, fillSortFields, trackScores, trackMaxScore, true);
436       }
437
438       collector.setScorer(fakeScorer);
439       collector.setNextReader(og.reader, og.docBase);
440       final int numChildDocs = og.counts[slot];
441       for(int docIDX=0;docIDX<numChildDocs;docIDX++) {
442         final int doc = og.docs[slot][docIDX];
443         fakeScorer.doc = doc;
444         if (trackScores) {
445           fakeScorer.score = og.scores[slot][docIDX];
446         }
447         collector.collect(doc);
448       }
449       totalGroupedHitCount += numChildDocs;
450
451       final Object[] groupSortValues;
452
453       if (fillSortFields) {
454         groupSortValues = new Object[comparators.length];
455         for(int sortFieldIDX=0;sortFieldIDX<comparators.length;sortFieldIDX++) {
456           groupSortValues[sortFieldIDX] = comparators[sortFieldIDX].value(og.slot);
457         }
458       } else {
459         groupSortValues = null;
460       }
461
462       final TopDocs topDocs = collector.topDocs(withinGroupOffset, maxDocsPerGroup);
463
464       groups[groupIDX-offset] = new GroupDocs<Integer>(topDocs.getMaxScore(),
465                                                        og.counts[slot],
466                                                        topDocs.scoreDocs,
467                                                        og.doc,
468                                                        groupSortValues);
469     }
470
471     return new TopGroups<Integer>(new TopGroups<Integer>(sort.getSort(),
472                                                          withinGroupSort == null ? null : withinGroupSort.getSort(),
473                                                          0, totalGroupedHitCount, groups),
474                                   totalHitCount);
475   }
476 }