Skip to content

Instantly share code, notes, and snippets.

@pervognsen
Created March 3, 2010 12:22
Show Gist options
  • Save pervognsen/320572 to your computer and use it in GitHub Desktop.
Save pervognsen/320572 to your computer and use it in GitHub Desktop.
def all_equal(iterable):
    """Return True when every item of *iterable* equals the first item.

    An empty iterable is treated as all-equal and yields True. The
    iterable is consumed at most once, so one-shot iterators and
    generators are accepted.
    """
    it = iter(iterable)
    try:
        # The next() builtin works on Python 2.6+ and Python 3, whereas
        # the it.next() method only exists on Python 2 iterators.
        first = next(it)
    except StopIteration:
        # Nothing to compare: vacuously all equal.
        return True
    return all(x == first for x in it)
def median(seq, key=None):
    """Return the median element of *seq* under the ordering given by *key*.

    The items are sorted (using *key* exactly as ``sorted`` does) and the
    element at index ``len // 2`` is returned, i.e. the upper median for
    an even number of items.

    Raises IndexError when *seq* is empty.
    """
    sorted_seq = sorted(seq, key=key)
    # Floor division keeps the index an int under true division, and
    # measuring the sorted copy lets *seq* be any iterable, not only a
    # sized sequence.
    return sorted_seq[len(sorted_seq) // 2]
def partition(seq, pivot, key=lambda x: x):
    """Split *seq* into three lists relative to *pivot*.

    Each item is compared to the pivot through *key*. Returns a tuple
    ``(left, middle, right)`` holding, in original order, the items whose
    key is less than, equal to (neither less nor greater), and greater
    than the pivot's key.
    """
    below, same, above = [], [], []
    for item in seq:
        if key(item) < key(pivot):
            bucket = below
        elif key(item) > key(pivot):
            bucket = above
        else:
            bucket = same
        bucket.append(item)
    return below, same, above
# ---
def hamming_distance(xs, ys):
    """Return a length-tolerant Hamming distance between two sequences.

    The result is the number of positions, over the common prefix length,
    where the sequences differ, plus the absolute difference of their
    lengths.
    """
    length_gap = abs(len(xs) - len(ys))
    mismatches = 0
    for a, b in zip(xs, ys):
        if a != b:
            mismatches += 1
    return length_gap + mismatches
# ---
def kdentry(key, value):
    """Pack *key* and *value* into a tree entry, a ``(key, value)`` tuple."""
    entry = key, value
    return entry
def kdkey(entry):
    """Return the key part (first element) of a tree entry."""
    key = entry[0]
    return key
def kdvalue(entry):
    """Return the value part (second element) of a tree entry."""
    value = entry[1]
    return value
class kdleaf:
    """Terminal tree node that simply holds a collection of entries."""

    def __init__(self, entries):
        """Store *entries* (the same object, not a copy) on the leaf."""
        self.entries = entries

    def search(self, center, radius):
        """Return a lazy iterator over the value of every stored entry.

        *center* and *radius* are accepted so the call signature matches
        kdnode.search, but they are not consulted: a leaf reports all of
        its values and leaves any distance filtering to the caller.
        """
        stored = self.entries
        return (kdvalue(item) for item in stored)
class kdnode:
    """Interior tree node that splits space on one coordinate.

    Attributes:
        dimension: index of the coordinate this node splits on.
        separator: split value along that coordinate.
        entries:   entries stored at this node; search() reports them when
                   the query interval touches the separator.
        left:      subtree searched for coordinates below the separator.
        right:     subtree searched for coordinates above the separator.

    As constructed by build_kdtree, ``left`` holds entries whose split
    coordinate is strictly less than ``separator``, ``right`` those
    strictly greater, and ``entries`` those equal to it.
    """

    def __init__(self, dimension, separator, entries, left, right):
        self.dimension = dimension
        self.separator = separator
        self.left = left
        self.right = right
        self.entries = entries

    def search(self, center, radius):
        """Yield values of candidate entries near *center*.

        Pruning is done one coordinate at a time, so the result is a
        superset of the true matches; callers are expected to apply their
        own exact distance filter.

        Args:
            center: query key, indexable by dimension.
            radius: maximum allowed per-coordinate deviation.
        """
        coordinate = center[self.dimension]
        # Left subtree: coordinates strictly below the separator, so it
        # can only matter if the interval reaches below the separator.
        if coordinate - radius < self.separator:
            for value in self.left.search(center, radius):
                yield value
        # Entries lying exactly on the splitting plane.
        if abs(coordinate - self.separator) <= radius:
            for entry in self.entries:
                yield kdvalue(entry)
        # Right subtree: coordinates strictly above the separator, so the
        # bound is strict. When coordinate + radius == separator nothing
        # on the right can be within radius.
        if coordinate + radius > self.separator:
            for value in self.right.search(center, radius):
                yield value
def build_kdtree(dimensions, entries, split_dimension=0):
    """Recursively build a tree of kdnode / kdleaf objects from *entries*.

    Args:
        dimensions: number of coordinates in every key.
        entries: list of (key, value) entries.
        split_dimension: coordinate index to split on at this level; the
            next level uses the following index, wrapping around.

    Returns:
        A kdleaf when all keys are equal (including when *entries* is
        empty), otherwise a kdnode whose own entries are those sharing
        the median's split coordinate.
    """
    keys = (kdkey(entry) for entry in entries)
    if all_equal(keys):
        return kdleaf(entries)

    def coordinate(entry):
        return kdkey(entry)[split_dimension]

    pivot = median(entries, key=coordinate)
    below, on_plane, above = partition(entries, pivot, key=coordinate)
    following_dimension = (split_dimension + 1) % dimensions

    if below or above:
        return kdnode(split_dimension, coordinate(pivot), on_plane,
                      build_kdtree(dimensions, below, following_dimension),
                      build_kdtree(dimensions, above, following_dimension))

    # Every entry shares this coordinate, so splitting here achieves
    # nothing; retry the same entries on the next dimension.
    return build_kdtree(dimensions, entries, following_dimension)
# ---
def word_vector(word):
    """Return the letter-count vector of *word* as a 26-tuple.

    The word is lower-cased and position ``i`` of the result counts the
    occurrences of the i-th letter of ``a``..``z``. Characters outside
    ``a``..``z`` (apostrophes, digits, hyphens, accented letters, ...)
    are ignored.
    """
    coords = [0] * 26
    base = ord('a')
    for char in word.lower():
        index = ord(char) - base
        # Skip non-letters: a negative index would silently wrap around
        # and bump the wrong counter, and an index >= 26 (or < -26)
        # would raise IndexError.
        if 0 <= index < 26:
            coords[index] += 1
    return tuple(coords)
def build_word_tree(words):
    """Build a 26-dimensional tree whose entries map each word's
    letter-count vector (from word_vector) to the word itself."""
    entries = []
    for word in words:
        entries.append(kdentry(word_vector(word), word))
    return build_kdtree(26, entries)
# TODO: Use Levenshtein distance for false-positive filtering. Hamming distance is an upper bound.
def word_search(word_tree, word, radius):
    """Return the words in *word_tree* within *radius* of *word*.

    Candidates are gathered by searching the tree around the word's
    letter-count vector, then kept only if hamming_distance(word,
    candidate) does not exceed *radius*. Results are in tree search order.
    """
    query = word_vector(word)
    candidates = list(word_tree.search(query, radius))
    matches = []
    for candidate in candidates:
        if hamming_distance(word, candidate) <= radius:
            matches.append(candidate)
    return matches
# ---
# Demo: index the system word list (one word per line) and print the
# matches for "book" at radii 0 through 3.
# The file is opened in a with-block so the handle is closed as soon as
# the words have been read.
with open("/usr/share/dict/words") as word_file:
    word_tree = build_word_tree([word.strip() for word in word_file])

# Single-argument print(...) produces the same output on Python 2 and 3.
for radius in range(4):
    print(word_search(word_tree, "book", radius))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment