You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

315 lines
8.5 KiB

4 years ago
import copy
import operator
try:
    from collections.abc import Mapping
except ImportError:  # Python 2
    from collections import Mapping

from toolz.compatibility import (map, zip, iteritems, iterkeys, itervalues,
                                 reduce)
  5. __all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
  6. 'valfilter', 'keyfilter', 'itemfilter',
  7. 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in')
  8. def _get_factory(f, kwargs):
  9. factory = kwargs.pop('factory', dict)
  10. if kwargs:
  11. raise TypeError("{0}() got an unexpected keyword argument "
  12. "'{1}'".format(f.__name__, kwargs.popitem()[0]))
  13. return factory
  14. def merge(*dicts, **kwargs):
  15. """ Merge a collection of dictionaries
  16. >>> merge({1: 'one'}, {2: 'two'})
  17. {1: 'one', 2: 'two'}
  18. Later dictionaries have precedence
  19. >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
  20. {1: 2, 3: 3, 4: 4}
  21. See Also:
  22. merge_with
  23. """
  24. if len(dicts) == 1 and not isinstance(dicts[0], dict):
  25. dicts = dicts[0]
  26. factory = _get_factory(merge, kwargs)
  27. rv = factory()
  28. for d in dicts:
  29. rv.update(d)
  30. return rv
  31. def merge_with(func, *dicts, **kwargs):
  32. """ Merge dictionaries and apply function to combined values
  33. A key may occur in more than one dict, and all values mapped from the key
  34. will be passed to the function as a list, such as func([val1, val2, ...]).
  35. >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
  36. {1: 11, 2: 22}
  37. >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP
  38. {1: 1, 2: 2, 3: 30}
  39. See Also:
  40. merge
  41. """
  42. if len(dicts) == 1 and not isinstance(dicts[0], dict):
  43. dicts = dicts[0]
  44. factory = _get_factory(merge_with, kwargs)
  45. result = factory()
  46. for d in dicts:
  47. for k, v in iteritems(d):
  48. if k not in result:
  49. result[k] = [v]
  50. else:
  51. result[k].append(v)
  52. return valmap(func, result, factory)
  53. def valmap(func, d, factory=dict):
  54. """ Apply function to values of dictionary
  55. >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
  56. >>> valmap(sum, bills) # doctest: +SKIP
  57. {'Alice': 65, 'Bob': 45}
  58. See Also:
  59. keymap
  60. itemmap
  61. """
  62. rv = factory()
  63. rv.update(zip(iterkeys(d), map(func, itervalues(d))))
  64. return rv
  65. def keymap(func, d, factory=dict):
  66. """ Apply function to keys of dictionary
  67. >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
  68. >>> keymap(str.lower, bills) # doctest: +SKIP
  69. {'alice': [20, 15, 30], 'bob': [10, 35]}
  70. See Also:
  71. valmap
  72. itemmap
  73. """
  74. rv = factory()
  75. rv.update(zip(map(func, iterkeys(d)), itervalues(d)))
  76. return rv
  77. def itemmap(func, d, factory=dict):
  78. """ Apply function to items of dictionary
  79. >>> accountids = {"Alice": 10, "Bob": 20}
  80. >>> itemmap(reversed, accountids) # doctest: +SKIP
  81. {10: "Alice", 20: "Bob"}
  82. See Also:
  83. keymap
  84. valmap
  85. """
  86. rv = factory()
  87. rv.update(map(func, iteritems(d)))
  88. return rv
  89. def valfilter(predicate, d, factory=dict):
  90. """ Filter items in dictionary by value
  91. >>> iseven = lambda x: x % 2 == 0
  92. >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
  93. >>> valfilter(iseven, d)
  94. {1: 2, 3: 4}
  95. See Also:
  96. keyfilter
  97. itemfilter
  98. valmap
  99. """
  100. rv = factory()
  101. for k, v in iteritems(d):
  102. if predicate(v):
  103. rv[k] = v
  104. return rv
  105. def keyfilter(predicate, d, factory=dict):
  106. """ Filter items in dictionary by key
  107. >>> iseven = lambda x: x % 2 == 0
  108. >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
  109. >>> keyfilter(iseven, d)
  110. {2: 3, 4: 5}
  111. See Also:
  112. valfilter
  113. itemfilter
  114. keymap
  115. """
  116. rv = factory()
  117. for k, v in iteritems(d):
  118. if predicate(k):
  119. rv[k] = v
  120. return rv
  121. def itemfilter(predicate, d, factory=dict):
  122. """ Filter items in dictionary by item
  123. >>> def isvalid(item):
  124. ... k, v = item
  125. ... return k % 2 == 0 and v < 4
  126. >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
  127. >>> itemfilter(isvalid, d)
  128. {2: 3}
  129. See Also:
  130. keyfilter
  131. valfilter
  132. itemmap
  133. """
  134. rv = factory()
  135. for item in iteritems(d):
  136. if predicate(item):
  137. k, v = item
  138. rv[k] = v
  139. return rv
  140. def assoc(d, key, value, factory=dict):
  141. """ Return a new dict with new key value pair
  142. New dict has d[key] set to value. Does not modify the initial dictionary.
  143. >>> assoc({'x': 1}, 'x', 2)
  144. {'x': 2}
  145. >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
  146. {'x': 1, 'y': 3}
  147. """
  148. d2 = factory()
  149. d2[key] = value
  150. return merge(d, d2, factory=factory)
  151. def dissoc(d, *keys):
  152. """ Return a new dict with the given key(s) removed.
  153. New dict has d[key] deleted for each supplied key.
  154. Does not modify the initial dictionary.
  155. >>> dissoc({'x': 1, 'y': 2}, 'y')
  156. {'x': 1}
  157. >>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
  158. {}
  159. >>> dissoc({'x': 1}, 'y') # Ignores missing keys
  160. {'x': 1}
  161. """
  162. d2 = copy.copy(d)
  163. for key in keys:
  164. if key in d2:
  165. del d2[key]
  166. return d2
  167. def assoc_in(d, keys, value, factory=dict):
  168. """ Return a new dict with new, potentially nested, key value pair
  169. >>> purchase = {'name': 'Alice',
  170. ... 'order': {'items': ['Apple', 'Orange'],
  171. ... 'costs': [0.50, 1.25]},
  172. ... 'credit card': '5555-1234-1234-1234'}
  173. >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
  174. {'credit card': '5555-1234-1234-1234',
  175. 'name': 'Alice',
  176. 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
  177. """
  178. return update_in(d, keys, lambda x: value, value, factory)
  179. def update_in(d, keys, func, default=None, factory=dict):
  180. """ Update value in a (potentially) nested dictionary
  181. inputs:
  182. d - dictionary on which to operate
  183. keys - list or tuple giving the location of the value to be changed in d
  184. func - function to operate on that value
  185. If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
  186. original dictionary with v replaced by func(v), but does not mutate the
  187. original dictionary.
  188. If k0 is not a key in d, update_in creates nested dictionaries to the depth
  189. specified by the keys, with the innermost value set to func(default).
  190. >>> inc = lambda x: x + 1
  191. >>> update_in({'a': 0}, ['a'], inc)
  192. {'a': 1}
  193. >>> transaction = {'name': 'Alice',
  194. ... 'purchase': {'items': ['Apple', 'Orange'],
  195. ... 'costs': [0.50, 1.25]},
  196. ... 'credit card': '5555-1234-1234-1234'}
  197. >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
  198. {'credit card': '5555-1234-1234-1234',
  199. 'name': 'Alice',
  200. 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
  201. >>> # updating a value when k0 is not in d
  202. >>> update_in({}, [1, 2, 3], str, default="bar")
  203. {1: {2: {3: 'bar'}}}
  204. >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
  205. {1: 'foo', 2: {3: {4: 1}}}
  206. """
  207. assert len(keys) > 0
  208. k, ks = keys[0], keys[1:]
  209. if ks:
  210. return assoc(d, k, update_in(d[k] if (k in d) else factory(),
  211. ks, func, default, factory),
  212. factory)
  213. else:
  214. innermost = func(d[k]) if (k in d) else func(default)
  215. return assoc(d, k, innermost, factory)
  216. def get_in(keys, coll, default=None, no_default=False):
  217. """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
  218. If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
  219. ``no_default`` is specified, then it raises KeyError or IndexError.
  220. ``get_in`` is a generalization of ``operator.getitem`` for nested data
  221. structures such as dictionaries and lists.
  222. >>> transaction = {'name': 'Alice',
  223. ... 'purchase': {'items': ['Apple', 'Orange'],
  224. ... 'costs': [0.50, 1.25]},
  225. ... 'credit card': '5555-1234-1234-1234'}
  226. >>> get_in(['purchase', 'items', 0], transaction)
  227. 'Apple'
  228. >>> get_in(['name'], transaction)
  229. 'Alice'
  230. >>> get_in(['purchase', 'total'], transaction)
  231. >>> get_in(['purchase', 'items', 'apple'], transaction)
  232. >>> get_in(['purchase', 'items', 10], transaction)
  233. >>> get_in(['purchase', 'total'], transaction, 0)
  234. 0
  235. >>> get_in(['y'], {}, no_default=True)
  236. Traceback (most recent call last):
  237. ...
  238. KeyError: 'y'
  239. See Also:
  240. itertoolz.get
  241. operator.getitem
  242. """
  243. try:
  244. return reduce(operator.getitem, keys, coll)
  245. except (KeyError, IndexError, TypeError):
  246. if no_default:
  247. raise
  248. return default