You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

982 lines
25 KiB

4 years ago
  1. import itertools
  2. import heapq
  3. import collections
  4. import operator
  5. from functools import partial
  6. from random import Random
  7. from toolz.compatibility import (map, filterfalse, zip, zip_longest, iteritems,
  8. filter)
  9. from toolz.utils import no_default
  10. __all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave',
  11. 'unique', 'isiterable', 'isdistinct', 'take', 'drop', 'take_nth',
  12. 'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv',
  13. 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate',
  14. 'sliding_window', 'partition', 'partition_all', 'count', 'pluck',
  15. 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample')
  16. def remove(predicate, seq):
  17. """ Return those items of sequence for which predicate(item) is False
  18. >>> def iseven(x):
  19. ... return x % 2 == 0
  20. >>> list(remove(iseven, [1, 2, 3, 4]))
  21. [1, 3]
  22. """
  23. return filterfalse(predicate, seq)
  24. def accumulate(binop, seq, initial=no_default):
  25. """ Repeatedly apply binary function to a sequence, accumulating results
  26. >>> from operator import add, mul
  27. >>> list(accumulate(add, [1, 2, 3, 4, 5]))
  28. [1, 3, 6, 10, 15]
  29. >>> list(accumulate(mul, [1, 2, 3, 4, 5]))
  30. [1, 2, 6, 24, 120]
  31. Accumulate is similar to ``reduce`` and is good for making functions like
  32. cumulative sum:
  33. >>> from functools import partial, reduce
  34. >>> sum = partial(reduce, add)
  35. >>> cumsum = partial(accumulate, add)
  36. Accumulate also takes an optional argument that will be used as the first
  37. value. This is similar to reduce.
  38. >>> list(accumulate(add, [1, 2, 3], -1))
  39. [-1, 0, 2, 5]
  40. >>> list(accumulate(add, [], 1))
  41. [1]
  42. See Also:
  43. itertools.accumulate : In standard itertools for Python 3.2+
  44. """
  45. seq = iter(seq)
  46. result = next(seq) if initial == no_default else initial
  47. yield result
  48. for elem in seq:
  49. result = binop(result, elem)
  50. yield result
  51. def groupby(key, seq):
  52. """ Group a collection by a key function
  53. >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
  54. >>> groupby(len, names) # doctest: +SKIP
  55. {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
  56. >>> iseven = lambda x: x % 2 == 0
  57. >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
  58. {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
  59. Non-callable keys imply grouping on a member.
  60. >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
  61. ... {'name': 'Bob', 'gender': 'M'},
  62. ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
  63. {'F': [{'gender': 'F', 'name': 'Alice'}],
  64. 'M': [{'gender': 'M', 'name': 'Bob'},
  65. {'gender': 'M', 'name': 'Charlie'}]}
  66. See Also:
  67. countby
  68. """
  69. if not callable(key):
  70. key = getter(key)
  71. d = collections.defaultdict(lambda: [].append)
  72. for item in seq:
  73. d[key(item)](item)
  74. rv = {}
  75. for k, v in iteritems(d):
  76. rv[k] = v.__self__
  77. return rv
  78. def merge_sorted(*seqs, **kwargs):
  79. """ Merge and sort a collection of sorted collections
  80. This works lazily and only keeps one value from each iterable in memory.
  81. >>> list(merge_sorted([1, 3, 5], [2, 4, 6]))
  82. [1, 2, 3, 4, 5, 6]
  83. >>> ''.join(merge_sorted('abc', 'abc', 'abc'))
  84. 'aaabbbccc'
  85. The "key" function used to sort the input may be passed as a keyword.
  86. >>> list(merge_sorted([2, 3], [1, 3], key=lambda x: x // 3))
  87. [2, 1, 3, 3]
  88. """
  89. if len(seqs) == 0:
  90. return iter([])
  91. elif len(seqs) == 1:
  92. return iter(seqs[0])
  93. key = kwargs.get('key', None)
  94. if key is None:
  95. return _merge_sorted_binary(seqs)
  96. else:
  97. return _merge_sorted_binary_key(seqs, key)
def _merge_sorted_binary(seqs):
    # Lazily merge the sorted iterables in ``seqs`` by splitting them in half,
    # merging each half recursively, then merging the two resulting streams.
    # Requires len(seqs) >= 2 (``merge_sorted`` handles the 0 and 1 cases);
    # with fewer, the split below would recurse without end.
    mid = len(seqs) // 2
    L1 = seqs[:mid]
    if len(L1) == 1:
        seq1 = iter(L1[0])
    else:
        seq1 = _merge_sorted_binary(L1)
    L2 = seqs[mid:]
    if len(L2) == 1:
        seq2 = iter(L2[0])
    else:
        seq2 = _merge_sorted_binary(L2)

    try:
        val2 = next(seq2)
    except StopIteration:
        # Right stream is empty: the merge is just the left stream.
        for val1 in seq1:
            yield val1
        return

    # Invariant: ``val2`` is the pending (not yet yielded) right value.
    for val1 in seq1:
        if val2 < val1:
            # Emit right values while they are strictly smaller than val1;
            # on ties the left value goes first (``<`` is strict).
            yield val2
            for val2 in seq2:
                if val2 < val1:
                    yield val2
                else:
                    yield val1
                    break
            else:
                # Right stream exhausted while val1 is still pending.
                break
        else:
            yield val1
    else:
        # Left stream exhausted: flush the pending right value and the rest.
        yield val2
        for val2 in seq2:
            yield val2
        return
    # Reached only via the inner for-else ``break``: flush the pending val1
    # and the remainder of the left stream.
    yield val1
    for val1 in seq1:
        yield val1
def _merge_sorted_binary_key(seqs, key):
    # Same divide-and-conquer merge as ``_merge_sorted_binary``, but comparing
    # ``key(value)`` instead of the values themselves. Each value's key is
    # computed once and cached in key1/key2.
    # Requires len(seqs) >= 2 (``merge_sorted`` handles the 0 and 1 cases).
    mid = len(seqs) // 2
    L1 = seqs[:mid]
    if len(L1) == 1:
        seq1 = iter(L1[0])
    else:
        seq1 = _merge_sorted_binary_key(L1, key)
    L2 = seqs[mid:]
    if len(L2) == 1:
        seq2 = iter(L2[0])
    else:
        seq2 = _merge_sorted_binary_key(L2, key)

    try:
        val2 = next(seq2)
    except StopIteration:
        # Right stream is empty: the merge is just the left stream.
        for val1 in seq1:
            yield val1
        return
    key2 = key(val2)

    # Invariant: ``val2`` (with key ``key2``) is the pending right value.
    for val1 in seq1:
        key1 = key(val1)
        if key2 < key1:
            # Emit right values whose keys are strictly smaller than key1;
            # on equal keys the left value goes first.
            yield val2
            for val2 in seq2:
                key2 = key(val2)
                if key2 < key1:
                    yield val2
                else:
                    yield val1
                    break
            else:
                # Right stream exhausted while val1 is still pending.
                break
        else:
            yield val1
    else:
        # Left stream exhausted: flush the pending right value and the rest.
        yield val2
        for val2 in seq2:
            yield val2
        return
    # Reached only via the inner for-else ``break``: flush the pending val1
    # and the remainder of the left stream.
    yield val1
    for val1 in seq1:
        yield val1
  179. def interleave(seqs):
  180. """ Interleave a sequence of sequences
  181. >>> list(interleave([[1, 2], [3, 4]]))
  182. [1, 3, 2, 4]
  183. >>> ''.join(interleave(('ABC', 'XY')))
  184. 'AXBYC'
  185. Both the individual sequences and the sequence of sequences may be infinite
  186. Returns a lazy iterator
  187. """
  188. iters = itertools.cycle(map(iter, seqs))
  189. while True:
  190. try:
  191. for itr in iters:
  192. yield next(itr)
  193. return
  194. except StopIteration:
  195. predicate = partial(operator.is_not, itr)
  196. iters = itertools.cycle(itertools.takewhile(predicate, iters))
  197. def unique(seq, key=None):
  198. """ Return only unique elements of a sequence
  199. >>> tuple(unique((1, 2, 3)))
  200. (1, 2, 3)
  201. >>> tuple(unique((1, 2, 1, 3)))
  202. (1, 2, 3)
  203. Uniqueness can be defined by key keyword
  204. >>> tuple(unique(['cat', 'mouse', 'dog', 'hen'], key=len))
  205. ('cat', 'mouse')
  206. """
  207. seen = set()
  208. seen_add = seen.add
  209. if key is None:
  210. for item in seq:
  211. if item not in seen:
  212. seen_add(item)
  213. yield item
  214. else: # calculate key
  215. for item in seq:
  216. val = key(item)
  217. if val not in seen:
  218. seen_add(val)
  219. yield item
  220. def isiterable(x):
  221. """ Is x iterable?
  222. >>> isiterable([1, 2, 3])
  223. True
  224. >>> isiterable('abc')
  225. True
  226. >>> isiterable(5)
  227. False
  228. """
  229. try:
  230. iter(x)
  231. return True
  232. except TypeError:
  233. return False
  234. def isdistinct(seq):
  235. """ All values in sequence are distinct
  236. >>> isdistinct([1, 2, 3])
  237. True
  238. >>> isdistinct([1, 2, 1])
  239. False
  240. >>> isdistinct("Hello")
  241. False
  242. >>> isdistinct("World")
  243. True
  244. """
  245. if iter(seq) is seq:
  246. seen = set()
  247. seen_add = seen.add
  248. for item in seq:
  249. if item in seen:
  250. return False
  251. seen_add(item)
  252. return True
  253. else:
  254. return len(seq) == len(set(seq))
  255. def take(n, seq):
  256. """ The first n elements of a sequence
  257. >>> list(take(2, [10, 20, 30, 40, 50]))
  258. [10, 20]
  259. See Also:
  260. drop
  261. tail
  262. """
  263. return itertools.islice(seq, n)
  264. def tail(n, seq):
  265. """ The last n elements of a sequence
  266. >>> tail(2, [10, 20, 30, 40, 50])
  267. [40, 50]
  268. See Also:
  269. drop
  270. take
  271. """
  272. try:
  273. return seq[-n:]
  274. except (TypeError, KeyError):
  275. return tuple(collections.deque(seq, n))
  276. def drop(n, seq):
  277. """ The sequence following the first n elements
  278. >>> list(drop(2, [10, 20, 30, 40, 50]))
  279. [30, 40, 50]
  280. See Also:
  281. take
  282. tail
  283. """
  284. return itertools.islice(seq, n, None)
  285. def take_nth(n, seq):
  286. """ Every nth item in seq
  287. >>> list(take_nth(2, [10, 20, 30, 40, 50]))
  288. [10, 30, 50]
  289. """
  290. return itertools.islice(seq, 0, None, n)
  291. def first(seq):
  292. """ The first element in a sequence
  293. >>> first('ABC')
  294. 'A'
  295. """
  296. return next(iter(seq))
  297. def second(seq):
  298. """ The second element in a sequence
  299. >>> second('ABC')
  300. 'B'
  301. """
  302. return next(itertools.islice(seq, 1, None))
  303. def nth(n, seq):
  304. """ The nth element in a sequence
  305. >>> nth(1, 'ABC')
  306. 'B'
  307. """
  308. if isinstance(seq, (tuple, list, collections.Sequence)):
  309. return seq[n]
  310. else:
  311. return next(itertools.islice(seq, n, None))
  312. def last(seq):
  313. """ The last element in a sequence
  314. >>> last('ABC')
  315. 'C'
  316. """
  317. return tail(1, seq)[0]
# ``rest(seq)`` is ``drop(1, seq)``: every element after the first, lazily.
rest = partial(drop, 1)
  319. def _get(ind, seq, default):
  320. try:
  321. return seq[ind]
  322. except (KeyError, IndexError):
  323. return default
  324. def get(ind, seq, default=no_default):
  325. """ Get element in a sequence or dict
  326. Provides standard indexing
  327. >>> get(1, 'ABC') # Same as 'ABC'[1]
  328. 'B'
  329. Pass a list to get multiple values
  330. >>> get([1, 2], 'ABC') # ('ABC'[1], 'ABC'[2])
  331. ('B', 'C')
  332. Works on any value that supports indexing/getitem
  333. For example here we see that it works with dictionaries
  334. >>> phonebook = {'Alice': '555-1234',
  335. ... 'Bob': '555-5678',
  336. ... 'Charlie':'555-9999'}
  337. >>> get('Alice', phonebook)
  338. '555-1234'
  339. >>> get(['Alice', 'Bob'], phonebook)
  340. ('555-1234', '555-5678')
  341. Provide a default for missing values
  342. >>> get(['Alice', 'Dennis'], phonebook, None)
  343. ('555-1234', None)
  344. See Also:
  345. pluck
  346. """
  347. try:
  348. return seq[ind]
  349. except TypeError: # `ind` may be a list
  350. if isinstance(ind, list):
  351. if default == no_default:
  352. if len(ind) > 1:
  353. return operator.itemgetter(*ind)(seq)
  354. elif ind:
  355. return (seq[ind[0]],)
  356. else:
  357. return ()
  358. else:
  359. return tuple(_get(i, seq, default) for i in ind)
  360. elif default != no_default:
  361. return default
  362. else:
  363. raise
  364. except (KeyError, IndexError): # we know `ind` is not a list
  365. if default == no_default:
  366. raise
  367. else:
  368. return default
  369. def concat(seqs):
  370. """ Concatenate zero or more iterables, any of which may be infinite.
  371. An infinite sequence will prevent the rest of the arguments from
  372. being included.
  373. We use chain.from_iterable rather than ``chain(*seqs)`` so that seqs
  374. can be a generator.
  375. >>> list(concat([[], [1], [2, 3]]))
  376. [1, 2, 3]
  377. See also:
  378. itertools.chain.from_iterable equivalent
  379. """
  380. return itertools.chain.from_iterable(seqs)
  381. def concatv(*seqs):
  382. """ Variadic version of concat
  383. >>> list(concatv([], ["a"], ["b", "c"]))
  384. ['a', 'b', 'c']
  385. See also:
  386. itertools.chain
  387. """
  388. return concat(seqs)
  389. def mapcat(func, seqs):
  390. """ Apply func to each sequence in seqs, concatenating results.
  391. >>> list(mapcat(lambda s: [c.upper() for c in s],
  392. ... [["a", "b"], ["c", "d", "e"]]))
  393. ['A', 'B', 'C', 'D', 'E']
  394. """
  395. return concat(map(func, seqs))
  396. def cons(el, seq):
  397. """ Add el to beginning of (possibly infinite) sequence seq.
  398. >>> list(cons(1, [2, 3]))
  399. [1, 2, 3]
  400. """
  401. return itertools.chain([el], seq)
  402. def interpose(el, seq):
  403. """ Introduce element between each pair of elements in seq
  404. >>> list(interpose("a", [1, 2, 3]))
  405. [1, 'a', 2, 'a', 3]
  406. """
  407. inposed = concat(zip(itertools.repeat(el), seq))
  408. next(inposed)
  409. return inposed
  410. def frequencies(seq):
  411. """ Find number of occurrences of each value in seq
  412. >>> frequencies(['cat', 'cat', 'ox', 'pig', 'pig', 'cat']) #doctest: +SKIP
  413. {'cat': 3, 'ox': 1, 'pig': 2}
  414. See Also:
  415. countby
  416. groupby
  417. """
  418. d = collections.defaultdict(int)
  419. for item in seq:
  420. d[item] += 1
  421. return dict(d)
  422. def reduceby(key, binop, seq, init=no_default):
  423. """ Perform a simultaneous groupby and reduction
  424. The computation:
  425. >>> result = reduceby(key, binop, seq, init) # doctest: +SKIP
  426. is equivalent to the following:
  427. >>> def reduction(group): # doctest: +SKIP
  428. ... return reduce(binop, group, init) # doctest: +SKIP
  429. >>> groups = groupby(key, seq) # doctest: +SKIP
  430. >>> result = valmap(reduction, groups) # doctest: +SKIP
  431. But the former does not build the intermediate groups, allowing it to
  432. operate in much less space. This makes it suitable for larger datasets
  433. that do not fit comfortably in memory
  434. The ``init`` keyword argument is the default initialization of the
  435. reduction. This can be either a constant value like ``0`` or a callable
  436. like ``lambda : 0`` as might be used in ``defaultdict``.
  437. Simple Examples
  438. ---------------
  439. >>> from operator import add, mul
  440. >>> iseven = lambda x: x % 2 == 0
  441. >>> data = [1, 2, 3, 4, 5]
  442. >>> reduceby(iseven, add, data) # doctest: +SKIP
  443. {False: 9, True: 6}
  444. >>> reduceby(iseven, mul, data) # doctest: +SKIP
  445. {False: 15, True: 8}
  446. Complex Example
  447. ---------------
  448. >>> projects = [{'name': 'build roads', 'state': 'CA', 'cost': 1000000},
  449. ... {'name': 'fight crime', 'state': 'IL', 'cost': 100000},
  450. ... {'name': 'help farmers', 'state': 'IL', 'cost': 2000000},
  451. ... {'name': 'help farmers', 'state': 'CA', 'cost': 200000}]
  452. >>> reduceby('state', # doctest: +SKIP
  453. ... lambda acc, x: acc + x['cost'],
  454. ... projects, 0)
  455. {'CA': 1200000, 'IL': 2100000}
  456. Example Using ``init``
  457. ----------------------
  458. >>> def set_add(s, i):
  459. ... s.add(i)
  460. ... return s
  461. >>> reduceby(iseven, set_add, [1, 2, 3, 4, 1, 2, 3], set) # doctest: +SKIP
  462. {True: set([2, 4]),
  463. False: set([1, 3])}
  464. """
  465. is_no_default = init == no_default
  466. if not is_no_default and not callable(init):
  467. _init = init
  468. init = lambda: _init
  469. if not callable(key):
  470. key = getter(key)
  471. d = {}
  472. for item in seq:
  473. k = key(item)
  474. if k not in d:
  475. if is_no_default:
  476. d[k] = item
  477. continue
  478. else:
  479. d[k] = init()
  480. d[k] = binop(d[k], item)
  481. return d
  482. def iterate(func, x):
  483. """ Repeatedly apply a function func onto an original input
  484. Yields x, then func(x), then func(func(x)), then func(func(func(x))), etc..
  485. >>> def inc(x): return x + 1
  486. >>> counter = iterate(inc, 0)
  487. >>> next(counter)
  488. 0
  489. >>> next(counter)
  490. 1
  491. >>> next(counter)
  492. 2
  493. >>> double = lambda x: x * 2
  494. >>> powers_of_two = iterate(double, 1)
  495. >>> next(powers_of_two)
  496. 1
  497. >>> next(powers_of_two)
  498. 2
  499. >>> next(powers_of_two)
  500. 4
  501. >>> next(powers_of_two)
  502. 8
  503. """
  504. while True:
  505. yield x
  506. x = func(x)
  507. def sliding_window(n, seq):
  508. """ A sequence of overlapping subsequences
  509. >>> list(sliding_window(2, [1, 2, 3, 4]))
  510. [(1, 2), (2, 3), (3, 4)]
  511. This function creates a sliding window suitable for transformations like
  512. sliding means / smoothing
  513. >>> mean = lambda seq: float(sum(seq)) / len(seq)
  514. >>> list(map(mean, sliding_window(2, [1, 2, 3, 4])))
  515. [1.5, 2.5, 3.5]
  516. """
  517. return zip(*(collections.deque(itertools.islice(it, i), 0) or it
  518. for i, it in enumerate(itertools.tee(seq, n))))
  519. no_pad = '__no__pad__'
  520. def partition(n, seq, pad=no_pad):
  521. """ Partition sequence into tuples of length n
  522. >>> list(partition(2, [1, 2, 3, 4]))
  523. [(1, 2), (3, 4)]
  524. If the length of ``seq`` is not evenly divisible by ``n``, the final tuple
  525. is dropped if ``pad`` is not specified, or filled to length ``n`` by pad:
  526. >>> list(partition(2, [1, 2, 3, 4, 5]))
  527. [(1, 2), (3, 4)]
  528. >>> list(partition(2, [1, 2, 3, 4, 5], pad=None))
  529. [(1, 2), (3, 4), (5, None)]
  530. See Also:
  531. partition_all
  532. """
  533. args = [iter(seq)] * n
  534. if pad is no_pad:
  535. return zip(*args)
  536. else:
  537. return zip_longest(*args, fillvalue=pad)
  538. def partition_all(n, seq):
  539. """ Partition all elements of sequence into tuples of length at most n
  540. The final tuple may be shorter to accommodate extra elements.
  541. >>> list(partition_all(2, [1, 2, 3, 4]))
  542. [(1, 2), (3, 4)]
  543. >>> list(partition_all(2, [1, 2, 3, 4, 5]))
  544. [(1, 2), (3, 4), (5,)]
  545. See Also:
  546. partition
  547. """
  548. args = [iter(seq)] * n
  549. it = zip_longest(*args, fillvalue=no_pad)
  550. try:
  551. prev = next(it)
  552. except StopIteration:
  553. return
  554. for item in it:
  555. yield prev
  556. prev = item
  557. if prev[-1] is no_pad:
  558. yield prev[:prev.index(no_pad)]
  559. else:
  560. yield prev
  561. def count(seq):
  562. """ Count the number of items in seq
  563. Like the builtin ``len`` but works on lazy sequencies.
  564. Not to be confused with ``itertools.count``
  565. See also:
  566. len
  567. """
  568. if hasattr(seq, '__len__'):
  569. return len(seq)
  570. return sum(1 for i in seq)
  571. def pluck(ind, seqs, default=no_default):
  572. """ plucks an element or several elements from each item in a sequence.
  573. ``pluck`` maps ``itertoolz.get`` over a sequence and returns one or more
  574. elements of each item in the sequence.
  575. This is equivalent to running `map(curried.get(ind), seqs)`
  576. ``ind`` can be either a single string/index or a list of strings/indices.
  577. ``seqs`` should be sequence containing sequences or dicts.
  578. e.g.
  579. >>> data = [{'id': 1, 'name': 'Cheese'}, {'id': 2, 'name': 'Pies'}]
  580. >>> list(pluck('name', data))
  581. ['Cheese', 'Pies']
  582. >>> list(pluck([0, 1], [[1, 2, 3], [4, 5, 7]]))
  583. [(1, 2), (4, 5)]
  584. See Also:
  585. get
  586. map
  587. """
  588. if default == no_default:
  589. get = getter(ind)
  590. return map(get, seqs)
  591. elif isinstance(ind, list):
  592. return (tuple(_get(item, seq, default) for item in ind)
  593. for seq in seqs)
  594. return (_get(ind, seq, default) for seq in seqs)
  595. def getter(index):
  596. if isinstance(index, list):
  597. if len(index) == 1:
  598. index = index[0]
  599. return lambda x: (x[index],)
  600. elif index:
  601. return operator.itemgetter(*index)
  602. else:
  603. return lambda x: ()
  604. else:
  605. return operator.itemgetter(index)
  606. def join(leftkey, leftseq, rightkey, rightseq,
  607. left_default=no_default, right_default=no_default):
  608. """ Join two sequences on common attributes
  609. This is a semi-streaming operation. The LEFT sequence is fully evaluated
  610. and placed into memory. The RIGHT sequence is evaluated lazily and so can
  611. be arbitrarily large.
  612. >>> friends = [('Alice', 'Edith'),
  613. ... ('Alice', 'Zhao'),
  614. ... ('Edith', 'Alice'),
  615. ... ('Zhao', 'Alice'),
  616. ... ('Zhao', 'Edith')]
  617. >>> cities = [('Alice', 'NYC'),
  618. ... ('Alice', 'Chicago'),
  619. ... ('Dan', 'Syndey'),
  620. ... ('Edith', 'Paris'),
  621. ... ('Edith', 'Berlin'),
  622. ... ('Zhao', 'Shanghai')]
  623. >>> # Vacation opportunities
  624. >>> # In what cities do people have friends?
  625. >>> result = join(second, friends,
  626. ... first, cities)
  627. >>> for ((a, b), (c, d)) in sorted(unique(result)):
  628. ... print((a, d))
  629. ('Alice', 'Berlin')
  630. ('Alice', 'Paris')
  631. ('Alice', 'Shanghai')
  632. ('Edith', 'Chicago')
  633. ('Edith', 'NYC')
  634. ('Zhao', 'Chicago')
  635. ('Zhao', 'NYC')
  636. ('Zhao', 'Berlin')
  637. ('Zhao', 'Paris')
  638. Specify outer joins with keyword arguments ``left_default`` and/or
  639. ``right_default``. Here is a full outer join in which unmatched elements
  640. are paired with None.
  641. >>> identity = lambda x: x
  642. >>> list(join(identity, [1, 2, 3],
  643. ... identity, [2, 3, 4],
  644. ... left_default=None, right_default=None))
  645. [(2, 2), (3, 3), (None, 4), (1, None)]
  646. Usually the key arguments are callables to be applied to the sequences. If
  647. the keys are not obviously callable then it is assumed that indexing was
  648. intended, e.g. the following is a legal change
  649. >>> # result = join(second, friends, first, cities)
  650. >>> result = join(1, friends, 0, cities) # doctest: +SKIP
  651. """
  652. if not callable(leftkey):
  653. leftkey = getter(leftkey)
  654. if not callable(rightkey):
  655. rightkey = getter(rightkey)
  656. d = groupby(leftkey, leftseq)
  657. seen_keys = set()
  658. left_default_is_no_default = (left_default == no_default)
  659. for item in rightseq:
  660. key = rightkey(item)
  661. seen_keys.add(key)
  662. try:
  663. left_matches = d[key]
  664. for match in left_matches:
  665. yield (match, item)
  666. except KeyError:
  667. if not left_default_is_no_default:
  668. yield (left_default, item)
  669. if right_default != no_default:
  670. for key, matches in d.items():
  671. if key not in seen_keys:
  672. for match in matches:
  673. yield (match, right_default)
  674. def diff(*seqs, **kwargs):
  675. """ Return those items that differ between sequences
  676. >>> list(diff([1, 2, 3], [1, 2, 10, 100]))
  677. [(3, 10)]
  678. Shorter sequences may be padded with a ``default`` value:
  679. >>> list(diff([1, 2, 3], [1, 2, 10, 100], default=None))
  680. [(3, 10), (None, 100)]
  681. A ``key`` function may also be applied to each item to use during
  682. comparisons:
  683. >>> list(diff(['apples', 'bananas'], ['Apples', 'Oranges'], key=str.lower))
  684. [('bananas', 'Oranges')]
  685. """
  686. N = len(seqs)
  687. if N == 1 and isinstance(seqs[0], list):
  688. seqs = seqs[0]
  689. N = len(seqs)
  690. if N < 2:
  691. raise TypeError('Too few sequences given (min 2 required)')
  692. default = kwargs.get('default', no_default)
  693. if default == no_default:
  694. iters = zip(*seqs)
  695. else:
  696. iters = zip_longest(*seqs, fillvalue=default)
  697. key = kwargs.get('key', None)
  698. if key is None:
  699. for items in iters:
  700. if items.count(items[0]) != N:
  701. yield items
  702. else:
  703. for items in iters:
  704. vals = tuple(map(key, items))
  705. if vals.count(vals[0]) != N:
  706. yield items
  707. def topk(k, seq, key=None):
  708. """ Find the k largest elements of a sequence
  709. Operates lazily in ``n*log(k)`` time
  710. >>> topk(2, [1, 100, 10, 1000])
  711. (1000, 100)
  712. Use a key function to change sorted order
  713. >>> topk(2, ['Alice', 'Bob', 'Charlie', 'Dan'], key=len)
  714. ('Charlie', 'Alice')
  715. See also:
  716. heapq.nlargest
  717. """
  718. if key is not None and not callable(key):
  719. key = getter(key)
  720. return tuple(heapq.nlargest(k, seq, key=key))
  721. def peek(seq):
  722. """ Retrieve the next element of a sequence
  723. Returns the first element and an iterable equivalent to the original
  724. sequence, still having the element retrieved.
  725. >>> seq = [0, 1, 2, 3, 4]
  726. >>> first, seq = peek(seq)
  727. >>> first
  728. 0
  729. >>> list(seq)
  730. [0, 1, 2, 3, 4]
  731. """
  732. iterator = iter(seq)
  733. item = next(iterator)
  734. return item, itertools.chain([item], iterator)
  735. def random_sample(prob, seq, random_state=None):
  736. """ Return elements from a sequence with probability of prob
  737. Returns a lazy iterator of random items from seq.
  738. ``random_sample`` considers each item independently and without
  739. replacement. See below how the first time it returned 13 items and the
  740. next time it returned 6 items.
  741. >>> seq = list(range(100))
  742. >>> list(random_sample(0.1, seq)) # doctest: +SKIP
  743. [6, 9, 19, 35, 45, 50, 58, 62, 68, 72, 78, 86, 95]
  744. >>> list(random_sample(0.1, seq)) # doctest: +SKIP
  745. [6, 44, 54, 61, 69, 94]
  746. Providing an integer seed for ``random_state`` will result in
  747. deterministic sampling. Given the same seed it will return the same sample
  748. every time.
  749. >>> list(random_sample(0.1, seq, random_state=2016))
  750. [7, 9, 19, 25, 30, 32, 34, 48, 59, 60, 81, 98]
  751. >>> list(random_sample(0.1, seq, random_state=2016))
  752. [7, 9, 19, 25, 30, 32, 34, 48, 59, 60, 81, 98]
  753. ``random_state`` can also be any object with a method ``random`` that
  754. returns floats between 0.0 and 1.0 (exclusive).
  755. >>> from random import Random
  756. >>> randobj = Random(2016)
  757. >>> list(random_sample(0.1, seq, random_state=randobj))
  758. [7, 9, 19, 25, 30, 32, 34, 48, 59, 60, 81, 98]
  759. """
  760. if not hasattr(random_state, 'random'):
  761. random_state = Random(random_state)
  762. return filter(lambda _: random_state.random() < prob, seq)