|
|
- from __future__ import absolute_import, division, unicode_literals
- from pip._vendor.six import text_type
-
- from bisect import bisect_left
-
- from ._base import Trie as ABCTrie
-
-
- class Trie(ABCTrie):
- def __init__(self, data):
- if not all(isinstance(x, text_type) for x in data.keys()):
- raise TypeError("All keys must be strings")
-
- self._data = data
- self._keys = sorted(data.keys())
- self._cachestr = ""
- self._cachepoints = (0, len(data))
-
- def __contains__(self, key):
- return key in self._data
-
- def __len__(self):
- return len(self._data)
-
- def __iter__(self):
- return iter(self._data)
-
- def __getitem__(self, key):
- return self._data[key]
-
- def keys(self, prefix=None):
- if prefix is None or prefix == "" or not self._keys:
- return set(self._keys)
-
- if prefix.startswith(self._cachestr):
- lo, hi = self._cachepoints
- start = i = bisect_left(self._keys, prefix, lo, hi)
- else:
- start = i = bisect_left(self._keys, prefix)
-
- keys = set()
- if start == len(self._keys):
- return keys
-
- while self._keys[i].startswith(prefix):
- keys.add(self._keys[i])
- i += 1
-
- self._cachestr = prefix
- self._cachepoints = (start, i)
-
- return keys
-
- def has_keys_with_prefix(self, prefix):
- if prefix in self._data:
- return True
-
- if prefix.startswith(self._cachestr):
- lo, hi = self._cachepoints
- i = bisect_left(self._keys, prefix, lo, hi)
- else:
- i = bisect_left(self._keys, prefix)
-
- if i == len(self._keys):
- return False
-
- return self._keys[i].startswith(prefix)
|