You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

492 lines
14 KiB

4 years ago
  1. from array import array
  2. from collections import abc
  3. import sys
  4. from ._abc import MultiMapping, MutableMultiMapping
# Unique sentinel used as the "no default supplied" value of ``default``
# parameters, so that callers can still pass ``None`` as a real default.
_marker = object()
  6. class istr(str):
  7. """Case insensitive str."""
  8. __is_istr__ = True
  9. def __new__(cls, val='',
  10. encoding=sys.getdefaultencoding(), errors='strict'):
  11. if getattr(val, '__is_istr__', False):
  12. # Faster than instance check
  13. return val
  14. if type(val) is str:
  15. pass
  16. else:
  17. val = str(val)
  18. val = val.title()
  19. return str.__new__(cls, val)
  20. def title(self):
  21. return self
  22. upstr = istr # for relaxing backward compatibility problems
  23. def getversion(md):
  24. if not isinstance(md, _Base):
  25. raise TypeError("Parameter should be multidict or proxy")
  26. return md._impl._version
  27. _version = array('Q', [0])
  28. class _Impl:
  29. __slots__ = ('_items', '_version')
  30. def __init__(self):
  31. self._items = []
  32. self.incr_version()
  33. def incr_version(self):
  34. global _version
  35. v = _version
  36. v[0] += 1
  37. self._version = v[0]
  38. class _Base:
  39. def _title(self, key):
  40. return key
  41. def getall(self, key, default=_marker):
  42. """Return a list of all values matching the key."""
  43. identity = self._title(key)
  44. res = [v for i, k, v in self._impl._items if i == identity]
  45. if res:
  46. return res
  47. if not res and default is not _marker:
  48. return default
  49. raise KeyError('Key not found: %r' % key)
  50. def getone(self, key, default=_marker):
  51. """Get first value matching the key."""
  52. identity = self._title(key)
  53. for i, k, v in self._impl._items:
  54. if i == identity:
  55. return v
  56. if default is not _marker:
  57. return default
  58. raise KeyError('Key not found: %r' % key)
  59. # Mapping interface #
  60. def __getitem__(self, key):
  61. return self.getone(key)
  62. def get(self, key, default=None):
  63. """Get first value matching the key.
  64. The method is alias for .getone().
  65. """
  66. return self.getone(key, default)
  67. def __iter__(self):
  68. return iter(self.keys())
  69. def __len__(self):
  70. return len(self._impl._items)
  71. def keys(self):
  72. """Return a new view of the dictionary's keys."""
  73. return _KeysView(self._impl)
  74. def items(self):
  75. """Return a new view of the dictionary's items *(key, value) pairs)."""
  76. return _ItemsView(self._impl)
  77. def values(self):
  78. """Return a new view of the dictionary's values."""
  79. return _ValuesView(self._impl)
  80. def __eq__(self, other):
  81. if not isinstance(other, abc.Mapping):
  82. return NotImplemented
  83. if isinstance(other, _Base):
  84. lft = self._impl._items
  85. rht = other._impl._items
  86. if len(lft) != len(rht):
  87. return False
  88. for (i1, k2, v1), (i2, k2, v2) in zip(lft, rht):
  89. if i1 != i2 or v1 != v2:
  90. return False
  91. return True
  92. if len(self._impl._items) != len(other):
  93. return False
  94. for k, v in self.items():
  95. nv = other.get(k, _marker)
  96. if v != nv:
  97. return False
  98. return True
  99. def __contains__(self, key):
  100. identity = self._title(key)
  101. for i, k, v in self._impl._items:
  102. if i == identity:
  103. return True
  104. return False
  105. def __repr__(self):
  106. body = ', '.join("'{}': {!r}".format(k, v) for k, v in self.items())
  107. return '<{}({})>'.format(self.__class__.__name__, body)
  108. class MultiDictProxy(_Base, MultiMapping):
  109. def __init__(self, arg):
  110. if not isinstance(arg, (MultiDict, MultiDictProxy)):
  111. raise TypeError(
  112. 'ctor requires MultiDict or MultiDictProxy instance'
  113. ', not {}'.format(
  114. type(arg)))
  115. self._impl = arg._impl
  116. def __reduce__(self):
  117. raise TypeError("can't pickle {} objects".format(
  118. self.__class__.__name__))
  119. def copy(self):
  120. """Return a copy of itself."""
  121. return MultiDict(self.items())
  122. class CIMultiDictProxy(MultiDictProxy):
  123. def __init__(self, arg):
  124. if not isinstance(arg, (CIMultiDict, CIMultiDictProxy)):
  125. raise TypeError(
  126. 'ctor requires CIMultiDict or CIMultiDictProxy instance'
  127. ', not {}'.format(
  128. type(arg)))
  129. self._impl = arg._impl
  130. def _title(self, key):
  131. return key.title()
  132. def copy(self):
  133. """Return a copy of itself."""
  134. return CIMultiDict(self.items())
  135. class MultiDict(_Base, MutableMultiMapping):
  136. def __init__(self, *args, **kwargs):
  137. self._impl = _Impl()
  138. self._extend(args, kwargs, self.__class__.__name__,
  139. self._extend_items)
  140. def __reduce__(self):
  141. return (self.__class__, (list(self.items()),))
  142. def _title(self, key):
  143. return key
  144. def _key(self, key):
  145. if isinstance(key, str):
  146. return key
  147. else:
  148. raise TypeError("MultiDict keys should be either str "
  149. "or subclasses of str")
  150. def add(self, key, value):
  151. identity = self._title(key)
  152. self._impl._items.append((identity, self._key(key), value))
  153. self._impl.incr_version()
  154. def copy(self):
  155. """Return a copy of itself."""
  156. cls = self.__class__
  157. return cls(self.items())
  158. __copy__ = copy
  159. def extend(self, *args, **kwargs):
  160. """Extend current MultiDict with more values.
  161. This method must be used instead of update.
  162. """
  163. self._extend(args, kwargs, 'extend', self._extend_items)
  164. def _extend(self, args, kwargs, name, method):
  165. if len(args) > 1:
  166. raise TypeError("{} takes at most 1 positional argument"
  167. " ({} given)".format(name, len(args)))
  168. if args:
  169. arg = args[0]
  170. if isinstance(args[0], (MultiDict, MultiDictProxy)) and not kwargs:
  171. items = arg._impl._items
  172. else:
  173. if hasattr(arg, 'items'):
  174. arg = arg.items()
  175. if kwargs:
  176. arg = list(arg)
  177. arg.extend(list(kwargs.items()))
  178. items = []
  179. for item in arg:
  180. if not len(item) == 2:
  181. raise TypeError(
  182. "{} takes either dict or list of (key, value) "
  183. "tuples".format(name))
  184. items.append((self._title(item[0]),
  185. self._key(item[0]),
  186. item[1]))
  187. method(items)
  188. else:
  189. method([(self._title(key), self._key(key), value)
  190. for key, value in kwargs.items()])
  191. def _extend_items(self, items):
  192. for identity, key, value in items:
  193. self.add(key, value)
  194. def clear(self):
  195. """Remove all items from MultiDict."""
  196. self._impl._items.clear()
  197. self._impl.incr_version()
  198. # Mapping interface #
  199. def __setitem__(self, key, value):
  200. self._replace(key, value)
  201. def __delitem__(self, key):
  202. identity = self._title(key)
  203. items = self._impl._items
  204. found = False
  205. for i in range(len(items) - 1, -1, -1):
  206. if items[i][0] == identity:
  207. del items[i]
  208. found = True
  209. if not found:
  210. raise KeyError(key)
  211. else:
  212. self._impl.incr_version()
  213. def setdefault(self, key, default=None):
  214. """Return value for key, set value to default if key is not present."""
  215. identity = self._title(key)
  216. for i, k, v in self._impl._items:
  217. if i == identity:
  218. return v
  219. self.add(key, default)
  220. return default
  221. def popone(self, key, default=_marker):
  222. """Remove specified key and return the corresponding value.
  223. If key is not found, d is returned if given, otherwise
  224. KeyError is raised.
  225. """
  226. identity = self._title(key)
  227. for i in range(len(self._impl._items)):
  228. if self._impl._items[i][0] == identity:
  229. value = self._impl._items[i][2]
  230. del self._impl._items[i]
  231. self._impl.incr_version()
  232. return value
  233. if default is _marker:
  234. raise KeyError(key)
  235. else:
  236. return default
  237. pop = popone
  238. def popall(self, key, default=_marker):
  239. """Remove all occurrences of key and return the list of corresponding
  240. values.
  241. If key is not found, default is returned if given, otherwise
  242. KeyError is raised.
  243. """
  244. found = False
  245. identity = self._title(key)
  246. ret = []
  247. for i in range(len(self._impl._items)-1, -1, -1):
  248. item = self._impl._items[i]
  249. if item[0] == identity:
  250. ret.append(item[2])
  251. del self._impl._items[i]
  252. self._impl.incr_version()
  253. found = True
  254. if not found:
  255. if default is _marker:
  256. raise KeyError(key)
  257. else:
  258. return default
  259. else:
  260. ret.reverse()
  261. return ret
  262. def popitem(self):
  263. """Remove and return an arbitrary (key, value) pair."""
  264. if self._impl._items:
  265. i = self._impl._items.pop(0)
  266. self._impl.incr_version()
  267. return i[1], i[2]
  268. else:
  269. raise KeyError("empty multidict")
  270. def update(self, *args, **kwargs):
  271. """Update the dictionary from *other*, overwriting existing keys."""
  272. self._extend(args, kwargs, 'update', self._update_items)
  273. def _update_items(self, items):
  274. if not items:
  275. return
  276. used_keys = {}
  277. for identity, key, value in items:
  278. start = used_keys.get(identity, 0)
  279. for i in range(start, len(self._impl._items)):
  280. item = self._impl._items[i]
  281. if item[0] == identity:
  282. used_keys[identity] = i + 1
  283. self._impl._items[i] = (identity, key, value)
  284. break
  285. else:
  286. self._impl._items.append((identity, key, value))
  287. used_keys[identity] = len(self._impl._items)
  288. # drop tails
  289. i = 0
  290. while i < len(self._impl._items):
  291. item = self._impl._items[i]
  292. identity = item[0]
  293. pos = used_keys.get(identity)
  294. if pos is None:
  295. i += 1
  296. continue
  297. if i >= pos:
  298. del self._impl._items[i]
  299. else:
  300. i += 1
  301. self._impl.incr_version()
  302. def _replace(self, key, value):
  303. key = self._key(key)
  304. identity = self._title(key)
  305. items = self._impl._items
  306. for i in range(len(items)):
  307. item = items[i]
  308. if item[0] == identity:
  309. items[i] = (identity, key, value)
  310. # i points to last found item
  311. rgt = i
  312. self._impl.incr_version()
  313. break
  314. else:
  315. self._impl._items.append((identity, key, value))
  316. self._impl.incr_version()
  317. return
  318. # remove all tail items
  319. i = rgt + 1
  320. while i < len(items):
  321. item = items[i]
  322. if item[0] == identity:
  323. del items[i]
  324. else:
  325. i += 1
  326. class CIMultiDict(MultiDict):
  327. def _title(self, key):
  328. return key.title()
  329. class _ViewBase:
  330. def __init__(self, impl):
  331. self._impl = impl
  332. self._version = impl._version
  333. def __len__(self):
  334. return len(self._impl._items)
  335. class _ItemsView(_ViewBase, abc.ItemsView):
  336. def __contains__(self, item):
  337. assert isinstance(item, tuple) or isinstance(item, list)
  338. assert len(item) == 2
  339. for i, k, v in self._impl._items:
  340. if item[0] == k and item[1] == v:
  341. return True
  342. return False
  343. def __iter__(self):
  344. for i, k, v in self._impl._items:
  345. if self._version != self._impl._version:
  346. raise RuntimeError("Dictionary changed during iteration")
  347. yield k, v
  348. def __repr__(self):
  349. lst = []
  350. for item in self._impl._items:
  351. lst.append("{!r}: {!r}".format(item[1], item[2]))
  352. body = ', '.join(lst)
  353. return '{}({})'.format(self.__class__.__name__, body)
  354. class _ValuesView(_ViewBase, abc.ValuesView):
  355. def __contains__(self, value):
  356. for item in self._impl._items:
  357. if item[2] == value:
  358. return True
  359. return False
  360. def __iter__(self):
  361. for item in self._impl._items:
  362. if self._version != self._impl._version:
  363. raise RuntimeError("Dictionary changed during iteration")
  364. yield item[2]
  365. def __repr__(self):
  366. lst = []
  367. for item in self._impl._items:
  368. lst.append("{!r}".format(item[2]))
  369. body = ', '.join(lst)
  370. return '{}({})'.format(self.__class__.__name__, body)
  371. class _KeysView(_ViewBase, abc.KeysView):
  372. def __contains__(self, key):
  373. for item in self._impl._items:
  374. if item[1] == key:
  375. return True
  376. return False
  377. def __iter__(self):
  378. for item in self._impl._items:
  379. if self._version != self._impl._version:
  380. raise RuntimeError("Dictionary changed during iteration")
  381. yield item[1]
  382. def __repr__(self):
  383. lst = []
  384. for item in self._impl._items:
  385. lst.append("{!r}".format(item[1]))
  386. body = ', '.join(lst)
  387. return '{}({})'.format(self.__class__.__name__, body)