package org.apache.lucene.search.highlight;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.analysis.CachingTokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.index.FilterIndexReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.memory.MemoryIndex;
import org.apache.lucene.search.*;
import org.apache.lucene.search.spans.FieldMaskingSpanQuery;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanNotQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.StringHelper;
/**
 * Class used to extract {@link WeightedSpanTerm}s from a {@link Query} based on whether
 * {@link Term}s from the {@link Query} are contained in a supplied {@link TokenStream}.
 */
public class WeightedSpanTermExtractor {

  // Field to restrict extraction to; null means match terms from any field found in the query.
  private String fieldName;
  // Token stream of the text being highlighted; indexed on demand into a MemoryIndex.
  private TokenStream tokenStream;
  // Per-field readers over MemoryIndex instances, built lazily and closed by closeReaders().
  private Map<String,IndexReader> readers = new HashMap<String,IndexReader>(10);
  // Interned default field name, matched in addition to fieldName.
  private String defaultField;
  // When true, MultiTermQuery instances are rewritten so their expanded terms can be highlighted.
  private boolean expandMultiTermQuery;
  // True once tokenStream has been wrapped in a CachingTokenFilter.
  private boolean cachedTokenStream;
  // Wrap non-caching token streams in a CachingTokenFilter for an efficient reset.
  private boolean wrapToCaching = true;
  // Upper bound on the number of characters of the document to analyze.
  private int maxDocCharsToAnalyze;
62 public WeightedSpanTermExtractor() {
65 public WeightedSpanTermExtractor(String defaultField) {
66 if (defaultField != null) {
67 this.defaultField = StringHelper.intern(defaultField);
71 private void closeReaders() {
72 Collection<IndexReader> readerSet = readers.values();
74 for (final IndexReader reader : readerSet) {
77 } catch (IOException e) {
84 * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
87 * Query to extract Terms from
89 * Map to place created WeightedSpanTerms in
92 private void extract(Query query, Map<String,WeightedSpanTerm> terms) throws IOException {
93 if (query instanceof BooleanQuery) {
94 BooleanClause[] queryClauses = ((BooleanQuery) query).getClauses();
96 for (int i = 0; i < queryClauses.length; i++) {
97 if (!queryClauses[i].isProhibited()) {
98 extract(queryClauses[i].getQuery(), terms);
101 } else if (query instanceof PhraseQuery) {
102 PhraseQuery phraseQuery = ((PhraseQuery) query);
103 Term[] phraseQueryTerms = phraseQuery.getTerms();
104 SpanQuery[] clauses = new SpanQuery[phraseQueryTerms.length];
105 for (int i = 0; i < phraseQueryTerms.length; i++) {
106 clauses[i] = new SpanTermQuery(phraseQueryTerms[i]);
108 int slop = phraseQuery.getSlop();
109 int[] positions = phraseQuery.getPositions();
110 // add largest position increment to slop
111 if (positions.length > 0) {
112 int lastPos = positions[0];
114 int sz = positions.length;
115 for (int i = 1; i < sz; i++) {
116 int pos = positions[i];
117 int inc = pos - lastPos;
118 if (inc > largestInc) {
128 boolean inorder = false;
134 SpanNearQuery sp = new SpanNearQuery(clauses, slop, inorder);
135 sp.setBoost(query.getBoost());
136 extractWeightedSpanTerms(terms, sp);
137 } else if (query instanceof TermQuery) {
138 extractWeightedTerms(terms, query);
139 } else if (query instanceof SpanQuery) {
140 extractWeightedSpanTerms(terms, (SpanQuery) query);
141 } else if (query instanceof FilteredQuery) {
142 extract(((FilteredQuery) query).getQuery(), terms);
143 } else if (query instanceof DisjunctionMaxQuery) {
144 for (Iterator<Query> iterator = ((DisjunctionMaxQuery) query).iterator(); iterator.hasNext();) {
145 extract(iterator.next(), terms);
147 } else if (query instanceof MultiTermQuery && expandMultiTermQuery) {
148 MultiTermQuery mtq = ((MultiTermQuery)query);
149 if(mtq.getRewriteMethod() != MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE) {
150 mtq = (MultiTermQuery) mtq.clone();
151 mtq.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE);
154 FakeReader fReader = new FakeReader();
155 MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE.rewrite(fReader, mtq);
156 if (fReader.field != null) {
157 IndexReader ir = getReaderForField(fReader.field);
158 extract(query.rewrite(ir), terms);
160 } else if (query instanceof MultiPhraseQuery) {
161 final MultiPhraseQuery mpq = (MultiPhraseQuery) query;
162 final List<Term[]> termArrays = mpq.getTermArrays();
163 final int[] positions = mpq.getPositions();
164 if (positions.length > 0) {
166 int maxPosition = positions[positions.length - 1];
167 for (int i = 0; i < positions.length - 1; ++i) {
168 if (positions[i] > maxPosition) {
169 maxPosition = positions[i];
173 @SuppressWarnings("unchecked")
174 final List<SpanQuery>[] disjunctLists = new List[maxPosition + 1];
175 int distinctPositions = 0;
177 for (int i = 0; i < termArrays.size(); ++i) {
178 final Term[] termArray = termArrays.get(i);
179 List<SpanQuery> disjuncts = disjunctLists[positions[i]];
180 if (disjuncts == null) {
181 disjuncts = (disjunctLists[positions[i]] = new ArrayList<SpanQuery>(termArray.length));
184 for (int j = 0; j < termArray.length; ++j) {
185 disjuncts.add(new SpanTermQuery(termArray[j]));
189 int positionGaps = 0;
191 final SpanQuery[] clauses = new SpanQuery[distinctPositions];
192 for (int i = 0; i < disjunctLists.length; ++i) {
193 List<SpanQuery> disjuncts = disjunctLists[i];
194 if (disjuncts != null) {
195 clauses[position++] = new SpanOrQuery(disjuncts
196 .toArray(new SpanQuery[disjuncts.size()]));
202 final int slop = mpq.getSlop();
203 final boolean inorder = (slop == 0);
205 SpanNearQuery sp = new SpanNearQuery(clauses, slop + positionGaps, inorder);
206 sp.setBoost(query.getBoost());
207 extractWeightedSpanTerms(terms, sp);
213 * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>SpanQuery</code>.
216 * Map to place created WeightedSpanTerms in
218 * SpanQuery to extract Terms from
219 * @throws IOException
221 private void extractWeightedSpanTerms(Map<String,WeightedSpanTerm> terms, SpanQuery spanQuery) throws IOException {
222 Set<String> fieldNames;
224 if (fieldName == null) {
225 fieldNames = new HashSet<String>();
226 collectSpanQueryFields(spanQuery, fieldNames);
228 fieldNames = new HashSet<String>(1);
229 fieldNames.add(fieldName);
231 // To support the use of the default field name
232 if (defaultField != null) {
233 fieldNames.add(defaultField);
236 Map<String, SpanQuery> queries = new HashMap<String, SpanQuery>();
238 Set<Term> nonWeightedTerms = new HashSet<Term>();
239 final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
240 if (mustRewriteQuery) {
241 for (final String field : fieldNames) {
242 final SpanQuery rewrittenQuery = (SpanQuery) spanQuery.rewrite(getReaderForField(field));
243 queries.put(field, rewrittenQuery);
244 rewrittenQuery.extractTerms(nonWeightedTerms);
247 spanQuery.extractTerms(nonWeightedTerms);
250 List<PositionSpan> spanPositions = new ArrayList<PositionSpan>();
252 for (final String field : fieldNames) {
254 IndexReader reader = getReaderForField(field);
256 if (mustRewriteQuery) {
257 spans = queries.get(field).getSpans(reader);
259 spans = spanQuery.getSpans(reader);
263 // collect span positions
264 while (spans.next()) {
265 spanPositions.add(new PositionSpan(spans.start(), spans.end() - 1));
270 if (spanPositions.size() == 0) {
275 for (final Term queryTerm : nonWeightedTerms) {
277 if (fieldNameComparator(queryTerm.field())) {
278 WeightedSpanTerm weightedSpanTerm = terms.get(queryTerm.text());
280 if (weightedSpanTerm == null) {
281 weightedSpanTerm = new WeightedSpanTerm(spanQuery.getBoost(), queryTerm.text());
282 weightedSpanTerm.addPositionSpans(spanPositions);
283 weightedSpanTerm.positionSensitive = true;
284 terms.put(queryTerm.text(), weightedSpanTerm);
286 if (spanPositions.size() > 0) {
287 weightedSpanTerm.addPositionSpans(spanPositions);
295 * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
298 * Map to place created WeightedSpanTerms in
300 * Query to extract Terms from
301 * @throws IOException
303 private void extractWeightedTerms(Map<String,WeightedSpanTerm> terms, Query query) throws IOException {
304 Set<Term> nonWeightedTerms = new HashSet<Term>();
305 query.extractTerms(nonWeightedTerms);
307 for (final Term queryTerm : nonWeightedTerms) {
309 if (fieldNameComparator(queryTerm.field())) {
310 WeightedSpanTerm weightedSpanTerm = new WeightedSpanTerm(query.getBoost(), queryTerm.text());
311 terms.put(queryTerm.text(), weightedSpanTerm);
317 * Necessary to implement matches for queries against <code>defaultField</code>
319 private boolean fieldNameComparator(String fieldNameToCheck) {
320 boolean rv = fieldName == null || fieldNameToCheck == fieldName
321 || fieldNameToCheck == defaultField;
325 private IndexReader getReaderForField(String field) throws IOException {
326 if(wrapToCaching && !cachedTokenStream && !(tokenStream instanceof CachingTokenFilter)) {
327 tokenStream = new CachingTokenFilter(new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
328 cachedTokenStream = true;
330 IndexReader reader = readers.get(field);
331 if (reader == null) {
332 MemoryIndex indexer = new MemoryIndex();
333 indexer.addField(field, new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
335 IndexSearcher searcher = indexer.createSearcher();
336 reader = searcher.getIndexReader();
337 readers.put(field, reader);
344 * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
351 * of text to be highlighted
352 * @return Map containing WeightedSpanTerms
353 * @throws IOException
355 public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream)
357 return getWeightedSpanTerms(query, tokenStream, null);
361 * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
368 * of text to be highlighted
370 * restricts Term's used based on field name
371 * @return Map containing WeightedSpanTerms
372 * @throws IOException
374 public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream,
375 String fieldName) throws IOException {
376 if (fieldName != null) {
377 this.fieldName = StringHelper.intern(fieldName);
379 this.fieldName = null;
382 Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
383 this.tokenStream = tokenStream;
385 extract(query, terms);
394 * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>. Uses a supplied
395 * <code>IndexReader</code> to properly weight terms (for gradient highlighting).
402 * of text to be highlighted
404 * restricts Term's used based on field name
407 * @return Map of WeightedSpanTerms with quasi tf/idf scores
408 * @throws IOException
410 public Map<String,WeightedSpanTerm> getWeightedSpanTermsWithScores(Query query, TokenStream tokenStream, String fieldName,
411 IndexReader reader) throws IOException {
412 if (fieldName != null) {
413 this.fieldName = StringHelper.intern(fieldName);
415 this.fieldName = null;
417 this.tokenStream = tokenStream;
419 Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
420 extract(query, terms);
422 int totalNumDocs = reader.numDocs();
423 Set<String> weightedTerms = terms.keySet();
424 Iterator<String> it = weightedTerms.iterator();
427 while (it.hasNext()) {
428 WeightedSpanTerm weightedSpanTerm = terms.get(it.next());
429 int docFreq = reader.docFreq(new Term(fieldName, weightedSpanTerm.term));
430 // docFreq counts deletes
431 if(totalNumDocs < docFreq) {
432 docFreq = totalNumDocs;
434 // IDF algorithm taken from DefaultSimilarity class
435 float idf = (float) (Math.log((float) totalNumDocs / (double) (docFreq + 1)) + 1.0);
436 weightedSpanTerm.weight *= idf;
446 private void collectSpanQueryFields(SpanQuery spanQuery, Set<String> fieldNames) {
447 if (spanQuery instanceof FieldMaskingSpanQuery) {
448 collectSpanQueryFields(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery(), fieldNames);
449 } else if (spanQuery instanceof SpanFirstQuery) {
450 collectSpanQueryFields(((SpanFirstQuery)spanQuery).getMatch(), fieldNames);
451 } else if (spanQuery instanceof SpanNearQuery) {
452 for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
453 collectSpanQueryFields(clause, fieldNames);
455 } else if (spanQuery instanceof SpanNotQuery) {
456 collectSpanQueryFields(((SpanNotQuery)spanQuery).getInclude(), fieldNames);
457 } else if (spanQuery instanceof SpanOrQuery) {
458 for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
459 collectSpanQueryFields(clause, fieldNames);
462 fieldNames.add(spanQuery.getField());
466 private boolean mustRewriteQuery(SpanQuery spanQuery) {
467 if (!expandMultiTermQuery) {
468 return false; // Will throw UnsupportedOperationException in case of a SpanRegexQuery.
469 } else if (spanQuery instanceof FieldMaskingSpanQuery) {
470 return mustRewriteQuery(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery());
471 } else if (spanQuery instanceof SpanFirstQuery) {
472 return mustRewriteQuery(((SpanFirstQuery)spanQuery).getMatch());
473 } else if (spanQuery instanceof SpanNearQuery) {
474 for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
475 if (mustRewriteQuery(clause)) {
480 } else if (spanQuery instanceof SpanNotQuery) {
481 SpanNotQuery spanNotQuery = (SpanNotQuery)spanQuery;
482 return mustRewriteQuery(spanNotQuery.getInclude()) || mustRewriteQuery(spanNotQuery.getExclude());
483 } else if (spanQuery instanceof SpanOrQuery) {
484 for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
485 if (mustRewriteQuery(clause)) {
490 } else if (spanQuery instanceof SpanTermQuery) {
498 * This class makes sure that if both position sensitive and insensitive
499 * versions of the same term are added, the position insensitive one wins.
501 static private class PositionCheckingMap<K> extends HashMap<K,WeightedSpanTerm> {
504 public void putAll(Map<? extends K,? extends WeightedSpanTerm> m) {
505 for (Map.Entry<? extends K,? extends WeightedSpanTerm> entry : m.entrySet())
506 this.put(entry.getKey(), entry.getValue());
510 public WeightedSpanTerm put(K key, WeightedSpanTerm value) {
511 WeightedSpanTerm prev = super.put(key, value);
512 if (prev == null) return prev;
513 WeightedSpanTerm prevTerm = prev;
514 WeightedSpanTerm newTerm = value;
515 if (!prevTerm.positionSensitive) {
516 newTerm.positionSensitive = false;
523 public boolean getExpandMultiTermQuery() {
524 return expandMultiTermQuery;
527 public void setExpandMultiTermQuery(boolean expandMultiTermQuery) {
528 this.expandMultiTermQuery = expandMultiTermQuery;
531 public boolean isCachedTokenStream() {
532 return cachedTokenStream;
535 public TokenStream getTokenStream() {
540 * By default, {@link TokenStream}s that are not of the type
541 * {@link CachingTokenFilter} are wrapped in a {@link CachingTokenFilter} to
542 * ensure an efficient reset - if you are already using a different caching
543 * {@link TokenStream} impl and you don't want it to be wrapped, set this to
548 public void setWrapIfNotCachingTokenFilter(boolean wrap) {
549 this.wrapToCaching = wrap;
554 * A fake IndexReader class to extract the field from a MultiTermQuery
557 static final class FakeReader extends FilterIndexReader {
559 private static final IndexReader EMPTY_MEMORY_INDEX_READER =
560 new MemoryIndex().createSearcher().getIndexReader();
565 super(EMPTY_MEMORY_INDEX_READER);
569 public TermEnum terms(final Term t) throws IOException {
570 // only set first fieldname, maybe use a Set?
571 if (t != null && field == null)
573 return super.terms(t);
579 protected final void setMaxDocCharsToAnalyze(int maxDocCharsToAnalyze) {
580 this.maxDocCharsToAnalyze = maxDocCharsToAnalyze;