  1. """
  2. Numerical python functions written for compatibility with MATLAB
  3. commands with the same names.
  4. MATLAB compatible functions
  5. ---------------------------
  6. :func:`cohere`
  7. Coherence (normalized cross spectral density)
  8. :func:`csd`
  9. Cross spectral density using Welch's average periodogram
  10. :func:`detrend`
  11. Remove the mean or best fit line from an array
  12. :func:`find`
  13. Return the indices where some condition is true;
  14. numpy.nonzero is similar but more general.
  15. :func:`griddata`
  16. Interpolate irregularly distributed data to a
  17. regular grid.
  18. :func:`prctile`
  19. Find the percentiles of a sequence
  20. :func:`prepca`
  21. Principal Component Analysis
  22. :func:`psd`
  23. Power spectral density using Welch's average periodogram
  24. :func:`rk4`
  25. A 4th order runge kutta integrator for 1D or ND systems
  26. :func:`specgram`
  27. Spectrogram (spectrum over segments of time)
  28. Miscellaneous functions
  29. -----------------------
  30. Functions that don't exist in MATLAB, but are useful anyway:
  31. :func:`cohere_pairs`
  32. Coherence over all pairs. This is not a MATLAB function, but we
  33. compute coherence a lot in my lab, and we compute it for a lot of
  34. pairs. This function is optimized to do this efficiently by
  35. caching the direct FFTs.
  36. :func:`rk4`
  37. A 4th order Runge-Kutta ODE integrator in case you ever find
  38. yourself stranded without scipy (and the far superior
  39. scipy.integrate tools)
  40. :func:`contiguous_regions`
  41. Return the indices of the regions spanned by some logical mask
  42. :func:`cross_from_below`
  43. Return the indices where a 1D array crosses a threshold from below
  44. :func:`cross_from_above`
  45. Return the indices where a 1D array crosses a threshold from above
  46. :func:`complex_spectrum`
  47. Return the complex-valued frequency spectrum of a signal
  48. :func:`magnitude_spectrum`
  49. Return the magnitude of the frequency spectrum of a signal
  50. :func:`angle_spectrum`
  51. Return the angle (wrapped phase) of the frequency spectrum of a signal
  52. :func:`phase_spectrum`
  53. Return the phase (unwrapped angle) of the frequency spectrum of a signal
  54. :func:`detrend_mean`
  55. Remove the mean from a line.
  56. :func:`demean`
  57. Remove the mean from a line. This function is the same as
  58. :func:`detrend_mean` except for the default *axis*.
  59. :func:`detrend_linear`
  60. Remove the best fit line from a line.
  61. :func:`detrend_none`
  62. Return the original line.
  63. :func:`stride_windows`
  64. Get all windows in an array in a memory-efficient manner
  65. :func:`stride_repeat`
  66. Repeat an array in a memory-efficient manner
  67. :func:`apply_window`
  68. Apply a window along a given axis
  69. record array helper functions
  70. -----------------------------
  71. A collection of helper methods for numpyrecord arrays
  72. .. _htmlonly:
  73. See :ref:`misc-examples-index`
  74. :func:`rec2txt`
  75. Pretty print a record array
  76. :func:`rec2csv`
  77. Store record array in CSV file
  78. :func:`csv2rec`
  79. Import record array from CSV file with type inspection
  80. :func:`rec_append_fields`
  81. Adds field(s)/array(s) to record array
  82. :func:`rec_drop_fields`
  83. Drop fields from record array
  84. :func:`rec_join`
  85. Join two record arrays on sequence of fields
  86. :func:`recs_join`
  87. A simple join of multiple recarrays using a single column as a key
  88. :func:`rec_groupby`
  89. Summarize data by groups (similar to SQL GROUP BY)
  90. :func:`rec_summarize`
  91. Helper code to filter rec array fields into new fields
  92. For the rec viewer functions(e rec2csv), there are a bunch of Format
  93. objects you can pass into the functions that will do things like color
  94. negative values red, set percent formatting and scaling, etc.
  95. Example usage::
  96. r = csv2rec('somefile.csv', checkrows=0)
  97. formatd = dict(
  98. weight = FormatFloat(2),
  99. change = FormatPercent(2),
  100. cost = FormatThousands(2),
  101. )
  102. rec2excel(r, 'test.xls', formatd=formatd)
  103. rec2csv(r, 'test.csv', formatd=formatd)
  104. """
import copy
import csv
import operator
import os
import warnings

import numpy as np

import matplotlib.cbook as cbook
from matplotlib import docstring
from matplotlib.path import Path
import math


@cbook.deprecated("2.2", alternative='numpy.logspace or numpy.geomspace')
def logspace(xmin, xmax, N):
    '''
    Return N values logarithmically spaced between xmin and xmax.
    '''
    return np.exp(np.linspace(np.log(xmin), np.log(xmax), N))


def window_hanning(x):
    '''
    Return x times the hanning window of len(x).

    See Also
    --------
    :func:`window_none`
        :func:`window_none` is another window algorithm.
    '''
    return np.hanning(len(x))*x


def window_none(x):
    '''
    No window function; simply return x.

    See Also
    --------
    :func:`window_hanning`
        :func:`window_hanning` is another window algorithm.
    '''
    return x


def apply_window(x, window, axis=0, return_window=None):
    '''
    Apply the given window to the given 1D or 2D array along the given axis.

    Parameters
    ----------
    x : 1D or 2D array or sequence
        Array or sequence containing the data.

    window : function or array
        Either a function to generate a window or an array with length
        *x*.shape[*axis*].

    axis : integer
        The axis over which to do the repetition.
        Must be 0 or 1.  The default is 0.

    return_window : bool
        If true, also return the 1D values of the window that was applied.
    '''
    x = np.asarray(x)
    if x.ndim < 1 or x.ndim > 2:
        raise ValueError('only 1D or 2D arrays can be used')
    if axis+1 > x.ndim:
        raise ValueError('axis(=%s) out of bounds' % axis)

    xshape = list(x.shape)
    xshapetarg = xshape.pop(axis)

    if cbook.iterable(window):
        if len(window) != xshapetarg:
            raise ValueError('The len(window) must be the same as the shape '
                             'of x for the chosen axis')
        windowVals = window
    else:
        windowVals = window(np.ones(xshapetarg, dtype=x.dtype))

    if x.ndim == 1:
        if return_window:
            return windowVals * x, windowVals
        else:
            return windowVals * x

    xshapeother = xshape.pop()
    otheraxis = (axis+1) % 2
    windowValsRep = stride_repeat(windowVals, xshapeother, axis=otheraxis)
    if return_window:
        return windowValsRep * x, windowVals
    else:
        return windowValsRep * x
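

# Editorial addition, not part of the original module: a minimal sketch
# showing apply_window() on a 2D array, windowing each column (axis=0)
# with window_hanning.  The helper name (_example_apply_window) is
# hypothetical.
def _example_apply_window():
    # Two columns of ones; after windowing, each column equals the window.
    x = np.ones((8, 2))
    windowed, win = apply_window(x, window_hanning, axis=0,
                                 return_window=True)
    assert windowed.shape == (8, 2)
    # Each column should now equal the 8-point Hanning window values.
    assert np.allclose(windowed[:, 0], win)
    return windowed, win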


def detrend(x, key=None, axis=None):
    '''
    Return x with its trend removed.

    Parameters
    ----------
    x : array or sequence
        Array or sequence containing the data.

    key : [ 'default' | 'constant' | 'mean' | 'linear' | 'none'] or function
        Specifies the detrend algorithm to use.  'default' is 'mean', which
        is the same as :func:`detrend_mean`.  'constant' is the same.
        'linear' is the same as :func:`detrend_linear`.  'none' is the same
        as :func:`detrend_none`.  The default is 'mean'.  See the
        corresponding functions for more details regarding the algorithms.
        Can also be a function that carries out the detrend operation.

    axis : integer
        The axis along which to do the detrending.

    See Also
    --------
    :func:`detrend_mean`
        :func:`detrend_mean` implements the 'mean' algorithm.

    :func:`detrend_linear`
        :func:`detrend_linear` implements the 'linear' algorithm.

    :func:`detrend_none`
        :func:`detrend_none` implements the 'none' algorithm.
    '''
    if key is None or key in ['constant', 'mean', 'default']:
        return detrend(x, key=detrend_mean, axis=axis)
    elif key == 'linear':
        return detrend(x, key=detrend_linear, axis=axis)
    elif key == 'none':
        return detrend(x, key=detrend_none, axis=axis)
    elif isinstance(key, str):
        raise ValueError("Unknown value for key %s, must be one of: "
                         "'default', 'constant', 'mean', "
                         "'linear', or a function" % key)

    if not callable(key):
        raise ValueError("Unknown value for key %s, must be one of: "
                         "'default', 'constant', 'mean', "
                         "'linear', or a function" % key)

    x = np.asarray(x)

    if axis is not None and axis+1 > x.ndim:
        raise ValueError('axis(=%s) out of bounds' % axis)

    if (axis is None and x.ndim == 0) or (not axis and x.ndim == 1):
        return key(x)

    # try to use the 'axis' argument if the function supports it,
    # otherwise use apply_along_axis to do it
    try:
        return key(x, axis=axis)
    except TypeError:
        return np.apply_along_axis(key, axis=axis, arr=x)
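

# Editorial addition, not part of the original module: a minimal sketch of
# the string and callable forms of detrend().  The helper name
# (_example_detrend) is hypothetical.
def _example_detrend():
    t = np.arange(10, dtype=float)
    x = 3.0 + 0.5*t          # constant offset plus linear ramp
    # 'mean' removes only the offset; 'linear' removes the whole trend.
    x_mean = detrend(x, key='mean')
    x_lin = detrend(x, key='linear')
    assert np.allclose(x_mean.mean(), 0)
    assert np.allclose(x_lin, 0)
    # A custom callable works too; it must accept the data array.
    x_none = detrend(x, key=lambda a: a)
    assert np.allclose(x_none, x)
    return x_mean, x_lin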


def demean(x, axis=0):
    '''
    Return x minus its mean along the specified axis.

    Parameters
    ----------
    x : array or sequence
        Array or sequence containing the data.
        Can have any dimensionality.

    axis : integer
        The axis along which to take the mean.  See numpy.mean for a
        description of this argument.

    See Also
    --------
    :func:`detrend_mean`
        This function is the same as :func:`detrend_mean` except for the
        default *axis*.
    '''
    return detrend_mean(x, axis=axis)


def detrend_mean(x, axis=None):
    '''
    Return x minus the mean(x).

    Parameters
    ----------
    x : array or sequence
        Array or sequence containing the data.
        Can have any dimensionality.

    axis : integer
        The axis along which to take the mean.  See numpy.mean for a
        description of this argument.

    See Also
    --------
    :func:`demean`
        This function is the same as :func:`demean` except for the default
        *axis*.

    :func:`detrend_linear`

    :func:`detrend_none`
        :func:`detrend_linear` and :func:`detrend_none` are other detrend
        algorithms.

    :func:`detrend`
        :func:`detrend` is a wrapper around all the detrend algorithms.
    '''
    x = np.asarray(x)

    if axis is not None and axis+1 > x.ndim:
        raise ValueError('axis(=%s) out of bounds' % axis)

    return x - x.mean(axis, keepdims=True)


def detrend_none(x, axis=None):
    '''
    Return x: no detrending.

    Parameters
    ----------
    x : any object
        An object containing the data.

    axis : integer
        This parameter is ignored.
        It is included for compatibility with detrend_mean.

    See Also
    --------
    :func:`detrend_mean`

    :func:`detrend_linear`
        :func:`detrend_mean` and :func:`detrend_linear` are other detrend
        algorithms.

    :func:`detrend`
        :func:`detrend` is a wrapper around all the detrend algorithms.
    '''
    return x


def detrend_linear(y):
    '''
    Return y minus the best-fit line; 'linear' detrending.

    Parameters
    ----------
    y : 0-D or 1-D array or sequence
        Array or sequence containing the data.

    See Also
    --------
    :func:`detrend_mean`

    :func:`detrend_none`
        :func:`detrend_mean` and :func:`detrend_none` are other detrend
        algorithms.

    :func:`detrend`
        :func:`detrend` is a wrapper around all the detrend algorithms.
    '''
    # This is faster than an algorithm based on linalg.lstsq.
    y = np.asarray(y)

    if y.ndim > 1:
        raise ValueError('y cannot have ndim > 1')

    # short-circuit 0-D array.
    if not y.ndim:
        return np.array(0., dtype=y.dtype)

    x = np.arange(y.size, dtype=float)

    C = np.cov(x, y, bias=1)
    b = C[0, 1]/C[0, 0]

    a = y.mean() - b*x.mean()
    return y - (b*x + a)
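

# Editorial addition, not part of the original module: a minimal sketch
# showing that detrend_linear() is equivalent to subtracting the
# least-squares line fitted with np.polyfit.  The helper name is
# hypothetical.
def _example_detrend_linear():
    rng = np.random.RandomState(0)
    t = np.arange(50, dtype=float)
    y = 1.5*t - 4.0 + rng.standard_normal(50)
    resid = detrend_linear(y)
    # Compare against an explicit least-squares line fit.
    b, a = np.polyfit(t, y, 1)
    assert np.allclose(resid, y - (b*t + a))
    return resid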


def stride_windows(x, n, noverlap=None, axis=0):
    '''
    Get all windows of x with length n as a single array,
    using strides to avoid data duplication.

    .. warning::

        It is not safe to write to the output array.  Multiple
        elements may point to the same piece of memory,
        so modifying one value may change others.

    Parameters
    ----------
    x : 1D array or sequence
        Array or sequence containing the data.

    n : integer
        The number of data points in each window.

    noverlap : integer
        The overlap between adjacent windows.
        Default is 0 (no overlap).

    axis : integer
        The axis along which the windows will run.

    References
    ----------
    `stackoverflow: Rolling window for 1D arrays in Numpy?
    <http://stackoverflow.com/a/6811241>`_
    `stackoverflow: Using strides for an efficient moving average filter
    <http://stackoverflow.com/a/4947453>`_
    '''
    if noverlap is None:
        noverlap = 0

    if noverlap >= n:
        raise ValueError('noverlap must be less than n')
    if n < 1:
        raise ValueError('n cannot be less than 1')

    x = np.asarray(x)

    if x.ndim != 1:
        raise ValueError('only 1-dimensional arrays can be used')
    if n == 1 and noverlap == 0:
        if axis == 0:
            return x[np.newaxis]
        else:
            return x[np.newaxis].transpose()
    if n > x.size:
        raise ValueError('n cannot be greater than the length of x')

    # np.lib.stride_tricks.as_strided easily leads to memory corruption for
    # non integer shape and strides, i.e. noverlap or n. See #3845.
    noverlap = int(noverlap)
    n = int(n)

    step = n - noverlap
    if axis == 0:
        shape = (n, (x.shape[-1]-noverlap)//step)
        strides = (x.strides[0], step*x.strides[0])
    else:
        shape = ((x.shape[-1]-noverlap)//step, n)
        strides = (step*x.strides[0], x.strides[0])
    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
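

# Editorial addition, not part of the original module: a minimal sketch of
# stride_windows() building 50%-overlapping segments without copying data.
# The helper name is hypothetical.
def _example_stride_windows():
    x = np.arange(8)
    # Windows of length 4 with 2 points of overlap, one window per column.
    w = stride_windows(x, n=4, noverlap=2, axis=0)
    expected = np.array([[0, 2, 4],
                         [1, 3, 5],
                         [2, 4, 6],
                         [3, 5, 7]])
    assert np.array_equal(w, expected)
    # The windows share memory with x, so treat the result as read-only.
    return w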


def stride_repeat(x, n, axis=0):
    '''
    Repeat the values in an array in a memory-efficient manner.  Array x is
    stacked vertically n times.

    .. warning::

        It is not safe to write to the output array.  Multiple
        elements may point to the same piece of memory, so
        modifying one value may change others.

    Parameters
    ----------
    x : 1D array or sequence
        Array or sequence containing the data.

    n : integer
        The number of times to repeat the array.

    axis : integer
        The axis along which the data will run.

    References
    ----------
    `stackoverflow: Repeat NumPy array without replicating data?
    <http://stackoverflow.com/a/5568169>`_
    '''
    if axis not in [0, 1]:
        raise ValueError('axis must be 0 or 1')
    x = np.asarray(x)
    if x.ndim != 1:
        raise ValueError('only 1-dimensional arrays can be used')

    if n == 1:
        if axis == 0:
            return np.atleast_2d(x)
        else:
            return np.atleast_2d(x).T
    if n < 1:
        raise ValueError('n cannot be less than 1')

    # np.lib.stride_tricks.as_strided easily leads to memory corruption for
    # non integer shape and strides, i.e. n. See #3845.
    n = int(n)

    if axis == 0:
        shape = (n, x.size)
        strides = (0, x.strides[0])
    else:
        shape = (x.size, n)
        strides = (x.strides[0], 0)

    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
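

# Editorial addition, not part of the original module: a minimal sketch of
# stride_repeat() producing n views of the same row with zero-stride tricks.
# The helper name is hypothetical.
def _example_stride_repeat():
    x = np.array([1.0, 2.0, 3.0])
    r = stride_repeat(x, n=4, axis=0)
    assert r.shape == (4, 3)
    # All four rows are views of the same memory, not copies.
    assert np.array_equal(r, np.tile(x, (4, 1)))
    return r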


def _spectral_helper(x, y=None, NFFT=None, Fs=None, detrend_func=None,
                     window=None, noverlap=None, pad_to=None,
                     sides=None, scale_by_freq=None, mode=None):
    '''
    This is a helper function that implements the commonality between the
    psd, csd, spectrogram and complex, magnitude, angle, and phase spectrums.
    It is *NOT* meant to be used outside of mlab and may change at any time.
    '''
    if y is None:
        # if y is None use x for y
        same_data = True
    else:
        # The checks for if y is x are so that we can use the same function to
        # implement the core of psd(), csd(), and spectrogram() without doing
        # extra calculations.  We return the unaveraged Pxy, freqs, and t.
        same_data = y is x

    if Fs is None:
        Fs = 2
    if noverlap is None:
        noverlap = 0
    if detrend_func is None:
        detrend_func = detrend_none
    if window is None:
        window = window_hanning

    # if NFFT is set to None use the default of 256
    if NFFT is None:
        NFFT = 256

    if mode is None or mode == 'default':
        mode = 'psd'
    elif mode not in ['psd', 'complex', 'magnitude', 'angle', 'phase']:
        raise ValueError("Unknown value for mode %s, must be one of: "
                         "'default', 'psd', 'complex', "
                         "'magnitude', 'angle', 'phase'" % mode)

    if not same_data and mode != 'psd':
        raise ValueError("x and y must be equal if mode is not 'psd'")

    # Make sure we're dealing with a numpy array. If y and x were the same
    # object to start with, keep them that way
    x = np.asarray(x)
    if not same_data:
        y = np.asarray(y)

    if sides is None or sides == 'default':
        if np.iscomplexobj(x):
            sides = 'twosided'
        else:
            sides = 'onesided'
    elif sides not in ['onesided', 'twosided']:
        raise ValueError("Unknown value for sides %s, must be one of: "
                         "'default', 'onesided', or 'twosided'" % sides)

    # zero pad x and y up to NFFT if they are shorter than NFFT
    if len(x) < NFFT:
        n = len(x)
        x = np.resize(x, (NFFT,))
        x[n:] = 0

    if not same_data and len(y) < NFFT:
        n = len(y)
        y = np.resize(y, (NFFT,))
        y[n:] = 0

    if pad_to is None:
        pad_to = NFFT

    if mode != 'psd':
        scale_by_freq = False
    elif scale_by_freq is None:
        scale_by_freq = True

    # For real x, ignore the negative frequencies unless told otherwise
    if sides == 'twosided':
        numFreqs = pad_to
        if pad_to % 2:
            freqcenter = (pad_to - 1)//2 + 1
        else:
            freqcenter = pad_to//2
        scaling_factor = 1.
    elif sides == 'onesided':
        if pad_to % 2:
            numFreqs = (pad_to + 1)//2
        else:
            numFreqs = pad_to//2 + 1
        scaling_factor = 2.

    result = stride_windows(x, NFFT, noverlap, axis=0)
    result = detrend(result, detrend_func, axis=0)
    result, windowVals = apply_window(result, window, axis=0,
                                      return_window=True)
    result = np.fft.fft(result, n=pad_to, axis=0)[:numFreqs, :]
    freqs = np.fft.fftfreq(pad_to, 1/Fs)[:numFreqs]

    if not same_data:
        # if same_data is False, mode must be 'psd'
        resultY = stride_windows(y, NFFT, noverlap)
        resultY = detrend(resultY, detrend_func, axis=0)
        resultY = apply_window(resultY, window, axis=0)
        resultY = np.fft.fft(resultY, n=pad_to, axis=0)[:numFreqs, :]
        result = np.conj(result) * resultY
    elif mode == 'psd':
        result = np.conj(result) * result
    elif mode == 'magnitude':
        result = np.abs(result) / np.abs(windowVals).sum()
    elif mode == 'angle' or mode == 'phase':
        # we unwrap the phase later to handle the onesided vs. twosided case
        result = np.angle(result)
    elif mode == 'complex':
        result /= np.abs(windowVals).sum()

    if mode == 'psd':

        # Also include scaling factors for one-sided densities and dividing
        # by the sampling frequency, if desired. Scale everything, except
        # the DC component and the NFFT/2 component:

        # if we have an even number of frequencies, don't scale NFFT/2
        if not NFFT % 2:
            slc = slice(1, -1, None)
        # if we have an odd number, just don't scale DC
        else:
            slc = slice(1, None, None)

        result[slc] *= scaling_factor

        # MATLAB divides by the sampling frequency so that density function
        # has units of dB/Hz and can be integrated by the plotted frequency
        # values. Perform the same scaling here.
        if scale_by_freq:
            result /= Fs
            # Scale the spectrum by the norm of the window to compensate for
            # windowing loss; see Bendat & Piersol Sec 11.5.2.
            result /= (np.abs(windowVals)**2).sum()
        else:
            # In this case, preserve power in the segment, not amplitude
            result /= np.abs(windowVals).sum()**2

    t = np.arange(NFFT/2, len(x) - NFFT/2 + 1, NFFT - noverlap)/Fs

    if sides == 'twosided':
        # center the frequency range at zero
        freqs = np.concatenate((freqs[freqcenter:], freqs[:freqcenter]))
        result = np.concatenate((result[freqcenter:, :],
                                 result[:freqcenter, :]), 0)
    elif not pad_to % 2:
        # get the last value correctly, it is negative otherwise
        freqs[-1] *= -1

    # we unwrap the phase here to handle the onesided vs. twosided case
    if mode == 'phase':
        result = np.unwrap(result, axis=0)

    return result, freqs, t


def _single_spectrum_helper(x, mode, Fs=None, window=None, pad_to=None,
                            sides=None):
    '''
    This is a helper function that implements the commonality between the
    complex, magnitude, angle, and phase spectrums.
    It is *NOT* meant to be used outside of mlab and may change at any time.
    '''
    if mode is None or mode == 'psd' or mode == 'default':
        raise ValueError('_single_spectrum_helper does not work with %s mode'
                         % mode)

    if pad_to is None:
        pad_to = len(x)

    spec, freqs, _ = _spectral_helper(x=x, y=None, NFFT=len(x), Fs=Fs,
                                      detrend_func=detrend_none,
                                      window=window,
                                      noverlap=0, pad_to=pad_to,
                                      sides=sides,
                                      scale_by_freq=False,
                                      mode=mode)
    if mode != 'complex':
        spec = spec.real

    if spec.ndim == 2 and spec.shape[1] == 1:
        spec = spec[:, 0]

    return spec, freqs


# Split out these keyword docs so that they can be used elsewhere
docstring.interpd.update(Spectral=cbook.dedent("""
    Fs : scalar
        The sampling frequency (samples per time unit).  It is used
        to calculate the Fourier frequencies, freqs, in cycles per time
        unit. The default value is 2.

    window : callable or ndarray
        A function or a vector of length *NFFT*. To create window
        vectors see :func:`window_hanning`, :func:`window_none`,
        :func:`numpy.blackman`, :func:`numpy.hamming`,
        :func:`numpy.bartlett`, :func:`scipy.signal`,
        :func:`scipy.signal.get_window`, etc. The default is
        :func:`window_hanning`.  If a function is passed as the
        argument, it must take a data segment as an argument and
        return the windowed version of the segment.

    sides : {'default', 'onesided', 'twosided'}
        Specifies which sides of the spectrum to return.  Default gives the
        default behavior, which returns one-sided for real data and both
        for complex data.  'onesided' forces the return of a one-sided
        spectrum, while 'twosided' forces two-sided.
"""))


docstring.interpd.update(Single_Spectrum=cbook.dedent("""
    pad_to : int
        The number of points to which the data segment is padded when
        performing the FFT.  While not increasing the actual resolution of
        the spectrum (the minimum distance between resolvable peaks),
        this can give more points in the plot, allowing for more
        detail. This corresponds to the *n* parameter in the call to fft().
        The default is None, which sets *pad_to* equal to the length of the
        input signal (i.e. no padding).
"""))


docstring.interpd.update(PSD=cbook.dedent("""
    pad_to : int
        The number of points to which the data segment is padded when
        performing the FFT.  This can be different from *NFFT*, which
        specifies the number of data points used.  While not increasing
        the actual resolution of the spectrum (the minimum distance between
        resolvable peaks), this can give more points in the plot,
        allowing for more detail. This corresponds to the *n* parameter
        in the call to fft(). The default is None, which sets *pad_to*
        equal to *NFFT*.

    NFFT : int
        The number of data points used in each block for the FFT.
        A power of 2 is most efficient.  The default value is 256.
        This should *NOT* be used to get zero padding, or the scaling of the
        result will be incorrect. Use *pad_to* for this instead.

    detrend : {'default', 'constant', 'mean', 'linear', 'none'} or callable
        The function applied to each segment before fft-ing,
        designed to remove the mean or linear trend.  Unlike in
        MATLAB, where the *detrend* parameter is a vector, in
        matplotlib it is a function.  The :mod:`~matplotlib.mlab`
        module defines :func:`~matplotlib.mlab.detrend_none`,
        :func:`~matplotlib.mlab.detrend_mean`, and
        :func:`~matplotlib.mlab.detrend_linear`, but you can use
        a custom function as well.  You can also use a string to choose
        one of the functions.  'default', 'constant', and 'mean' call
        :func:`~matplotlib.mlab.detrend_mean`.  'linear' calls
        :func:`~matplotlib.mlab.detrend_linear`.  'none' calls
        :func:`~matplotlib.mlab.detrend_none`.

    scale_by_freq : bool, optional
        Specifies whether the resulting density values should be scaled
        by the scaling frequency, which gives density in units of Hz^-1.
        This allows for integration over the returned frequency values.
        The default is True for MATLAB compatibility.
"""))


@docstring.dedent_interpd
def psd(x, NFFT=None, Fs=None, detrend=None, window=None,
        noverlap=None, pad_to=None, sides=None, scale_by_freq=None):
    r"""
    Compute the power spectral density.

    Call signature::

        psd(x, NFFT=256, Fs=2, detrend=mlab.detrend_none,
            window=mlab.window_hanning, noverlap=0, pad_to=None,
            sides='default', scale_by_freq=None)

    The power spectral density :math:`P_{xx}` is computed by Welch's
    average periodogram method.  The vector *x* is divided into *NFFT*
    length segments.  Each segment is detrended by function *detrend* and
    windowed by function *window*.  *noverlap* gives the length of
    the overlap between segments.  The :math:`|\mathrm{fft}(i)|^2`
    of each segment :math:`i` are averaged to compute :math:`P_{xx}`.

    If len(*x*) < *NFFT*, it will be zero padded to *NFFT*.

    Parameters
    ----------
    x : 1-D array or sequence
        Array or sequence containing the data

    %(Spectral)s

    %(PSD)s

    noverlap : integer
        The number of points of overlap between segments.
        The default value is 0 (no overlap).

    Returns
    -------
    Pxx : 1-D array
        The values for the power spectrum `P_{xx}` (real valued)

    freqs : 1-D array
        The frequencies corresponding to the elements in *Pxx*

    References
    ----------
    Bendat & Piersol -- Random Data: Analysis and Measurement Procedures,
    John Wiley & Sons (1986)

    See Also
    --------
    :func:`specgram`
        :func:`specgram` differs in the default overlap; in not returning
        the mean of the segment periodograms; and in returning the times of
        the segments.

    :func:`magnitude_spectrum`
        :func:`magnitude_spectrum` returns the magnitude spectrum.

    :func:`csd`
        :func:`csd` returns the spectral density between two signals.
    """
    Pxx, freqs = csd(x=x, y=None, NFFT=NFFT, Fs=Fs, detrend=detrend,
                     window=window, noverlap=noverlap, pad_to=pad_to,
                     sides=sides, scale_by_freq=scale_by_freq)
    return Pxx.real, freqs
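

# Editorial addition, not part of the original module: a minimal sketch of
# psd() recovering the frequency of a sine wave via Welch averaging.  The
# helper name and the chosen parameters (Fs=100, a 10 Hz tone) are
# illustrative, not defaults from this module.
def _example_psd():
    Fs = 100.0                       # samples per second
    t = np.arange(0, 10, 1/Fs)       # 10 seconds of data
    x = np.sin(2*np.pi*10.0*t)       # 10 Hz tone
    Pxx, freqs = psd(x, NFFT=256, Fs=Fs, noverlap=128)
    # The spectral peak should land on the bin nearest 10 Hz.
    peak = freqs[np.argmax(Pxx)]
    assert abs(peak - 10.0) < Fs/256  # within one frequency bin
    return Pxx, freqs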


@docstring.dedent_interpd
def csd(x, y, NFFT=None, Fs=None, detrend=None, window=None,
        noverlap=None, pad_to=None, sides=None, scale_by_freq=None):
    """
    Compute the cross-spectral density.

    Call signature::

        csd(x, y, NFFT=256, Fs=2, detrend=mlab.detrend_none,
            window=mlab.window_hanning, noverlap=0, pad_to=None,
            sides='default', scale_by_freq=None)

    The cross spectral density :math:`P_{xy}` is computed by Welch's
    average periodogram method.  The vectors *x* and *y* are divided into
    *NFFT* length segments.  Each segment is detrended by function
    *detrend* and windowed by function *window*.  *noverlap* gives
    the length of the overlap between segments.  The product of
    the direct FFTs of *x* and *y* are averaged over each segment
    to compute :math:`P_{xy}`, with a scaling to correct for power
    loss due to windowing.

    If len(*x*) < *NFFT* or len(*y*) < *NFFT*, they will be zero
    padded to *NFFT*.

    Parameters
    ----------
    x, y : 1-D arrays or sequences
        Arrays or sequences containing the data

    %(Spectral)s

    %(PSD)s

    noverlap : integer
        The number of points of overlap between segments.
        The default value is 0 (no overlap).

    Returns
    -------
    Pxy : 1-D array
        The values for the cross spectrum `P_{xy}` before scaling (complex
        valued)

    freqs : 1-D array
        The frequencies corresponding to the elements in *Pxy*

    References
    ----------
    Bendat & Piersol -- Random Data: Analysis and Measurement Procedures,
    John Wiley & Sons (1986)

    See Also
    --------
    :func:`psd`
        :func:`psd` is the equivalent to setting y=x.
    """
    if NFFT is None:
        NFFT = 256
    Pxy, freqs, _ = _spectral_helper(x=x, y=y, NFFT=NFFT, Fs=Fs,
                                     detrend_func=detrend, window=window,
                                     noverlap=noverlap, pad_to=pad_to,
                                     sides=sides, scale_by_freq=scale_by_freq,
                                     mode='psd')

    if Pxy.ndim == 2:
        if Pxy.shape[1] > 1:
            Pxy = Pxy.mean(axis=1)
        else:
            Pxy = Pxy[:, 0]
    return Pxy, freqs
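

# Editorial addition, not part of the original module: a minimal sketch
# relating csd() and psd(): with y = x, the cross spectrum reduces to the
# power spectrum.  The helper name is hypothetical.
def _example_csd():
    rng = np.random.RandomState(42)
    x = rng.standard_normal(1024)
    Pxy, freqs = csd(x, x, NFFT=256, Fs=2, noverlap=128)
    Pxx, _ = psd(x, NFFT=256, Fs=2, noverlap=128)
    # csd(x, x) is complex-typed but numerically real, and equals psd(x).
    assert np.allclose(Pxy.real, Pxx)
    assert np.allclose(Pxy.imag, 0)
    return Pxy, freqs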


@docstring.dedent_interpd
def complex_spectrum(x, Fs=None, window=None, pad_to=None,
                     sides=None):
    """
    Compute the complex-valued frequency spectrum of *x*.  Data is padded to
    a length of *pad_to* and the windowing function *window* is applied to
    the signal.

    Parameters
    ----------
    x : 1-D array or sequence
        Array or sequence containing the data

    %(Spectral)s

    %(Single_Spectrum)s

    Returns
    -------
    spectrum : 1-D array
        The values for the complex spectrum (complex valued)

    freqs : 1-D array
        The frequencies corresponding to the elements in *spectrum*

    See Also
    --------
    :func:`magnitude_spectrum`
        :func:`magnitude_spectrum` returns the absolute value of this
        function.

    :func:`angle_spectrum`
        :func:`angle_spectrum` returns the angle of this function.

    :func:`phase_spectrum`
        :func:`phase_spectrum` returns the phase (unwrapped angle) of this
        function.

    :func:`specgram`
        :func:`specgram` can return the complex spectrum of segments within
        the signal.
    """
    return _single_spectrum_helper(x=x, Fs=Fs, window=window, pad_to=pad_to,
                                   sides=sides, mode='complex')


@docstring.dedent_interpd
def magnitude_spectrum(x, Fs=None, window=None, pad_to=None,
                       sides=None):
    """
    Compute the magnitude (absolute value) of the frequency spectrum of
    *x*.  Data is padded to a length of *pad_to* and the windowing function
    *window* is applied to the signal.

    Parameters
    ----------
    x : 1-D array or sequence
        Array or sequence containing the data

    %(Spectral)s

    %(Single_Spectrum)s

    Returns
    -------
    spectrum : 1-D array
        The values for the magnitude spectrum (real valued)

    freqs : 1-D array
        The frequencies corresponding to the elements in *spectrum*

    See Also
    --------
    :func:`psd`
        :func:`psd` returns the power spectral density.

    :func:`complex_spectrum`
        This function returns the absolute value of :func:`complex_spectrum`.

    :func:`angle_spectrum`
        :func:`angle_spectrum` returns the angles of the corresponding
        frequencies.

    :func:`phase_spectrum`
        :func:`phase_spectrum` returns the phase (unwrapped angle) of the
        corresponding frequencies.

    :func:`specgram`
        :func:`specgram` can return the magnitude spectrum of segments
        within the signal.
    """
    return _single_spectrum_helper(x=x, Fs=Fs, window=window, pad_to=pad_to,
                                   sides=sides, mode='magnitude')


@docstring.dedent_interpd
def angle_spectrum(x, Fs=None, window=None, pad_to=None,
                   sides=None):
    """
    Compute the angle of the frequency spectrum (wrapped phase spectrum) of
    *x*.  Data is padded to a length of *pad_to* and the windowing function
    *window* is applied to the signal.

    Parameters
    ----------
    x : 1-D array or sequence
        Array or sequence containing the data

    %(Spectral)s

    %(Single_Spectrum)s

    Returns
    -------
    spectrum : 1-D array
        The values for the angle spectrum in radians (real valued)

    freqs : 1-D array
        The frequencies corresponding to the elements in *spectrum*

    See Also
    --------
    :func:`complex_spectrum`
        This function returns the angle value of :func:`complex_spectrum`.

    :func:`magnitude_spectrum`
        :func:`magnitude_spectrum` returns the magnitudes of the
        corresponding frequencies.

    :func:`phase_spectrum`
        :func:`phase_spectrum` returns the unwrapped version of this
        function.

    :func:`specgram`
        :func:`specgram` can return the angle spectrum of segments within
        the signal.
    """
    return _single_spectrum_helper(x=x, Fs=Fs, window=window, pad_to=pad_to,
                                   sides=sides, mode='angle')


@docstring.dedent_interpd
def phase_spectrum(x, Fs=None, window=None, pad_to=None,
                   sides=None):
    """
    Compute the phase of the frequency spectrum (unwrapped angle spectrum)
    of *x*.  Data is padded to a length of *pad_to* and the windowing
    function *window* is applied to the signal.

    Parameters
    ----------
    x : 1-D array or sequence
        Array or sequence containing the data

    %(Spectral)s

    %(Single_Spectrum)s

    Returns
    -------
    spectrum : 1-D array
        The values for the phase spectrum in radians (real valued)

    freqs : 1-D array
        The frequencies corresponding to the elements in *spectrum*

    See Also
    --------
    :func:`complex_spectrum`
        This function returns the phase (unwrapped angle) of
        :func:`complex_spectrum`.

    :func:`magnitude_spectrum`
        :func:`magnitude_spectrum` returns the magnitudes of the
        corresponding frequencies.

    :func:`angle_spectrum`
        :func:`angle_spectrum` returns the wrapped version of this function.

    :func:`specgram`
        :func:`specgram` can return the phase spectrum of segments within
        the signal.
    """
    return _single_spectrum_helper(x=x, Fs=Fs, window=window, pad_to=pad_to,
                                   sides=sides, mode='phase')


@docstring.dedent_interpd
def specgram(x, NFFT=None, Fs=None, detrend=None, window=None,
             noverlap=None, pad_to=None, sides=None, scale_by_freq=None,
             mode=None):
    """
    Compute a spectrogram.

    Compute and plot a spectrogram of data in x.  Data are split into
    NFFT length segments and the spectrum of each section is
    computed.  The windowing function window is applied to each
    segment, and the amount of overlap of each segment is
    specified with noverlap.

    Parameters
    ----------
    x : array_like
        1-D array or sequence.

    %(Spectral)s

    %(PSD)s

    noverlap : int, optional
        The number of points of overlap between blocks.  The default
        value is 128.
    mode : str, optional
        What sort of spectrum to use, default is 'psd'.
            'psd'
                Returns the power spectral density.
            'complex'
                Returns the complex-valued frequency spectrum.
            'magnitude'
                Returns the magnitude spectrum.
            'angle'
                Returns the phase spectrum without unwrapping.
            'phase'
                Returns the phase spectrum with unwrapping.

    Returns
    -------
    spectrum : array_like
        2-D array, columns are the periodograms of successive segments.

    freqs : array_like
        1-D array, frequencies corresponding to the rows in *spectrum*.

    t : array_like
        1-D array, the times corresponding to midpoints of segments
        (i.e the columns in *spectrum*).

    See Also
    --------
    psd : differs in the overlap and in the return values.
    complex_spectrum : similar, but with complex valued frequencies.
    magnitude_spectrum : similar single segment when mode is 'magnitude'.
    angle_spectrum : similar to single segment when mode is 'angle'.
    phase_spectrum : similar to single segment when mode is 'phase'.

    Notes
    -----
    detrend and scale_by_freq only apply when *mode* is set to 'psd'.
    """
    if noverlap is None:
        noverlap = 128  # default in _spectral_helper() is noverlap = 0
    if NFFT is None:
        NFFT = 256  # same default as in _spectral_helper()

    if len(x) <= NFFT:
        warnings.warn("Only one segment is calculated since parameter NFFT "
                      "(=%d) >= signal length (=%d)." % (NFFT, len(x)))

    spec, freqs, t = _spectral_helper(x=x, y=None, NFFT=NFFT, Fs=Fs,
                                      detrend_func=detrend, window=window,
                                      noverlap=noverlap, pad_to=pad_to,
                                      sides=sides,
                                      scale_by_freq=scale_by_freq,
                                      mode=mode)

    if mode != 'complex':
        spec = spec.real  # Needed since helper implements generically

    return spec, freqs, t
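

# Editorial addition, not part of the original module: a minimal sketch of
# specgram() output shapes: one row per frequency, one column per segment.
# The helper name and parameter choices are illustrative.
def _example_specgram():
    Fs = 100.0
    t = np.arange(0, 20, 1/Fs)          # 2000 samples
    x = np.sin(2*np.pi*15.0*t)          # 15 Hz tone
    spec, freqs, times = specgram(x, NFFT=256, Fs=Fs, noverlap=128)
    # onesided spectrum of real data: 256//2 + 1 frequency rows
    assert spec.shape[0] == 129
    assert spec.shape == (len(freqs), len(times))
    return spec, freqs, times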


_coh_error = """Coherence is calculated by averaging over *NFFT*
length segments.  Your signal is too short for your choice of *NFFT*.
"""


@docstring.dedent_interpd
def cohere(x, y, NFFT=256, Fs=2, detrend=detrend_none, window=window_hanning,
           noverlap=0, pad_to=None, sides='default', scale_by_freq=None):
    """
    The coherence between *x* and *y*.  Coherence is the normalized
    cross spectral density:

    .. math::

        C_{xy} = \\frac{|P_{xy}|^2}{P_{xx}P_{yy}}

    Parameters
    ----------
    x, y
        Array or sequence containing the data

    %(Spectral)s

    %(PSD)s

    noverlap : integer
        The number of points of overlap between blocks.  The default value
        is 0 (no overlap).

    Returns
    -------
    The return value is the tuple (*Cxy*, *f*), where *f* are the
    frequencies of the coherence vector.  For cohere, scaling the
    individual densities by the sampling frequency has no effect,
    since the factors cancel out.

    See Also
    --------
    :func:`psd`, :func:`csd` :
        For information about the methods used to compute :math:`P_{xy}`,
        :math:`P_{xx}` and :math:`P_{yy}`.
    """
    if len(x) < 2 * NFFT:
        raise ValueError(_coh_error)
    Pxx, f = psd(x, NFFT, Fs, detrend, window, noverlap, pad_to, sides,
                 scale_by_freq)
    Pyy, f = psd(y, NFFT, Fs, detrend, window, noverlap, pad_to, sides,
                 scale_by_freq)
    Pxy, f = csd(x, y, NFFT, Fs, detrend, window, noverlap, pad_to, sides,
                 scale_by_freq)
    Cxy = np.abs(Pxy) ** 2 / (Pxx * Pyy)
    return Cxy, f
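

# Editorial addition, not part of the original module: a minimal sketch of
# cohere().  A signal is perfectly coherent with a scaled copy of itself,
# so Cxy should be ~1 at every frequency.  The helper name is hypothetical.
def _example_cohere():
    rng = np.random.RandomState(0)
    x = rng.standard_normal(2048)
    y = 3.0 * x                      # linearly related signal
    Cxy, f = cohere(x, y, NFFT=256, Fs=2, noverlap=128)
    assert np.allclose(Cxy, 1.0)
    return Cxy, f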


@cbook.deprecated('2.2')
def donothing_callback(*args):
    pass


@cbook.deprecated('2.2', 'scipy.signal.coherence')
def cohere_pairs(X, ij, NFFT=256, Fs=2, detrend=detrend_none,
                 window=window_hanning, noverlap=0,
                 preferSpeedOverMemory=True,
                 progressCallback=donothing_callback,
                 returnPxx=False):
    """
    Compute the coherence and phase for all pairs *ij*, in *X*.

    *X* is a *numSamples* * *numCols* array

    *ij* is a list of tuples.  Each tuple is a pair of indexes into
    the columns of X for which you want to compute coherence.  For
    example, if *X* has 64 columns, and you want to compute all
    nonredundant pairs, define *ij* as::

      ij = []
      for i in range(64):
          for j in range(i+1, 64):
              ij.append((i, j))

    *preferSpeedOverMemory* is an optional bool.  Defaults to true.  If
    False, limits the caching by only making one, rather than two,
    complex cache arrays.  This is useful if memory becomes critical.
    Even when *preferSpeedOverMemory* is False, :func:`cohere_pairs`
    will still give significant performance gains over calling
    :func:`cohere` for each pair, and will use substantially less
    memory than if *preferSpeedOverMemory* is True.  In my tests with
    a 43000,64 array over all nonredundant pairs,
    *preferSpeedOverMemory* = True delivered a 33% performance boost
    on a 1.7GHZ Athlon with 512MB RAM compared with
    *preferSpeedOverMemory* = False.  But both solutions were more
    than 10x faster than naively crunching all possible pairs through
    :func:`cohere`.

    Returns
    -------
    Cxy : dictionary of (*i*, *j*) tuples -> coherence vector for
        that pair.  i.e., ``Cxy[(i, j)] = cohere(X[:, i], X[:, j])``.
        Number of dictionary keys is ``len(ij)``.

    Phase : dictionary of phases of the cross spectral density at
        each frequency for each pair.  Keys are (*i*, *j*).

    freqs : vector of frequencies, equal in length to either the
        coherence or phase vectors for any (*i*, *j*) key.

    e.g., to make a coherence Bode plot::

        subplot(211)
        plot(freqs, Cxy[(12, 19)])
        subplot(212)
        plot(freqs, Phase[(12, 19)])

    For a large number of pairs, :func:`cohere_pairs` can be much more
    efficient than just calling :func:`cohere` for each pair, because
    it caches most of the intensive computations.  If :math:`N` is the
    number of pairs, this function is :math:`O(N)` for most of the
    heavy lifting, whereas calling cohere for each pair is
    :math:`O(N^2)`.  However, because of the caching, it is also more
    memory intensive, making 2 additional complex arrays with
    approximately the same number of elements as *X*.

    See :file:`test/cohere_pairs_test.py` in the src tree for an
    example script that shows that this :func:`cohere_pairs` and
    :func:`cohere` give the same results for a given pair.

    See Also
    --------
    :func:`psd`
        For information about the methods used to compute :math:`P_{xy}`,
        :math:`P_{xx}` and :math:`P_{yy}`.
    """
    numRows, numCols = X.shape

    # zero pad if X is too short
    if numRows < NFFT:
        tmp = X
        X = np.zeros((NFFT, numCols), X.dtype)
        X[:numRows, :] = tmp
        del tmp

    numRows, numCols = X.shape
    # get all the columns of X that we are interested in by checking
    # the ij tuples
    allColumns = set()
    for i, j in ij:
        allColumns.add(i)
        allColumns.add(j)
    Ncols = len(allColumns)

    # for real X, ignore the negative frequencies
    if np.iscomplexobj(X):
        numFreqs = NFFT
    else:
        numFreqs = NFFT//2+1

    # cache the FFT of every windowed, detrended NFFT length segment
    # of every channel.  If preferSpeedOverMemory, cache the conjugate
    # as well
    if cbook.iterable(window):
        if len(window) != NFFT:
            raise ValueError("The length of the window must be equal to NFFT")
        windowVals = window
    else:
        windowVals = window(np.ones(NFFT, X.dtype))

    ind = list(range(0, numRows-NFFT+1, NFFT-noverlap))
    numSlices = len(ind)
    FFTSlices = {}
    FFTConjSlices = {}
    Pxx = {}
    slices = range(numSlices)
    normVal = np.linalg.norm(windowVals)**2
    for iColNum, iCol in enumerate(allColumns):
        progressCallback(iColNum/Ncols, 'Caching FFTs')
        Slices = np.zeros((numSlices, numFreqs), dtype=np.complex_)
        for iSlice in slices:
            thisSlice = X[ind[iSlice]:ind[iSlice]+NFFT, iCol]
            thisSlice = windowVals*detrend(thisSlice)
            Slices[iSlice, :] = np.fft.fft(thisSlice)[:numFreqs]

        FFTSlices[iCol] = Slices
        if preferSpeedOverMemory:
            FFTConjSlices[iCol] = np.conj(Slices)
        Pxx[iCol] = np.divide(np.mean(abs(Slices)**2, axis=0), normVal)
    del Slices, ind, windowVals

    # compute the coherences and phases for all pairs using the
    # cached FFTs
    Cxy = {}
    Phase = {}
    count = 0
    N = len(ij)
    for i, j in ij:
        count += 1
        if count % 10 == 0:
            progressCallback(count/N, 'Computing coherences')

        if preferSpeedOverMemory:
            Pxy = FFTSlices[i] * FFTConjSlices[j]
        else:
            Pxy = FFTSlices[i] * np.conj(FFTSlices[j])
        if numSlices > 1:
            Pxy = np.mean(Pxy, axis=0)
        # Pxy = np.divide(Pxy, normVal)
        Pxy /= normVal
        # Cxy[(i,j)] = np.divide(np.absolute(Pxy)**2, Pxx[i]*Pxx[j])
        Cxy[i, j] = abs(Pxy)**2 / (Pxx[i]*Pxx[j])
        Phase[i, j] = np.arctan2(Pxy.imag, Pxy.real)

    freqs = Fs/NFFT*np.arange(numFreqs)
    if returnPxx:
        return Cxy, Phase, freqs, Pxx
    else:
        return Cxy, Phase, freqs


@cbook.deprecated('2.2', 'scipy.stats.entropy')
def entropy(y, bins):
    r"""
    Return the entropy of the data in *y* in units of nat.

    .. math::

        -\sum p_i \ln(p_i)

    where :math:`p_i` is the probability of observing *y* in the
    :math:`i^{th}` bin of *bins*.  *bins* can be a number of bins or a
    range of bins; see :func:`numpy.histogram`.

    Compare *S* with analytic calculation for a Gaussian::

        x = mu + sigma * randn(200000)
        Sanalytic = 0.5 * (1.0 + log(2*pi*sigma**2.0))
    """
    n, bins = np.histogram(y, bins)
    n = n.astype(float)

    n = np.take(n, np.nonzero(n)[0])   # get the positive

    p = np.divide(n, len(y))

    delta = bins[1] - bins[0]
    S = -1.0 * np.sum(p * np.log(p)) + np.log(delta)
    return S
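

# Editorial addition, not part of the original module: a minimal sketch
# checking entropy() against the analytic value for a Gaussian, following
# the comparison suggested in the docstring above.  The helper name and
# the 0.05 tolerance are illustrative choices.
def _example_entropy():
    rng = np.random.RandomState(0)
    sigma = 2.0
    x = sigma * rng.standard_normal(200000)
    S = entropy(x, bins=1000)
    Sanalytic = 0.5 * (1.0 + np.log(2*np.pi*sigma**2))
    # The histogram estimate should be close to the analytic entropy.
    assert abs(S - Sanalytic) < 0.05
    return S, Sanalytic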


@cbook.deprecated('2.2', 'scipy.stats.norm.pdf')
def normpdf(x, *args):
    "Return the normal pdf evaluated at *x*; args provides *mu*, *sigma*"
    mu, sigma = args
    return 1./(np.sqrt(2*np.pi)*sigma)*np.exp(-0.5 * (1./sigma*(x - mu))**2)


@cbook.deprecated('2.2')
def find(condition):
    "Return the indices where ravel(condition) is true"
    res, = np.nonzero(np.ravel(condition))
    return res


@cbook.deprecated('2.2')
def longest_contiguous_ones(x):
    """
    Return the indices of the longest stretch of contiguous ones in *x*,
    assuming *x* is a vector of zeros and ones.  If there are two
    equally long stretches, pick the first.
    """
    x = np.ravel(x)
    if len(x) == 0:
        return np.array([])

    ind = (x == 0).nonzero()[0]
    if len(ind) == 0:
        return np.arange(len(x))
    if len(ind) == len(x):
        return np.array([])

    y = np.zeros((len(x)+2,), x.dtype)
    y[1:-1] = x
    dif = np.diff(y)
    up = (dif == 1).nonzero()[0]
    dn = (dif == -1).nonzero()[0]
    i = (dn-up == max(dn - up)).nonzero()[0][0]
    ind = np.arange(up[i], dn[i])
    return ind


@cbook.deprecated('2.2')
def longest_ones(x):
    '''alias for longest_contiguous_ones'''
    return longest_contiguous_ones(x)


@cbook.deprecated('2.2')
class PCA(object):
    def __init__(self, a, standardize=True):
        """
        compute the SVD of a and store data for PCA.  Use project to
        project the data onto a reduced set of dimensions

        Parameters
        ----------
        a : np.ndarray
            A numobservations x numdims array
        standardize : bool
            True if input data are to be standardized.  If False, only
            centering will be carried out.

        Attributes
        ----------
        a
            A centered unit sigma version of input ``a``.

        numrows, numcols
            The dimensions of ``a``.

        mu
            A numdims array of means of ``a``.  This is the vector that
            points to the origin of PCA space.

        sigma
            A numdims array of standard deviation of ``a``.

        fracs
            The proportion of variance of each of the principal components.

        s
            The actual eigenvalues of the decomposition.

        Wt
            The weight vector for projecting a numdims point or array into
            PCA space.

        Y
            A projected into PCA space.

        Notes
        -----
        The factor loadings are in the ``Wt`` factor, i.e., the factor
        loadings for the first principal component are given by ``Wt[0]``.
        This row is also the first eigenvector.
        """
        n, m = a.shape
        if n < m:
            raise RuntimeError('we assume data in a is organized with '
                               'numrows>numcols')

        self.numrows, self.numcols = n, m
        self.mu = a.mean(axis=0)
        self.sigma = a.std(axis=0)
        self.standardize = standardize

        a = self.center(a)

        self.a = a

        U, s, Vh = np.linalg.svd(a, full_matrices=False)

        # Note: .H indicates the conjugate transposed / Hermitian.
        #
        # The SVD is commonly written as a = U s V.H.
        # If U is a unitary matrix, it means that it satisfies U.H = inv(U).
        #
        # The rows of Vh are the eigenvectors of a.H a.
        # The columns of U are the eigenvectors of a a.H.
        # For row i in Vh and column i in U, the corresponding eigenvalue is
        # s[i]**2.
        self.Wt = Vh

        # save the transposed coordinates
        Y = np.dot(Vh, a.T).T
        self.Y = Y

        # save the eigenvalues
        self.s = s**2

        # and now the contribution of the individual components
        vars = self.s / len(s)
        self.fracs = vars/vars.sum()

    def project(self, x, minfrac=0.):
        '''
        project x onto the principal axes, dropping any axes where fraction
        of variance < minfrac
        '''
        x = np.asarray(x)
        if x.shape[-1] != self.numcols:
            raise ValueError('Expected an array with dims[-1]==%d' %
                             self.numcols)
        Y = np.dot(self.Wt, self.center(x).T).T
        mask = self.fracs >= minfrac
        if x.ndim == 2:
            Yreduced = Y[:, mask]
        else:
            Yreduced = Y[mask]
        return Yreduced

    def center(self, x):
        '''
        center and optionally standardize the data using the mean and sigma
        from training set a
        '''
        if self.standardize:
            return (x - self.mu)/self.sigma
        else:
            return (x - self.mu)

    @staticmethod
    def _get_colinear():
        c0 = np.array([
            0.19294738, 0.6202667, 0.45962655, 0.07608613, 0.135818,
            0.83580842, 0.07218851, 0.48318321, 0.84472463, 0.18348462,
            0.81585306, 0.96923926, 0.12835919, 0.35075355, 0.15807861,
            0.837437, 0.10824303, 0.1723387, 0.43926494, 0.83705486])

        c1 = np.array([
            -1.17705601, -0.513883, -0.26614584, 0.88067144, 1.00474954,
            -1.1616545, 0.0266109, 0.38227157, 1.80489433, 0.21472396,
            -1.41920399, -2.08158544, -0.10559009, 1.68999268, 0.34847107,
            -0.4685737, 1.23980423, -0.14638744, -0.35907697, 0.22442616])

        c2 = c0 + 2*c1
        c3 = -3*c0 + 4*c1
        a = np.array([c3, c0, c1, c2]).T
        return a


@cbook.deprecated('2.2', 'numpy.percentile')
def prctile(x, p=(0.0, 25.0, 50.0, 75.0, 100.0)):
    """
    Return the percentiles of *x*.  *p* can either be a sequence of
    percentile values or a scalar.  If *p* is a sequence, the ith
    element of the return sequence is the *p*(i)-th percentile of *x*.
    If *p* is a scalar, the largest value of *x* less than or equal to
    the *p* percentage point in the sequence is returned.
    """

    # This implementation derived from scipy.stats.scoreatpercentile
    def _interpolate(a, b, fraction):
        """Returns the point at the given fraction between a and b, where
        'fraction' must be between 0 and 1.
        """
        return a + (b - a) * fraction

    per = np.array(p)
    values = np.sort(x, axis=None)

    idxs = per / 100 * (values.shape[0] - 1)
    ai = idxs.astype(int)
    bi = ai + 1
    frac = idxs % 1

    # handle cases where attempting to interpolate past last index
    cond = bi >= len(values)
    if per.ndim:
        ai[cond] -= 1
        bi[cond] -= 1
        frac[cond] += 1
    else:
        if cond:
            ai -= 1
            bi -= 1
            frac += 1

    return _interpolate(values[ai], values[bi], frac)


@cbook.deprecated('2.2')
def prctile_rank(x, p):
    """
    Return the rank for each element in *x*, return the rank
    0..len(*p*).  e.g., if *p* = (25, 50, 75), the return value will be a
    len(*x*) array with values in [0,1,2,3] where 0 indicates the
    value is less than the 25th percentile, 1 indicates the value is
    >= the 25th and < 50th percentile, ... and 3 indicates the value
    is above the 75th percentile cutoff.

    *p* is either an array of percentiles in [0..100] or a scalar which
    indicates how many quantiles of data you want ranked.
    """
    if not cbook.iterable(p):
        p = np.arange(100.0/p, 100.0, 100.0/p)
    else:
        p = np.asarray(p)

    if p.max() <= 1 or p.min() < 0 or p.max() > 100:
        raise ValueError('percentiles should be in range 0..100, not 0..1')

    ptiles = prctile(x, p)
    return np.searchsorted(ptiles, x)


@cbook.deprecated('2.2')
def center_matrix(M, dim=0):
    """
    Return the matrix *M* with each row having zero mean and unit std.

    If *dim* = 1 operate on columns instead of rows.  (*dim* is
    opposite to the numpy axis kwarg.)
    """
    M = np.asarray(M, float)
    if dim:
        M = (M - M.mean(axis=0)) / M.std(axis=0)
    else:
        M = (M - M.mean(axis=1)[:, np.newaxis])
        M = M / M.std(axis=1)[:, np.newaxis]
    return M


@cbook.deprecated('2.2', 'scipy.integrate.ode')
def rk4(derivs, y0, t):
    """
    Integrate 1D or ND system of ODEs using 4-th order Runge-Kutta.
    This is a toy implementation which may be useful if you find
    yourself stranded on a system w/o scipy.  Otherwise use
    :func:`scipy.integrate`.

    Parameters
    ----------
    y0
        initial state vector

    t
        sample times

    derivs
        returns the derivative of the system and has the
        signature ``dy = derivs(yi, ti)``

    Examples
    --------
    A 2D system::

        def derivs6(x, t):
            d1 = x[0] + 2*x[1]
            d2 = -3*x[0] + 4*x[1]
            return (d1, d2)

        dt = 0.0005
        t = arange(0.0, 2.0, dt)
        y0 = (1, 2)
        yout = rk4(derivs6, y0, t)

    A 1D system::

        alpha = 2
        def derivs(x, t):
            return -alpha*x + exp(-t)

        y0 = 1
        yout = rk4(derivs, y0, t)

    If you have access to scipy, you should probably be using the
    scipy.integrate tools rather than this function.
    """
    try:
        Ny = len(y0)
    except TypeError:
        yout = np.zeros((len(t),), float)
    else:
        yout = np.zeros((len(t), Ny), float)

    yout[0] = y0
    i = 0

    for i in np.arange(len(t)-1):

        thist = t[i]
        dt = t[i+1] - thist
        dt2 = dt/2.0
        y0 = yout[i]

        k1 = np.asarray(derivs(y0, thist))
        k2 = np.asarray(derivs(y0 + dt2*k1, thist+dt2))
        k3 = np.asarray(derivs(y0 + dt2*k2, thist+dt2))
        k4 = np.asarray(derivs(y0 + dt*k3, thist+dt))
        yout[i+1] = y0 + dt/6.0*(k1 + 2*k2 + 2*k3 + k4)
    return yout


@cbook.deprecated('2.2')
def bivariate_normal(X, Y, sigmax=1.0, sigmay=1.0,
                     mux=0.0, muy=0.0, sigmaxy=0.0):
    """
    Bivariate Gaussian distribution for equal shape *X*, *Y*.

    See `bivariate normal
    <http://mathworld.wolfram.com/BivariateNormalDistribution.html>`_
    at mathworld.
    """
    Xmu = X - mux
    Ymu = Y - muy

    rho = sigmaxy / (sigmax * sigmay)
    z = Xmu**2/sigmax**2 + Ymu**2/sigmay**2 - 2*rho*Xmu*Ymu/(sigmax*sigmay)
    denom = 2 * np.pi * sigmax * sigmay * np.sqrt(1 - rho**2)
    return np.exp(-z / (2 * (1 - rho**2))) / denom
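
# A hedged sketch of evaluating the density on a grid (hypothetical values):
#
#     X, Y = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100))
#     Z = bivariate_normal(X, Y, sigmax=1.0, sigmay=2.0, sigmaxy=0.5)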


@cbook.deprecated('2.2')
def get_xyz_where(Z, Cond):
    """
    *Z* and *Cond* are *M* x *N* matrices. *Z* are data and *Cond* is
    a boolean matrix where some condition is satisfied. Return value
    is (*x*, *y*, *z*) where *x* and *y* are the indices into *Z* and
    *z* are the values of *Z* at those indices. *x*, *y*, and *z* are
    1D arrays.
    """
    X, Y = np.indices(Z.shape)
    return X[Cond], Y[Cond], Z[Cond]


@cbook.deprecated('2.2')
def get_sparse_matrix(M, N, frac=0.1):
    """
    Return a *M* x *N* sparse matrix with *frac* elements randomly
    filled.
    """
    data = np.zeros((M, N))
    for i in range(int(M * N * frac)):
        x = np.random.randint(0, M - 1)
        y = np.random.randint(0, N - 1)
        data[x, y] = np.random.rand()
    return data


@cbook.deprecated('2.2', 'numpy.hypot')
def dist(x, y):
    """
    Return the distance between two points.
    """
    d = x - y
    return np.sqrt(np.dot(d, d))


@cbook.deprecated('2.2')
def dist_point_to_segment(p, s0, s1):
    """
    Get the distance of a point to a segment.

    *p*, *s0*, *s1* are *xy* sequences.

    This algorithm is from
    http://geomalgorithms.com/a02-_lines.html
    """
    p = np.asarray(p, float)
    s0 = np.asarray(s0, float)
    s1 = np.asarray(s1, float)
    v = s1 - s0
    w = p - s0

    c1 = np.dot(w, v)
    if c1 <= 0:
        return dist(p, s0)

    c2 = np.dot(v, v)
    if c2 <= c1:
        return dist(p, s1)

    b = c1 / c2
    pb = s0 + b * v
    return dist(p, pb)
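
# A minimal sketch (hypothetical coordinates): the distance from (1, 1) to
# the segment from (0, 0) to (2, 0) is 1.0.
#
#     d = dist_point_to_segment((1, 1), (0, 0), (2, 0))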


@cbook.deprecated('2.2')
def segments_intersect(s1, s2):
    """
    Return *True* if *s1* and *s2* intersect.

    *s1* and *s2* are defined as::

        s1: (x1, y1), (x2, y2)
        s2: (x3, y3), (x4, y4)
    """
    (x1, y1), (x2, y2) = s1
    (x3, y3), (x4, y4) = s2

    den = ((y4-y3) * (x2-x1)) - ((x4-x3)*(y2-y1))
    n1 = ((x4-x3) * (y1-y3)) - ((y4-y3)*(x1-x3))
    n2 = ((x2-x1) * (y1-y3)) - ((y2-y1)*(x1-x3))

    if den == 0:
        # lines parallel
        return False

    u1 = n1 / den
    u2 = n2 / den

    return 0.0 <= u1 <= 1.0 and 0.0 <= u2 <= 1.0
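
# A short sketch: the unit-square diagonals cross, so this returns True
# (hypothetical coordinates):
#
#     crossing = segments_intersect(((0, 0), (1, 1)), ((0, 1), (1, 0)))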


@cbook.deprecated('2.2')
def fftsurr(x, detrend=detrend_none, window=window_none):
    """
    Compute an FFT phase randomized surrogate of *x*.
    """
    if cbook.iterable(window):
        x = window * detrend(x)
    else:
        x = window(detrend(x))

    z = np.fft.fft(x)
    a = 2. * np.pi * 1j
    phase = a * np.random.rand(len(x))
    z = z * np.exp(phase)
    return np.fft.ifft(z).real


@cbook.deprecated('2.2')
def movavg(x, n):
    """
    Compute the *n*-point moving average of *x*.
    """
    w = np.empty((n,), dtype=float)
    w[:] = 1.0 / n
    return np.convolve(x, w, mode='valid')
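
# A short sketch: a 3-point moving average (hypothetical data); the output
# has length len(x) - n + 1 because of mode='valid'.
#
#     smoothed = movavg(np.arange(10.0), 3)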


# the following code was written and submitted by Fernando Perez
# from the ipython numutils package under a BSD license
# begin fperez functions
"""
A set of convenient utilities for numerical work.

Most of this module requires numpy or is meant to be used with it.

Copyright (c) 2001-2004, Fernando Perez. <Fernando.Perez@colorado.edu>
All rights reserved.

This license was generated from the BSD license template as found in:
http://www.opensource.org/licenses/bsd-license.php

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

  * Redistributions of source code must retain the above copyright notice,
    this list of conditions and the following disclaimer.

  * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

  * Neither the name of the IPython project nor the names of its
    contributors may be used to endorse or promote products derived from
    this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

# *****************************************************************************
# Globals
# ****************************************************************************
# function definitions
exp_safe_MIN = math.log(2.2250738585072014e-308)
exp_safe_MAX = 1.7976931348623157e+308


@cbook.deprecated("2.2", 'numpy.exp')
def exp_safe(x):
    """
    Compute exponentials which safely underflow to zero.

    Slow, but convenient to use. Note that numpy provides proper
    floating point exception handling with access to the underlying
    hardware.
    """
    if type(x) is np.ndarray:
        return np.exp(np.clip(x, exp_safe_MIN, exp_safe_MAX))
    else:
        return math.exp(x)


@cbook.deprecated("2.2", alternative='numpy.array(list(map(...)))')
def amap(fn, *args):
    """
    amap(function, sequence[, sequence, ...]) -> array.

    Works like :func:`map`, but it returns an array. This is just a
    convenient shorthand for ``numpy.array(map(...))``.
    """
    return np.array(list(map(fn, *args)))


@cbook.deprecated("2.2")
def rms_flat(a):
    """
    Return the root mean square of all the elements of *a*, flattened out.
    """
    return np.sqrt(np.mean(np.abs(a) ** 2))


@cbook.deprecated("2.2", alternative='numpy.linalg.norm(a, ord=1)')
def l1norm(a):
    """
    Return the *l1* norm of *a*, flattened out.

    Implemented as a separate function (not a call to :func:`norm` for speed).
    """
    return np.sum(np.abs(a))


@cbook.deprecated("2.2", alternative='numpy.linalg.norm(a, ord=2)')
def l2norm(a):
    """
    Return the *l2* norm of *a*, flattened out.

    Implemented as a separate function (not a call to :func:`norm` for speed).
    """
    return np.sqrt(np.sum(np.abs(a) ** 2))


@cbook.deprecated("2.2", alternative='numpy.linalg.norm(a.flat, ord=p)')
def norm_flat(a, p=2):
    """
    norm(a,p=2) -> l-p norm of a.flat

    Return the l-p norm of *a*, considered as a flat array. This is NOT a true
    matrix norm, since arrays of arbitrary rank are always flattened.

    *p* can be a number or the string 'Infinity' to get the L-infinity norm.
    """
    # This function was being masked by a more general norm later in
    # the file. We may want to simply delete it.
    if p == 'Infinity':
        return np.max(np.abs(a))
    else:
        return np.sum(np.abs(a) ** p) ** (1 / p)


@cbook.deprecated("2.2", 'numpy.arange')
def frange(xini, xfin=None, delta=None, **kw):
    """
    frange([start,] stop[, step, keywords]) -> array of floats

    Return a numpy ndarray containing a progression of floats. Similar to
    :func:`numpy.arange`, but defaults to a closed interval.

    ``frange(x0, x1)`` returns ``[x0, x0+1, x0+2, ..., x1]``; *start*
    defaults to 0, and the endpoint *is included*. This behavior is
    different from that of :func:`range` and
    :func:`numpy.arange`. This is deliberate, since :func:`frange`
    will probably be more useful for generating lists of points for
    function evaluation, and endpoints are often desired in this
    use. The usual behavior of :func:`range` can be obtained by
    setting the keyword *closed* = 0; in this case, :func:`frange`
    basically becomes :func:`numpy.arange`.

    When *step* is given, it specifies the increment (or
    decrement). All arguments can be floating point numbers.

    ``frange(x0, x1, d)`` returns ``[x0, x0+d, x0+2d, ..., xfin]`` where
    *xfin* <= *x1*.

    :func:`frange` can also be called with the keyword *npts*. This
    sets the number of points the list should contain (and overrides
    the value *step* might have been given). :func:`numpy.arange`
    doesn't offer this option.

    Examples::

        >>> frange(3)
        array([ 0.,  1.,  2.,  3.])
        >>> frange(3, closed=0)
        array([ 0.,  1.,  2.])
        >>> frange(1, 6, 2)
        array([1, 3, 5])   or 1, 3, 5, 7, depending on floating point vagaries
        >>> frange(1, 6.5, npts=5)
        array([ 1.   ,  2.375,  3.75 ,  5.125,  6.5  ])
    """
    # defaults
    kw.setdefault('closed', 1)
    endpoint = kw['closed'] != 0

    # funny logic to allow the *first* argument to be optional (like range())
    # This was modified with a simpler version from a similar frange() found
    # at http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66472
    if xfin is None:
        xfin = xini + 0.0
        xini = 0.0

    if delta is None:
        delta = 1.0

    # compute # of points, spacing and return final list
    try:
        npts = kw['npts']
        delta = (xfin - xini) / (npts - endpoint)
    except KeyError:
        npts = int(np.round((xfin - xini) / delta)) + endpoint
        # round finds the nearest, so the endpoint can be up to
        # delta/2 larger than xfin.

    return np.arange(npts) * delta + xini
# end frange()


@cbook.deprecated("2.2", 'numpy.identity')
def identity(n, rank=2, dtype='l', typecode=None):
    """
    Returns the identity matrix of shape (*n*, *n*, ..., *n*) (rank *r*).

    For ranks higher than 2, this object is simply a multi-index Kronecker
    delta::

                            /  1  if i0=i1=...=iR,
        id[i0,i1,...,iR] = -|
                            \\ 0  otherwise.

    Optionally a *dtype* (or typecode) may be given (it defaults to 'l').

    Since rank defaults to 2, this function behaves in the default case (when
    only *n* is given) like ``numpy.identity(n)`` -- but surprisingly, it is
    much faster.
    """
    if typecode is not None:
        dtype = typecode
    iden = np.zeros((n,) * rank, dtype)
    for i in range(n):
        idx = (i,) * rank
        iden[idx] = 1
    return iden


@cbook.deprecated("2.2")
def base_repr(number, base=2, padding=0):
    """
    Return the representation of a *number* in any given *base*.
    """
    chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    if number < base:
        return (padding - 1) * chars[0] + chars[int(number)]
    max_exponent = int(math.log(number) / math.log(base))
    max_power = int(base) ** max_exponent
    lead_digit = int(number / max_power)
    # Left-pad with zeros when the requested *padding* is wider than the
    # representation, then render the remainder with exactly max_exponent
    # digits; the previous recursion appended the surplus zeros to the
    # right, corrupting numbers such as base_repr(10, 2).
    return (chars[0] * (padding - 1 - max_exponent) +
            chars[lead_digit] +
            base_repr(number - max_power * lead_digit, base, max_exponent))
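
# A short usage sketch for the corrected padding behavior (hand-checked
# values): base_repr(10) -> '1010', base_repr(255, 16) -> 'FF', and
# base_repr(2, 2, padding=3) -> '010' (zeros pad on the left).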


@cbook.deprecated("2.2")
def binary_repr(number, max_length=1025):
    """
    Return the binary representation of the input *number* as a
    string.

    This is more efficient than using :func:`base_repr` with base 2.

    Increase the value of max_length for very large numbers. Note that
    on 32-bit machines, 2**1023 is the largest integer power of 2
    which can be converted to a Python float.
    """
    # assert number < 2 << max_length
    shifts = map(operator.rshift, max_length * [number],
                 range(max_length - 1, -1, -1))
    digits = list(map(operator.mod, shifts, max_length * [2]))
    if not digits.count(1):
        # return a string, as documented (previously returned the int 0)
        return '0'
    digits = digits[digits.index(1):]
    return ''.join(map(repr, digits)).replace('L', '')


@cbook.deprecated("2.2", 'numpy.log2')
def log2(x, ln2=math.log(2.0)):
    """
    Return the log(*x*) in base 2.

    This is a _slow_ function but which is guaranteed to return the correct
    integer value if the input is an integer exact power of 2.
    """
    try:
        bin_n = binary_repr(x)[1:]
    except (AssertionError, TypeError):
        return math.log(x) / ln2
    else:
        if '1' in bin_n:
            return math.log(x) / ln2
        else:
            return len(bin_n)


@cbook.deprecated("2.2")
def ispower2(n):
    """
    Returns the log base 2 of *n* if *n* is a power of 2, zero otherwise.

    Note the potential ambiguity if *n* == 1: 2**0 == 1, interpret accordingly.
    """
    bin_n = binary_repr(n)[1:]
    if '1' in bin_n:
        return 0
    else:
        return len(bin_n)


@cbook.deprecated("2.2")
def isvector(X):
    """
    Like the MATLAB function with the same name, returns *True*
    if the supplied numpy array or matrix *X* looks like a vector,
    meaning it has one non-singleton axis (i.e., it can have
    multiple axes, but all must have length 1, except for one of
    them).

    If you just want to see if the array has 1 axis, use X.ndim == 1.
    """
    return np.prod(X.shape) == np.max(X.shape)
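
# A short sketch: a (1, 5, 1) array counts as a vector, a (2, 3) array does
# not (hypothetical shapes):
#
#     isvector(np.zeros((1, 5, 1)))   # True
#     isvector(np.zeros((2, 3)))      # False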
# end fperez numutils code


# helpers for loading, saving, manipulating and viewing numpy record arrays
@cbook.deprecated("2.2", 'numpy.isnan')
def safe_isnan(x):
    ':func:`numpy.isnan` for arbitrary types'
    if isinstance(x, str):
        return False
    try:
        b = np.isnan(x)
    except NotImplementedError:
        return False
    except TypeError:
        return False
    else:
        return b


@cbook.deprecated("2.2", 'numpy.isinf')
def safe_isinf(x):
    ':func:`numpy.isinf` for arbitrary types'
    if isinstance(x, str):
        return False
    try:
        b = np.isinf(x)
    except NotImplementedError:
        return False
    except TypeError:
        return False
    else:
        return b


@cbook.deprecated("2.2")
def rec_append_fields(rec, names, arrs, dtypes=None):
    """
    Return a new record array with field names populated with data
    from arrays in *arrs*. If appending a single field, then *names*,
    *arrs* and *dtypes* do not have to be lists. They can just be the
    values themselves.
    """
    if (not isinstance(names, str) and cbook.iterable(names)
            and len(names) and isinstance(names[0], str)):
        if len(names) != len(arrs):
            raise ValueError("number of arrays does not match number of "
                             "names")
    else:  # we have only 1 name and 1 array
        names = [names]
        arrs = [arrs]
    arrs = list(map(np.asarray, arrs))
    if dtypes is None:
        dtypes = [a.dtype for a in arrs]
    elif not cbook.iterable(dtypes):
        dtypes = [dtypes]
    if len(arrs) != len(dtypes):
        if len(dtypes) == 1:
            dtypes = dtypes * len(arrs)
        else:
            raise ValueError("dtypes must be None, a single dtype or a list")
    old_dtypes = rec.dtype.descr
    newdtype = np.dtype(old_dtypes + list(zip(names, dtypes)))
    newrec = np.recarray(rec.shape, dtype=newdtype)
    for field in rec.dtype.fields:
        newrec[field] = rec[field]
    for name, arr in zip(names, arrs):
        newrec[name] = arr
    return newrec
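
# A minimal sketch appending one field (hypothetical record array):
#
#     r = np.rec.fromarrays([np.arange(3), np.ones(3)], names='a,b')
#     r2 = rec_append_fields(r, 'c', np.zeros(3))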


@cbook.deprecated("2.2")
def rec_drop_fields(rec, names):
    """
    Return a new numpy record array with fields in *names* dropped.
    """
    names = set(names)

    newdtype = np.dtype([(name, rec.dtype[name]) for name in rec.dtype.names
                         if name not in names])

    newrec = np.recarray(rec.shape, dtype=newdtype)
    for field in newdtype.names:
        newrec[field] = rec[field]

    return newrec


@cbook.deprecated("2.2")
def rec_keep_fields(rec, names):
    """
    Return a new numpy record array with only fields listed in *names*.
    """
    if isinstance(names, str):
        names = names.split(',')

    arrays = []
    for name in names:
        arrays.append(rec[name])

    return np.rec.fromarrays(arrays, names=names)


@cbook.deprecated("2.2")
def rec_groupby(r, groupby, stats):
    """
    *r* is a numpy record array

    *groupby* is a sequence of record array attribute names that
    together form the grouping key. e.g., ('date', 'productcode')

    *stats* is a sequence of (*attr*, *func*, *outname*) tuples which
    will call ``x = func(attr)`` and assign *x* to the record array
    output with attribute *outname*. For example::

        stats = ( ('sales', len, 'numsales'), ('sales', np.mean, 'avgsale') )

    Return record array has *dtype* names for each attribute name in
    the *groupby* argument, with the associated group values, and
    for each outname name in the *stats* argument, with the associated
    stat summary output.
    """
    # build a dictionary from groupby keys -> list of indices into r with
    # those keys
    rowd = {}
    for i, row in enumerate(r):
        key = tuple([row[attr] for attr in groupby])
        rowd.setdefault(key, []).append(i)

    rows = []
    # sort the output by groupby keys
    for key in sorted(rowd):
        row = list(key)
        # get the indices for this groupby key
        ind = rowd[key]
        thisr = r[ind]
        # call each stat function for this groupby slice
        row.extend([func(thisr[attr]) for attr, func, outname in stats])
        rows.append(row)

    # build the output record array with groupby and outname attributes
    attrs, funcs, outnames = list(zip(*stats))
    names = list(groupby)
    names.extend(outnames)
    return np.rec.fromrecords(rows, names=names)
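
# A hedged usage sketch (hypothetical sales records with 'date' and 'sales'
# fields):
#
#     summary = rec_groupby(r, ('date',),
#                           (('sales', len, 'numsales'),
#                            ('sales', np.mean, 'avgsale')))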


@cbook.deprecated("2.2")
def rec_summarize(r, summaryfuncs):
    """
    *r* is a numpy record array

    *summaryfuncs* is a list of (*attr*, *func*, *outname*) tuples
    which will apply *func* to the array *r*[attr] and assign the
    output to a new attribute name *outname*. The returned record
    array is identical to *r*, with extra arrays for each element in
    *summaryfuncs*.
    """
    names = list(r.dtype.names)
    arrays = [r[name] for name in names]

    for attr, func, outname in summaryfuncs:
        names.append(outname)
        arrays.append(np.asarray(func(r[attr])))

    return np.rec.fromarrays(arrays, names=names)


@cbook.deprecated("2.2")
def rec_join(key, r1, r2, jointype='inner', defaults=None, r1postfix='1',
             r2postfix='2'):
    """
    Join record arrays *r1* and *r2* on *key*; *key* is a tuple of
    field names -- if *key* is a string it is assumed to be a single
    attribute name. If *r1* and *r2* have equal values on all the keys
    in the *key* tuple, then their fields will be merged into a new
    record array containing the intersection of the fields of *r1* and
    *r2*.

    *r1* (also *r2*) must not have any duplicate keys.

    The *jointype* keyword can be 'inner', 'outer', 'leftouter'. To
    do a rightouter join just reverse *r1* and *r2*.

    The *defaults* keyword is a dictionary filled with
    ``{column_name:default_value}`` pairs.

    The keywords *r1postfix* and *r2postfix* are postfixed to column names
    (other than keys) that are both in *r1* and *r2*.
    """
    if isinstance(key, str):
        key = (key, )

    for name in key:
        if name not in r1.dtype.names:
            raise ValueError('r1 does not have key field %s' % name)
        if name not in r2.dtype.names:
            raise ValueError('r2 does not have key field %s' % name)

    def makekey(row):
        return tuple([row[name] for name in key])

    r1d = {makekey(row): i for i, row in enumerate(r1)}
    r2d = {makekey(row): i for i, row in enumerate(r2)}

    r1keys = set(r1d)
    r2keys = set(r2d)

    common_keys = r1keys & r2keys

    r1ind = np.array([r1d[k] for k in common_keys])
    r2ind = np.array([r2d[k] for k in common_keys])

    common_len = len(common_keys)
    left_len = right_len = 0
    if jointype == "outer" or jointype == "leftouter":
        left_keys = r1keys.difference(r2keys)
        left_ind = np.array([r1d[k] for k in left_keys])
        left_len = len(left_ind)

    if jointype == "outer":
        right_keys = r2keys.difference(r1keys)
        right_ind = np.array([r2d[k] for k in right_keys])
        right_len = len(right_ind)

    def key_desc(name):
        '''
        if name is a string key, use the larger size of r1 or r2 before
        merging
        '''
        dt1 = r1.dtype[name]
        if dt1.type != np.string_:
            return (name, dt1.descr[0][1])

        dt2 = r2.dtype[name]
        if dt1 != dt2:
            raise ValueError("The '{}' fields in arrays 'r1' and 'r2' must "
                             "have the same dtype".format(name))
        if dt1.num > dt2.num:
            return (name, dt1.descr[0][1])
        else:
            return (name, dt2.descr[0][1])

    keydesc = [key_desc(name) for name in key]

    def mapped_r1field(name):
        """
        The column name in *newrec* that corresponds to the column in *r1*.
        """
        if name in key or name not in r2.dtype.names:
            return name
        else:
            return name + r1postfix

    def mapped_r2field(name):
        """
        The column name in *newrec* that corresponds to the column in *r2*.
        """
        if name in key or name not in r1.dtype.names:
            return name
        else:
            return name + r2postfix

    r1desc = [(mapped_r1field(desc[0]), desc[1]) for desc in r1.dtype.descr
              if desc[0] not in key]
    r2desc = [(mapped_r2field(desc[0]), desc[1]) for desc in r2.dtype.descr
              if desc[0] not in key]
    all_dtypes = keydesc + r1desc + r2desc
    newdtype = np.dtype(all_dtypes)
    newrec = np.recarray((common_len + left_len + right_len,), dtype=newdtype)

    if defaults is not None:
        for thiskey in defaults:
            if thiskey not in newdtype.names:
                warnings.warn('rec_join defaults key="%s" not in new dtype '
                              'names "%s"' % (thiskey, newdtype.names))

    for name in newdtype.names:
        dt = newdtype[name]
        if dt.kind in ('f', 'i'):
            newrec[name] = 0

    if jointype != 'inner' and defaults is not None:
        # fill in the defaults en masse
        newrec_fields = list(newrec.dtype.fields)
        for k, v in defaults.items():
            if k in newrec_fields:
                newrec[k] = v

    for field in r1.dtype.names:
        newfield = mapped_r1field(field)
        if common_len:
            newrec[newfield][:common_len] = r1[field][r1ind]
        if (jointype == "outer" or jointype == "leftouter") and left_len:
            newrec[newfield][common_len:(common_len + left_len)] = (
                r1[field][left_ind]
            )

    for field in r2.dtype.names:
        newfield = mapped_r2field(field)
        if field not in key and common_len:
            newrec[newfield][:common_len] = r2[field][r2ind]
        if jointype == "outer" and right_len:
            newrec[newfield][-right_len:] = r2[field][right_ind]

    newrec.sort(order=key)

    return newrec
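
# A hedged sketch of an outer join on a shared 'date' field (hypothetical
# record arrays r1 and r2 that each carry a 'date' column):
#
#     joined = rec_join('date', r1, r2, jointype='outer')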


@cbook.deprecated("2.2")
def recs_join(key, name, recs, jointype='outer', missing=0., postfixes=None):
    """
    Join a sequence of record arrays on a single column key.

    This function only joins a single column of the multiple record arrays.

    *key*
      is the column name that acts as a key

    *name*
      is the name of the column that we want to join

    *recs*
      is a list of record arrays to join

    *jointype*
      is a string 'inner' or 'outer'

    *missing*
      is what any missing field is replaced by

    *postfixes*
      if not None, a len(recs) sequence of postfixes

    returns a record array with columns [rowkey, name0, name1, ... namen-1],
    or, if postfixes [PF0, PF1, ..., PFN-1] are supplied,
    [rowkey, namePF0, namePF1, ... namePFN-1].

    Example::

        r = recs_join("date", "close", recs=[r0, r1], missing=0.)
    """
    results = []
    aligned_iters = cbook.align_iterators(operator.attrgetter(key),
                                          *[iter(r) for r in recs])

    def extract(r):
        if r is None:
            return missing
        else:
            return r[name]

    if jointype == "outer":
        for rowkey, row in aligned_iters:
            results.append([rowkey] + list(map(extract, row)))
    elif jointype == "inner":
        for rowkey, row in aligned_iters:
            if None not in row:  # throw out any Nones
                results.append([rowkey] + list(map(extract, row)))

    if postfixes is None:
        postfixes = ['%d' % i for i in range(len(recs))]
    names = ",".join([key] + ["%s%s" % (name, postfix)
                              for postfix in postfixes])
    return np.rec.fromrecords(results, names=names)


@cbook.deprecated("2.2")
def csv2rec(fname, comments='#', skiprows=0, checkrows=0, delimiter=',',
            converterd=None, names=None, missing='', missingd=None,
            use_mrecords=False, dayfirst=False, yearfirst=False):
    """
    Load data from comma/space/tab delimited file in *fname* into a
    numpy record array and return the record array.

    If *names* is *None*, a header row is required to automatically
    assign the recarray names. The headers will be lower cased,
    spaces will be converted to underscores, and illegal attribute
    name characters removed. If *names* is not *None*, it is a
    sequence of names to use for the column names. In this case, it
    is assumed there is no header row.

    - *fname*: can be a filename or a file handle. Support for gzipped
      files is automatic, if the filename ends in '.gz'

    - *comments*: the character used to indicate the start of a comment
      in the file, or *None* to switch off the removal of comments

    - *skiprows*: is the number of rows from the top to skip

    - *checkrows*: is the number of rows to check to validate the column
      data type. When set to zero all rows are validated.

    - *converterd*: if not *None*, is a dictionary mapping column number or
      munged column name to a converter function.

    - *names*: if not None, is a list of header names. In this case, no
      header will be read from the file

    - *missingd* is a dictionary mapping munged column names to field values
      which signify that the field does not contain actual data and should
      be masked, e.g., '0000-00-00' or 'unused'

    - *missing*: a string whose value signals a missing field regardless of
      the column it appears in

    - *use_mrecords*: if True, return an mrecords.fromrecords record array if
      any of the data are missing

    - *dayfirst*: default is False so that MM-DD-YY has precedence over
      DD-MM-YY. See
      http://labix.org/python-dateutil#head-b95ce2094d189a89f80f5ae52a05b4ab7b41af47
      for further information.

    - *yearfirst*: default is False so that MM-DD-YY has precedence over
      YY-MM-DD. See
      http://labix.org/python-dateutil#head-b95ce2094d189a89f80f5ae52a05b4ab7b41af47
      for further information.

    If no rows are found, *None* is returned
    """
    if converterd is None:
        converterd = dict()

    if missingd is None:
        missingd = {}

    import dateutil.parser
    import datetime

    fh = cbook.to_filehandle(fname)

    delimiter = str(delimiter)

    class FH:
        """
        For space-delimited files, we want different behavior than
        comma or tab. Generally, we want multiple spaces to be
        treated as a single separator, whereas with comma and tab we
        want multiple commas to return multiple (empty) fields. The
        join/strip trick below effects this.
        """
        def __init__(self, fh):
            self.fh = fh

        def close(self):
            self.fh.close()

        def seek(self, arg):
            self.fh.seek(arg)

        def fix(self, s):
            return ' '.join(s.split())

        def __next__(self):
            return self.fix(next(self.fh))

        def __iter__(self):
            for line in self.fh:
                yield self.fix(line)

    if delimiter == ' ':
        fh = FH(fh)

    reader = csv.reader(fh, delimiter=delimiter)

    def process_skiprows(reader):
        if skiprows:
            for i, row in enumerate(reader):
                if i >= (skiprows - 1):
                    break

        return fh, reader

    process_skiprows(reader)

    def ismissing(name, val):
        "Should the value val in column name be masked?"
        return val == missing or val == missingd.get(name) or val == ''

    def with_default_value(func, default):
        def newfunc(name, val):
            if ismissing(name, val):
                return default
            else:
                return func(val)
        return newfunc

    def mybool(x):
        if x == 'True':
            return True
        elif x == 'False':
            return False
        else:
            raise ValueError('invalid bool')

    dateparser = dateutil.parser.parse

    def mydateparser(x):
        # try and return a datetime object
        d = dateparser(x, dayfirst=dayfirst, yearfirst=yearfirst)
        return d

    mydateparser = with_default_value(mydateparser, datetime.datetime(1, 1, 1))

    myfloat = with_default_value(float, np.nan)
    myint = with_default_value(int, -1)
    mystr = with_default_value(str, '')
    mybool = with_default_value(mybool, None)

    def mydate(x):
        # try and return a date object
        d = dateparser(x, dayfirst=dayfirst, yearfirst=yearfirst)

        if d.hour > 0 or d.minute > 0 or d.second > 0:
            raise ValueError('not a date')
        return d.date()
    mydate = with_default_value(mydate, datetime.date(1, 1, 1))

    def get_func(name, item, func):
        # promote functions in this order
        funcs = [mybool, myint, myfloat, mydate, mydateparser, mystr]
        for func in funcs[funcs.index(func):]:
            try:
                func(name, item)
            except Exception:
                continue
            return func
        raise ValueError('Could not find a working conversion function')

    # map column names that clash with builtins -- TODO - extend this list
    itemd = {
        'return': 'return_',
        'file': 'file_',
        'print': 'print_',
        }

    def get_converters(reader, comments):

        converters = None
        i = 0
        for row in reader:
            if (len(row) and comments is not None and
                    row[0].startswith(comments)):
                continue
            if i == 0:
                converters = [mybool] * len(row)
            if checkrows and i > checkrows:
                break
            i += 1

            for j, (name, item) in enumerate(zip(names, row)):
                func = converterd.get(j)
                if func is None:
                    func = converterd.get(name)
                if func is None:
                    func = converters[j]
                if len(item.strip()):
                    func = get_func(name, item, func)
                else:
                    # how should we handle custom converters and defaults?
                    func = with_default_value(func, None)
                converters[j] = func
        return converters

    # Get header and remove invalid characters
    needheader = names is None

    if needheader:
        for row in reader:
            if (len(row) and comments is not None and
                    row[0].startswith(comments)):
                continue
            headers = row
            break

        # remove these chars
        delete = set(r"""~!@#$%^&*()-=+~\|}[]{';: /?.>,<""")
        delete.add('"')

        names = []
        seen = dict()
        for i, item in enumerate(headers):
            item = item.strip().lower().replace(' ', '_')
            item = ''.join([c for c in item if c not in delete])
            if not len(item):
                item = 'column%d' % i

            item = itemd.get(item, item)
            cnt = seen.get(item, 0)
            if cnt > 0:
                names.append(item + '_%d' % cnt)
            else:
                names.append(item)
            seen[item] = cnt + 1

    else:
        if isinstance(names, str):
            names = [n.strip() for n in names.split(',')]

    # get the converter functions by inspecting checkrows
    converters = get_converters(reader, comments)
    if converters is None:
        raise ValueError('Could not find any valid data in CSV file')

    # reset the reader and start over
    fh.seek(0)
    reader = csv.reader(fh, delimiter=delimiter)
    process_skiprows(reader)

    if needheader:
        while True:
            # skip past any comments and consume one line of column header
            row = next(reader)
            if (len(row) and comments is not None and
                    row[0].startswith(comments)):
                continue
            break

    # iterate over the remaining rows and convert the data to date
    # objects, ints, or floats as appropriate
    rows = []
    rowmasks = []
    for i, row in enumerate(reader):
        if not len(row):
            continue
        if comments is not None and row[0].startswith(comments):
            continue
        # Ensure that the row returned always has the same nr of elements
        row.extend([''] * (len(converters) - len(row)))
        rows.append([func(name, val)
                     for func, name, val in zip(converters, names, row)])
        rowmasks.append([ismissing(name, val)
                         for name, val in zip(names, row)])
    fh.close()

    if not len(rows):
        return None

    if use_mrecords and np.any(rowmasks):
        r = np.ma.mrecords.fromrecords(rows, names=names, mask=rowmasks)
    else:
        r = np.rec.fromrecords(rows, names=names)
    return r
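
# A hedged usage sketch (hypothetical file 'data.csv' with a header row;
# per-column types are inferred automatically):
#
#     r = csv2rec('data.csv')
#     print(r.dtype.names)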


# a series of classes for describing the format intentions of various rec views
@cbook.deprecated("2.2")
class FormatObj(object):
    def tostr(self, x):
        return self.toval(x)

    def toval(self, x):
        return str(x)

    def fromstr(self, s):
        return s

    def __hash__(self):
        """
        override the hash function of any of the formatters, so that we don't
        create duplicate excel format styles
        """
        return hash(self.__class__)


@cbook.deprecated("2.2")
class FormatString(FormatObj):
    def tostr(self, x):
        val = repr(x)
        return val[1:-1]


@cbook.deprecated("2.2")
class FormatFormatStr(FormatObj):
    def __init__(self, fmt):
        self.fmt = fmt

    def tostr(self, x):
        if x is None:
            return 'None'
        return self.fmt % self.toval(x)


@cbook.deprecated("2.2")
class FormatFloat(FormatFormatStr):
    def __init__(self, precision=4, scale=1.):
        FormatFormatStr.__init__(self, '%%1.%df' % precision)
        self.precision = precision
        self.scale = scale

    def __hash__(self):
        return hash((self.__class__, self.precision, self.scale))

    def toval(self, x):
        if x is not None:
            x = x * self.scale
        return x

    def fromstr(self, s):
        return float(s) / self.scale


@cbook.deprecated("2.2")
class FormatInt(FormatObj):
    def tostr(self, x):
        return '%d' % int(x)

    def toval(self, x):
        return int(x)

    def fromstr(self, s):
        return int(s)


@cbook.deprecated("2.2")
class FormatBool(FormatObj):
    def toval(self, x):
        return str(x)

    def fromstr(self, s):
        return bool(s)


@cbook.deprecated("2.2")
class FormatPercent(FormatFloat):
    def __init__(self, precision=4):
        FormatFloat.__init__(self, precision, scale=100.)


@cbook.deprecated("2.2")
class FormatThousands(FormatFloat):
    def __init__(self, precision=4):
        FormatFloat.__init__(self, precision, scale=1e-3)


@cbook.deprecated("2.2")
class FormatMillions(FormatFloat):
    def __init__(self, precision=4):
        FormatFloat.__init__(self, precision, scale=1e-6)


@cbook.deprecated("2.2", alternative='date.strftime')
class FormatDate(FormatObj):
    def __init__(self, fmt):
        self.fmt = fmt

    def __hash__(self):
        return hash((self.__class__, self.fmt))

    def toval(self, x):
        if x is None:
            return 'None'
        return x.strftime(self.fmt)

    def fromstr(self, x):
        import dateutil.parser
        return dateutil.parser.parse(x).date()


@cbook.deprecated("2.2", alternative='datetime.strftime')
class FormatDatetime(FormatDate):
    def __init__(self, fmt='%Y-%m-%d %H:%M:%S'):
        FormatDate.__init__(self, fmt)

    def fromstr(self, x):
        import dateutil.parser
        return dateutil.parser.parse(x)


@cbook.deprecated("2.2")
def get_formatd(r, formatd=None):
    'build a formatd guaranteed to have a key for every dtype name'
    defaultformatd = {
        np.bool_: FormatBool(),
        np.int16: FormatInt(),
        np.int32: FormatInt(),
        np.int64: FormatInt(),
        np.float32: FormatFloat(),
        np.float64: FormatFloat(),
        np.object_: FormatObj(),
        np.string_: FormatString()}

    if formatd is None:
        formatd = dict()

    for i, name in enumerate(r.dtype.names):
        dt = r.dtype[name]
        format = formatd.get(name)
        if format is None:
            format = defaultformatd.get(dt.type, FormatObj())
        formatd[name] = format
    return formatd


@cbook.deprecated("2.2")
def csvformat_factory(format):
    format = copy.deepcopy(format)
    if isinstance(format, FormatFloat):
        format.scale = 1.  # override scaling for storage
        format.fmt = '%r'
    return format


@cbook.deprecated("2.2", alternative='numpy.recarray.tofile')
def rec2txt(r, header=None, padding=3, precision=3, fields=None):
    """
    Returns a textual representation of a record array.

    Parameters
    ----------
    r : numpy recarray

    header : list
        column headers

    padding :
        space between each column

    precision : number of decimal places to use for floats.
        Set to an integer to apply to all floats. Set to a
        list of integers to apply precision individually.
        Precision for non-floats is simply ignored.

    fields : list
        If not None, a list of field names to print. fields
        can be a list of strings like ['field1', 'field2'] or a single
        comma separated string like 'field1,field2'

    Examples
    --------
    For ``precision=[0, 2, 3]``, the output is ::

        ID    Price   Return
        ABC   12.54    0.234
        XYZ    6.32   -0.076
    """
    if fields is not None:
        r = rec_keep_fields(r, fields)

    if cbook.is_numlike(precision):
        precision = [precision] * len(r.dtype)

    def get_type(item, atype=int):
        tdict = {None: int, int: float, float: str}
        try:
            atype(str(item))
        except Exception:
            return get_type(item, tdict[atype])
        return atype

    def get_justify(colname, column, precision):
        ntype = column.dtype

        if np.issubdtype(ntype, np.character):
            fixed_width = int(ntype.str[2:])
            length = max(len(colname), fixed_width)
            return 0, length + padding, "%s"  # left justify

        if np.issubdtype(ntype, np.integer):
            length = max(len(colname),
                         np.max(list(map(len, list(map(str, column))))))
            return 1, length + padding, "%d"  # right justify

        if np.issubdtype(ntype, np.floating):
            fmt = "%." + str(precision) + "f"
            length = max(
                len(colname),
                np.max(list(map(len, list(map(lambda x: fmt % x, column)))))
            )
            return 1, length + padding, fmt  # right justify

        return (0,
                max(len(colname),
                    np.max(list(map(len, list(map(str, column)))))) + padding,
                "%s")

    if header is None:
        header = r.dtype.names

    justify_pad_prec = [get_justify(header[i], r.__getitem__(colname),
                                    precision[i])
                        for i, colname in enumerate(r.dtype.names)]

    justify_pad_prec_spacer = []
    for i in range(len(justify_pad_prec)):
        just, pad, prec = justify_pad_prec[i]
        if i == 0:
            justify_pad_prec_spacer.append((just, pad, prec, 0))
        else:
            pjust, ppad, pprec = justify_pad_prec[i - 1]
            if pjust == 0 and just == 1:
                justify_pad_prec_spacer.append((just, pad - padding, prec, 0))
            elif pjust == 1 and just == 0:
                justify_pad_prec_spacer.append((just, pad, prec, padding))
            else:
                justify_pad_prec_spacer.append((just, pad, prec, 0))

    def format(item, just_pad_prec_spacer):
        just, pad, prec, spacer = just_pad_prec_spacer
        if just == 0:
            return spacer * ' ' + str(item).ljust(pad)
        else:
            if get_type(item) == float:
                item = (prec % float(item))
            elif get_type(item) == int:
                item = (prec % int(item))
            return item.rjust(pad)

    textl = []
    textl.append(''.join([format(colitem, justify_pad_prec_spacer[j])
                          for j, colitem in enumerate(header)]))
    for i, row in enumerate(r):
        textl.append(''.join([format(colitem, justify_pad_prec_spacer[j])
                              for j, colitem in enumerate(row)]))
        if i == 0:
            textl[0] = textl[0].rstrip()

    text = os.linesep.join(textl)
    return text


@cbook.deprecated("2.2", alternative='numpy.recarray.tofile')
def rec2csv(r, fname, delimiter=',', formatd=None, missing='',
            missingd=None, withheader=True):
    """
    Save the data from numpy recarray *r* into a
    comma-/space-/tab-delimited file. The record array dtype names
    will be used for column headers.

    *fname*: can be a filename or a file handle. Support for gzipped
    files is automatic, if the filename ends in '.gz'

    *withheader*: if withheader is False, do not write the attribute
    names in the first row

    for formatd type FormatFloat, we override the precision to store
    full precision floats in the CSV file

    See Also
    --------
    :func:`csv2rec`
        For information about *missing* and *missingd*, which can be used to
        fill in masked values into your CSV file.
    """
    delimiter = str(delimiter)

    if missingd is None:
        missingd = dict()

    def with_mask(func):
        def newfunc(val, mask, mval):
            if mask:
                return mval
            else:
                return func(val)
        return newfunc

    if r.ndim != 1:
        raise ValueError('rec2csv only operates on 1 dimensional recarrays')

    formatd = get_formatd(r, formatd)
    funcs = []
    for i, name in enumerate(r.dtype.names):
        funcs.append(with_mask(csvformat_factory(formatd[name]).tostr))

    fh, opened = cbook.to_filehandle(fname, 'wb', return_opened=True)
    writer = csv.writer(fh, delimiter=delimiter)
    header = r.dtype.names
    if withheader:
        writer.writerow(header)

    # Our list of specials for missing values
    mvals = []
    for name in header:
        mvals.append(missingd.get(name, missing))

    ismasked = False
    if len(r):
        row = r[0]
        ismasked = hasattr(row, '_fieldmask')

    for row in r:
        if ismasked:
            row, rowmask = row.item(), row._fieldmask.item()
        else:
            rowmask = [False] * len(row)
        writer.writerow([func(val, mask, mval) for func, val, mask, mval
                         in zip(funcs, row, rowmask, mvals)])
    if opened:
        fh.close()
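
# A hedged round-trip sketch pairing rec2csv with csv2rec above
# (hypothetical record array r and output path):
#
#     rec2csv(r, 'out.csv')
#     r2 = csv2rec('out.csv')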


@cbook.deprecated('2.2', alternative='scipy.interpolate.griddata')
def griddata(x, y, z, xi, yi, interp='nn'):
    """
    Interpolates from a nonuniformly spaced grid to some other grid.

    Fits a surface of the form z = f(`x`, `y`) to the data in the
    (usually) nonuniformly spaced vectors (`x`, `y`, `z`), then
    interpolates this surface at the points specified by
    (`xi`, `yi`) to produce `zi`.

    Parameters
    ----------
    x, y, z : 1d array_like
        Coordinates of grid points to interpolate from.
    xi, yi : 1d or 2d array_like
        Coordinates of grid points to interpolate to.
    interp : string key from {'nn', 'linear'}
        Interpolation algorithm, either 'nn' for natural neighbor, or
        'linear' for linear interpolation.

    Returns
    -------
    2d float array
        Array of values interpolated at (`xi`, `yi`) points. Array
        will be masked if any of (`xi`, `yi`) are outside the convex
        hull of (`x`, `y`).

    Notes
    -----
    If `interp` is 'nn' (the default), uses natural neighbor
    interpolation based on Delaunay triangulation. This option is
    only available if the mpl_toolkits.natgrid module is installed.
    This can be downloaded from https://github.com/matplotlib/natgrid.
    The (`xi`, `yi`) grid must be regular and monotonically increasing
    in this case.

    If `interp` is 'linear', linear interpolation is used via
    matplotlib.tri.LinearTriInterpolator.

    Instead of using `griddata`, more flexible functionality and other
    interpolation options are available using a
    matplotlib.tri.Triangulation and a matplotlib.tri.TriInterpolator.
    """
    # Check input arguments.
    x = np.asanyarray(x, dtype=np.float64)
    y = np.asanyarray(y, dtype=np.float64)
    z = np.asanyarray(z, dtype=np.float64)
    if x.shape != y.shape or x.shape != z.shape or x.ndim != 1:
        raise ValueError("x, y and z must be equal-length 1-D arrays")

    xi = np.asanyarray(xi, dtype=np.float64)
    yi = np.asanyarray(yi, dtype=np.float64)
    if xi.ndim != yi.ndim:
        raise ValueError("xi and yi must be arrays with the same number of "
                         "dimensions (1 or 2)")
    if xi.ndim == 2 and xi.shape != yi.shape:
        raise ValueError("if xi and yi are 2D arrays, they must have the same "
                         "shape")
    if xi.ndim == 1:
        xi, yi = np.meshgrid(xi, yi)

    if interp == 'nn':
        use_nn_interpolation = True
    elif interp == 'linear':
        use_nn_interpolation = False
    else:
        raise ValueError("interp keyword must be one of 'linear' (for linear "
                         "interpolation) or 'nn' (for natural neighbor "
                         "interpolation). Default is 'nn'.")

    # Remove masked points.
    mask = np.ma.getmask(z)
    if mask is not np.ma.nomask:
        x = x.compress(~mask)
        y = y.compress(~mask)
        z = z.compressed()

    if use_nn_interpolation:
        try:
            from mpl_toolkits.natgrid import _natgrid
        except ImportError:
            raise RuntimeError(
                "To use interp='nn' (Natural Neighbor interpolation) in "
                "griddata, natgrid must be installed. Either install it "
                "from http://github.com/matplotlib/natgrid or use "
                "interp='linear' instead.")

        if xi.ndim == 2:
            # natgrid expects 1D xi and yi arrays.
            xi = xi[0, :]
            yi = yi[:, 0]

        # Override default natgrid internal parameters.
        _natgrid.seti(b'ext', 0)
        _natgrid.setr(b'nul', np.nan)

        if np.min(np.diff(xi)) < 0 or np.min(np.diff(yi)) < 0:
            raise ValueError("Output grid defined by xi,yi must be monotone "
                             "increasing")

        # Allocate array for output (buffer will be overwritten by natgridd)
        zi = np.empty((yi.shape[0], xi.shape[0]), np.float64)

        # Natgrid requires each array to be contiguous rather than e.g. a view
        # that is a non-contiguous slice of another array. Use numpy.require
        # to deal with this, which will copy if necessary.
        x = np.require(x, requirements=['C'])
        y = np.require(y, requirements=['C'])
        z = np.require(z, requirements=['C'])
        xi = np.require(xi, requirements=['C'])
        yi = np.require(yi, requirements=['C'])
        _natgrid.natgridd(x, y, z, xi, yi, zi)

        # Mask points on grid outside convex hull of input data.
        if np.any(np.isnan(zi)):
            zi = np.ma.masked_where(np.isnan(zi), zi)
        return zi
    else:
        # Linear interpolation performed using a matplotlib.tri.Triangulation
        # and a matplotlib.tri.LinearTriInterpolator.
        from .tri import Triangulation, LinearTriInterpolator
        triang = Triangulation(x, y)
        interpolator = LinearTriInterpolator(triang, z)
        return interpolator(xi, yi)
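
# A hedged sketch of regridding scattered samples with the pure-matplotlib
# 'linear' path (hypothetical scattered data; 'nn' additionally requires
# natgrid):
#
#     x, y = np.random.rand(2, 200)
#     z = np.sin(x * np.pi) * np.cos(y * np.pi)
#     xi = yi = np.linspace(0, 1, 50)
#     zi = griddata(x, y, z, xi, yi, interp='linear')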


##################################################
# Linear interpolation algorithms
##################################################
@cbook.deprecated("2.2", alternative="numpy.interp")
def less_simple_linear_interpolation(x, y, xi, extrap=False):
    """
    This function provides simple (but somewhat less so than
    :func:`cbook.simple_linear_interpolation`) linear interpolation.
    :func:`simple_linear_interpolation` will give a list of points
    between a start and an end, while this does true linear
    interpolation at an arbitrary set of points.

    This is very inefficient linear interpolation meant to be used
    only for a small number of points in relatively non-intensive use
    cases. For real linear interpolation, use scipy.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    xi = np.atleast_1d(xi)

    s = list(y.shape)
    s[0] = len(xi)
    yi = np.tile(np.nan, s)

    for ii, xx in enumerate(xi):
        bb = x == xx
        if np.any(bb):
            jj, = np.nonzero(bb)
            yi[ii] = y[jj[0]]
        elif xx < x[0]:
            if extrap:
                yi[ii] = y[0]
        elif xx > x[-1]:
            if extrap:
                yi[ii] = y[-1]
        else:
            jj, = np.nonzero(x < xx)
            jj = max(jj)
            yi[ii] = y[jj] + (xx-x[jj])/(x[jj+1]-x[jj]) * (y[jj+1]-y[jj])

    return yi


@cbook.deprecated("2.2")
def slopes(x, y):
    """
    :func:`slopes` calculates the slope *y*'(*x*)

    The slope is estimated using the slope obtained from that of a
    parabola through any three consecutive points.

    This method should be superior to that described in the appendix
    of A CONSISTENTLY WELL BEHAVED METHOD OF INTERPOLATION by Russell
    W. Stineman (Creative Computing July 1980) in at least one aspect:

      Circles for interpolation demand a known aspect ratio between
      *x*- and *y*-values. For many functions, however, the abscissa
      are given in different dimensions, so an aspect ratio is
      completely arbitrary.

    The parabola method gives very similar results to the circle
    method for most regular cases but behaves much better in special
    cases.

    Norbert Nemec, Institute of Theoretical Physics, University of
    Regensburg, April 2006 Norbert.Nemec at physik.uni-regensburg.de

    (inspired by an original implementation by Halldor Bjornsson,
    Icelandic Meteorological Office, March 2006 halldor at vedur.is)
    """
    # Cast key variables as float.
    x = np.asarray(x, float)
    y = np.asarray(y, float)

    yp = np.zeros(y.shape, float)

    dx = x[1:] - x[:-1]
    dy = y[1:] - y[:-1]
    dydx = dy / dx
    yp[1:-1] = (dydx[:-1] * dx[1:] + dydx[1:] * dx[:-1]) / (dx[1:] + dx[:-1])
    yp[0] = 2.0 * dy[0] / dx[0] - yp[1]
    yp[-1] = 2.0 * dy[-1] / dx[-1] - yp[-2]
    return yp
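
# A short sketch of estimating slopes on a coarse sine curve (hypothetical
# data; pairs naturally with stineman_interp below):
#
#     x = np.linspace(0, 2 * np.pi, 20)
#     yp = slopes(x, np.sin(x))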


@cbook.deprecated("2.2")
def stineman_interp(xi, x, y, yp=None):
    """
    Given data vectors *x* and *y*, the slope vector *yp* and a new
    abscissa vector *xi*, the function :func:`stineman_interp` uses
    Stineman interpolation to calculate a vector *yi* corresponding to
    *xi*.

    Here's an example that generates a coarse sine curve, then
    interpolates over a finer abscissa::

        x = linspace(0,2*pi,20);  y = sin(x); yp = cos(x)
        xi = linspace(0,2*pi,40);
        yi = stineman_interp(xi,x,y,yp);
        plot(x,y,'o',xi,yi)

    The interpolation method is described in the article A
    CONSISTENTLY WELL BEHAVED METHOD OF INTERPOLATION by Russell
    W. Stineman. The article appeared in the July 1980 issue of
    Creative Computing with a note from the editor stating that while
    they were:

      not an academic journal but once in a while something serious
      and original comes in

    adding that this was
    "apparently a real solution" to a well known problem.

    For *yp* = *None*, the routine automatically determines the slopes
    using the :func:`slopes` routine.

    *x* is assumed to be sorted in increasing order.

    For values ``xi[j] < x[0]`` or ``xi[j] > x[-1]``, the routine
    tries an extrapolation. The relevance of the data obtained from
    this, of course, is questionable...

    Original implementation by Halldor Bjornsson, Icelandic
    Meteorological Office, March 2006 halldor at vedur.is

    Completely reworked and optimized for Python by Norbert Nemec,
    Institute of Theoretical Physics, University of Regensburg, April
    2006 Norbert.Nemec at physik.uni-regensburg.de
    """
    # Cast key variables as float.
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    if x.shape != y.shape:
        raise ValueError("'x' and 'y' must be of same shape")

    if yp is None:
        yp = slopes(x, y)
    else:
        yp = np.asarray(yp, float)

    xi = np.asarray(xi, float)
    yi = np.zeros(xi.shape, float)

    # calculate linear slopes
    dx = x[1:] - x[:-1]
    dy = y[1:] - y[:-1]
    s = dy / dx  # note length of s is N-1 so last element is #N-2

    # find the segment each xi is in
    # this line actually is the key to the efficiency of this implementation
    idx = np.searchsorted(x[1:-1], xi)

    # now we have generally: x[idx[j]] <= xi[j] <= x[idx[j]+1]
    # except at the boundaries, where it may be that xi[j] < x[0] or
    # xi[j] > x[-1]

    # the y-values that would come out from a linear interpolation:
    sidx = s.take(idx)
    xidx = x.take(idx)
    yidx = y.take(idx)
    xidxp1 = x.take(idx + 1)
    yo = yidx + sidx * (xi - xidx)

    # the difference that comes when using the slopes given in yp
    # using the yp slope of the left point
    dy1 = (yp.take(idx) - sidx) * (xi - xidx)
    # using the yp slope of the right point
    dy2 = (yp.take(idx + 1) - sidx) * (xi - xidxp1)

    dy1dy2 = dy1 * dy2
    # The following is optimized for Python. The solution actually
    # does more calculations than necessary but exploiting the power
    # of numpy, this is far more efficient than coding a loop by hand
    # in Python
    yi = yo + dy1dy2 * np.choose(np.array(np.sign(dy1dy2), np.int32) + 1,
                                 ((2*xi-xidx-xidxp1)/((dy1-dy2)*(xidxp1-xidx)),
                                  0.0,
                                  1/(dy1+dy2),))
    return yi


class GaussianKDE(object):
    """
    Representation of a kernel-density estimate using Gaussian kernels.

    Parameters
    ----------
    dataset : array_like
        Datapoints to estimate from. In case of univariate data this is a 1-D
        array, otherwise a 2-D array with shape (# of dims, # of data).
    bw_method : str, scalar or callable, optional
        The method used to calculate the estimator bandwidth.  This can be
        'scott', 'silverman', a scalar constant or a callable.  If a
        scalar, this will be used directly as `kde.factor`.  If a
        callable, it should take a `GaussianKDE` instance as its only
        parameter and return a scalar.  If None (default), 'scott' is used.

    Attributes
    ----------
    dataset : ndarray
        The dataset with which `GaussianKDE` was initialized.
    dim : int
        Number of dimensions.
    num_dp : int
        Number of datapoints.
    factor : float
        The bandwidth factor, obtained from `kde.covariance_factor`, with which
        the covariance matrix is multiplied.
    covariance : ndarray
        The covariance matrix of `dataset`, scaled by the calculated bandwidth
        (`kde.factor`).
    inv_cov : ndarray
        The inverse of `covariance`.

    Methods
    -------
    kde.evaluate(points) : ndarray
        Evaluate the estimated pdf on a provided set of points.
    kde(points) : ndarray
        Same as kde.evaluate(points)
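
    Examples
    --------
    A minimal usage sketch for univariate data (the density values
    depend on the random draw)::

        import numpy as np
        from matplotlib.mlab import GaussianKDE

        np.random.seed(0)
        data = np.random.randn(1000)        # 1-D dataset
        kde = GaussianKDE(data, bw_method='silverman')
        grid = np.linspace(-4, 4, 200)
        density = kde(grid)                 # same as kde.evaluate(grid)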
  2852. """
  2853. # This implementation with minor modification was too good to pass up.
  2854. # from scipy: https://github.com/scipy/scipy/blob/master/scipy/stats/kde.py
  2855. def __init__(self, dataset, bw_method=None):
  2856. self.dataset = np.atleast_2d(dataset)
  2857. if not np.array(self.dataset).size > 1:
  2858. raise ValueError("`dataset` input should have multiple elements.")
  2859. self.dim, self.num_dp = np.array(self.dataset).shape
  2860. isString = isinstance(bw_method, str)
  2861. if bw_method is None:
  2862. pass
  2863. elif (isString and bw_method == 'scott'):
  2864. self.covariance_factor = self.scotts_factor
  2865. elif (isString and bw_method == 'silverman'):
  2866. self.covariance_factor = self.silverman_factor
  2867. elif (np.isscalar(bw_method) and not isString):
  2868. self._bw_method = 'use constant'
  2869. self.covariance_factor = lambda: bw_method
  2870. elif callable(bw_method):
  2871. self._bw_method = bw_method
  2872. self.covariance_factor = lambda: self._bw_method(self)
  2873. else:
  2874. raise ValueError("`bw_method` should be 'scott', 'silverman', a "
  2875. "scalar or a callable")
  2876. # Computes the covariance matrix for each Gaussian kernel using
  2877. # covariance_factor().
  2878. self.factor = self.covariance_factor()
  2879. # Cache covariance and inverse covariance of the data
  2880. if not hasattr(self, '_data_inv_cov'):
  2881. self.data_covariance = np.atleast_2d(
  2882. np.cov(
  2883. self.dataset,
  2884. rowvar=1,
  2885. bias=False))
  2886. self.data_inv_cov = np.linalg.inv(self.data_covariance)
  2887. self.covariance = self.data_covariance * self.factor ** 2
  2888. self.inv_cov = self.data_inv_cov / self.factor ** 2
  2889. self.norm_factor = np.sqrt(
  2890. np.linalg.det(
  2891. 2 * np.pi * self.covariance)) * self.num_dp
  2892. def scotts_factor(self):
  2893. return np.power(self.num_dp, -1. / (self.dim + 4))
  2894. def silverman_factor(self):
  2895. return np.power(
  2896. self.num_dp * (self.dim + 2.0) / 4.0, -1. / (self.dim + 4))
  2897. # Default method to calculate bandwidth, can be overwritten by subclass
  2898. covariance_factor = scotts_factor

    def evaluate(self, points):
        """Evaluate the estimated pdf on a set of points.

        Parameters
        ----------
        points : (# of dimensions, # of points)-array
            Alternatively, a (# of dimensions,) vector can be passed in and
            treated as a single point.

        Returns
        -------
        values : (# of points,)-array
            The values at each point.

        Raises
        ------
        ValueError : if the dimensionality of the input points is different
            than the dimensionality of the KDE.

        """
        points = np.atleast_2d(points)

        dim, num_m = np.array(points).shape
        if dim != self.dim:
            raise ValueError("points have dimension {}, dataset has dimension "
                             "{}".format(dim, self.dim))

        result = np.zeros((num_m,), dtype=float)

        if num_m >= self.num_dp:
            # there are more points than data, so loop over data
            for i in range(self.num_dp):
                diff = self.dataset[:, i, np.newaxis] - points
                tdiff = np.dot(self.inv_cov, diff)
                energy = np.sum(diff * tdiff, axis=0) / 2.0
                result = result + np.exp(-energy)
        else:
            # loop over points
            for i in range(num_m):
                diff = self.dataset - points[:, i, np.newaxis]
                tdiff = np.dot(self.inv_cov, diff)
                energy = np.sum(diff * tdiff, axis=0) / 2.0
                result[i] = np.sum(np.exp(-energy), axis=0)

        result = result / self.norm_factor

        return result

    __call__ = evaluate


##################################################
# Code related to things in and around polygons
##################################################
@cbook.deprecated("2.2")
def inside_poly(points, verts):
    """
    *points* is a sequence of *x*, *y* points.
    *verts* is a sequence of *x*, *y* vertices of a polygon.

    Return value is a sequence of indices into points for the points
    that are inside the polygon.
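
    For example, keeping only the point that lies inside the unit
    square::

        verts = [(0, 0), (1, 0), (1, 1), (0, 1)]
        points = [(0.5, 0.5), (2.0, 2.0)]
        inside_poly(points, verts)    # -> [0]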
  2948. """
  2949. # Make a closed polygon path
  2950. poly = Path(verts)
  2951. # Check to see which points are contained within the Path
  2952. return [idx for idx, p in enumerate(points) if poly.contains_point(p)]


@cbook.deprecated("2.2")
def poly_below(xmin, xs, ys):
    """
    Given a sequence of *xs* and *ys*, return the vertices of a
    polygon that has a horizontal base at *xmin* and an upper bound at
    the *ys*.  *xmin* is a scalar.

    Intended for use with :meth:`matplotlib.axes.Axes.fill`, e.g.,::

      xv, yv = poly_below(0, x, y)
      ax.fill(xv, yv)
    """
    if any(isinstance(var, np.ma.MaskedArray) for var in [xs, ys]):
        numpy = np.ma
    else:
        numpy = np

    xs = numpy.asarray(xs)
    ys = numpy.asarray(ys)
    Nx = len(xs)
    Ny = len(ys)
    if Nx != Ny:
        raise ValueError("'xs' and 'ys' must have the same length")

    x = xmin*numpy.ones(2*Nx)
    y = numpy.ones(2*Nx)
    x[:Nx] = xs
    y[:Nx] = ys
    y[Nx:] = ys[::-1]
    return x, y


@cbook.deprecated("2.2")
def poly_between(x, ylower, yupper):
    """
    Given a sequence of *x*, *ylower* and *yupper*, return the polygon
    that fills the regions between them.  *ylower* or *yupper* can be
    scalar or iterable.  If they are iterable, they must be equal in
    length to *x*.

    Return value is *x*, *y* arrays for use with
    :meth:`matplotlib.axes.Axes.fill`.
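
    For example, to shade the region between a sine curve and the
    *x*-axis (assuming, as in the :func:`poly_below` example, that *ax*
    is an existing Axes)::

      x = np.linspace(0, 2*np.pi, 100)
      xv, yv = poly_between(x, 0, np.sin(x))
      ax.fill(xv, yv)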
  2988. """
  2989. if any(isinstance(var, np.ma.MaskedArray) for var in [ylower, yupper, x]):
  2990. numpy = np.ma
  2991. else:
  2992. numpy = np
  2993. Nx = len(x)
  2994. if not cbook.iterable(ylower):
  2995. ylower = ylower*numpy.ones(Nx)
  2996. if not cbook.iterable(yupper):
  2997. yupper = yupper*numpy.ones(Nx)
  2998. x = numpy.concatenate((x, x[::-1]))
  2999. y = numpy.concatenate((yupper, ylower[::-1]))
  3000. return x, y


@cbook.deprecated('2.2')
def is_closed_polygon(X):
    """
    Tests whether first and last object in a sequence are the same.  These are
    presumably coordinates on a polygonal curve, in which case this function
    tests if that curve is closed.
    """
    return np.all(X[0] == X[-1])


@cbook.deprecated("2.2", message='Moved to matplotlib.cbook')
def contiguous_regions(mask):
    """
    return a list of (ind0, ind1) such that mask[ind0:ind1].all() is
    True and we cover all such regions
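
    For example (a minimal sketch; one (ind0, ind1) pair per run of
    True values)::

        mask = np.array([False, True, True, False, True])
        contiguous_regions(mask)    # -> [(1, 3), (4, 5)]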
  3014. """
  3015. return cbook.contiguous_regions(mask)


@cbook.deprecated("2.2")
def cross_from_below(x, threshold):
    """
    return the indices into *x* where *x* crosses some threshold from
    below, e.g., the i's where::

      x[i-1]<threshold and x[i]>=threshold

    Example code::

        import matplotlib.pyplot as plt

        t = np.arange(0.0, 2.0, 0.1)
        s = np.sin(2*np.pi*t)

        fig, ax = plt.subplots()
        ax.plot(t, s, '-o')
        ax.axhline(0.5)
        ax.axhline(-0.5)

        ind = cross_from_below(s, 0.5)
        ax.vlines(t[ind], -1, 1)

        ind = cross_from_above(s, -0.5)
        ax.vlines(t[ind], -1, 1)

        plt.show()

    See Also
    --------
    :func:`cross_from_above` and :func:`contiguous_regions`

    """
    x = np.asarray(x)
    ind = np.nonzero((x[:-1] < threshold) & (x[1:] >= threshold))[0]
    if len(ind):
        return ind+1
    else:
        return ind


@cbook.deprecated("2.2")
def cross_from_above(x, threshold):
    """
    return the indices into *x* where *x* crosses some threshold from
    above, e.g., the i's where::

      x[i-1]>threshold and x[i]<=threshold

    See Also
    --------
    :func:`cross_from_below` and :func:`contiguous_regions`

    """
    x = np.asarray(x)
    ind = np.nonzero((x[:-1] >= threshold) & (x[1:] < threshold))[0]
    if len(ind):
        return ind+1
    else:
        return ind


##################################################
# Vector and path length geometry calculations
##################################################
@cbook.deprecated('2.2')
def vector_lengths(X, P=2., axis=None):
    """
    Finds the length of a set of vectors in *n* dimensions.  This is
    like the :func:`numpy.linalg.norm` function for vectors, but has the
    ability to work over a particular axis of the supplied array or matrix.

    Computes ``(sum((x_i)^P))^(1/P)`` for each ``{x_i}`` being the
    elements of *X* along the given axis.  If *axis* is *None*,
    compute over all elements of *X*.
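
    For example, the row-wise Euclidean lengths (``P=2``) of two 2-D
    vectors::

        X = np.array([[3.0, 4.0], [5.0, 12.0]])
        vector_lengths(X, axis=1)    # -> array([  5.,  13.])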
  3073. """
  3074. X = np.asarray(X)
  3075. return (np.sum(X**(P), axis=axis))**(1./P)


@cbook.deprecated('2.2')
def distances_along_curve(X):
    """
    Computes the distance between a set of successive points in *N*
    dimensions, where *X* is an *M* x *N* array or matrix.  The distances
    between successive rows are computed.  Distance is the standard
    Euclidean distance.
    """
    X = np.diff(X, axis=0)
    return vector_lengths(X, axis=1)


@cbook.deprecated('2.2')
def path_length(X):
    """
    Computes the distance travelled along a polygonal curve in *N*
    dimensions, where *X* is an *M* x *N* array or matrix.  Returns an
    array of length *M* consisting of the distance along the curve at
    each point (i.e., the rows of *X*).
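
    For example, for three points stepping along two edges of the unit
    square::

        X = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
        path_length(X)    # -> array([ 0.,  1.,  2.])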
  3093. """
  3094. X = distances_along_curve(X)
  3095. return np.concatenate((np.zeros(1), np.cumsum(X)))


@cbook.deprecated('2.2')
def quad2cubic(q0x, q0y, q1x, q1y, q2x, q2y):
    """
    Converts a quadratic Bezier curve to an equivalent cubic one
    (exact degree elevation).

    The inputs are the *x* and *y* coordinates of the three control
    points of a quadratic curve, and the output is a tuple of *x* and
    *y* coordinates of the four control points of the cubic curve.
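
    For example, elevating the quadratic curve with control points
    (0, 0), (1, 2) and (2, 0)::

        quad2cubic(0, 0, 1, 2, 2, 0)
        # -> (0, 0, 0.667, 1.333, 1.333, 1.333, 2, 0), approximately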
  3103. """
  3104. # TODO: Candidate for deprecation -- no longer used internally
  3105. # c0x, c0y = q0x, q0y
  3106. c1x, c1y = q0x + 2./3. * (q1x - q0x), q0y + 2./3. * (q1y - q0y)
  3107. c2x, c2y = c1x + 1./3. * (q2x - q0x), c1y + 1./3. * (q2y - q0y)
  3108. # c3x, c3y = q2x, q2y
  3109. return q0x, q0y, c1x, c1y, c2x, c2y, q2x, q2y


@cbook.deprecated("2.2")
def offset_line(y, yerr):
    """
    Offsets an array *y* by +/- an error and returns a tuple
    (y - err, y + err).

    The error term can be:

    * A scalar. In this case, the returned tuple is obvious.
    * A vector of the same length as *y*. The quantities y +/- err are
      computed component-wise.
    * A tuple of length 2. In this case, yerr[0] is the error below *y* and
      yerr[1] is the error above *y*. For example::

        import numpy as np
        import matplotlib.pyplot as plt
        import matplotlib.mlab as mlab

        x = np.linspace(0, 2*np.pi, num=100, endpoint=True)
        y = np.sin(x)
        y_minus, y_plus = mlab.offset_line(y, 0.1)
        plt.plot(x, y)
        plt.fill_between(x, y_minus, y2=y_plus)
        plt.show()
    """
    if cbook.is_numlike(yerr) or (cbook.iterable(yerr) and
                                  len(yerr) == len(y)):
        ymin = y - yerr
        ymax = y + yerr
    elif len(yerr) == 2:
        ymin, ymax = y - yerr[0], y + yerr[1]
    else:
        raise ValueError("yerr must be scalar, 1xN or 2xN")
    return ymin, ymax