1 package org.apache.lucene.search.suggest.tst;
3 import java.io.DataInputStream;
4 import java.io.DataOutputStream;
6 import java.io.FileInputStream;
7 import java.io.FileOutputStream;
8 import java.io.IOException;
9 import java.util.ArrayList;
10 import java.util.List;
12 import org.apache.lucene.search.suggest.Lookup;
13 import org.apache.lucene.search.suggest.SortedTermFreqIteratorWrapper;
14 import org.apache.lucene.search.spell.SortedIterator;
15 import org.apache.lucene.search.spell.TermFreqIterator;
/**
 * {@link Lookup} implementation that keeps its entries in a tree of
 * {@link TernaryTreeNode}s (each node has a low, an equal and a high child)
 * and delegates tree construction, insertion and prefix completion to a
 * {@link TSTAutocomplete} helper.
 */
17 public class TSTLookup extends Lookup {
// Root of the tree; build() and load() replace it with a fresh node before filling it.
18 TernaryTreeNode root = new TernaryTreeNode();
// Helper that performs balancedTree(), insert() and prefixCompletion() on the tree.
19 TSTAutocomplete autocomplete = new TSTAutocomplete();
/**
 * Rebuilds the tree from the given term/frequency iterator.
 *
 * <p>Starts from a fresh root, wraps the iterator in a
 * {@link SortedTermFreqIteratorWrapper} unless it already is a
 * {@link SortedIterator}, buffers every term together with its frequency,
 * and hands both buffers (as arrays) to {@code autocomplete.balancedTree}.
 *
 * @param tfit source of terms and their frequencies
 * @throws IOException declared by the signature; NOTE(review): which call
 *         actually raises it is not evident from this method - confirm
 */
22 public void build(TermFreqIterator tfit) throws IOException {
// drop any previously built tree
23 root = new TernaryTreeNode();
25 if (!(tfit instanceof SortedIterator)) {
26 // make sure it's sorted
27 tfit = new SortedTermFreqIteratorWrapper(tfit);
// tokens and vals are filled in lock-step: index i in one matches index i in the other
30 ArrayList<String> tokens = new ArrayList<String>();
31 ArrayList<Float> vals = new ArrayList<Float>();
32 while (tfit.hasNext()) {
33 tokens.add(tfit.next());
// presumably freq() reports the frequency of the term just returned by next() - TODO confirm iterator contract
34 vals.add(new Float(tfit.freq()));
// build over the whole buffered range, indices 0 .. size-1, into the fresh root
36 autocomplete.balancedTree(tokens.toArray(), vals.toArray(), 0, tokens.size() - 1, root);
/**
 * Inserts a single key/value pair into the current tree through
 * {@code autocomplete.insert}, passing the root node and 0 as the last
 * argument (presumably the starting character offset - TODO confirm).
 *
 * @param key term to insert
 * @param value value to associate with the key
 * @return NOTE(review): the return statement is not visible here; per the
 *         XXX remark below, the result cannot tell whether a new node was
 *         actually created - confirm
 */
40 public boolean add(String key, Object value) {
41 autocomplete.insert(root, key, value, 0);
42 // XXX we don't know if a new node was created
/**
 * Looks for an entry whose token is exactly {@code key}.
 *
 * <p>Collects the prefix completions of {@code key} from the tree and scans
 * them for a node whose token equals {@code key}.
 *
 * @param key term to look up
 * @return NOTE(review): the return statements are not visible here;
 *         presumably the matching node's value, or null when there is no
 *         completion or no exact match - confirm
 */
47 public Object get(String key) {
48 List<TernaryTreeNode> list = autocomplete.prefixCompletion(root, key, 0);
// guard: nothing in the tree starts with this key
49 if (list == null || list.isEmpty()) {
// completions may be longer than the key, so an exact token comparison is required
52 for (TernaryTreeNode n : list) {
53 if (n.token.equals(key)) {
/**
 * Produces suggestions for the given prefix.
 *
 * <p>All prefix completions of {@code key} are fetched first. With
 * {@code onlyMorePopular} set, every completion is offered to a
 * {@link LookupPriorityQueue} of capacity {@code num} and the queue's
 * results are iterated. The index loop further down takes the first
 * {@code min(num, list.size())} completions in list order.
 * NOTE(review): the branch separating the two loops and the final return
 * are not visible here; presumably the index loop is the else branch and
 * {@code res} is returned - confirm.
 *
 * <p>Node values are cast to {@link Float}; NOTE(review): add() accepts any
 * Object, so a non-Float value would presumably fail here with a
 * ClassCastException - confirm.
 *
 * @param key prefix to complete
 * @param onlyMorePopular whether to route candidates through the priority queue
 * @param num maximum number of results wanted
 * @return the collected results (see the note above)
 */
61 public List<LookupResult> lookup(String key, boolean onlyMorePopular, int num) {
62 List<TernaryTreeNode> list = autocomplete.prefixCompletion(root, key, 0);
63 List<LookupResult> res = new ArrayList<LookupResult>();
// guard: no completion for this prefix
64 if (list == null || list.size() == 0) {
// never read past the end of the completion list
67 int maxCnt = Math.min(num, list.size());
68 if (onlyMorePopular) {
// the queue holds at most num entries; insertWithOverflow decides which ones stay
69 LookupPriorityQueue queue = new LookupPriorityQueue(num);
70 for (TernaryTreeNode ttn : list) {
71 queue.insertWithOverflow(new LookupResult(ttn.token, (Float)ttn.val));
73 for (LookupResult lr : queue.getResults()) {
// plain truncation: first maxCnt completions, in the order prefixCompletion returned them
77 for (int i = 0; i < maxCnt; i++) {
78 TernaryTreeNode ttn = list.get(i);
79 res.add(new LookupResult(ttn.token, (Float)ttn.val));
// Name of the data file, resolved against the store directory by load() and store().
85 public static final String FILENAME = "tst.dat";
// Bit flags of the per-node mask byte in the serialized form:
// the three *_KID bits mark which child subtrees follow the node,
// HAS_TOKEN / HAS_VALUE mark whether a token string / a float value is present.
87 private static final byte LO_KID = 0x01;
88 private static final byte EQ_KID = 0x02;
89 private static final byte HI_KID = 0x04;
90 private static final byte HAS_TOKEN = 0x08;
91 private static final byte HAS_VALUE = 0x10;
/**
 * Reads a tree from the file {@code FILENAME} inside {@code storeDir},
 * replacing the current root.
 *
 * @param storeDir directory expected to contain the data file
 * @return NOTE(review): the return statements are not visible here;
 *         presumably false when the file is missing or unreadable and true
 *         after a successful read - confirm
 * @throws IOException if opening or reading the data file fails
 */
94 public synchronized boolean load(File storeDir) throws IOException {
95 File data = new File(storeDir, FILENAME);
// guard: the data file must exist and be readable
96 if (!data.exists() || !data.canRead()) {
// NOTE(review): closing of this stream is not visible here - confirm it is closed on all paths
99 DataInputStream in = new DataInputStream(new FileInputStream(data));
// start from an empty root; readRecursively fills it in place
100 root = new TernaryTreeNode();
102 readRecursively(in, root);
109 // pre-order traversal
/**
 * Reads one node from the stream into {@code node} and then, recursively,
 * the subtrees that the node's mask byte announces.
 *
 * <p>Per-node layout, in read order: split character (char), mask byte,
 * the token (modified UTF-8 string) if HAS_TOKEN is set, the value (float)
 * if HAS_VALUE is set, then the low, equal and high subtrees - each only
 * when its *_KID bit is set.
 *
 * @param in stream positioned at the start of a serialized node
 * @param node node to populate; child nodes are created as needed
 * @throws IOException if the stream cannot be read
 */
110 private void readRecursively(DataInputStream in, TernaryTreeNode node) throws IOException {
111 node.splitchar = in.readChar();
112 byte mask = in.readByte();
// optional payload comes before any child data
113 if ((mask & HAS_TOKEN) != 0) {
114 node.token = in.readUTF();
116 if ((mask & HAS_VALUE) != 0) {
117 node.val = new Float(in.readFloat());
// children follow in fixed order: low, equal, high
119 if ((mask & LO_KID) != 0) {
120 node.loKid = new TernaryTreeNode();
121 readRecursively(in, node.loKid);
123 if ((mask & EQ_KID) != 0) {
124 node.eqKid = new TernaryTreeNode();
125 readRecursively(in, node.eqKid);
127 if ((mask & HI_KID) != 0) {
128 node.hiKid = new TernaryTreeNode();
129 readRecursively(in, node.hiKid);
/**
 * Writes the current tree to the file {@code FILENAME} inside
 * {@code storeDir}.
 *
 * @param storeDir target directory
 * @return NOTE(review): the return statements are not visible here;
 *         presumably false when the directory check below fails and true
 *         after a successful write - confirm
 * @throws IOException if opening or writing the data file fails
 */
134 public synchronized boolean store(File storeDir) throws IOException {
// guard: the target must be an existing, writable directory
135 if (!storeDir.exists() || !storeDir.isDirectory() || !storeDir.canWrite()) {
138 File data = new File(storeDir, FILENAME);
// NOTE(review): flushing/closing of this stream is not visible here - confirm it is closed on all paths
139 DataOutputStream out = new DataOutputStream(new FileOutputStream(data));
141 writeRecursively(out, root);
149 // pre-order traversal
/**
 * Writes {@code node} and then, recursively, its non-null children.
 *
 * <p>Visible write order: split character (char), the token via
 * {@code writeUTF} when non-null, the value via {@code writeFloat} when
 * non-null, then the low, equal and high children - the same order in which
 * readRecursively consumes them.
 * NOTE(review): the declaration of {@code mask} and the write of the mask
 * byte are not visible here; readRecursively expects that byte right after
 * the split character - confirm.
 *
 * @param out stream to append the serialized node to
 * @param node node to serialize; its value is cast to {@link Float}
 * @throws IOException if the stream cannot be written
 */
150 private void writeRecursively(DataOutputStream out, TernaryTreeNode node) throws IOException {
151 // write out the current node
152 out.writeChar(node.splitchar);
153 // prepare a mask of kids
155 if (node.eqKid != null) mask |= EQ_KID;
156 if (node.loKid != null) mask |= LO_KID;
157 if (node.hiKid != null) mask |= HI_KID;
// the mask also records which optional payload fields follow
158 if (node.token != null) mask |= HAS_TOKEN;
159 if (node.val != null) mask |= HAS_VALUE;
// payload: token first, then value, matching the reader
161 if (node.token != null) out.writeUTF(node.token);
162 if (node.val != null) out.writeFloat((Float)node.val);
163 // recurse and write kids
164 if (node.loKid != null) {
165 writeRecursively(out, node.loKid);
167 if (node.eqKid != null) {
168 writeRecursively(out, node.eqKid);
170 if (node.hiKid != null) {
171 writeRecursively(out, node.hiKid);