pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / src / test / org / apache / lucene / search / TestSubScorerFreqs.java
1 package org.apache.lucene.search;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import org.apache.lucene.analysis.MockAnalyzer;
21 import java.io.*;
22 import java.util.*;
23 import org.apache.lucene.document.*;
24 import org.apache.lucene.index.*;
25 import org.apache.lucene.search.BooleanClause.Occur;
26 import org.apache.lucene.search.Scorer.ScorerVisitor;
27 import org.apache.lucene.store.*;
28 import org.apache.lucene.util.*;
29 import org.junit.AfterClass;
30 import org.junit.BeforeClass;
31 import org.junit.Test;
32
33 public class TestSubScorerFreqs extends LuceneTestCase {
34
35   private static Directory dir;
36   private static IndexSearcher s;
37
38   @BeforeClass
39   public static void makeIndex() throws Exception {
40     dir = new RAMDirectory();
41     RandomIndexWriter w = new RandomIndexWriter(
42                                                 random, dir, newIndexWriterConfig(TEST_VERSION_CURRENT, new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy()));
43     // make sure we have more than one segment occationally
44     int num = atLeast(31);
45     for (int i = 0; i < num; i++) {
46       Document doc = new Document();
47       doc.add(newField("f", "a b c d b c d c d d", Field.Store.NO,
48           Field.Index.ANALYZED));
49       w.addDocument(doc);
50
51       doc = new Document();
52       doc.add(newField("f", "a b c d", Field.Store.NO, Field.Index.ANALYZED));
53       w.addDocument(doc);
54     }
55
56     s = newSearcher(w.getReader());
57     w.close();
58   }
59
60   @AfterClass
61   public static void finish() throws Exception {
62     s.getIndexReader().close();
63     s.close();
64     s = null;
65     dir.close();
66     dir = null;
67   }
68
69   private static class CountingCollector extends Collector {
70     private final Collector other;
71     private int docBase;
72
73     public final Map<Integer, Map<Query, Float>> docCounts = new HashMap<Integer, Map<Query, Float>>();
74
75     private final Map<Query, Scorer> subScorers = new HashMap<Query, Scorer>();
76     private final ScorerVisitor<Query, Query, Scorer> visitor = new MockScorerVisitor();
77     private final EnumSet<Occur> collect;
78
79     private class MockScorerVisitor extends ScorerVisitor<Query, Query, Scorer> {
80
81       @Override
82       public void visitOptional(Query parent, Query child, Scorer scorer) {
83         if (collect.contains(Occur.SHOULD))
84           subScorers.put(child, scorer);
85       }
86
87       @Override
88       public void visitProhibited(Query parent, Query child, Scorer scorer) {
89         if (collect.contains(Occur.MUST_NOT))
90           subScorers.put(child, scorer);
91       }
92
93       @Override
94       public void visitRequired(Query parent, Query child, Scorer scorer) {
95         if (collect.contains(Occur.MUST))
96           subScorers.put(child, scorer);
97       }
98
99     }
100
101     public CountingCollector(Collector other) {
102       this(other, EnumSet.allOf(Occur.class));
103     }
104
105     public CountingCollector(Collector other, EnumSet<Occur> collect) {
106       this.other = other;
107       this.collect = collect;
108     }
109
110     @Override
111     public void setScorer(Scorer scorer) throws IOException {
112       other.setScorer(scorer);
113       scorer.visitScorers(visitor);
114     }
115
116     @Override
117     public void collect(int doc) throws IOException {
118       final Map<Query, Float> freqs = new HashMap<Query, Float>();
119       for (Map.Entry<Query, Scorer> ent : subScorers.entrySet()) {
120         Scorer value = ent.getValue();
121         int matchId = value.docID();
122         freqs.put(ent.getKey(), matchId == doc ? value.freq() : 0.0f);
123       }
124       docCounts.put(doc + docBase, freqs);
125       other.collect(doc);
126     }
127
128     @Override
129     public void setNextReader(IndexReader reader, int docBase)
130         throws IOException {
131       this.docBase = docBase;
132       other.setNextReader(reader, docBase);
133     }
134
135     @Override
136     public boolean acceptsDocsOutOfOrder() {
137       return other.acceptsDocsOutOfOrder();
138     }
139   }
140
141   private static final float FLOAT_TOLERANCE = 0.00001F;
142
143   @Test
144   public void testTermQuery() throws Exception {
145     TermQuery q = new TermQuery(new Term("f", "d"));
146     CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10,
147         true));
148     s.search(q, null, c);
149     final int maxDocs = s.maxDoc();
150     assertEquals(maxDocs, c.docCounts.size());
151     for (int i = 0; i < maxDocs; i++) {
152       Map<Query, Float> doc0 = c.docCounts.get(i);
153       assertEquals(1, doc0.size());
154       assertEquals(4.0F, doc0.get(q), FLOAT_TOLERANCE);
155
156       Map<Query, Float> doc1 = c.docCounts.get(++i);
157       assertEquals(1, doc1.size());
158       assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE);
159     }
160   }
161
162   @SuppressWarnings("unchecked")
163   @Test
164   public void testBooleanQuery() throws Exception {
165     TermQuery aQuery = new TermQuery(new Term("f", "a"));
166     TermQuery dQuery = new TermQuery(new Term("f", "d"));
167     TermQuery cQuery = new TermQuery(new Term("f", "c"));
168     TermQuery yQuery = new TermQuery(new Term("f", "y"));
169
170     BooleanQuery query = new BooleanQuery();
171     BooleanQuery inner = new BooleanQuery();
172
173     inner.add(cQuery, Occur.SHOULD);
174     inner.add(yQuery, Occur.MUST_NOT);
175     query.add(inner, Occur.MUST);
176     query.add(aQuery, Occur.MUST);
177     query.add(dQuery, Occur.MUST);
178     EnumSet<Occur>[] occurList = new EnumSet[] {EnumSet.of(Occur.MUST), EnumSet.of(Occur.MUST, Occur.SHOULD)};
179     for (EnumSet<Occur> occur : occurList) {
180       CountingCollector c = new CountingCollector(TopScoreDocCollector.create(
181           10, true), occur);
182       s.search(query, null, c);
183       final int maxDocs = s.maxDoc();
184       assertEquals(maxDocs, c.docCounts.size());
185       boolean includeOptional = occur.contains(Occur.SHOULD);
186       for (int i = 0; i < maxDocs; i++) {
187         Map<Query, Float> doc0 = c.docCounts.get(i);
188         assertEquals(includeOptional ? 5 : 4, doc0.size());
189         assertEquals(1.0F, doc0.get(aQuery), FLOAT_TOLERANCE);
190         assertEquals(4.0F, doc0.get(dQuery), FLOAT_TOLERANCE);
191         if (includeOptional)
192           assertEquals(3.0F, doc0.get(cQuery), FLOAT_TOLERANCE);
193
194         Map<Query, Float> doc1 = c.docCounts.get(++i);
195         assertEquals(includeOptional ? 5 : 4, doc1.size());
196         assertEquals(1.0F, doc1.get(aQuery), FLOAT_TOLERANCE);
197         assertEquals(1.0F, doc1.get(dQuery), FLOAT_TOLERANCE);
198         if (includeOptional)
199           assertEquals(1.0F, doc1.get(cQuery), FLOAT_TOLERANCE);
200
201       }
202     }
203   }
204
205   @Test
206   public void testPhraseQuery() throws Exception {
207     PhraseQuery q = new PhraseQuery();
208     q.add(new Term("f", "b"));
209     q.add(new Term("f", "c"));
210     CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10,
211         true));
212     s.search(q, null, c);
213     final int maxDocs = s.maxDoc();
214     assertEquals(maxDocs, c.docCounts.size());
215     for (int i = 0; i < maxDocs; i++) {
216       Map<Query, Float> doc0 = c.docCounts.get(i);
217       assertEquals(1, doc0.size());
218       assertEquals(2.0F, doc0.get(q), FLOAT_TOLERANCE);
219
220       Map<Query, Float> doc1 = c.docCounts.get(++i);
221       assertEquals(1, doc1.size());
222       assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE);
223     }
224
225   }
226 }