pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / src / test / org / apache / lucene / search / TestSubScorerFreqs.java
diff --git a/lucene-java-3.5.0/lucene/src/test/org/apache/lucene/search/TestSubScorerFreqs.java b/lucene-java-3.5.0/lucene/src/test/org/apache/lucene/search/TestSubScorerFreqs.java
new file mode 100644 (file)
index 0000000..46d7ed9
--- /dev/null
@@ -0,0 +1,226 @@
+package org.apache.lucene.search;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.analysis.MockAnalyzer;
+import java.io.*;
+import java.util.*;
+import org.apache.lucene.document.*;
+import org.apache.lucene.index.*;
+import org.apache.lucene.search.BooleanClause.Occur;
+import org.apache.lucene.search.Scorer.ScorerVisitor;
+import org.apache.lucene.store.*;
+import org.apache.lucene.util.*;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TestSubScorerFreqs extends LuceneTestCase {
+
+  private static Directory dir;
+  private static IndexSearcher s;
+
+  @BeforeClass
+  public static void makeIndex() throws Exception {
+    dir = new RAMDirectory();
+    RandomIndexWriter w = new RandomIndexWriter(
+                                                random, dir, newIndexWriterConfig(TEST_VERSION_CURRENT, new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy()));
+    // make sure we have more than one segment occationally
+    int num = atLeast(31);
+    for (int i = 0; i < num; i++) {
+      Document doc = new Document();
+      doc.add(newField("f", "a b c d b c d c d d", Field.Store.NO,
+          Field.Index.ANALYZED));
+      w.addDocument(doc);
+
+      doc = new Document();
+      doc.add(newField("f", "a b c d", Field.Store.NO, Field.Index.ANALYZED));
+      w.addDocument(doc);
+    }
+
+    s = newSearcher(w.getReader());
+    w.close();
+  }
+
+  @AfterClass
+  public static void finish() throws Exception {
+    s.getIndexReader().close();
+    s.close();
+    s = null;
+    dir.close();
+    dir = null;
+  }
+
+  private static class CountingCollector extends Collector {
+    private final Collector other;
+    private int docBase;
+
+    public final Map<Integer, Map<Query, Float>> docCounts = new HashMap<Integer, Map<Query, Float>>();
+
+    private final Map<Query, Scorer> subScorers = new HashMap<Query, Scorer>();
+    private final ScorerVisitor<Query, Query, Scorer> visitor = new MockScorerVisitor();
+    private final EnumSet<Occur> collect;
+
+    private class MockScorerVisitor extends ScorerVisitor<Query, Query, Scorer> {
+
+      @Override
+      public void visitOptional(Query parent, Query child, Scorer scorer) {
+        if (collect.contains(Occur.SHOULD))
+          subScorers.put(child, scorer);
+      }
+
+      @Override
+      public void visitProhibited(Query parent, Query child, Scorer scorer) {
+        if (collect.contains(Occur.MUST_NOT))
+          subScorers.put(child, scorer);
+      }
+
+      @Override
+      public void visitRequired(Query parent, Query child, Scorer scorer) {
+        if (collect.contains(Occur.MUST))
+          subScorers.put(child, scorer);
+      }
+
+    }
+
+    public CountingCollector(Collector other) {
+      this(other, EnumSet.allOf(Occur.class));
+    }
+
+    public CountingCollector(Collector other, EnumSet<Occur> collect) {
+      this.other = other;
+      this.collect = collect;
+    }
+
+    @Override
+    public void setScorer(Scorer scorer) throws IOException {
+      other.setScorer(scorer);
+      scorer.visitScorers(visitor);
+    }
+
+    @Override
+    public void collect(int doc) throws IOException {
+      final Map<Query, Float> freqs = new HashMap<Query, Float>();
+      for (Map.Entry<Query, Scorer> ent : subScorers.entrySet()) {
+        Scorer value = ent.getValue();
+        int matchId = value.docID();
+        freqs.put(ent.getKey(), matchId == doc ? value.freq() : 0.0f);
+      }
+      docCounts.put(doc + docBase, freqs);
+      other.collect(doc);
+    }
+
+    @Override
+    public void setNextReader(IndexReader reader, int docBase)
+        throws IOException {
+      this.docBase = docBase;
+      other.setNextReader(reader, docBase);
+    }
+
+    @Override
+    public boolean acceptsDocsOutOfOrder() {
+      return other.acceptsDocsOutOfOrder();
+    }
+  }
+
+  private static final float FLOAT_TOLERANCE = 0.00001F;
+
+  @Test
+  public void testTermQuery() throws Exception {
+    TermQuery q = new TermQuery(new Term("f", "d"));
+    CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10,
+        true));
+    s.search(q, null, c);
+    final int maxDocs = s.maxDoc();
+    assertEquals(maxDocs, c.docCounts.size());
+    for (int i = 0; i < maxDocs; i++) {
+      Map<Query, Float> doc0 = c.docCounts.get(i);
+      assertEquals(1, doc0.size());
+      assertEquals(4.0F, doc0.get(q), FLOAT_TOLERANCE);
+
+      Map<Query, Float> doc1 = c.docCounts.get(++i);
+      assertEquals(1, doc1.size());
+      assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void testBooleanQuery() throws Exception {
+    TermQuery aQuery = new TermQuery(new Term("f", "a"));
+    TermQuery dQuery = new TermQuery(new Term("f", "d"));
+    TermQuery cQuery = new TermQuery(new Term("f", "c"));
+    TermQuery yQuery = new TermQuery(new Term("f", "y"));
+
+    BooleanQuery query = new BooleanQuery();
+    BooleanQuery inner = new BooleanQuery();
+
+    inner.add(cQuery, Occur.SHOULD);
+    inner.add(yQuery, Occur.MUST_NOT);
+    query.add(inner, Occur.MUST);
+    query.add(aQuery, Occur.MUST);
+    query.add(dQuery, Occur.MUST);
+    EnumSet<Occur>[] occurList = new EnumSet[] {EnumSet.of(Occur.MUST), EnumSet.of(Occur.MUST, Occur.SHOULD)};
+    for (EnumSet<Occur> occur : occurList) {
+      CountingCollector c = new CountingCollector(TopScoreDocCollector.create(
+          10, true), occur);
+      s.search(query, null, c);
+      final int maxDocs = s.maxDoc();
+      assertEquals(maxDocs, c.docCounts.size());
+      boolean includeOptional = occur.contains(Occur.SHOULD);
+      for (int i = 0; i < maxDocs; i++) {
+        Map<Query, Float> doc0 = c.docCounts.get(i);
+        assertEquals(includeOptional ? 5 : 4, doc0.size());
+        assertEquals(1.0F, doc0.get(aQuery), FLOAT_TOLERANCE);
+        assertEquals(4.0F, doc0.get(dQuery), FLOAT_TOLERANCE);
+        if (includeOptional)
+          assertEquals(3.0F, doc0.get(cQuery), FLOAT_TOLERANCE);
+
+        Map<Query, Float> doc1 = c.docCounts.get(++i);
+        assertEquals(includeOptional ? 5 : 4, doc1.size());
+        assertEquals(1.0F, doc1.get(aQuery), FLOAT_TOLERANCE);
+        assertEquals(1.0F, doc1.get(dQuery), FLOAT_TOLERANCE);
+        if (includeOptional)
+          assertEquals(1.0F, doc1.get(cQuery), FLOAT_TOLERANCE);
+
+      }
+    }
+  }
+
+  @Test
+  public void testPhraseQuery() throws Exception {
+    PhraseQuery q = new PhraseQuery();
+    q.add(new Term("f", "b"));
+    q.add(new Term("f", "c"));
+    CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10,
+        true));
+    s.search(q, null, c);
+    final int maxDocs = s.maxDoc();
+    assertEquals(maxDocs, c.docCounts.size());
+    for (int i = 0; i < maxDocs; i++) {
+      Map<Query, Float> doc0 = c.docCounts.get(i);
+      assertEquals(1, doc0.size());
+      assertEquals(2.0F, doc0.get(q), FLOAT_TOLERANCE);
+
+      Map<Query, Float> doc1 = c.docCounts.get(++i);
+      assertEquals(1, doc1.size());
+      assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE);
+    }
+
+  }
+}