--- /dev/null
+package org.apache.lucene.search.join;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.Set;
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter; // javadocs
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.DocIdSet;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.Filter;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Searcher;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.grouping.TopGroups;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.FixedBitSet;
+
+/**
+ * This query requires that you index
+ * children and parent docs as a single block, using the
+ * {@link IndexWriter#addDocuments} or {@link
+ * IndexWriter#updateDocuments} API. In each block, the
+ * child documents must appear first, ending with the parent
+ * document. At search time you provide a Filter
+ * identifying the parents; however, this Filter must provide
+ * a {@link FixedBitSet} per sub-reader.
+ *
+ * <p>Once the block index is built, use this query to wrap
+ * any sub-query matching only child docs and join matches in that
+ * child document space up to the parent document space.
+ * You can then use this Query as a clause with
+ * other queries in the parent document space.</p>
+ *
+ * <p>The child documents must be orthogonal to the parent
+ * documents: the wrapped child query must never
+ * return a parent document.</p>
+ *
+ * <p>If you'd like to retrieve {@link TopGroups} for the
+ * resulting query, use the {@link BlockJoinCollector}.
+ * Note that this is not necessary, i.e., if you simply want
+ * to collect the parent documents and don't need to see
+ * which child documents matched under that parent, then
+ * you can use any collector.</p>
+ *
+ * <p><b>NOTE</b>: If the overall query contains parent-only
+ * matches, for example you OR a parent-only query with a
+ * joined child-only query, then the resulting collected documents
+ * will be correct, however the {@link TopGroups} you get
+ * from {@link BlockJoinCollector} will not contain every
+ * child for parents that had matched.
+ *
+ * <p>See {@link org.apache.lucene.search.join} for an
+ * overview. </p>
+ *
+ * @lucene.experimental
+ */
+
+public class BlockJoinQuery extends Query {
+
+ /**
+  * How the scores of the matching child documents are combined into
+  * the score of their parent: {@code None} computes no score,
+  * {@code Avg} takes the mean, {@code Max} the highest and
+  * {@code Total} the sum of the child scores.
+  */
+ public enum ScoreMode { None, Avg, Max, Total }
+
+ // The child query exactly as the caller supplied it. equals() and
+ // hashCode() are computed from this one, which keeps a rewritten
+ // instance equal to the query the user built, so callers need not
+ // rewrite() before searching:
+ private final Query origChildQuery;
+
+ // The child query the weight is created from (possibly a rewritten
+ // form of origChildQuery):
+ private final Query childQuery;
+
+ private final Filter parentsFilter;
+ private final ScoreMode scoreMode;
+
+ /**
+  * Creates a join over the given child query.
+  *
+  * @param childQuery query matching only child documents
+  * @param parentsFilter filter identifying the parent documents
+  * @param scoreMode how child scores are combined into the parent score
+  */
+ public BlockJoinQuery(Query childQuery, Filter parentsFilter, ScoreMode scoreMode) {
+   // Not rewritten yet: the original and the effective child query coincide.
+   this(childQuery, childQuery, parentsFilter, scoreMode);
+ }
+
+ /**
+  * Full constructor, also used by {@link #rewrite} to keep track of the
+  * original child query alongside its rewritten form.
+  */
+ private BlockJoinQuery(Query origChildQuery, Query childQuery, Filter parentsFilter, ScoreMode scoreMode) {
+   super();
+   this.scoreMode = scoreMode;
+   this.parentsFilter = parentsFilter;
+   this.childQuery = childQuery;
+   this.origChildQuery = origChildQuery;
+ }
+
+ /**
+  * Builds the weight for this join by wrapping the weight of the child
+  * query together with the parents filter and score mode.
+  */
+ @Override
+ public Weight createWeight(Searcher searcher) throws IOException {
+   final Weight childWeight = childQuery.createWeight(searcher);
+   return new BlockJoinWeight(this, childWeight, parentsFilter, scoreMode);
+ }
+
+ /**
+  * Weight for {@link BlockJoinQuery}: delegates weighting to the child
+  * query's weight (scaled by the join query's boost) and produces a
+  * {@link BlockJoinScorer} that iterates parent documents.
+  */
+ private static class BlockJoinWeight extends Weight {
+   private final Query joinQuery;
+   private final Weight childWeight;
+   private final Filter parentsFilter;
+   private final ScoreMode scoreMode;
+
+   public BlockJoinWeight(Query joinQuery, Weight childWeight, Filter parentsFilter, ScoreMode scoreMode) {
+     super();
+     this.scoreMode = scoreMode;
+     this.parentsFilter = parentsFilter;
+     this.childWeight = childWeight;
+     this.joinQuery = joinQuery;
+   }
+
+   /** Returns the join query this weight was created for. */
+   @Override
+   public Query getQuery() {
+     return joinQuery;
+   }
+
+   /** Returns the value of the wrapped child weight. */
+   @Override
+   public float getValue() {
+     return childWeight.getValue();
+   }
+
+   /** Child weight's sum of squared weights, multiplied by the squared join boost. */
+   @Override
+   public float sumOfSquaredWeights() throws IOException {
+     final float childSum = childWeight.sumOfSquaredWeights();
+     final float boost = joinQuery.getBoost();
+     return childSum * boost * boost;
+   }
+
+   /** Normalizes the child weight with the norm multiplied by the join boost. */
+   @Override
+   public void normalize(float norm) {
+     final float boostedNorm = norm * joinQuery.getBoost();
+     childWeight.normalize(boostedNorm);
+   }
+
+   /**
+    * Creates the parent-space scorer for one reader.
+    *
+    * @return the scorer, or {@code null} when the child query has no
+    *         match in this reader or the parents filter yields no set
+    * @throws IllegalStateException if the parents filter does not
+    *         return a {@link FixedBitSet}
+    */
+   @Override
+   public Scorer scorer(IndexReader reader, boolean scoreDocsInOrder, boolean topScorer) throws IOException {
+     // The child scorer is always asked for in-order, non-top scoring,
+     // regardless of what was requested from us:
+     final Scorer children = childWeight.scorer(reader, true, false);
+
+     // Position the child scorer on its first match; no scorer or no
+     // first match both mean there is nothing to join:
+     final int firstMatch =
+         (children == null) ? DocIdSetIterator.NO_MORE_DOCS : children.nextDoc();
+     if (firstMatch == DocIdSetIterator.NO_MORE_DOCS) {
+       return null;
+     }
+
+     // TODO: once we do random-access filters we can
+     // generalize this:
+     final DocIdSet parentDocs = parentsFilter.getDocIdSet(reader);
+     if (parentDocs == null) {
+       // No parents in this reader, hence no matches
+       return null;
+     }
+     if (parentDocs instanceof FixedBitSet) {
+       return new BlockJoinScorer(this, children, (FixedBitSet) parentDocs, firstMatch, scoreMode);
+     }
+     throw new IllegalStateException("parentFilter must return FixedBitSet; got " + parentDocs);
+   }
+
+   /** Not implemented yet; always throws {@link UnsupportedOperationException}. */
+   @Override
+   public Explanation explain(IndexReader reader, int doc) throws IOException {
+     // TODO
+     final String className = getClass().getName();
+     throw new UnsupportedOperationException(className + " cannot explain match on parent document");
+   }
+
+   /** This weight's scorer always returns documents in order. */
+   @Override
+   public boolean scoresDocsOutOfOrder() {
+     return false;
+   }
+ }
+
+ /**
+  * Scorer that walks the child scorer and reports, for each run of
+  * matching children, the parent document that closes their block.
+  * The matching child doc IDs (and, unless the score mode is
+  * {@code None}, their scores) of the current parent are buffered so
+  * that a collector can take them via {@link #swapChildDocs} and
+  * {@link #swapChildScores}.
+  */
+ static class BlockJoinScorer extends Scorer {
+   private final Scorer childScorer;
+   private final FixedBitSet parentBits;
+   private final ScoreMode scoreMode;
+
+   // -1 until nextDoc()/advance() is first called, so that docID() does
+   // not report doc 0 as the current match before iteration starts:
+   private int parentDoc = -1;
+   private float parentScore;
+
+   // The child scorer's current doc; it is the first child of the next
+   // parent to be returned (or NO_MORE_DOCS when exhausted):
+   private int nextChildDoc;
+
+   // Children of the current parent; only the first childDocUpto
+   // entries are valid:
+   private int[] pendingChildDocs = new int[5];
+   private float[] pendingChildScores;
+   private int childDocUpto;
+
+   /**
+    * @param weight weight that created this scorer
+    * @param childScorer child scorer, already positioned on firstChildDoc
+    * @param parentBits bit set with one bit set per parent document
+    * @param firstChildDoc the doc the child scorer is positioned on
+    * @param scoreMode how child scores are combined into the parent score
+    */
+   public BlockJoinScorer(Weight weight, Scorer childScorer, FixedBitSet parentBits, int firstChildDoc, ScoreMode scoreMode) {
+     super(weight);
+     this.parentBits = parentBits;
+     this.childScorer = childScorer;
+     this.scoreMode = scoreMode;
+     if (scoreMode != ScoreMode.None) {
+       pendingChildScores = new float[5];
+     }
+     nextChildDoc = firstChildDoc;
+   }
+
+   /** Visits this scorer and then the scorers below the child scorer. */
+   @Override
+   public void visitSubScorers(Query parent, BooleanClause.Occur relationship,
+                               ScorerVisitor<Query, Query, Scorer> visitor) {
+     super.visitSubScorers(parent, relationship, visitor);
+     childScorer.visitScorers(visitor);
+   }
+
+   /** Number of matching children buffered for the current parent. */
+   int getChildCount() {
+     return childDocUpto;
+   }
+
+   /**
+    * Hands out the buffer holding the current parent's child doc IDs and
+    * installs {@code other} (or a fresh array when {@code other} is
+    * null) as the new buffer.
+    */
+   int[] swapChildDocs(int[] other) {
+     final int[] ret = pendingChildDocs;
+     if (other == null) {
+       pendingChildDocs = new int[5];
+     } else {
+       pendingChildDocs = other;
+     }
+     return ret;
+   }
+
+   /**
+    * Hands out the buffer holding the current parent's child scores and
+    * installs {@code other} (or a fresh array when {@code other} is
+    * null) as the new buffer.
+    *
+    * @throws IllegalStateException if the score mode is {@code None}
+    */
+   float[] swapChildScores(float[] other) {
+     if (scoreMode == ScoreMode.None) {
+       throw new IllegalStateException("ScoreMode is None");
+     }
+     final float[] ret = pendingChildScores;
+     if (other == null) {
+       pendingChildScores = new float[5];
+     } else {
+       pendingChildScores = other;
+     }
+     return ret;
+   }
+
+   /**
+    * Collects every matching child up to the next parent bit, combines
+    * their scores according to the score mode and returns that parent.
+    */
+   @Override
+   public int nextDoc() throws IOException {
+     if (nextChildDoc == NO_MORE_DOCS) {
+       return parentDoc = NO_MORE_DOCS;
+     }
+
+     // Gather all children sharing the same parent as nextChildDoc
+     parentDoc = parentBits.nextSetBit(nextChildDoc);
+     assert parentDoc != -1;
+
+     float totalScore = 0;
+     float maxScore = Float.NEGATIVE_INFINITY;
+
+     childDocUpto = 0;
+     do {
+       // The two buffers are grown independently: swapChildDocs() and
+       // swapChildScores() may have installed arrays of different
+       // lengths, so a full doc buffer says nothing about the score
+       // buffer (and vice versa).
+       if (pendingChildDocs.length == childDocUpto) {
+         pendingChildDocs = ArrayUtil.grow(pendingChildDocs);
+       }
+       pendingChildDocs[childDocUpto] = nextChildDoc;
+       if (scoreMode != ScoreMode.None) {
+         if (pendingChildScores.length == childDocUpto) {
+           pendingChildScores = ArrayUtil.grow(pendingChildScores);
+         }
+         // TODO: specialize this into dedicated classes per-scoreMode
+         final float childScore = childScorer.score();
+         pendingChildScores[childDocUpto] = childScore;
+         maxScore = Math.max(childScore, maxScore);
+         totalScore += childScore;
+       }
+       childDocUpto++;
+       nextChildDoc = childScorer.nextDoc();
+     } while (nextChildDoc < parentDoc);
+
+     // Parent & child docs are supposed to be orthogonal:
+     assert nextChildDoc != parentDoc;
+
+     switch(scoreMode) {
+     case Avg:
+       parentScore = totalScore / childDocUpto;
+       break;
+     case Max:
+       parentScore = maxScore;
+       break;
+     case Total:
+       parentScore = totalScore;
+       break;
+     case None:
+       break;
+     }
+
+     return parentDoc;
+   }
+
+   /** Current parent doc; -1 before iteration has started. */
+   @Override
+   public int docID() {
+     return parentDoc;
+   }
+
+   /** Combined score of the current parent, as computed by nextDoc(). */
+   @Override
+   public float score() throws IOException {
+     return parentScore;
+   }
+
+   /**
+    * Skips the child scorer past the parent preceding
+    * {@code parentTarget} and then returns the next matching parent.
+    *
+    * @param parentTarget doc ID in the parent document space
+    */
+   @Override
+   public int advance(int parentTarget) throws IOException {
+     if (parentTarget == NO_MORE_DOCS) {
+       return parentDoc = NO_MORE_DOCS;
+     }
+
+     if (parentTarget == 0) {
+       // Doc 0 has no documents before it, so as a parent it has no
+       // children and cannot match. Handle it here rather than passing
+       // the out-of-range index -1 to prevSetBit below:
+       return nextDoc();
+     }
+
+     final int prevParentDoc = parentBits.prevSetBit(parentTarget-1);
+
+     assert prevParentDoc >= parentDoc;
+     if (prevParentDoc > nextChildDoc) {
+       // Children before the previous parent belong to blocks the
+       // caller is skipping over:
+       nextChildDoc = childScorer.advance(prevParentDoc);
+     }
+
+     // Parent & child docs are supposed to be orthogonal:
+     assert nextChildDoc != prevParentDoc;
+
+     return nextDoc();
+   }
+ }
+
+ /** Adds the terms of the (possibly rewritten) child query to {@code terms}. */
+ @Override
+ public void extractTerms(Set<Term> terms) {
+   final Query termSource = childQuery;
+   termSource.extractTerms(terms);
+ }
+
+ /**
+  * Rewrites the child query. When that yields a different query, a new
+  * join is returned that uses the rewritten child but still remembers
+  * the original child query, so that it stays equal to this query;
+  * otherwise {@code this} is returned.
+  */
+ @Override
+ public Query rewrite(IndexReader reader) throws IOException {
+   final Query childRewrite = childQuery.rewrite(reader);
+   if (childRewrite == childQuery) {
+     return this;
+   }
+   // Carry origChildQuery forward, not childQuery: this instance may
+   // itself be the result of an earlier rewrite, and equals()/hashCode()
+   // must keep comparing against the query the user originally supplied.
+   final Query rewritten = new BlockJoinQuery(origChildQuery,
+                                              childRewrite,
+                                              parentsFilter,
+                                              scoreMode);
+   rewritten.setBoost(getBoost());
+   return rewritten;
+ }
+
+ /**
+  * Renders this query as the child query's string form wrapped in
+  * {@code BlockJoinQuery (...)}; the {@code field} argument is not used.
+  */
+ @Override
+ public String toString(String field) {
+   final StringBuilder text = new StringBuilder("BlockJoinQuery (");
+   text.append(childQuery.toString());
+   text.append(")");
+   return text.toString();
+ }
+
+ /**
+  * Two join queries are equal when they have the same original child
+  * query, parents filter, score mode and boost. The original (not the
+  * rewritten) child query is compared, so a rewritten query equals the
+  * query it was rewritten from.
+  */
+ @Override
+ public boolean equals(Object _other) {
+   if (this == _other) {
+     return true;
+   }
+   if (_other instanceof BlockJoinQuery) {
+     final BlockJoinQuery other = (BlockJoinQuery) _other;
+     // The boost is part of the query's identity: two joins that differ
+     // only in boost score differently and must not compare equal.
+     return origChildQuery.equals(other.origChildQuery) &&
+       parentsFilter.equals(other.parentsFilter) &&
+       scoreMode == other.scoreMode &&
+       Float.floatToIntBits(getBoost()) == Float.floatToIntBits(other.getBoost());
+   } else {
+     return false;
+   }
+ }
+
+ /** Hash over the same components that {@link #equals} compares. */
+ @Override
+ public int hashCode() {
+   final int prime = 31;
+   int hash = 1;
+   hash = prime * hash + origChildQuery.hashCode();
+   hash = prime * hash + scoreMode.hashCode();
+   hash = prime * hash + parentsFilter.hashCode();
+   hash = prime * hash + Float.floatToIntBits(getBoost());
+   return hash;
+ }
+
+ /**
+  * Returns a copy of this query with cloned child queries. The copy
+  * keeps the boost and, if this query was produced by a rewrite, the
+  * rewritten child query as well as the original one. The parents
+  * filter is shared, not cloned.
+  */
+ @Override
+ public Object clone() {
+   final Query origClone = (Query) origChildQuery.clone();
+   // When not rewritten both fields refer to the same query; keep that
+   // true for the copy instead of cloning the same query twice.
+   final Query childClone =
+       (childQuery == origChildQuery) ? origClone : (Query) childQuery.clone();
+   final BlockJoinQuery copy = new BlockJoinQuery(origClone,
+                                                  childClone,
+                                                  parentsFilter,
+                                                  scoreMode);
+   copy.setBoost(getBoost());
+   return copy;
+ }
+}