add --shared
[pylucene.git] / lucene-java-3.4.0 / lucene / src / test / org / apache / lucene / search / function / TestCustomScoreQuery.java
1 package org.apache.lucene.search.function;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19
20 import org.apache.lucene.queryParser.QueryParser;
21 import org.apache.lucene.queryParser.ParseException;
22 import org.apache.lucene.search.*;
23 import org.junit.BeforeClass;
24 import org.junit.Test;
25 import java.io.IOException;
26 import java.util.HashMap;
27 import java.util.Map;
28
29 import org.apache.lucene.index.IndexReader;
30 import org.apache.lucene.index.Term;
31
32 /**
33  * Test CustomScoreQuery search.
34  */
35 public class TestCustomScoreQuery extends FunctionTestSetup {
36
37   @BeforeClass
38   public static void beforeClass() throws Exception {
39     createIndex(true);
40   }
41
42   /**
43    * Test that CustomScoreQuery of Type.BYTE returns the expected scores.
44    */
45   @Test
46   public void testCustomScoreByte() throws Exception, ParseException {
47     // INT field values are small enough to be parsed as byte
48     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.BYTE, 1.0);
49     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.BYTE, 2.0);
50   }
51
52   /**
53    * Test that CustomScoreQuery of Type.SHORT returns the expected scores.
54    */
55   @Test
56   public void testCustomScoreShort() throws Exception, ParseException {
57     // INT field values are small enough to be parsed as short
58     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.SHORT, 1.0);
59     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.SHORT, 3.0);
60   }
61
62   /**
63    * Test that CustomScoreQuery of Type.INT returns the expected scores.
64    */
65   @Test
66   public void testCustomScoreInt() throws Exception, ParseException {
67     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.INT, 1.0);
68     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.INT, 4.0);
69   }
70
71   /**
72    * Test that CustomScoreQuery of Type.FLOAT returns the expected scores.
73    */
74   @Test
75   public void testCustomScoreFloat() throws Exception, ParseException {
76     // INT field can be parsed as float
77     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.FLOAT, 1.0);
78     doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.FLOAT, 5.0);
79     // same values, but in float format
80     doTestCustomScore(FLOAT_FIELD, FieldScoreQuery.Type.FLOAT, 1.0);
81     doTestCustomScore(FLOAT_FIELD, FieldScoreQuery.Type.FLOAT, 6.0);
82   }
83
84   // must have static class otherwise serialization tests fail
85   private static class CustomAddQuery extends CustomScoreQuery {
86     // constructor
87     CustomAddQuery(Query q, ValueSourceQuery qValSrc) {
88       super(q, qValSrc);
89     }
90
91     /*(non-Javadoc) @see org.apache.lucene.search.function.CustomScoreQuery#name() */
92     @Override
93     public String name() {
94       return "customAdd";
95     }
96     
97     @Override
98     protected CustomScoreProvider getCustomScoreProvider(IndexReader reader) {
99       return new CustomScoreProvider(reader) {
100         @Override
101         public float customScore(int doc, float subQueryScore, float valSrcScore) {
102           return subQueryScore + valSrcScore;
103         }
104
105         @Override
106         public Explanation customExplain(int doc, Explanation subQueryExpl, Explanation valSrcExpl) {
107           float valSrcScore = valSrcExpl == null ? 0 : valSrcExpl.getValue();
108           Explanation exp = new Explanation(valSrcScore + subQueryExpl.getValue(), "custom score: sum of:");
109           exp.addDetail(subQueryExpl);
110           if (valSrcExpl != null) {
111             exp.addDetail(valSrcExpl);
112           }
113           return exp;
114         }
115       };
116     }
117   }
118
119   // must have static class otherwise serialization tests fail
120   private static class CustomMulAddQuery extends CustomScoreQuery {
121     // constructor
122     CustomMulAddQuery(Query q, ValueSourceQuery qValSrc1, ValueSourceQuery qValSrc2) {
123       super(q, new ValueSourceQuery[]{qValSrc1, qValSrc2});
124     }
125
126     /*(non-Javadoc) @see org.apache.lucene.search.function.CustomScoreQuery#name() */
127     @Override
128     public String name() {
129       return "customMulAdd";
130     }
131
132     @Override
133     protected CustomScoreProvider getCustomScoreProvider(IndexReader reader) {
134       return new CustomScoreProvider(reader) {
135         @Override
136         public float customScore(int doc, float subQueryScore, float valSrcScores[]) {
137           if (valSrcScores.length == 0) {
138             return subQueryScore;
139           }
140           if (valSrcScores.length == 1) {
141             return subQueryScore + valSrcScores[0];
142             // confirm that skipping beyond the last doc, on the
143             // previous reader, hits NO_MORE_DOCS
144           }
145           return (subQueryScore + valSrcScores[0]) * valSrcScores[1]; // we know there are two
146         }
147
148         @Override
149         public Explanation customExplain(int doc, Explanation subQueryExpl, Explanation valSrcExpls[]) {
150           if (valSrcExpls.length == 0) {
151             return subQueryExpl;
152           }
153           Explanation exp = new Explanation(valSrcExpls[0].getValue() + subQueryExpl.getValue(), "sum of:");
154           exp.addDetail(subQueryExpl);
155           exp.addDetail(valSrcExpls[0]);
156           if (valSrcExpls.length == 1) {
157             exp.setDescription("CustomMulAdd, sum of:");
158             return exp;
159           }
160           Explanation exp2 = new Explanation(valSrcExpls[1].getValue() * exp.getValue(), "custom score: product of:");
161           exp2.addDetail(valSrcExpls[1]);
162           exp2.addDetail(exp);
163           return exp2;
164         }
165       };
166     }
167   }
168
169   private final class CustomExternalQuery extends CustomScoreQuery {
170
171     @Override
172     protected CustomScoreProvider getCustomScoreProvider(IndexReader reader) throws IOException {
173       final int[] values = FieldCache.DEFAULT.getInts(reader, INT_FIELD);
174       return new CustomScoreProvider(reader) {
175         @Override
176         public float customScore(int doc, float subScore, float valSrcScore) throws IOException {
177           assertTrue(doc <= reader.maxDoc());
178           return values[doc];
179         }
180       };
181     }
182
183     public CustomExternalQuery(Query q) {
184       super(q);
185     }
186   }
187
188   @Test
189   public void testCustomExternalQuery() throws Exception {
190     QueryParser qp = new QueryParser(TEST_VERSION_CURRENT, TEXT_FIELD,anlzr); 
191     String qtxt = "first aid text"; // from the doc texts in FunctionQuerySetup.
192     Query q1 = qp.parse(qtxt); 
193     
194     final Query q = new CustomExternalQuery(q1);
195     log(q);
196
197     IndexSearcher s = new IndexSearcher(dir, true);
198     TopDocs hits = s.search(q, 1000);
199     assertEquals(N_DOCS, hits.totalHits);
200     for(int i=0;i<N_DOCS;i++) {
201       final int doc = hits.scoreDocs[i].doc;
202       final float score = hits.scoreDocs[i].score;
203       assertEquals("doc=" + doc, (float) 1+(4*doc) % N_DOCS, score, 0.0001);
204     }
205     s.close();
206   }
207   
208   @Test
209   public void testRewrite() throws Exception {
210     final IndexSearcher s = new IndexSearcher(dir, true);
211
212     Query q = new TermQuery(new Term(TEXT_FIELD, "first"));
213     CustomScoreQuery original = new CustomScoreQuery(q);
214     CustomScoreQuery rewritten = (CustomScoreQuery) original.rewrite(s.getIndexReader());
215     assertTrue("rewritten query should be identical, as TermQuery does not rewrite", original == rewritten);
216     assertTrue("no hits for query", s.search(rewritten,1).totalHits > 0);
217     assertEquals(s.search(q,1).totalHits, s.search(rewritten,1).totalHits);
218
219     q = new TermRangeQuery(TEXT_FIELD, null, null, true, true); // everything
220     original = new CustomScoreQuery(q);
221     rewritten = (CustomScoreQuery) original.rewrite(s.getIndexReader());
222     assertTrue("rewritten query should not be identical, as TermRangeQuery rewrites", original != rewritten);
223     assertTrue("no hits for query", s.search(rewritten,1).totalHits > 0);
224     assertEquals(s.search(q,1).totalHits, s.search(original,1).totalHits);
225     assertEquals(s.search(q,1).totalHits, s.search(rewritten,1).totalHits);
226     
227     s.close();
228   }
229   
230   // Test that FieldScoreQuery returns docs with expected score.
231   private void doTestCustomScore(String field, FieldScoreQuery.Type tp, double dboost) throws Exception, ParseException {
232     float boost = (float) dboost;
233     IndexSearcher s = new IndexSearcher(dir, true);
234     FieldScoreQuery qValSrc = new FieldScoreQuery(field, tp); // a query that would score by the field
235     QueryParser qp = new QueryParser(TEST_VERSION_CURRENT, TEXT_FIELD, anlzr);
236     String qtxt = "first aid text"; // from the doc texts in FunctionQuerySetup.
237
238     // regular (boolean) query.
239     Query q1 = qp.parse(qtxt);
240     log(q1);
241
242     // custom query, that should score the same as q1.
243     Query q2CustomNeutral = new CustomScoreQuery(q1);
244     q2CustomNeutral.setBoost(boost);
245     log(q2CustomNeutral);
246
247     // custom query, that should (by default) multiply the scores of q1 by that of the field
248     CustomScoreQuery q3CustomMul = new CustomScoreQuery(q1, qValSrc);
249     q3CustomMul.setStrict(true);
250     q3CustomMul.setBoost(boost);
251     log(q3CustomMul);
252
253     // custom query, that should add the scores of q1 to that of the field
254     CustomScoreQuery q4CustomAdd = new CustomAddQuery(q1, qValSrc);
255     q4CustomAdd.setStrict(true);
256     q4CustomAdd.setBoost(boost);
257     log(q4CustomAdd);
258
259     // custom query, that multiplies and adds the field score to that of q1
260     CustomScoreQuery q5CustomMulAdd = new CustomMulAddQuery(q1, qValSrc, qValSrc);
261     q5CustomMulAdd.setStrict(true);
262     q5CustomMulAdd.setBoost(boost);
263     log(q5CustomMulAdd);
264
265     // do al the searches 
266     TopDocs td1 = s.search(q1, null, 1000);
267     TopDocs td2CustomNeutral = s.search(q2CustomNeutral, null, 1000);
268     TopDocs td3CustomMul = s.search(q3CustomMul, null, 1000);
269     TopDocs td4CustomAdd = s.search(q4CustomAdd, null, 1000);
270     TopDocs td5CustomMulAdd = s.search(q5CustomMulAdd, null, 1000);
271
272     // put results in map so we can verify the scores although they have changed
273     Map<Integer,Float> h1               = topDocsToMap(td1);
274     Map<Integer,Float> h2CustomNeutral  = topDocsToMap(td2CustomNeutral);
275     Map<Integer,Float> h3CustomMul      = topDocsToMap(td3CustomMul);
276     Map<Integer,Float> h4CustomAdd      = topDocsToMap(td4CustomAdd);
277     Map<Integer,Float> h5CustomMulAdd   = topDocsToMap(td5CustomMulAdd);
278     
279     verifyResults(boost, s, 
280         h1, h2CustomNeutral, h3CustomMul, h4CustomAdd, h5CustomMulAdd,
281         q1, q2CustomNeutral, q3CustomMul, q4CustomAdd, q5CustomMulAdd);
282     s.close();
283   }
284
285   // verify results are as expected.
286   private void verifyResults(float boost, IndexSearcher s, 
287       Map<Integer,Float> h1, Map<Integer,Float> h2customNeutral, Map<Integer,Float> h3CustomMul, Map<Integer,Float> h4CustomAdd, Map<Integer,Float> h5CustomMulAdd,
288       Query q1, Query q2, Query q3, Query q4, Query q5) throws Exception {
289     
290     // verify numbers of matches
291     log("#hits = "+h1.size());
292     assertEquals("queries should have same #hits",h1.size(),h2customNeutral.size());
293     assertEquals("queries should have same #hits",h1.size(),h3CustomMul.size());
294     assertEquals("queries should have same #hits",h1.size(),h4CustomAdd.size());
295     assertEquals("queries should have same #hits",h1.size(),h5CustomMulAdd.size());
296
297     QueryUtils.check(random, q1,s);
298     QueryUtils.check(random, q2,s);
299     QueryUtils.check(random, q3,s);
300     QueryUtils.check(random, q4,s);
301     QueryUtils.check(random, q5,s);
302
303     // verify scores ratios
304     for (final Integer doc : h1.keySet()) {
305
306       log("doc = "+doc);
307
308       float fieldScore = expectedFieldScore(s.getIndexReader().document(doc).get(ID_FIELD));
309       log("fieldScore = " + fieldScore);
310       assertTrue("fieldScore should not be 0", fieldScore > 0);
311
312       float score1 = h1.get(doc);
313       logResult("score1=", s, q1, doc, score1);
314       
315       float score2 = h2customNeutral.get(doc);
316       logResult("score2=", s, q2, doc, score2);
317       assertEquals("same score (just boosted) for neutral", boost * score1, score2, TEST_SCORE_TOLERANCE_DELTA);
318
319       float score3 = h3CustomMul.get(doc);
320       logResult("score3=", s, q3, doc, score3);
321       assertEquals("new score for custom mul", boost * fieldScore * score1, score3, TEST_SCORE_TOLERANCE_DELTA);
322       
323       float score4 = h4CustomAdd.get(doc);
324       logResult("score4=", s, q4, doc, score4);
325       assertEquals("new score for custom add", boost * (fieldScore + score1), score4, TEST_SCORE_TOLERANCE_DELTA);
326       
327       float score5 = h5CustomMulAdd.get(doc);
328       logResult("score5=", s, q5, doc, score5);
329       assertEquals("new score for custom mul add", boost * fieldScore * (score1 + fieldScore), score5, TEST_SCORE_TOLERANCE_DELTA);
330     }
331   }
332
333   private void logResult(String msg, Searcher s, Query q, int doc, float score1) throws IOException {
334     log(msg+" "+score1);
335     log("Explain by: "+q);
336     log(s.explain(q,doc));
337   }
338
339   // since custom scoring modifies the order of docs, map results 
340   // by doc ids so that we can later compare/verify them 
341   private Map<Integer,Float> topDocsToMap(TopDocs td) {
342     Map<Integer,Float> h = new HashMap<Integer,Float>();
343     for (int i=0; i<td.totalHits; i++) {
344       h.put(td.scoreDocs[i].doc, td.scoreDocs[i].score);
345     }
346     return h;
347   }
348
349 }