pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / contrib / highlighter / src / java / org / apache / lucene / search / highlight / WeightedSpanTermExtractor.java
diff --git a/lucene-java-3.5.0/lucene/contrib/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java b/lucene-java-3.5.0/lucene/contrib/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java
new file mode 100644 (file)
index 0000000..c81390c
--- /dev/null
@@ -0,0 +1,583 @@
+package org.apache.lucene.search.highlight;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.lucene.analysis.CachingTokenFilter;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.index.FilterIndexReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermEnum;
+import org.apache.lucene.index.memory.MemoryIndex;
+import org.apache.lucene.search.*;
+import org.apache.lucene.search.spans.FieldMaskingSpanQuery;
+import org.apache.lucene.search.spans.SpanFirstQuery;
+import org.apache.lucene.search.spans.SpanNearQuery;
+import org.apache.lucene.search.spans.SpanNotQuery;
+import org.apache.lucene.search.spans.SpanOrQuery;
+import org.apache.lucene.search.spans.SpanQuery;
+import org.apache.lucene.search.spans.SpanTermQuery;
+import org.apache.lucene.search.spans.Spans;
+import org.apache.lucene.util.StringHelper;
+
+/**
+ * Class used to extract {@link WeightedSpanTerm}s from a {@link Query} based on whether 
+ * {@link Term}s from the {@link Query} are contained in a supplied {@link TokenStream}.
+ */
+public class WeightedSpanTermExtractor {
+
+  private String fieldName;
+  private TokenStream tokenStream;
+  private Map<String,IndexReader> readers = new HashMap<String,IndexReader>(10); 
+  private String defaultField;
+  private boolean expandMultiTermQuery;
+  private boolean cachedTokenStream;
+  private boolean wrapToCaching = true;
+  private int maxDocCharsToAnalyze;
+
+  public WeightedSpanTermExtractor() {
+  }
+
+  public WeightedSpanTermExtractor(String defaultField) {
+    if (defaultField != null) {
+      this.defaultField = StringHelper.intern(defaultField);
+    }
+  }
+
+  private void closeReaders() {
+    Collection<IndexReader> readerSet = readers.values();
+
+    for (final IndexReader reader : readerSet) {
+      try {
+        reader.close();
+      } catch (IOException e) {
+        // alert?
+      }
+    }
+  }
+
+  /**
+   * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
+   * 
+   * @param query
+   *          Query to extract Terms from
+   * @param terms
+   *          Map to place created WeightedSpanTerms in
+   * @throws IOException
+   */
+  private void extract(Query query, Map<String,WeightedSpanTerm> terms) throws IOException {
+    if (query instanceof BooleanQuery) {
+      BooleanClause[] queryClauses = ((BooleanQuery) query).getClauses();
+
+      for (int i = 0; i < queryClauses.length; i++) {
+        if (!queryClauses[i].isProhibited()) {
+          extract(queryClauses[i].getQuery(), terms);
+        }
+      }
+    } else if (query instanceof PhraseQuery) {
+      PhraseQuery phraseQuery = ((PhraseQuery) query);
+      Term[] phraseQueryTerms = phraseQuery.getTerms();
+      SpanQuery[] clauses = new SpanQuery[phraseQueryTerms.length];
+      for (int i = 0; i < phraseQueryTerms.length; i++) {
+        clauses[i] = new SpanTermQuery(phraseQueryTerms[i]);
+      }
+      int slop = phraseQuery.getSlop();
+      int[] positions = phraseQuery.getPositions();
+      // add largest position increment to slop
+      if (positions.length > 0) {
+        int lastPos = positions[0];
+        int largestInc = 0;
+        int sz = positions.length;
+        for (int i = 1; i < sz; i++) {
+          int pos = positions[i];
+          int inc = pos - lastPos;
+          if (inc > largestInc) {
+            largestInc = inc;
+          }
+          lastPos = pos;
+        }
+        if(largestInc > 1) {
+          slop += largestInc;
+        }
+      }
+
+      boolean inorder = false;
+
+      if (slop == 0) {
+        inorder = true;
+      }
+
+      SpanNearQuery sp = new SpanNearQuery(clauses, slop, inorder);
+      sp.setBoost(query.getBoost());
+      extractWeightedSpanTerms(terms, sp);
+    } else if (query instanceof TermQuery) {
+      extractWeightedTerms(terms, query);
+    } else if (query instanceof SpanQuery) {
+      extractWeightedSpanTerms(terms, (SpanQuery) query);
+    } else if (query instanceof FilteredQuery) {
+      extract(((FilteredQuery) query).getQuery(), terms);
+    } else if (query instanceof DisjunctionMaxQuery) {
+      for (Iterator<Query> iterator = ((DisjunctionMaxQuery) query).iterator(); iterator.hasNext();) {
+        extract(iterator.next(), terms);
+      }
+    } else if (query instanceof MultiTermQuery && expandMultiTermQuery) {
+      MultiTermQuery mtq = ((MultiTermQuery)query);
+      if(mtq.getRewriteMethod() != MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE) {
+        mtq = (MultiTermQuery) mtq.clone();
+        mtq.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE);
+        query = mtq;
+      }
+      FakeReader fReader = new FakeReader();
+      MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE.rewrite(fReader, mtq);
+      if (fReader.field != null) {
+        IndexReader ir = getReaderForField(fReader.field);
+        extract(query.rewrite(ir), terms);
+      }
+    } else if (query instanceof MultiPhraseQuery) {
+      final MultiPhraseQuery mpq = (MultiPhraseQuery) query;
+      final List<Term[]> termArrays = mpq.getTermArrays();
+      final int[] positions = mpq.getPositions();
+      if (positions.length > 0) {
+
+        int maxPosition = positions[positions.length - 1];
+        for (int i = 0; i < positions.length - 1; ++i) {
+          if (positions[i] > maxPosition) {
+            maxPosition = positions[i];
+          }
+        }
+
+        @SuppressWarnings("unchecked")
+        final List<SpanQuery>[] disjunctLists = new List[maxPosition + 1];
+        int distinctPositions = 0;
+
+        for (int i = 0; i < termArrays.size(); ++i) {
+          final Term[] termArray = termArrays.get(i);
+          List<SpanQuery> disjuncts = disjunctLists[positions[i]];
+          if (disjuncts == null) {
+            disjuncts = (disjunctLists[positions[i]] = new ArrayList<SpanQuery>(termArray.length));
+            ++distinctPositions;
+          }
+          for (int j = 0; j < termArray.length; ++j) {
+            disjuncts.add(new SpanTermQuery(termArray[j]));
+          }
+        }
+
+        int positionGaps = 0;
+        int position = 0;
+        final SpanQuery[] clauses = new SpanQuery[distinctPositions];
+        for (int i = 0; i < disjunctLists.length; ++i) {
+          List<SpanQuery> disjuncts = disjunctLists[i];
+          if (disjuncts != null) {
+            clauses[position++] = new SpanOrQuery(disjuncts
+                .toArray(new SpanQuery[disjuncts.size()]));
+          } else {
+            ++positionGaps;
+          }
+        }
+
+        final int slop = mpq.getSlop();
+        final boolean inorder = (slop == 0);
+
+        SpanNearQuery sp = new SpanNearQuery(clauses, slop + positionGaps, inorder);
+        sp.setBoost(query.getBoost());
+        extractWeightedSpanTerms(terms, sp);
+      }
+    }
+  }
+
+  /**
+   * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>SpanQuery</code>.
+   * 
+   * @param terms
+   *          Map to place created WeightedSpanTerms in
+   * @param spanQuery
+   *          SpanQuery to extract Terms from
+   * @throws IOException
+   */
+  private void extractWeightedSpanTerms(Map<String,WeightedSpanTerm> terms, SpanQuery spanQuery) throws IOException {
+    Set<String> fieldNames;
+
+    if (fieldName == null) {
+      fieldNames = new HashSet<String>();
+      collectSpanQueryFields(spanQuery, fieldNames);
+    } else {
+      fieldNames = new HashSet<String>(1);
+      fieldNames.add(fieldName);
+    }
+    // To support the use of the default field name
+    if (defaultField != null) {
+      fieldNames.add(defaultField);
+    }
+    
+    Map<String, SpanQuery> queries = new HashMap<String, SpanQuery>();
+    Set<Term> nonWeightedTerms = new HashSet<Term>();
+    final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
+    if (mustRewriteQuery) {
+      for (final String field : fieldNames) {
+        final SpanQuery rewrittenQuery = (SpanQuery) spanQuery.rewrite(getReaderForField(field));
+        queries.put(field, rewrittenQuery);
+        rewrittenQuery.extractTerms(nonWeightedTerms);
+      }
+    } else {
+      spanQuery.extractTerms(nonWeightedTerms);
+    }
+
+    List<PositionSpan> spanPositions = new ArrayList<PositionSpan>();
+
+    for (final String field : fieldNames) {
+
+      IndexReader reader = getReaderForField(field);
+      final Spans spans;
+      if (mustRewriteQuery) {
+        spans = queries.get(field).getSpans(reader);
+      } else {
+        spans = spanQuery.getSpans(reader);
+      }
+
+
+      // collect span positions
+      while (spans.next()) {
+        spanPositions.add(new PositionSpan(spans.start(), spans.end() - 1));
+      }
+      
+    }
+
+    if (spanPositions.size() == 0) {
+      // no spans found
+      return;
+    }
+
+    for (final Term queryTerm :  nonWeightedTerms) {
+
+      if (fieldNameComparator(queryTerm.field())) {
+        WeightedSpanTerm weightedSpanTerm = terms.get(queryTerm.text());
+
+        if (weightedSpanTerm == null) {
+          weightedSpanTerm = new WeightedSpanTerm(spanQuery.getBoost(), queryTerm.text());
+          weightedSpanTerm.addPositionSpans(spanPositions);
+          weightedSpanTerm.positionSensitive = true;
+          terms.put(queryTerm.text(), weightedSpanTerm);
+        } else {
+          if (spanPositions.size() > 0) {
+            weightedSpanTerm.addPositionSpans(spanPositions);
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
+   * 
+   * @param terms
+   *          Map to place created WeightedSpanTerms in
+   * @param query
+   *          Query to extract Terms from
+   * @throws IOException
+   */
+  private void extractWeightedTerms(Map<String,WeightedSpanTerm> terms, Query query) throws IOException {
+    Set<Term> nonWeightedTerms = new HashSet<Term>();
+    query.extractTerms(nonWeightedTerms);
+
+    for (final Term queryTerm : nonWeightedTerms) {
+
+      if (fieldNameComparator(queryTerm.field())) {
+        WeightedSpanTerm weightedSpanTerm = new WeightedSpanTerm(query.getBoost(), queryTerm.text());
+        terms.put(queryTerm.text(), weightedSpanTerm);
+      }
+    }
+  }
+
+  /**
+   * Necessary to implement matches for queries against <code>defaultField</code>
+   */
+  private boolean fieldNameComparator(String fieldNameToCheck) {
+    boolean rv = fieldName == null || fieldNameToCheck == fieldName
+        || fieldNameToCheck == defaultField;
+    return rv;
+  }
+
+  private IndexReader getReaderForField(String field) throws IOException {
+    if(wrapToCaching && !cachedTokenStream && !(tokenStream instanceof CachingTokenFilter)) {
+      tokenStream = new CachingTokenFilter(new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
+      cachedTokenStream = true;
+    }
+    IndexReader reader = readers.get(field);
+    if (reader == null) {
+      MemoryIndex indexer = new MemoryIndex();
+      indexer.addField(field, new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
+      tokenStream.reset();
+      IndexSearcher searcher = indexer.createSearcher();
+      reader = searcher.getIndexReader();
+      readers.put(field, reader);
+    }
+
+    return reader;
+  }
+
+  /**
+   * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
+   * 
+   * <p>
+   * 
+   * @param query
+   *          that caused hit
+   * @param tokenStream
+   *          of text to be highlighted
+   * @return Map containing WeightedSpanTerms
+   * @throws IOException
+   */
+  public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream)
+      throws IOException {
+    return getWeightedSpanTerms(query, tokenStream, null);
+  }
+
+  /**
+   * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
+   * 
+   * <p>
+   * 
+   * @param query
+   *          that caused hit
+   * @param tokenStream
+   *          of text to be highlighted
+   * @param fieldName
+   *          restricts Term's used based on field name
+   * @return Map containing WeightedSpanTerms
+   * @throws IOException
+   */
+  public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream,
+      String fieldName) throws IOException {
+    if (fieldName != null) {
+      this.fieldName = StringHelper.intern(fieldName);
+    } else {
+      this.fieldName = null;
+    }
+
+    Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
+    this.tokenStream = tokenStream;
+    try {
+      extract(query, terms);
+    } finally {
+      closeReaders();
+    }
+
+    return terms;
+  }
+
+  /**
+   * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>. Uses a supplied
+   * <code>IndexReader</code> to properly weight terms (for gradient highlighting).
+   * 
+   * <p>
+   * 
+   * @param query
+   *          that caused hit
+   * @param tokenStream
+   *          of text to be highlighted
+   * @param fieldName
+   *          restricts Term's used based on field name
+   * @param reader
+   *          to use for scoring
+   * @return Map of WeightedSpanTerms with quasi tf/idf scores
+   * @throws IOException
+   */
+  public Map<String,WeightedSpanTerm> getWeightedSpanTermsWithScores(Query query, TokenStream tokenStream, String fieldName,
+      IndexReader reader) throws IOException {
+    if (fieldName != null) {
+      this.fieldName = StringHelper.intern(fieldName);
+    } else {
+      this.fieldName = null;
+    }
+    this.tokenStream = tokenStream;
+
+    Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
+    extract(query, terms);
+
+    int totalNumDocs = reader.numDocs();
+    Set<String> weightedTerms = terms.keySet();
+    Iterator<String> it = weightedTerms.iterator();
+
+    try {
+      while (it.hasNext()) {
+        WeightedSpanTerm weightedSpanTerm = terms.get(it.next());
+        int docFreq = reader.docFreq(new Term(fieldName, weightedSpanTerm.term));
+        // docFreq counts deletes
+        if(totalNumDocs < docFreq) {
+          docFreq = totalNumDocs;
+        }
+        // IDF algorithm taken from DefaultSimilarity class
+        float idf = (float) (Math.log((float) totalNumDocs / (double) (docFreq + 1)) + 1.0);
+        weightedSpanTerm.weight *= idf;
+      }
+    } finally {
+
+      closeReaders();
+    }
+
+    return terms;
+  }
+  
+  private void collectSpanQueryFields(SpanQuery spanQuery, Set<String> fieldNames) {
+    if (spanQuery instanceof FieldMaskingSpanQuery) {
+      collectSpanQueryFields(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery(), fieldNames);
+    } else if (spanQuery instanceof SpanFirstQuery) {
+      collectSpanQueryFields(((SpanFirstQuery)spanQuery).getMatch(), fieldNames);
+    } else if (spanQuery instanceof SpanNearQuery) {
+      for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
+        collectSpanQueryFields(clause, fieldNames);
+      }
+    } else if (spanQuery instanceof SpanNotQuery) {
+      collectSpanQueryFields(((SpanNotQuery)spanQuery).getInclude(), fieldNames);
+    } else if (spanQuery instanceof SpanOrQuery) {
+      for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
+        collectSpanQueryFields(clause, fieldNames);
+      }
+    } else {
+      fieldNames.add(spanQuery.getField());
+    }
+  }
+  
+  private boolean mustRewriteQuery(SpanQuery spanQuery) {
+    if (!expandMultiTermQuery) {
+      return false; // Will throw UnsupportedOperationException in case of a SpanRegexQuery.
+    } else if (spanQuery instanceof FieldMaskingSpanQuery) {
+      return mustRewriteQuery(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery());
+    } else if (spanQuery instanceof SpanFirstQuery) {
+      return mustRewriteQuery(((SpanFirstQuery)spanQuery).getMatch());
+    } else if (spanQuery instanceof SpanNearQuery) {
+      for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
+        if (mustRewriteQuery(clause)) {
+          return true;
+        }
+      }
+      return false; 
+    } else if (spanQuery instanceof SpanNotQuery) {
+      SpanNotQuery spanNotQuery = (SpanNotQuery)spanQuery;
+      return mustRewriteQuery(spanNotQuery.getInclude()) || mustRewriteQuery(spanNotQuery.getExclude());
+    } else if (spanQuery instanceof SpanOrQuery) {
+      for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
+        if (mustRewriteQuery(clause)) {
+          return true;
+        }
+      }
+      return false; 
+    } else if (spanQuery instanceof SpanTermQuery) {
+      return false;
+    } else {
+      return true;
+    }
+  }
+  
+  /**
+   * This class makes sure that if both position sensitive and insensitive
+   * versions of the same term are added, the position insensitive one wins.
+   */
+  static private class PositionCheckingMap<K> extends HashMap<K,WeightedSpanTerm> {
+
+    @Override
+    public void putAll(Map<? extends K,? extends WeightedSpanTerm> m) {
+      for (Map.Entry<? extends K,? extends WeightedSpanTerm> entry : m.entrySet())
+        this.put(entry.getKey(), entry.getValue());
+    }
+
+    @Override
+    public WeightedSpanTerm put(K key, WeightedSpanTerm value) {
+      WeightedSpanTerm prev = super.put(key, value);
+      if (prev == null) return prev;
+      WeightedSpanTerm prevTerm = prev;
+      WeightedSpanTerm newTerm = value;
+      if (!prevTerm.positionSensitive) {
+        newTerm.positionSensitive = false;
+      }
+      return prev;
+    }
+    
+  }
+  
+  public boolean getExpandMultiTermQuery() {
+    return expandMultiTermQuery;
+  }
+
+  public void setExpandMultiTermQuery(boolean expandMultiTermQuery) {
+    this.expandMultiTermQuery = expandMultiTermQuery;
+  }
+  
+  public boolean isCachedTokenStream() {
+    return cachedTokenStream;
+  }
+  
+  public TokenStream getTokenStream() {
+    return tokenStream;
+  }
+  
+  /**
+   * By default, {@link TokenStream}s that are not of the type
+   * {@link CachingTokenFilter} are wrapped in a {@link CachingTokenFilter} to
+   * ensure an efficient reset - if you are already using a different caching
+   * {@link TokenStream} impl and you don't want it to be wrapped, set this to
+   * false.
+   * 
+   * @param wrap
+   */
+  public void setWrapIfNotCachingTokenFilter(boolean wrap) {
+    this.wrapToCaching = wrap;
+  }
+  
+  /**
+   * 
+   * A fake IndexReader class to extract the field from a MultiTermQuery
+   * 
+   */
+  static final class FakeReader extends FilterIndexReader {
+
+    private static final IndexReader EMPTY_MEMORY_INDEX_READER =
+      new MemoryIndex().createSearcher().getIndexReader();
+    
+    String field;
+
+    FakeReader() {
+      super(EMPTY_MEMORY_INDEX_READER);
+    }
+
+    @Override
+    public TermEnum terms(final Term t) throws IOException {
+      // only set first fieldname, maybe use a Set?
+      if (t != null && field == null)
+        field = t.field();
+      return super.terms(t);
+    }
+
+
+  }
+
+  protected final void setMaxDocCharsToAnalyze(int maxDocCharsToAnalyze) {
+    this.maxDocCharsToAnalyze = maxDocCharsToAnalyze;
+  }
+
+}