You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

71 lines
1.8 KiB

4 years ago
  1. """
  2. Aliases for functions which may be accelerated by Scipy.
  3. Scipy_ can be built to use accelerated or otherwise improved libraries
  4. for FFTs, linear algebra, and special functions. This module allows
  5. developers to transparently support these accelerated functions when
  6. scipy is available but still support users who have only installed
  7. NumPy.
  8. .. _Scipy : http://www.scipy.org
  9. """
  10. from __future__ import division, absolute_import, print_function
# This module should be used for functions that exist in both numpy and
# scipy if you want to use the scipy version if available but the numpy
# version otherwise.
# Usage --- from numpy.dual import fft, inv

# Names of the dual functions. register_func and restore_func raise
# ValueError for any name that is not listed here.
__all__ = ['fft', 'ifft', 'fftn', 'ifftn', 'fft2', 'ifft2',
           'norm', 'inv', 'svd', 'solve', 'det', 'eig', 'eigvals',
           'eigh', 'eigvalsh', 'lstsq', 'pinv', 'cholesky', 'i0']
  18. import numpy.linalg as linpkg
  19. import numpy.fft as fftpkg
  20. from numpy.lib import i0
  21. import sys
  22. fft = fftpkg.fft
  23. ifft = fftpkg.ifft
  24. fftn = fftpkg.fftn
  25. ifftn = fftpkg.ifftn
  26. fft2 = fftpkg.fft2
  27. ifft2 = fftpkg.ifft2
  28. norm = linpkg.norm
  29. inv = linpkg.inv
  30. svd = linpkg.svd
  31. solve = linpkg.solve
  32. det = linpkg.det
  33. eig = linpkg.eig
  34. eigvals = linpkg.eigvals
  35. eigh = linpkg.eigh
  36. eigvalsh = linpkg.eigvalsh
  37. lstsq = linpkg.lstsq
  38. pinv = linpkg.pinv
  39. cholesky = linpkg.cholesky
  40. _restore_dict = {}
  41. def register_func(name, func):
  42. if name not in __all__:
  43. raise ValueError("%s not a dual function." % name)
  44. f = sys._getframe(0).f_globals
  45. _restore_dict[name] = f[name]
  46. f[name] = func
  47. def restore_func(name):
  48. if name not in __all__:
  49. raise ValueError("%s not a dual function." % name)
  50. try:
  51. val = _restore_dict[name]
  52. except KeyError:
  53. return
  54. else:
  55. sys._getframe(0).f_globals[name] = val
  56. def restore_all():
  57. for name in _restore_dict.keys():
  58. restore_func(name)