--- /dev/null
+package org.apache.lucene.search;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.analysis.MockAnalyzer;
+import java.io.*;
+import java.util.*;
+import org.apache.lucene.document.*;
+import org.apache.lucene.index.*;
+import org.apache.lucene.search.BooleanClause.Occur;
+import org.apache.lucene.search.Scorer.ScorerVisitor;
+import org.apache.lucene.store.*;
+import org.apache.lucene.util.*;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TestSubScorerFreqs extends LuceneTestCase {
+
+ private static Directory dir;
+ private static IndexSearcher s;
+
+ @BeforeClass
+ public static void makeIndex() throws Exception {
+ dir = new RAMDirectory();
+ RandomIndexWriter w = new RandomIndexWriter(
+ random, dir, newIndexWriterConfig(TEST_VERSION_CURRENT, new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy()));
+ // make sure we have more than one segment occationally
+ int num = atLeast(31);
+ for (int i = 0; i < num; i++) {
+ Document doc = new Document();
+ doc.add(newField("f", "a b c d b c d c d d", Field.Store.NO,
+ Field.Index.ANALYZED));
+ w.addDocument(doc);
+
+ doc = new Document();
+ doc.add(newField("f", "a b c d", Field.Store.NO, Field.Index.ANALYZED));
+ w.addDocument(doc);
+ }
+
+ s = newSearcher(w.getReader());
+ w.close();
+ }
+
+ @AfterClass
+ public static void finish() throws Exception {
+ s.getIndexReader().close();
+ s.close();
+ s = null;
+ dir.close();
+ dir = null;
+ }
+
+ private static class CountingCollector extends Collector {
+ private final Collector other;
+ private int docBase;
+
+ public final Map<Integer, Map<Query, Float>> docCounts = new HashMap<Integer, Map<Query, Float>>();
+
+ private final Map<Query, Scorer> subScorers = new HashMap<Query, Scorer>();
+ private final ScorerVisitor<Query, Query, Scorer> visitor = new MockScorerVisitor();
+ private final EnumSet<Occur> collect;
+
+ private class MockScorerVisitor extends ScorerVisitor<Query, Query, Scorer> {
+
+ @Override
+ public void visitOptional(Query parent, Query child, Scorer scorer) {
+ if (collect.contains(Occur.SHOULD))
+ subScorers.put(child, scorer);
+ }
+
+ @Override
+ public void visitProhibited(Query parent, Query child, Scorer scorer) {
+ if (collect.contains(Occur.MUST_NOT))
+ subScorers.put(child, scorer);
+ }
+
+ @Override
+ public void visitRequired(Query parent, Query child, Scorer scorer) {
+ if (collect.contains(Occur.MUST))
+ subScorers.put(child, scorer);
+ }
+
+ }
+
+ public CountingCollector(Collector other) {
+ this(other, EnumSet.allOf(Occur.class));
+ }
+
+ public CountingCollector(Collector other, EnumSet<Occur> collect) {
+ this.other = other;
+ this.collect = collect;
+ }
+
+ @Override
+ public void setScorer(Scorer scorer) throws IOException {
+ other.setScorer(scorer);
+ scorer.visitScorers(visitor);
+ }
+
+ @Override
+ public void collect(int doc) throws IOException {
+ final Map<Query, Float> freqs = new HashMap<Query, Float>();
+ for (Map.Entry<Query, Scorer> ent : subScorers.entrySet()) {
+ Scorer value = ent.getValue();
+ int matchId = value.docID();
+ freqs.put(ent.getKey(), matchId == doc ? value.freq() : 0.0f);
+ }
+ docCounts.put(doc + docBase, freqs);
+ other.collect(doc);
+ }
+
+ @Override
+ public void setNextReader(IndexReader reader, int docBase)
+ throws IOException {
+ this.docBase = docBase;
+ other.setNextReader(reader, docBase);
+ }
+
+ @Override
+ public boolean acceptsDocsOutOfOrder() {
+ return other.acceptsDocsOutOfOrder();
+ }
+ }
+
+ private static final float FLOAT_TOLERANCE = 0.00001F;
+
+ @Test
+ public void testTermQuery() throws Exception {
+ TermQuery q = new TermQuery(new Term("f", "d"));
+ CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10,
+ true));
+ s.search(q, null, c);
+ final int maxDocs = s.maxDoc();
+ assertEquals(maxDocs, c.docCounts.size());
+ for (int i = 0; i < maxDocs; i++) {
+ Map<Query, Float> doc0 = c.docCounts.get(i);
+ assertEquals(1, doc0.size());
+ assertEquals(4.0F, doc0.get(q), FLOAT_TOLERANCE);
+
+ Map<Query, Float> doc1 = c.docCounts.get(++i);
+ assertEquals(1, doc1.size());
+ assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testBooleanQuery() throws Exception {
+ TermQuery aQuery = new TermQuery(new Term("f", "a"));
+ TermQuery dQuery = new TermQuery(new Term("f", "d"));
+ TermQuery cQuery = new TermQuery(new Term("f", "c"));
+ TermQuery yQuery = new TermQuery(new Term("f", "y"));
+
+ BooleanQuery query = new BooleanQuery();
+ BooleanQuery inner = new BooleanQuery();
+
+ inner.add(cQuery, Occur.SHOULD);
+ inner.add(yQuery, Occur.MUST_NOT);
+ query.add(inner, Occur.MUST);
+ query.add(aQuery, Occur.MUST);
+ query.add(dQuery, Occur.MUST);
+ EnumSet<Occur>[] occurList = new EnumSet[] {EnumSet.of(Occur.MUST), EnumSet.of(Occur.MUST, Occur.SHOULD)};
+ for (EnumSet<Occur> occur : occurList) {
+ CountingCollector c = new CountingCollector(TopScoreDocCollector.create(
+ 10, true), occur);
+ s.search(query, null, c);
+ final int maxDocs = s.maxDoc();
+ assertEquals(maxDocs, c.docCounts.size());
+ boolean includeOptional = occur.contains(Occur.SHOULD);
+ for (int i = 0; i < maxDocs; i++) {
+ Map<Query, Float> doc0 = c.docCounts.get(i);
+ assertEquals(includeOptional ? 5 : 4, doc0.size());
+ assertEquals(1.0F, doc0.get(aQuery), FLOAT_TOLERANCE);
+ assertEquals(4.0F, doc0.get(dQuery), FLOAT_TOLERANCE);
+ if (includeOptional)
+ assertEquals(3.0F, doc0.get(cQuery), FLOAT_TOLERANCE);
+
+ Map<Query, Float> doc1 = c.docCounts.get(++i);
+ assertEquals(includeOptional ? 5 : 4, doc1.size());
+ assertEquals(1.0F, doc1.get(aQuery), FLOAT_TOLERANCE);
+ assertEquals(1.0F, doc1.get(dQuery), FLOAT_TOLERANCE);
+ if (includeOptional)
+ assertEquals(1.0F, doc1.get(cQuery), FLOAT_TOLERANCE);
+
+ }
+ }
+ }
+
+ @Test
+ public void testPhraseQuery() throws Exception {
+ PhraseQuery q = new PhraseQuery();
+ q.add(new Term("f", "b"));
+ q.add(new Term("f", "c"));
+ CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10,
+ true));
+ s.search(q, null, c);
+ final int maxDocs = s.maxDoc();
+ assertEquals(maxDocs, c.docCounts.size());
+ for (int i = 0; i < maxDocs; i++) {
+ Map<Query, Float> doc0 = c.docCounts.get(i);
+ assertEquals(1, doc0.size());
+ assertEquals(2.0F, doc0.get(q), FLOAT_TOLERANCE);
+
+ Map<Query, Float> doc1 = c.docCounts.get(++i);
+ assertEquals(1, doc1.size());
+ assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE);
+ }
+
+ }
+}