pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / contrib / highlighter / src / java / org / apache / lucene / search / highlight / WeightedSpanTermExtractor.java
1 package org.apache.lucene.search.highlight;
2
3 /**
4  * Licensed to the Apache Software Foundation (ASF) under one or more
5  * contributor license agreements.  See the NOTICE file distributed with
6  * this work for additional information regarding copyright ownership.
7  * The ASF licenses this file to You under the Apache License, Version 2.0
8  * (the "License"); you may not use this file except in compliance with
9  * the License.  You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 import java.io.IOException;
20 import java.util.ArrayList;
21 import java.util.Collection;
22 import java.util.HashMap;
23 import java.util.HashSet;
24 import java.util.Iterator;
25 import java.util.List;
26 import java.util.Map;
27 import java.util.Set;
28
29 import org.apache.lucene.analysis.CachingTokenFilter;
30 import org.apache.lucene.analysis.TokenStream;
31 import org.apache.lucene.index.FilterIndexReader;
32 import org.apache.lucene.index.IndexReader;
33 import org.apache.lucene.index.Term;
34 import org.apache.lucene.index.TermEnum;
35 import org.apache.lucene.index.memory.MemoryIndex;
36 import org.apache.lucene.search.*;
37 import org.apache.lucene.search.spans.FieldMaskingSpanQuery;
38 import org.apache.lucene.search.spans.SpanFirstQuery;
39 import org.apache.lucene.search.spans.SpanNearQuery;
40 import org.apache.lucene.search.spans.SpanNotQuery;
41 import org.apache.lucene.search.spans.SpanOrQuery;
42 import org.apache.lucene.search.spans.SpanQuery;
43 import org.apache.lucene.search.spans.SpanTermQuery;
44 import org.apache.lucene.search.spans.Spans;
45 import org.apache.lucene.util.StringHelper;
46
47 /**
48  * Class used to extract {@link WeightedSpanTerm}s from a {@link Query} based on whether 
49  * {@link Term}s from the {@link Query} are contained in a supplied {@link TokenStream}.
50  */
51 public class WeightedSpanTermExtractor {
52
53   private String fieldName;
54   private TokenStream tokenStream;
55   private Map<String,IndexReader> readers = new HashMap<String,IndexReader>(10); 
56   private String defaultField;
57   private boolean expandMultiTermQuery;
58   private boolean cachedTokenStream;
59   private boolean wrapToCaching = true;
60   private int maxDocCharsToAnalyze;
61
62   public WeightedSpanTermExtractor() {
63   }
64
65   public WeightedSpanTermExtractor(String defaultField) {
66     if (defaultField != null) {
67       this.defaultField = StringHelper.intern(defaultField);
68     }
69   }
70
71   private void closeReaders() {
72     Collection<IndexReader> readerSet = readers.values();
73
74     for (final IndexReader reader : readerSet) {
75       try {
76         reader.close();
77       } catch (IOException e) {
78         // alert?
79       }
80     }
81   }
82
83   /**
84    * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
85    * 
86    * @param query
87    *          Query to extract Terms from
88    * @param terms
89    *          Map to place created WeightedSpanTerms in
90    * @throws IOException
91    */
92   private void extract(Query query, Map<String,WeightedSpanTerm> terms) throws IOException {
93     if (query instanceof BooleanQuery) {
94       BooleanClause[] queryClauses = ((BooleanQuery) query).getClauses();
95
96       for (int i = 0; i < queryClauses.length; i++) {
97         if (!queryClauses[i].isProhibited()) {
98           extract(queryClauses[i].getQuery(), terms);
99         }
100       }
101     } else if (query instanceof PhraseQuery) {
102       PhraseQuery phraseQuery = ((PhraseQuery) query);
103       Term[] phraseQueryTerms = phraseQuery.getTerms();
104       SpanQuery[] clauses = new SpanQuery[phraseQueryTerms.length];
105       for (int i = 0; i < phraseQueryTerms.length; i++) {
106         clauses[i] = new SpanTermQuery(phraseQueryTerms[i]);
107       }
108       int slop = phraseQuery.getSlop();
109       int[] positions = phraseQuery.getPositions();
110       // add largest position increment to slop
111       if (positions.length > 0) {
112         int lastPos = positions[0];
113         int largestInc = 0;
114         int sz = positions.length;
115         for (int i = 1; i < sz; i++) {
116           int pos = positions[i];
117           int inc = pos - lastPos;
118           if (inc > largestInc) {
119             largestInc = inc;
120           }
121           lastPos = pos;
122         }
123         if(largestInc > 1) {
124           slop += largestInc;
125         }
126       }
127
128       boolean inorder = false;
129
130       if (slop == 0) {
131         inorder = true;
132       }
133
134       SpanNearQuery sp = new SpanNearQuery(clauses, slop, inorder);
135       sp.setBoost(query.getBoost());
136       extractWeightedSpanTerms(terms, sp);
137     } else if (query instanceof TermQuery) {
138       extractWeightedTerms(terms, query);
139     } else if (query instanceof SpanQuery) {
140       extractWeightedSpanTerms(terms, (SpanQuery) query);
141     } else if (query instanceof FilteredQuery) {
142       extract(((FilteredQuery) query).getQuery(), terms);
143     } else if (query instanceof DisjunctionMaxQuery) {
144       for (Iterator<Query> iterator = ((DisjunctionMaxQuery) query).iterator(); iterator.hasNext();) {
145         extract(iterator.next(), terms);
146       }
147     } else if (query instanceof MultiTermQuery && expandMultiTermQuery) {
148       MultiTermQuery mtq = ((MultiTermQuery)query);
149       if(mtq.getRewriteMethod() != MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE) {
150         mtq = (MultiTermQuery) mtq.clone();
151         mtq.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE);
152         query = mtq;
153       }
154       FakeReader fReader = new FakeReader();
155       MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE.rewrite(fReader, mtq);
156       if (fReader.field != null) {
157         IndexReader ir = getReaderForField(fReader.field);
158         extract(query.rewrite(ir), terms);
159       }
160     } else if (query instanceof MultiPhraseQuery) {
161       final MultiPhraseQuery mpq = (MultiPhraseQuery) query;
162       final List<Term[]> termArrays = mpq.getTermArrays();
163       final int[] positions = mpq.getPositions();
164       if (positions.length > 0) {
165
166         int maxPosition = positions[positions.length - 1];
167         for (int i = 0; i < positions.length - 1; ++i) {
168           if (positions[i] > maxPosition) {
169             maxPosition = positions[i];
170           }
171         }
172
173         @SuppressWarnings("unchecked")
174         final List<SpanQuery>[] disjunctLists = new List[maxPosition + 1];
175         int distinctPositions = 0;
176
177         for (int i = 0; i < termArrays.size(); ++i) {
178           final Term[] termArray = termArrays.get(i);
179           List<SpanQuery> disjuncts = disjunctLists[positions[i]];
180           if (disjuncts == null) {
181             disjuncts = (disjunctLists[positions[i]] = new ArrayList<SpanQuery>(termArray.length));
182             ++distinctPositions;
183           }
184           for (int j = 0; j < termArray.length; ++j) {
185             disjuncts.add(new SpanTermQuery(termArray[j]));
186           }
187         }
188
189         int positionGaps = 0;
190         int position = 0;
191         final SpanQuery[] clauses = new SpanQuery[distinctPositions];
192         for (int i = 0; i < disjunctLists.length; ++i) {
193           List<SpanQuery> disjuncts = disjunctLists[i];
194           if (disjuncts != null) {
195             clauses[position++] = new SpanOrQuery(disjuncts
196                 .toArray(new SpanQuery[disjuncts.size()]));
197           } else {
198             ++positionGaps;
199           }
200         }
201
202         final int slop = mpq.getSlop();
203         final boolean inorder = (slop == 0);
204
205         SpanNearQuery sp = new SpanNearQuery(clauses, slop + positionGaps, inorder);
206         sp.setBoost(query.getBoost());
207         extractWeightedSpanTerms(terms, sp);
208       }
209     }
210   }
211
212   /**
213    * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>SpanQuery</code>.
214    * 
215    * @param terms
216    *          Map to place created WeightedSpanTerms in
217    * @param spanQuery
218    *          SpanQuery to extract Terms from
219    * @throws IOException
220    */
221   private void extractWeightedSpanTerms(Map<String,WeightedSpanTerm> terms, SpanQuery spanQuery) throws IOException {
222     Set<String> fieldNames;
223
224     if (fieldName == null) {
225       fieldNames = new HashSet<String>();
226       collectSpanQueryFields(spanQuery, fieldNames);
227     } else {
228       fieldNames = new HashSet<String>(1);
229       fieldNames.add(fieldName);
230     }
231     // To support the use of the default field name
232     if (defaultField != null) {
233       fieldNames.add(defaultField);
234     }
235     
236     Map<String, SpanQuery> queries = new HashMap<String, SpanQuery>();
237  
238     Set<Term> nonWeightedTerms = new HashSet<Term>();
239     final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
240     if (mustRewriteQuery) {
241       for (final String field : fieldNames) {
242         final SpanQuery rewrittenQuery = (SpanQuery) spanQuery.rewrite(getReaderForField(field));
243         queries.put(field, rewrittenQuery);
244         rewrittenQuery.extractTerms(nonWeightedTerms);
245       }
246     } else {
247       spanQuery.extractTerms(nonWeightedTerms);
248     }
249
250     List<PositionSpan> spanPositions = new ArrayList<PositionSpan>();
251
252     for (final String field : fieldNames) {
253
254       IndexReader reader = getReaderForField(field);
255       final Spans spans;
256       if (mustRewriteQuery) {
257         spans = queries.get(field).getSpans(reader);
258       } else {
259         spans = spanQuery.getSpans(reader);
260       }
261
262
263       // collect span positions
264       while (spans.next()) {
265         spanPositions.add(new PositionSpan(spans.start(), spans.end() - 1));
266       }
267       
268     }
269
270     if (spanPositions.size() == 0) {
271       // no spans found
272       return;
273     }
274
275     for (final Term queryTerm :  nonWeightedTerms) {
276
277       if (fieldNameComparator(queryTerm.field())) {
278         WeightedSpanTerm weightedSpanTerm = terms.get(queryTerm.text());
279
280         if (weightedSpanTerm == null) {
281           weightedSpanTerm = new WeightedSpanTerm(spanQuery.getBoost(), queryTerm.text());
282           weightedSpanTerm.addPositionSpans(spanPositions);
283           weightedSpanTerm.positionSensitive = true;
284           terms.put(queryTerm.text(), weightedSpanTerm);
285         } else {
286           if (spanPositions.size() > 0) {
287             weightedSpanTerm.addPositionSpans(spanPositions);
288           }
289         }
290       }
291     }
292   }
293
294   /**
295    * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
296    * 
297    * @param terms
298    *          Map to place created WeightedSpanTerms in
299    * @param query
300    *          Query to extract Terms from
301    * @throws IOException
302    */
303   private void extractWeightedTerms(Map<String,WeightedSpanTerm> terms, Query query) throws IOException {
304     Set<Term> nonWeightedTerms = new HashSet<Term>();
305     query.extractTerms(nonWeightedTerms);
306
307     for (final Term queryTerm : nonWeightedTerms) {
308
309       if (fieldNameComparator(queryTerm.field())) {
310         WeightedSpanTerm weightedSpanTerm = new WeightedSpanTerm(query.getBoost(), queryTerm.text());
311         terms.put(queryTerm.text(), weightedSpanTerm);
312       }
313     }
314   }
315
316   /**
317    * Necessary to implement matches for queries against <code>defaultField</code>
318    */
319   private boolean fieldNameComparator(String fieldNameToCheck) {
320     boolean rv = fieldName == null || fieldNameToCheck == fieldName
321         || fieldNameToCheck == defaultField;
322     return rv;
323   }
324
325   private IndexReader getReaderForField(String field) throws IOException {
326     if(wrapToCaching && !cachedTokenStream && !(tokenStream instanceof CachingTokenFilter)) {
327       tokenStream = new CachingTokenFilter(new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
328       cachedTokenStream = true;
329     }
330     IndexReader reader = readers.get(field);
331     if (reader == null) {
332       MemoryIndex indexer = new MemoryIndex();
333       indexer.addField(field, new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
334       tokenStream.reset();
335       IndexSearcher searcher = indexer.createSearcher();
336       reader = searcher.getIndexReader();
337       readers.put(field, reader);
338     }
339
340     return reader;
341   }
342
343   /**
344    * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
345    * 
346    * <p>
347    * 
348    * @param query
349    *          that caused hit
350    * @param tokenStream
351    *          of text to be highlighted
352    * @return Map containing WeightedSpanTerms
353    * @throws IOException
354    */
355   public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream)
356       throws IOException {
357     return getWeightedSpanTerms(query, tokenStream, null);
358   }
359
360   /**
361    * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
362    * 
363    * <p>
364    * 
365    * @param query
366    *          that caused hit
367    * @param tokenStream
368    *          of text to be highlighted
369    * @param fieldName
370    *          restricts Term's used based on field name
371    * @return Map containing WeightedSpanTerms
372    * @throws IOException
373    */
374   public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream,
375       String fieldName) throws IOException {
376     if (fieldName != null) {
377       this.fieldName = StringHelper.intern(fieldName);
378     } else {
379       this.fieldName = null;
380     }
381
382     Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
383     this.tokenStream = tokenStream;
384     try {
385       extract(query, terms);
386     } finally {
387       closeReaders();
388     }
389
390     return terms;
391   }
392
393   /**
394    * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>. Uses a supplied
395    * <code>IndexReader</code> to properly weight terms (for gradient highlighting).
396    * 
397    * <p>
398    * 
399    * @param query
400    *          that caused hit
401    * @param tokenStream
402    *          of text to be highlighted
403    * @param fieldName
404    *          restricts Term's used based on field name
405    * @param reader
406    *          to use for scoring
407    * @return Map of WeightedSpanTerms with quasi tf/idf scores
408    * @throws IOException
409    */
410   public Map<String,WeightedSpanTerm> getWeightedSpanTermsWithScores(Query query, TokenStream tokenStream, String fieldName,
411       IndexReader reader) throws IOException {
412     if (fieldName != null) {
413       this.fieldName = StringHelper.intern(fieldName);
414     } else {
415       this.fieldName = null;
416     }
417     this.tokenStream = tokenStream;
418
419     Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
420     extract(query, terms);
421
422     int totalNumDocs = reader.numDocs();
423     Set<String> weightedTerms = terms.keySet();
424     Iterator<String> it = weightedTerms.iterator();
425
426     try {
427       while (it.hasNext()) {
428         WeightedSpanTerm weightedSpanTerm = terms.get(it.next());
429         int docFreq = reader.docFreq(new Term(fieldName, weightedSpanTerm.term));
430         // docFreq counts deletes
431         if(totalNumDocs < docFreq) {
432           docFreq = totalNumDocs;
433         }
434         // IDF algorithm taken from DefaultSimilarity class
435         float idf = (float) (Math.log((float) totalNumDocs / (double) (docFreq + 1)) + 1.0);
436         weightedSpanTerm.weight *= idf;
437       }
438     } finally {
439
440       closeReaders();
441     }
442
443     return terms;
444   }
445   
446   private void collectSpanQueryFields(SpanQuery spanQuery, Set<String> fieldNames) {
447     if (spanQuery instanceof FieldMaskingSpanQuery) {
448       collectSpanQueryFields(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery(), fieldNames);
449     } else if (spanQuery instanceof SpanFirstQuery) {
450       collectSpanQueryFields(((SpanFirstQuery)spanQuery).getMatch(), fieldNames);
451     } else if (spanQuery instanceof SpanNearQuery) {
452       for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
453         collectSpanQueryFields(clause, fieldNames);
454       }
455     } else if (spanQuery instanceof SpanNotQuery) {
456       collectSpanQueryFields(((SpanNotQuery)spanQuery).getInclude(), fieldNames);
457     } else if (spanQuery instanceof SpanOrQuery) {
458       for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
459         collectSpanQueryFields(clause, fieldNames);
460       }
461     } else {
462       fieldNames.add(spanQuery.getField());
463     }
464   }
465   
466   private boolean mustRewriteQuery(SpanQuery spanQuery) {
467     if (!expandMultiTermQuery) {
468       return false; // Will throw UnsupportedOperationException in case of a SpanRegexQuery.
469     } else if (spanQuery instanceof FieldMaskingSpanQuery) {
470       return mustRewriteQuery(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery());
471     } else if (spanQuery instanceof SpanFirstQuery) {
472       return mustRewriteQuery(((SpanFirstQuery)spanQuery).getMatch());
473     } else if (spanQuery instanceof SpanNearQuery) {
474       for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
475         if (mustRewriteQuery(clause)) {
476           return true;
477         }
478       }
479       return false; 
480     } else if (spanQuery instanceof SpanNotQuery) {
481       SpanNotQuery spanNotQuery = (SpanNotQuery)spanQuery;
482       return mustRewriteQuery(spanNotQuery.getInclude()) || mustRewriteQuery(spanNotQuery.getExclude());
483     } else if (spanQuery instanceof SpanOrQuery) {
484       for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
485         if (mustRewriteQuery(clause)) {
486           return true;
487         }
488       }
489       return false; 
490     } else if (spanQuery instanceof SpanTermQuery) {
491       return false;
492     } else {
493       return true;
494     }
495   }
496   
497   /**
498    * This class makes sure that if both position sensitive and insensitive
499    * versions of the same term are added, the position insensitive one wins.
500    */
501   static private class PositionCheckingMap<K> extends HashMap<K,WeightedSpanTerm> {
502
503     @Override
504     public void putAll(Map<? extends K,? extends WeightedSpanTerm> m) {
505       for (Map.Entry<? extends K,? extends WeightedSpanTerm> entry : m.entrySet())
506         this.put(entry.getKey(), entry.getValue());
507     }
508
509     @Override
510     public WeightedSpanTerm put(K key, WeightedSpanTerm value) {
511       WeightedSpanTerm prev = super.put(key, value);
512       if (prev == null) return prev;
513       WeightedSpanTerm prevTerm = prev;
514       WeightedSpanTerm newTerm = value;
515       if (!prevTerm.positionSensitive) {
516         newTerm.positionSensitive = false;
517       }
518       return prev;
519     }
520     
521   }
522   
523   public boolean getExpandMultiTermQuery() {
524     return expandMultiTermQuery;
525   }
526
527   public void setExpandMultiTermQuery(boolean expandMultiTermQuery) {
528     this.expandMultiTermQuery = expandMultiTermQuery;
529   }
530   
531   public boolean isCachedTokenStream() {
532     return cachedTokenStream;
533   }
534   
535   public TokenStream getTokenStream() {
536     return tokenStream;
537   }
538   
539   /**
540    * By default, {@link TokenStream}s that are not of the type
541    * {@link CachingTokenFilter} are wrapped in a {@link CachingTokenFilter} to
542    * ensure an efficient reset - if you are already using a different caching
543    * {@link TokenStream} impl and you don't want it to be wrapped, set this to
544    * false.
545    * 
546    * @param wrap
547    */
548   public void setWrapIfNotCachingTokenFilter(boolean wrap) {
549     this.wrapToCaching = wrap;
550   }
551   
552   /**
553    * 
554    * A fake IndexReader class to extract the field from a MultiTermQuery
555    * 
556    */
557   static final class FakeReader extends FilterIndexReader {
558
559     private static final IndexReader EMPTY_MEMORY_INDEX_READER =
560       new MemoryIndex().createSearcher().getIndexReader();
561     
562     String field;
563
564     FakeReader() {
565       super(EMPTY_MEMORY_INDEX_READER);
566     }
567
568     @Override
569     public TermEnum terms(final Term t) throws IOException {
570       // only set first fieldname, maybe use a Set?
571       if (t != null && field == null)
572         field = t.field();
573       return super.terms(t);
574     }
575
576
577   }
578
579   protected final void setMaxDocCharsToAnalyze(int maxDocCharsToAnalyze) {
580     this.maxDocCharsToAnalyze = maxDocCharsToAnalyze;
581   }
582
583 }