You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

145 lines
3.5 KiB

4 years ago
  1. # coding: utf8
  2. from __future__ import unicode_literals
  3. import os
  4. import sys
  5. import ujson
  6. import itertools
  7. import locale
  8. from thinc.neural.util import copy_array
  9. try:
  10. import cPickle as pickle
  11. except ImportError:
  12. import pickle
  13. try:
  14. import copy_reg
  15. except ImportError:
  16. import copyreg as copy_reg
  17. try:
  18. from cupy.cuda.stream import Stream as CudaStream
  19. except ImportError:
  20. CudaStream = None
  21. try:
  22. import cupy
  23. except ImportError:
  24. cupy = None
  25. try:
  26. from thinc.neural.optimizers import Optimizer
  27. except ImportError:
  28. from thinc.neural.optimizers import Adam as Optimizer
# Re-bind the names imported (with fallbacks) above to themselves at module
# level.
# NOTE(review): presumably this marks them as intentional re-exports of this
# module and keeps linters from flagging the imports as unused — confirm.
pickle = pickle
copy_reg = copy_reg
CudaStream = CudaStream
cupy = cupy
copy_array = copy_array
  34. izip = getattr(itertools, "izip", zip)
  35. is_windows = sys.platform.startswith("win")
  36. is_linux = sys.platform.startswith("linux")
  37. is_osx = sys.platform == "darwin"
  38. # See: https://github.com/benjaminp/six/blob/master/six.py
  39. is_python2 = sys.version_info[0] == 2
  40. is_python3 = sys.version_info[0] == 3
  41. is_python_pre_3_5 = is_python2 or (is_python3 and sys.version_info[1] < 5)
  42. if is_python2:
  43. bytes_ = str
  44. unicode_ = unicode # noqa: F821
  45. basestring_ = basestring # noqa: F821
  46. input_ = raw_input # noqa: F821
  47. json_dumps = lambda data: ujson.dumps(
  48. data, indent=2, escape_forward_slashes=False
  49. ).decode("utf8")
  50. path2str = lambda path: str(path).decode("utf8")
  51. elif is_python3:
  52. bytes_ = bytes
  53. unicode_ = str
  54. basestring_ = str
  55. input_ = input
  56. json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False)
  57. path2str = lambda path: str(path)
  58. def b_to_str(b_str):
  59. if is_python2:
  60. return b_str
  61. # important: if no encoding is set, string becomes "b'...'"
  62. return str(b_str, encoding="utf8")
  63. def getattr_(obj, name, *default):
  64. if is_python3 and isinstance(name, bytes):
  65. name = name.decode("utf8")
  66. return getattr(obj, name, *default)
  67. def symlink_to(orig, dest):
  68. if is_windows:
  69. import subprocess
  70. subprocess.call(["mklink", "/d", path2str(orig), path2str(dest)], shell=True)
  71. else:
  72. orig.symlink_to(dest)
  73. def symlink_remove(link):
  74. # https://stackoverflow.com/q/26554135/6400719
  75. if os.path.isdir(path2str(link)) and is_windows:
  76. # this should only be on Py2.7 and windows
  77. os.rmdir(path2str(link))
  78. else:
  79. os.unlink(path2str(link))
  80. def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
  81. return (
  82. python2 in (None, is_python2)
  83. and python3 in (None, is_python3)
  84. and windows in (None, is_windows)
  85. and linux in (None, is_linux)
  86. and osx in (None, is_osx)
  87. )
  88. def normalize_string_keys(old):
  89. """Given a dictionary, make sure keys are unicode strings, not bytes."""
  90. new = {}
  91. for key, value in old.items():
  92. if isinstance(key, bytes_):
  93. new[key.decode("utf8")] = value
  94. else:
  95. new[key] = value
  96. return new
  97. def import_file(name, loc):
  98. loc = str(loc)
  99. if is_python_pre_3_5:
  100. import imp
  101. return imp.load_source(name, loc)
  102. else:
  103. import importlib.util
  104. spec = importlib.util.spec_from_file_location(name, str(loc))
  105. module = importlib.util.module_from_spec(spec)
  106. spec.loader.exec_module(module)
  107. return module
  108. def locale_escape(string, errors="replace"):
  109. """
  110. Mangle non-supported characters, for savages with ascii terminals.
  111. """
  112. encoding = locale.getpreferredencoding()
  113. string = string.encode(encoding, errors).decode("utf8")
  114. return string