You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

131 lines
4.5 KiB

4 years ago
  1. # coding: utf-8
  2. from __future__ import unicode_literals
  3. from collections import Sized
  4. import os
  5. import traceback
  6. class UndefinedOperatorError(TypeError):
  7. def __init__(self, op, arg1, arg2, operators):
  8. self.tb = traceback.extract_stack()
  9. TypeError.__init__(self, get_error(
  10. "Undefined operator: {op}".format(op=op),
  11. "Called by ({arg1}, {arg2})".format(arg1=arg1, arg2=arg2),
  12. "Available: {ops}".format(ops= ', '.join(operators.keys())),
  13. tb=self.tb,
  14. highlight=op
  15. ))
  16. class OutsideRangeError(ValueError):
  17. def __init__(self, arg, val, operator):
  18. self.tb = traceback.extract_stack()
  19. ValueError.__init__(self, get_error(
  20. "Outside range: {v} needs to be {o} {v2}".format(
  21. v=_repr(arg), o=operator, v2=_repr(val)),
  22. tb=self.tb
  23. ))
  24. class DifferentLengthError(ValueError):
  25. def __init__(self, lengths, arg):
  26. self.tb = traceback.extract_stack()
  27. ValueError.__init__(self, get_error(
  28. "Values need to be equal length: {v}".format(v=_repr(lengths)),
  29. tb=self.tb
  30. ))
  31. class ShapeMismatchError(ValueError):
  32. def __init__(self, shape, dim, shape_names):
  33. self.tb = traceback.extract_stack()
  34. shape = _repr(shape)
  35. dim = _repr(dim)
  36. ValueError.__init__(self, get_error(
  37. "Shape mismatch: input {s} not compatible with {d}.".format(s=shape, d=dim),
  38. tb=self.tb
  39. ))
  40. class TooFewDimensionsError(ValueError):
  41. def __init__(self, shape, axis):
  42. self.tb = traceback.extract_stack()
  43. ValueError.__init__(self, get_error(
  44. "Shape mismatch: input {s} has too short for axis {d}.".format(
  45. s=_repr(shape), d=axis), tb=self.tb
  46. ))
  47. class ExpectedTypeError(TypeError):
  48. max_to_print_of_value = 200
  49. def __init__(self, bad_type, expected):
  50. if isinstance(expected, str):
  51. expected = [expected]
  52. self.tb = traceback.extract_stack()
  53. TypeError.__init__(self, get_error(
  54. "Expected type {e}, but got: {v} ({t})".format(e='/'.join(expected), v=_repr(bad_type), t=type(bad_type)),
  55. tb=self.tb,
  56. highlight=_repr(bad_type)
  57. ))
  58. def get_error(title, *args, **kwargs):
  59. template = '\n\n\t{title}{info}{tb}\n'
  60. info = '\n'.join(['\t' + l for l in args]) if args else ''
  61. highlight = kwargs['highlight'] if 'highlight' in kwargs else False
  62. tb = _get_traceback(kwargs['tb'], highlight) if 'tb' in kwargs else ''
  63. return template.format(title=color(title, 'red', attrs=['bold']),
  64. info=info, tb=tb)
  65. def _repr(obj, max_len=50):
  66. string = repr(obj)
  67. if len(string) >= max_len:
  68. half = int(max_len/2)
  69. return string[:half] + ' ... ' + string[-half:]
  70. else:
  71. return string
  72. def _get_traceback(tb, highlight):
  73. template = '\n\n\t{title}\n\t{tb}'
  74. # Prune "check.py" from tb (hacky)
  75. tb = [record for record in tb if not record[0].endswith('check.py')]
  76. tb_range = tb[-5:-2]
  77. tb_list = [_format_traceback(p, l, fn, t, i, len(tb_range), highlight) for i, (p, l, fn, t) in enumerate(tb_range)]
  78. return template.format(title=color('Traceback:', 'blue', attrs=['bold']),
  79. tb='\n'.join(tb_list).strip())
  80. def _format_traceback(path, line, fn, text, i, count, highlight):
  81. template = '\t{i} {fn} [{l}] in {p}{t}'
  82. indent = ('└─' if i == count-1 else '├─') + '──'*i
  83. filename = path.rsplit('/thinc/', 1)[1] if '/thinc/' in path else path
  84. text = _format_user_error(text, i, highlight) if i == count-1 else ''
  85. return template.format(l=str(line), i=indent, t=text,
  86. fn=color(fn, attrs=['bold']),
  87. p=color(filename, attrs=['underline']))
  88. def _format_user_error(text, i, highlight):
  89. template = '\n\t {sp} {t}'
  90. spacing = ' '*i + color(' >>>', 'red')
  91. if highlight:
  92. text = text.replace(str(highlight), color(str(highlight), 'yellow'))
  93. return template.format(sp=spacing, t=text)
  94. def color(text, fg=None, attrs=None):
  95. """Wrap text in color / style ANSI escape sequences."""
  96. if os.getenv('ANSI_COLORS_DISABLED') is not None:
  97. return text
  98. attrs = attrs or []
  99. tpl = '\x1b[{}m'
  100. styles = {'red': 31, 'blue': 34, 'yellow': 33, 'bold': 1, 'underline': 4}
  101. style = ''
  102. for attr in attrs:
  103. if attr in styles:
  104. style += tpl.format(styles[attr])
  105. if fg and fg in styles:
  106. style += tpl.format(styles[fg])
  107. return '{}{}\x1b[0m'.format(style, text)