pylucene 3.5.0-3
[pylucene.git] / lucene-java-3.5.0 / lucene / contrib / grouping / src / java / org / apache / lucene / search / grouping / SearchGroup.java
diff --git a/lucene-java-3.5.0/lucene/contrib/grouping/src/java/org/apache/lucene/search/grouping/SearchGroup.java b/lucene-java-3.5.0/lucene/contrib/grouping/src/java/org/apache/lucene/search/grouping/SearchGroup.java
new file mode 100644 (file)
index 0000000..ce8ffd2
--- /dev/null
@@ -0,0 +1,343 @@
+package org.apache.lucene.search.grouping;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.search.FieldComparator;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Represents a group that is found during the first pass search.
+ *
+ * @lucene.experimental
+ */
+public class SearchGroup<GROUP_VALUE_TYPE> {
+
+  /** The value that defines this group  */
+  public GROUP_VALUE_TYPE groupValue;
+
+  /** The sort values used during sorting. These are the
+   *  groupSort field values of the highest rank document
+   *  (by the groupSort) within the group.  Can be
+   * <code>null</code> if <code>fillFields=false</code> had
+   * been passed to {@link AbstractFirstPassGroupingCollector#getTopGroups} */
+  public Object[] sortValues;
+
+  @Override
+  public String toString() {
+    return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")");
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    SearchGroup that = (SearchGroup) o;
+
+    if (groupValue == null) {
+      if (that.groupValue != null) {
+        return false;
+      }
+    } else if (!groupValue.equals(that.groupValue)) {
+      return false;
+    }
+
+    return true;
+  }
+
+  @Override
+  public int hashCode() {
+    return groupValue != null ? groupValue.hashCode() : 0;
+  }
+
+  private static class ShardIter<T> {
+    public final Iterator<SearchGroup<T>> iter;
+    public final int shardIndex;
+
+    public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) {
+      this.shardIndex = shardIndex;
+      iter = shard.iterator();
+      assert iter.hasNext();
+    }
+
+    public SearchGroup<T> next() {
+      assert iter.hasNext();
+      final SearchGroup<T> group = iter.next();
+      if (group.sortValues == null) {
+        throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector");
+      }
+      return group;
+    }
+    
+    @Override
+    public String toString() {
+      return "ShardIter(shard=" + shardIndex + ")";
+    }
+  }
+
+  // Holds all shards currently on the same group
+  private static class MergedGroup<T> {
+
+    // groupValue may be null!
+    public final T groupValue;
+
+    public Object[] topValues;
+    public final List<ShardIter<T>> shards = new ArrayList<ShardIter<T>>();
+    public int minShardIndex;
+    public boolean processed;
+    public boolean inQueue;
+
+    public MergedGroup(T groupValue) {
+      this.groupValue = groupValue;
+    }
+
+    // Only for assert
+    private boolean neverEquals(Object _other) {
+      if (_other instanceof MergedGroup) {
+        MergedGroup other = (MergedGroup) _other;
+        if (groupValue == null) {
+          assert other.groupValue != null;
+        } else {
+          assert !groupValue.equals(other.groupValue);
+        }
+      }
+      return true;
+    }
+
+    @Override
+    public boolean equals(Object _other) {
+      // We never have another MergedGroup instance with
+      // same groupValue
+      assert neverEquals(_other);
+
+      if (_other instanceof MergedGroup) {
+        MergedGroup other = (MergedGroup) _other;
+        if (groupValue == null) {
+          return other == null;
+        } else {
+          return groupValue.equals(other);
+        }
+      } else {
+        return false;
+      }
+    }
+
+    @Override
+    public int hashCode() {
+      if (groupValue == null) {
+        return 0;
+      } else {
+        return groupValue.hashCode();
+      }
+    }
+  }
+
+  private static class GroupComparator<T> implements Comparator<MergedGroup<T>> {
+
+    public final FieldComparator[] comparators;
+    public final int[] reversed;
+
+    public GroupComparator(Sort groupSort) throws IOException {
+      final SortField[] sortFields = groupSort.getSort();
+      comparators = new FieldComparator[sortFields.length];
+      reversed = new int[sortFields.length];
+      for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
+        final SortField sortField = sortFields[compIDX];
+        comparators[compIDX] = sortField.getComparator(1, compIDX);
+        reversed[compIDX] = sortField.getReverse() ? -1 : 1;
+      }
+    }
+
+    @SuppressWarnings("unchecked")
+    public int compare(MergedGroup<T> group, MergedGroup<T> other) {
+      if (group == other) {
+        return 0;
+      }
+      //System.out.println("compare group=" + group + " other=" + other);
+      final Object[] groupValues = group.topValues;
+      final Object[] otherValues = other.topValues;
+      //System.out.println("  groupValues=" + groupValues + " otherValues=" + otherValues);
+      for (int compIDX = 0;compIDX < comparators.length; compIDX++) {
+        final int c = reversed[compIDX] * comparators[compIDX].compareValues(groupValues[compIDX],
+                                                                             otherValues[compIDX]);
+        if (c != 0) {
+          return c;
+        }
+      }
+
+      // Tie break by min shard index:
+      assert group.minShardIndex != other.minShardIndex;
+      return group.minShardIndex - other.minShardIndex;
+    }
+  }
+
+  private static class GroupMerger<T> {
+
+    private final GroupComparator<T> groupComp;
+    private final SortedSet<MergedGroup<T>> queue;
+    private final Map<T,MergedGroup<T>> groupsSeen;
+
+    public GroupMerger(Sort groupSort) throws IOException {
+      groupComp = new GroupComparator<T>(groupSort);
+      queue = new TreeSet<MergedGroup<T>>(groupComp);
+      groupsSeen = new HashMap<T,MergedGroup<T>>();
+    }
+
+    @SuppressWarnings("unchecked")
+    private void updateNextGroup(int topN, ShardIter<T> shard) {
+      while(shard.iter.hasNext()) {
+        final SearchGroup<T> group = shard.next();
+        MergedGroup<T> mergedGroup = groupsSeen.get(group.groupValue);
+        final boolean isNew = mergedGroup == null;
+        //System.out.println("    next group=" + (group.groupValue == null ? "null" : ((BytesRef) group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues));
+
+        if (isNew) {
+          // Start a new group:
+          //System.out.println("      new");
+          mergedGroup = new MergedGroup<T>(group.groupValue);
+          mergedGroup.minShardIndex = shard.shardIndex;
+          assert group.sortValues != null;
+          mergedGroup.topValues = group.sortValues;
+          groupsSeen.put(group.groupValue, mergedGroup);
+          mergedGroup.inQueue = true;
+          queue.add(mergedGroup);
+        } else if (mergedGroup.processed) {
+          // This shard produced a group that we already
+          // processed; move on to next group...
+          continue;
+        } else {
+          //System.out.println("      old");
+          boolean competes = false;
+          for(int compIDX=0;compIDX<groupComp.comparators.length;compIDX++) {
+            final int cmp = groupComp.reversed[compIDX] * groupComp.comparators[compIDX].compareValues(group.sortValues[compIDX],
+                                                                                                       mergedGroup.topValues[compIDX]);
+            if (cmp < 0) {
+              // Definitely competes
+              competes = true;
+              break;
+            } else if (cmp > 0) {
+              // Definitely does not compete
+              break;
+            } else if (compIDX == groupComp.comparators.length-1) {
+              if (shard.shardIndex < mergedGroup.minShardIndex) {
+                competes = true;
+              }
+            }
+          }
+
+          //System.out.println("      competes=" + competes);
+
+          if (competes) {
+            // Group's sort changed -- remove & re-insert
+            if (mergedGroup.inQueue) {
+              queue.remove(mergedGroup);
+            }
+            mergedGroup.topValues = group.sortValues;
+            mergedGroup.minShardIndex = shard.shardIndex;
+            queue.add(mergedGroup);
+            mergedGroup.inQueue = true;
+          }
+        }
+
+        mergedGroup.shards.add(shard);
+        break;
+      }
+
+      // Prune un-competitive groups:
+      while(queue.size() > topN) {
+        // TODO java 1.6: .pollLast
+        final MergedGroup<T> group = queue.last();
+        //System.out.println("PRUNE: " + group);
+        queue.remove(group);
+        group.inQueue = false;
+      }
+    }
+
+    public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) {
+
+      final int maxQueueSize = offset + topN;
+
+      //System.out.println("merge");
+      // Init queue:
+      for(int shardIDX=0;shardIDX<shards.size();shardIDX++) {
+        final Collection<SearchGroup<T>> shard = shards.get(shardIDX);
+        if (!shard.isEmpty()) {
+          //System.out.println("  insert shard=" + shardIDX);
+          updateNextGroup(maxQueueSize, new ShardIter<T>(shard, shardIDX));
+        }
+      }
+
+      // Pull merged topN groups:
+      final List<SearchGroup<T>> newTopGroups = new ArrayList<SearchGroup<T>>();
+
+      int count = 0;
+
+      while(queue.size() != 0) {
+        // TODO Java 1.6: pollFirst()
+        final MergedGroup<T> group = queue.first();
+        queue.remove(group);
+        group.processed = true;
+        //System.out.println("  pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues));
+        if (count++ >= offset) {
+          final SearchGroup<T> newGroup = new SearchGroup<T>();
+          newGroup.groupValue = group.groupValue;
+          newGroup.sortValues = group.topValues;
+          newTopGroups.add(newGroup);
+          if (newTopGroups.size() == topN) {
+            break;
+          }
+        //} else {
+        // System.out.println("    skip < offset");
+        }
+
+        // Advance all iters in this group:
+        for(ShardIter<T> shardIter : group.shards) {
+          updateNextGroup(maxQueueSize, shardIter);
+        }
+      }
+
+      if (newTopGroups.size() == 0) {
+        return null;
+      } else {
+        return newTopGroups;
+      }
+    }
+  }
+
+  /** Merges multiple collections of top groups, for example
+   *  obtained from separate index shards.  The provided
+   *  groupSort must match how the groups were sorted, and
+   *  the provided SearchGroups must have been computed
+   *  with fillFields=true passed to {@link
+   *  AbstractFirstPassGroupingCollector#getTopGroups}.
+   *
+   * <p>NOTE: this returns null if the topGroups is empty.
+   */
+  public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort)
+    throws IOException {
+    if (topGroups.size() == 0) {
+      return null;
+    } else {
+      return new GroupMerger<T>(groupSort).merge(topGroups, offset, topN);
+    }
+  }
+}