1 package org.apache.lucene.search.join;
4 * Licensed to the Apache Software Foundation (ASF) under one or more
5 * contributor license agreements. See the NOTICE file distributed with
6 * this work for additional information regarding copyright ownership.
7 * The ASF licenses this file to You under the Apache License, Version 2.0
8 * (the "License"); you may not use this file except in compliance with
9 * the License. You may obtain a copy of the License at
11 * http://www.apache.org/licenses/LICENSE-2.0
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
20 import java.io.IOException;
21 import java.util.Arrays;
22 import java.util.HashMap;
25 import org.apache.lucene.index.IndexReader;
26 import org.apache.lucene.index.IndexWriter; // javadocs
27 import org.apache.lucene.search.Collector;
28 import org.apache.lucene.search.FieldComparator;
29 import org.apache.lucene.search.FieldValueHitQueue;
30 import org.apache.lucene.search.Query;
31 import org.apache.lucene.search.ScoreCachingWrappingScorer;
32 import org.apache.lucene.search.Scorer;
33 import org.apache.lucene.search.Sort;
34 import org.apache.lucene.search.TopDocs;
35 import org.apache.lucene.search.TopDocsCollector;
36 import org.apache.lucene.search.TopFieldCollector;
37 import org.apache.lucene.search.TopScoreDocCollector;
38 import org.apache.lucene.search.Weight;
39 import org.apache.lucene.search.grouping.GroupDocs;
40 import org.apache.lucene.search.grouping.TopGroups;
41 import org.apache.lucene.util.ArrayUtil;
44 /** Collects parent document hits for a Query containing one or more
45 * BlockJoinQuery clauses, sorted by the
46 * specified parent Sort. Note that this cannot perform
47 * arbitrary joins; rather, it requires that all joined
48 * documents are indexed as a doc block (using {@link
49 * IndexWriter#addDocuments} or {@link
50 * IndexWriter#updateDocuments}). That is, the join is computed
53 * <p>The parent Sort must only use
54 * fields from the parent documents; sorting by field in
55 * the child documents is not supported.</p>
57 * <p>You should only use this
58 * collector if one or more of the clauses in the query is
59 * a {@link BlockJoinQuery}. This collector will find those query
60 * clauses and record the matching child documents for the
61 * top scoring parent documents.</p>
63 * <p>Multiple joins (star join) and nested joins and a mix
64 * of the two are allowed, as long as in all cases the
65 * documents corresponding to a single row of each joined
66 * parent table were indexed as a doc block.</p>
68 * <p>For the simple star join you can retrieve the
69 * {@link TopGroups} instance containing each {@link BlockJoinQuery}'s
70 * matching child documents for the top parent groups,
71 * using {@link #getTopGroups}. That is,
72 * a single query, which will contain two or more
73 * {@link BlockJoinQuery} clauses representing the star join,
74 * can then retrieve two or more {@link TopGroups} instances.</p>
76 * <p>For nested joins, the query will run correctly (i.e.,
77 * match the right parent and child documents), however,
78 * because TopGroups is currently unable to support nesting
79 * (each group is not able to hold another TopGroups), you
80 * are only able to retrieve the TopGroups of the first
81 * join. The TopGroups of the nested joins will not be
84 * See {@link org.apache.lucene.search.join} for a code
87 * @lucene.experimental
89 public class BlockJoinCollector extends Collector {
// Parent sort: its SortFields build the parent hit queue and are reported
// as the group sort of the TopGroups returned by getTopGroups.
91 private final Sort sort;
93 // Maps each BlockJoinQuery instance to its "slot" in
94 // joinScorers and in OneGroup's cached doc/scores/count:
// (a HashMap keyed by the query object, so queries that are equal per
// equals/hashCode share a single slot)
95 private final Map<Query,Integer> joinQueryID = new HashMap<Query,Integer>();
// Maximum number of parent hits retained; the queue is treated as full once
// totalHitCount reaches this value.
96 private final int numParentHits;
// Top parent hits ordered by the parent sort; the queue's top entry is the
// current weakest ("bottom") hit.
97 private final FieldValueHitQueue<OneGroup> queue;
// Per-SortField comparators and direction multipliers, both taken from the
// queue so collect() can compare new hits against the bottom directly.
98 private final FieldComparator[] comparators;
99 private final int[] reverseMul;
// Index of the last comparator (comparators.length - 1).
100 private final int compEnd;
// trackMaxScore: presumably score every parent hit (before the
// competitiveness test) to maintain maxScore.
// trackScores: keep a score for each queued parent hit and cache child scores.
101 private final boolean trackMaxScore;
102 private final boolean trackScores;
// BlockJoinScorers found under the current segment's scorer, indexed by the
// slot recorded in joinQueryID; a slot is null when that query has no
// scorer in the current segment.
105 private BlockJoinQuery.BlockJoinScorer[] joinScorers = new BlockJoinQuery.BlockJoinScorer[0];
// Reader and score-caching scorer of the segment being collected.
106 private IndexReader currentReader;
107 private Scorer scorer;
// True once numParentHits parent hits have been added to the queue.
108 private boolean queueFull;
// Weakest entry currently in the queue (the one a competitive hit replaces).
110 private OneGroup bottom;
// Number of parent hits collected so far (presumably incremented once per
// collect() call).
111 private int totalHitCount;
// NOTE(review): maxScore starts as NaN and collect() only updates it when
// `score > maxScore`, which is always false against NaN. Unless it is reset
// elsewhere it can never leave NaN; consider starting from
// Float.NEGATIVE_INFINITY when trackMaxScore is set.
112 private float maxScore = Float.NaN;
114 /** Creates a BlockJoinCollector. The provided sort must
116 public BlockJoinCollector(Sort sort, int numParentHits, boolean trackScores, boolean trackMaxScore) throws IOException {
117 // TODO: allow null sort to be specialized to relevance
// `sort` must be non-null: sort.getSort() is dereferenced below.
120 this.trackMaxScore = trackMaxScore;
121 this.trackScores = trackScores;
122 this.numParentHits = numParentHits;
// The hit queue holds up to numParentHits parents ordered by the parent
// sort. Its comparators and reverse multipliers are cached in fields so
// collect() can test each new hit against the queue bottom.
123 queue = FieldValueHitQueue.create(sort.getSort(), numParentHits);
124 comparators = queue.getComparators();
125 reverseMul = queue.getReverseMul();
126 compEnd = comparators.length - 1;
// One parent hit in the queue, together with the child hits cached from each
// join query. docs, scores and counts are indexed by the join slot from
// joinQueryID. collect() also records the parent's segment reader and
// docBase on the entry. The initial per-join arrays (length 5) are handed to
// the join scorers in copyGroups via swapChildDocs/swapChildScores.
129 private static final class OneGroup extends FieldValueHitQueue.Entry {
130 public OneGroup(int comparatorSlot, int parentDoc, float parentScore, int numJoins, boolean doScores) {
131 super(comparatorSlot, parentDoc, parentScore);
132 docs = new int[numJoins][];
133 for(int joinID=0;joinID<numJoins;joinID++) {
134 docs[joinID] = new int[5];
// NOTE(review): in this view `doScores` is never consulted. Presumably it is
// meant to guard the score-array allocation below (copyGroups only grows
// `scores` when trackScores is set) -- confirm.
137 scores = new float[numJoins][];
138 for(int joinID=0;joinID<numJoins;joinID++) {
139 scores[joinID] = new float[5];
142 counts = new int[numJoins];
// Collects one parent hit. If the queue is full, the hit is compared against
// the bottom and, when competitive, overwrites it. Otherwise a new OneGroup
// is added. parentDoc is segment-relative; docBase + parentDoc is the
// top-level doc ID stored in the queue.
152 public void collect(int parentDoc) throws IOException {
153 //System.out.println("C parentDoc=" + parentDoc);
// Stays NaN when this hit is not scored.
156 float score = Float.NaN;
159 score = scorer.score();
// NOTE(review): maxScore is initialised to Float.NaN and `score > NaN` is
// always false, so this branch only fires if maxScore was first set to a
// non-NaN value elsewhere -- confirm maxScore is actually tracked.
160 if (score > maxScore) {
165 // TODO: we could sweep all joinScorers here and
166 // aggregate total child hit count, so we can fill this
167 // in getTopGroups (we wire it to 0 now)
170 //System.out.println(" queueFull");
171 // Fastmatch: return if this hit is not competitive
// Compare against the bottom one SortField at a time. reverseMul[i]
// (presumably -1 for reversed fields) orients each comparison. The loop has
// no condition, so it must be left from within the branches below.
172 for (int i = 0;; i++) {
173 final int c = reverseMul[i] * comparators[i].compareBottom(parentDoc);
175 // Definitely not competitive.
176 //System.out.println(" skip");
179 // Definitely competitive.
181 } else if (i == compEnd) {
182 // Here c=0. If we're at the last comparator, this doc is not
183 // competitive, since docs are visited in doc Id order, which means
184 // this doc cannot compete with any other document in the queue.
185 //System.out.println(" skip");
190 //System.out.println(" competes! doc=" + (docBase + parentDoc));
192 // This hit is competitive - replace bottom element in queue & adjustTop
193 for (int i = 0; i < comparators.length; i++) {
194 comparators[i].copy(bottom.slot, parentDoc);
// When trackMaxScore is set the hit was presumably already scored above;
// otherwise score it only if per-hit scores are kept.
196 if (!trackMaxScore && trackScores) {
197 score = scorer.score();
// Reuse the evicted bottom entry in place; doc becomes a top-level doc ID.
199 bottom.doc = docBase + parentDoc;
200 bottom.reader = currentReader;
201 bottom.docBase = docBase;
202 bottom.score = score;
// Re-heapify; the new top of the queue is the new weakest entry.
204 bottom = queue.updateTop();
206 for (int i = 0; i < comparators.length; i++) {
207 comparators[i].setBottom(bottom.slot);
210 // Startup transient: queue is not yet full:
// Comparator slots are assigned in arrival order. This presumes totalHitCount
// was already incremented for this hit, so the first hit uses slot 0.
211 final int comparatorSlot = totalHitCount - 1;
213 // Copy hit into queue
214 for (int i = 0; i < comparators.length; i++) {
215 comparators[i].copy(comparatorSlot, parentDoc);
217 //System.out.println(" startup: new OG doc=" + (docBase+parentDoc));
// The new group is sized for the join queries known so far; copyGroups grows
// its arrays if more join queries are enrolled in later segments.
218 final OneGroup og = new OneGroup(comparatorSlot, docBase+parentDoc, score, joinScorers.length, trackScores);
219 og.reader = currentReader;
220 og.docBase = docBase;
// queue.add returns the new top, i.e. the current weakest entry.
222 bottom = queue.add(og);
223 queueFull = totalHitCount == numParentHits;
225 // End of startup transient: queue just filled up:
226 for (int i = 0; i < comparators.length; i++) {
227 comparators[i].setBottom(bottom.slot);
233 // Pulls out child doc and scores for all join queries:
// For every join slot whose scorer is non-null in the current segment, this
// records the scorer's child count in og.counts and exchanges og's cached
// docs (and scores) with the scorer via swapChildDocs/swapChildScores.
// NOTE(review): in this view nothing resets og.counts[scorerIDX] when the
// slot's scorer is null. Because og may be a recycled queue bottom, it can
// keep the evicted hit's child count/docs (from a different parent and
// reader). Resetting the count to 0 in that case looks necessary -- confirm.
234 private void copyGroups(OneGroup og) {
235 // While rare, it's possible top arrays could be too
236 // short if join query had null scorer on first
237 // segment(s) but then became non-null on later segments
238 final int numSubScorers = joinScorers.length;
239 if (og.docs.length < numSubScorers) {
240 // While rare, this could happen if join query had
241 // null scorer on first segment(s) but then became
242 // non-null on later segments
// NOTE(review): the single-argument ArrayUtil.grow takes no target size.
// Reaching numSubScorers therefore relies on its over-allocation if several
// join queries were enrolled after og was created. Newly added slots are
// presumably null and get passed to swapChildDocs/swapChildScores, which
// must accept null -- confirm.
243 og.docs = ArrayUtil.grow(og.docs);
245 if (og.counts.length < numSubScorers) {
246 og.counts = ArrayUtil.grow(og.counts);
248 if (trackScores && og.scores.length < numSubScorers) {
249 og.scores = ArrayUtil.grow(og.scores);
252 //System.out.println("copyGroups parentDoc=" + og.doc);
253 for(int scorerIDX = 0;scorerIDX < numSubScorers;scorerIDX++) {
254 final BlockJoinQuery.BlockJoinScorer joinScorer = joinScorers[scorerIDX];
255 //System.out.println(" scorer=" + joinScorer);
// A slot is null when its join query has no scorer in this segment.
256 if (joinScorer != null) {
257 og.counts[scorerIDX] = joinScorer.getChildCount();
258 //System.out.println(" count=" + og.counts[scorerIDX]);
259 og.docs[scorerIDX] = joinScorer.swapChildDocs(og.docs[scorerIDX]);
// NOTE(review): as shown, this loop writes every cached child doc to
// System.out on each call. If it is not already disabled (e.g. inside a
// block comment), it is leftover debug output and should be removed.
261 for(int idx=0;idx<og.counts[scorerIDX];idx++) {
262 System.out.println(" docs[" + idx + "]=" + og.docs[scorerIDX][idx]);
266 og.scores[scorerIDX] = joinScorer.swapChildScores(og.scores[scorerIDX]);
// Switches to a new segment: remembers its reader and docBase (used by
// collect() to form top-level doc IDs and tag queued groups) and forwards
// both to every parent comparator.
273 public void setNextReader(IndexReader reader, int docBase) throws IOException {
274 currentReader = reader;
275 this.docBase = docBase;
276 for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
277 comparators[compIDX].setNextReader(reader, docBase);
// NOTE(review): the return value is not visible here. collect() treats a tie
// on the last comparator as non-competitive because docs arrive in
// increasing doc ID order, so this is presumably expected to return false.
282 public boolean acceptsDocsOutOfOrder() {
// Installs a score-caching wrapper around the given scorer and shares it
// with the parent comparators. It then walks the scorer tree to bind the
// BlockJoinScorer of every BlockJoinQuery clause to that query's slot.
287 public void setScorer(Scorer scorer) {
288 //System.out.println("C.setScorer scorer=" + scorer);
289 // Since we invoke .score(), and the comparators likely
290 // do as well, cache it so it's only "really" computed
292 this.scorer = new ScoreCachingWrappingScorer(scorer);
293 for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
294 comparators[compIDX].setScorer(this.scorer);
// Forget the previous scorers; a slot stays null if its query is not found
// under the new scorer.
296 Arrays.fill(joinScorers, null);
298 // Find any BlockJoinScorers out there:
// Walks the original scorer (the parameter), not the caching wrapper.
299 scorer.visitScorers(new Scorer.ScorerVisitor<Query,Query,Scorer>() {
// Binds a BlockJoinScorer to its query's slot. The first time a query is
// seen it gets the next free slot and joinScorers grows by one element.
300 private void enroll(BlockJoinQuery query, BlockJoinQuery.BlockJoinScorer scorer) {
301 final Integer slot = joinQueryID.get(query);
303 joinQueryID.put(query, joinScorers.length);
304 //System.out.println("found JQ: " + query + " slot=" + joinScorers.length);
305 final BlockJoinQuery.BlockJoinScorer[] newArray = new BlockJoinQuery.BlockJoinScorer[1+joinScorers.length];
306 System.arraycopy(joinScorers, 0, newArray, 0, joinScorers.length);
307 joinScorers = newArray;
308 joinScorers[joinScorers.length-1] = scorer;
310 joinScorers[slot] = scorer;
// Optional, required and prohibited clauses are handled identically: any
// BlockJoinQuery child is enrolled. The cast assumes a BlockJoinQuery clause
// is always scored by a BlockJoinScorer.
315 public void visitOptional(Query parent, Query child, Scorer scorer) {
316 //System.out.println("visitOpt");
317 if (child instanceof BlockJoinQuery) {
318 enroll((BlockJoinQuery) child,
319 (BlockJoinQuery.BlockJoinScorer) scorer);
324 public void visitRequired(Query parent, Query child, Scorer scorer) {
325 //System.out.println("visitReq parent=" + parent + " child=" + child + " scorer=" + scorer);
326 if (child instanceof BlockJoinQuery) {
327 enroll((BlockJoinQuery) child,
328 (BlockJoinQuery.BlockJoinScorer) scorer);
333 public void visitProhibited(Query parent, Query child, Scorer scorer) {
334 //System.out.println("visitProh");
335 if (child instanceof BlockJoinQuery) {
336 enroll((BlockJoinQuery) child,
337 (BlockJoinQuery.BlockJoinScorer) scorer);
// Minimal Scorer used by getTopGroups to replay cached child hits into the
// within-group collectors. The caller sets its doc and score fields before
// each collect(); iteration (advance/nextDoc) is unsupported.
343 private final static class FakeScorer extends Scorer {
348 public FakeScorer() {
// No Weight is associated with this scorer.
349 super((Weight) null);
353 public float score() {
363 public int advance(int target) {
364 throw new UnsupportedOperationException();
368 public int nextDoc() {
369 throw new UnsupportedOperationException();
// Queue contents in sort order, best hit first; null until sortQueue() runs.
373 private OneGroup[] sortedGroups;
// Drains the queue into sortedGroups, best hit first. pop() removes the
// weakest remaining entry (the queue's top is the bottom hit), so the array
// is filled from the end.
// NOTE(review): this leaves the queue empty. Any later test against
// queue.size() -- such as the offset check in getTopGroups on a second call
// for another join query -- sees 0; use sortedGroups.length once sorted.
375 private void sortQueue() {
376 sortedGroups = new OneGroup[queue.size()];
377 for(int downTo=queue.size()-1;downTo>=0;downTo--) {
378 sortedGroups[downTo] = queue.pop();
382 /** Return the TopGroups for the specified
383 * BlockJoinQuery. The groupValue of each GroupDocs will
384 * be the parent docID for that group. Note that the
385 * {@link GroupDocs#totalHits}, which would be the
386 * total number of child documents matching that parent,
387 * is not computed (will always be 0). Returns null if
388 * no groups matched. */
389 @SuppressWarnings("unchecked")
390 public TopGroups<Integer> getTopGroups(BlockJoinQuery query, Sort withinGroupSort, int offset, int maxDocsPerGroup, int withinGroupOffset, boolean fillSortFields)
394 final Integer _slot = joinQueryID.get(query);
396 if (totalHitCount == 0) {
399 throw new IllegalArgumentException("the Query did not contain the provided BlockJoinQuery");
404 final int slot = _slot;
406 if (offset >= queue.size()) {
409 int totalGroupedHitCount = 0;
411 if (sortedGroups == null) {
415 final FakeScorer fakeScorer = new FakeScorer();
417 final GroupDocs<Integer>[] groups = new GroupDocs[sortedGroups.length - offset];
419 for(int groupIDX=offset;groupIDX<sortedGroups.length;groupIDX++) {
420 final OneGroup og = sortedGroups[groupIDX];
422 // At this point we hold all docs w/ in each group,
423 // unsorted; we now sort them:
424 final TopDocsCollector collector;
425 if (withinGroupSort == null) {
428 throw new IllegalArgumentException("cannot sort by relevance within group: trackScores=false");
430 collector = TopScoreDocCollector.create(maxDocsPerGroup, true);
433 collector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, fillSortFields, trackScores, trackMaxScore, true);
436 collector.setScorer(fakeScorer);
437 collector.setNextReader(og.reader, og.docBase);
438 final int numChildDocs = og.counts[slot];
439 for(int docIDX=0;docIDX<numChildDocs;docIDX++) {
440 final int doc = og.docs[slot][docIDX];
441 fakeScorer.doc = doc;
443 fakeScorer.score = og.scores[slot][docIDX];
445 collector.collect(doc);
447 totalGroupedHitCount += numChildDocs;
449 final Object[] groupSortValues;
451 if (fillSortFields) {
452 groupSortValues = new Object[comparators.length];
453 for(int sortFieldIDX=0;sortFieldIDX<comparators.length;sortFieldIDX++) {
454 groupSortValues[sortFieldIDX] = comparators[sortFieldIDX].value(og.slot);
457 groupSortValues = null;
460 final TopDocs topDocs = collector.topDocs(withinGroupOffset, maxDocsPerGroup);
462 groups[groupIDX-offset] = new GroupDocs<Integer>(topDocs.getMaxScore(),
469 return new TopGroups<Integer>(new TopGroups<Integer>(sort.getSort(),
470 withinGroupSort == null ? null : withinGroupSort.getSort(),
471 0, totalGroupedHitCount, groups),