You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

167 lines
5.5 KiB
Python

class RadixTreeNode(object):
    """One node of a radix tree.

    ``path`` is the key fragment owned by this node and ``methods`` maps a
    method to the handler bound to it. Children are stored in ``children``;
    the character at the same position in ``indices`` is the key used to
    reach each child. :meth:`insert_child` keeps ``indices`` sorted, and
    ``size`` counts the children.
    """

    def __init__(self, path=None, handler=None, methods=None):
        self.path = path
        self.methods = {}
        # Parallel containers: indices[k] is the lookup char for children[k].
        self.children = []
        self.indices = ''
        self.size = 0
        self.add_methods(methods, handler)

    def __repr__(self):
        fields = (self.path, self.methods, self.indices, self.children)
        return ('<RadixTreeNode path: {}, methods: {}, indices: {}, children: '
                '{}>').format(*fields)

    def add_methods(self, methods, handler):
        """Bind *handler* to each of *methods* on this node.

        *methods* may be a single value or a list/tuple/set of values; a
        falsy value is a no-op. Re-binding a method to an equal handler is
        allowed.

        :raises KeyError: if a method is already bound to a handler that
            compares unequal to *handler*.
        """
        if not methods:
            return
        names = methods if isinstance(methods, (list, tuple, set)) else [methods]
        for name in names:
            if name in self.methods:
                bound = self.methods[name]
                if bound != handler:
                    raise KeyError(
                        '{} conflicts with existed handler '
                        '{}'.format(handler, bound))
            self.methods[name] = handler

    def bisect(self, target):
        """Return the leftmost insertion point for *target* in ``indices``.

        Only the first ``size`` characters are searched; this is the same
        lower-bound search as ``bisect.bisect_left``.
        """
        lo, hi = 0, self.size
        while lo < hi:
            middle = (lo + hi) // 2
            if self.indices[middle] < target:
                lo = middle + 1
            else:
                hi = middle
        return lo

    def insert_child(self, index, child):
        """Attach *child* under key character *index* and return *child*.

        The insertion position keeps ``indices`` sorted. No check is made
        for an already-present *index*; callers are expected to ensure it.
        """
        at = self.bisect(index)
        self.indices = ''.join((self.indices[:at], index, self.indices[at:]))
        self.children.insert(at, child)
        self.size += 1
        return child

    def get_child(self, index):
        """Return the child reached via key character *index*, or None."""
        for char, child in zip(self.indices, self.children):
            if char == index:
                return child
        return None
class RadixTree(object):
    """Radix tree mapping route-like keys to per-method handlers.

    Keys may contain two kinds of dynamic segment:

    * ``:name`` -- matches everything up to the next ``/`` (or the end of
      the key) and is captured as ``params['name']``;
    * ``*name`` -- matches the whole remainder of the key and is captured
      as ``params['name']``.

    A dynamic segment is stored as a child keyed by ``':'`` or ``'*'`` whose
    ``path`` is the parameter name. :meth:`get` reads such a node as the
    single child of its parent, and :meth:`insert` refuses keys that would
    give it a sibling.
    """

    def __init__(self):
        self.root = RadixTreeNode()

    def __repr__(self):
        return repr(self.root)

    def insert(self, key, handler, methods):
        """Register *handler* for *methods* under *key*.

        :param key: route key, e.g. ``'/users/:id'`` or ``'/static/*path'``.
        :param handler: object stored for each method.
        :param methods: a single method or a list/tuple/set of methods
            (forwarded to ``RadixTreeNode.add_methods``).
        :raises Exception: if *key* clashes with an existing dynamic segment
            (different parameter name at the same position, a static and a
            ``:`` segment at the same position, a ``*`` next to other
            children, or anything added where a ``*`` child already exists).
        :raises KeyError: if a method is already bound to a different
            handler (raised by ``add_methods``).
        """
        i, n, root = 0, len(key), self.root

        def get_position(found):
            # Map str.find's "not found" (-1) to the end of the key.
            return n if found == -1 else found

        while i < n:
            # num == 1: exactly one of (new key, existing children) has a ':'
            # segment at this position; num == 2: both do.
            conflict, num = [], (key[i] == ':') + (root.indices == ':')
            if (root.indices == '*' or
                    key[i] == '*' and root.indices or
                    num == 1 or
                    num == 2 and key[i + 1:get_position(
                        key.find('/', i))] != root.get_child(':').path):
                conflict = [key[:i] + p for p in self.traverse(root)]
            # traverse() is empty when root has no children, i.e. there is
            # nothing to clash with yet.
            if conflict:
                raise Exception('"{}" conflicts with {}'.format(key, conflict))

            child = root.get_child(key[i])
            if child is None:
                # No existing branch: build the rest of the key as new nodes.
                # When the dynamic segment starts right at i (pos == i) there
                # is no static prefix; creating one anyway would add an
                # empty-path node under the ':'/'*' index, which get() would
                # then mistake for the parameter node.
                pos = get_position(key.find(':', i))
                if pos == n:
                    pos = get_position(key.find('*', i))
                    if pos == n:
                        # Purely static remainder: a single leaf node.
                        root.insert_child(
                            key[i], RadixTreeNode(key[i:], handler, methods))
                        return
                    # Optional static prefix, then a catch-all segment.
                    if pos > i:
                        root = root.insert_child(
                            key[i], RadixTreeNode(key[i:pos]))
                    root.insert_child(
                        '*', RadixTreeNode(key[pos + 1:], handler, methods))
                    return
                # Optional static prefix, then a ':' segment up to the next '/'.
                if pos > i:
                    root = root.insert_child(key[i], RadixTreeNode(key[i:pos]))
                i = get_position(key.find('/', pos))
                root = root.insert_child(':', RadixTreeNode(key[pos + 1:i]))
                if i == n:
                    root.add_methods(methods, handler)
            else:
                root = child
                if key[i] == ':':
                    # Same parameter name as the existing node (checked
                    # above), so skip ':' plus the name.
                    i += len(root.path) + 1
                    if i == n:
                        root.add_methods(methods, handler)
                else:
                    # Walk the common prefix of the key and this node's path.
                    j, m = 0, len(root.path)
                    while i < n and j < m and key[i] == root.path[j]:
                        i += 1
                        j += 1
                    if j < m:
                        # Partial match: split the node. The unmatched tail
                        # takes over the node's methods and children, and the
                        # node keeps only the shared prefix.
                        child = RadixTreeNode(root.path[j:])
                        child.methods = root.methods
                        child.children = root.children
                        child.indices = root.indices
                        child.size = root.size
                        root.path = root.path[:j]
                        root.methods = {}
                        root.children = [child]
                        root.indices = child.path[0]
                        root.size = 1
                    if i == n:
                        root.add_methods(methods, handler)

    def get(self, key, method):
        """Look up *key* and return ``(found, handler, params)``.

        ``(False, None, {})`` is returned when *key* leaves the tree.
        Otherwise ``found`` is True, ``handler`` is the handler the matched
        node has for *method* (``None`` if it has none), and ``params`` maps
        parameter names to the captured parts of *key*.

        NOTE(review): the matched node may be an intermediate prefix node
        with no methods at all, yielding ``(True, None, params)``; confirm
        callers treat that the way they intend.
        """
        i, n, root, params = 0, len(key), self.root, {}
        while i < n:
            if root.indices == ':':
                # Parameter segment: capture up to the next '/' or the end.
                root, pos = root.children[0], key.find('/', i)
                if pos == -1:
                    pos = n
                params[root.path], i = key[i:pos], pos
            elif root.indices == '*':
                # Catch-all: capture the whole remainder and stop.
                root = root.children[0]
                params[root.path] = key[i:]
                break
            else:
                root = root.get_child(key[i])
                if root is None:
                    return False, None, {}
                pos = i + len(root.path)
                if key[i:pos] != root.path:
                    return False, None, {}
                i = pos
        return True, root.methods.get(method, None), params

    def traverse(self, root):
        """Return the key suffixes reachable below *root*.

        Every leaf is listed, plus every inner node that has methods.
        Dynamic segments are rendered with their ':' / '*' marker. Used to
        report what a new key conflicts with.
        """
        paths = []
        for i, char in enumerate(root.indices):
            child = root.children[i]
            # Dynamic children store only the parameter name in .path, so
            # restore the marker from the index character.
            path = '{}{}'.format(
                char if char in (':', '*') else '', child.path)
            if child.methods and child.indices:
                # Inner node that is itself an endpoint.
                paths.append(path)
            # `or ['']` makes a leaf contribute its own path.
            paths.extend(path + p for p in self.traverse(child) or [''])
        return paths