131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
# coding: utf-8
|
|
from __future__ import unicode_literals
|
|
from collections import Sized
|
|
|
|
import os
|
|
import traceback
|
|
|
|
|
|
class UndefinedOperatorError(TypeError):
    """Raised when an operator is used that the operands do not define.

    The message names the operator, shows the two operands it was called
    with, and lists the operator names that are available.
    """

    def __init__(self, op, arg1, arg2, operators):
        # Snapshot the call stack now so the message can include a short
        # traceback pointing at where the error was constructed.
        self.tb = traceback.extract_stack()
        headline = "Undefined operator: {op}".format(op=op)
        called_by = "Called by ({arg1}, {arg2})".format(arg1=arg1, arg2=arg2)
        available = "Available: {ops}".format(ops=', '.join(operators.keys()))
        message = get_error(headline, called_by, available,
                            tb=self.tb, highlight=op)
        TypeError.__init__(self, message)
|
|
|
|
|
|
class OutsideRangeError(ValueError):
    """Raised when an argument's value lies outside its permitted range."""

    def __init__(self, arg, val, operator):
        # Snapshot the call stack for the traceback section of the message.
        self.tb = traceback.extract_stack()
        # Reads as "<arg> needs to be <operator> <val>".
        summary = "Outside range: {v} needs to be {o} {v2}".format(
            v=_repr(arg), o=operator, v2=_repr(val))
        ValueError.__init__(self, get_error(summary, tb=self.tb))
|
|
|
|
|
|
class DifferentLengthError(ValueError):
    """Raised when values that must share a length have different lengths."""

    def __init__(self, lengths, arg):
        # NOTE(review): `arg` is accepted but not used in the message.
        self.tb = traceback.extract_stack()
        summary = "Values need to be equal length: {v}".format(v=_repr(lengths))
        ValueError.__init__(self, get_error(summary, tb=self.tb))
|
|
|
|
|
|
class ShapeMismatchError(ValueError):
    """Raised when an input's shape is incompatible with an expected dim."""

    def __init__(self, shape, dim, shape_names):
        # NOTE(review): `shape_names` is accepted but not used in the message.
        self.tb = traceback.extract_stack()
        summary = "Shape mismatch: input {s} not compatible with {d}.".format(
            s=_repr(shape), d=_repr(dim))
        ValueError.__init__(self, get_error(summary, tb=self.tb))
|
|
|
|
|
|
class TooFewDimensionsError(ValueError):
    """Raised when an input has too few dimensions for the requested axis.

    Args:
        shape: The offending input's shape; shown via ``_repr``.
        axis: The axis that was requested.
    """

    def __init__(self, shape, axis):
        # Snapshot the call stack for the traceback section of the message.
        self.tb = traceback.extract_stack()
        summary = (
            "Shape mismatch: input {s} has too few dimensions "
            "for axis {d}.".format(s=_repr(shape), d=axis))
        ValueError.__init__(self, get_error(summary, tb=self.tb))
|
|
|
|
|
|
class ExpectedTypeError(TypeError):
    """Raised when an argument is not of the expected type.

    ``expected`` is either a single type name (a ``str``) or an iterable of
    names; multiple names are shown joined with ``/``.
    """

    # NOTE(review): not referenced by this class's own code; kept for callers.
    max_to_print_of_value = 200

    def __init__(self, bad_type, expected):
        # Accept a bare name as shorthand for a one-element list.
        names = [expected] if isinstance(expected, str) else expected
        self.tb = traceback.extract_stack()
        summary = "Expected type {e}, but got: {v} ({t})".format(
            e='/'.join(names), v=_repr(bad_type), t=type(bad_type))
        # The value's (possibly elided) repr is coloured in the traceback's
        # source line.
        TypeError.__init__(self, get_error(
            summary, tb=self.tb, highlight=_repr(bad_type)))
|
|
|
|
|
|
def get_error(title, *args, **kwargs):
    """Build a formatted, colourised multi-line error message.

    Args:
        title: Headline of the message, rendered bold red via ``color``.
        *args: Extra detail lines; each is placed on its own tab-indented
            line below the title.
        **kwargs: Optional ``tb`` (a captured stack, as produced by
            ``traceback.extract_stack()``) to append a short traceback
            section, and optional ``highlight`` (a value whose text is
            coloured in the source line of the last traceback frame shown).

    Returns:
        The message string; it begins with two newlines and ends with one.
    """
    highlight = kwargs.get('highlight', False)
    stack = kwargs.get('tb')
    # Give every detail line its own leading newline so none of them runs on
    # from the end of the title.
    info = ''.join('\n\t' + line for line in args)
    tb = _get_traceback(stack, highlight) if stack is not None else ''
    return '\n\n\t{title}{info}{tb}\n'.format(
        title=color(title, 'red', attrs=['bold']), info=info, tb=tb)
|
|
|
|
def _repr(obj, max_len=50):
|
|
string = repr(obj)
|
|
if len(string) >= max_len:
|
|
half = int(max_len/2)
|
|
return string[:half] + ' ... ' + string[-half:]
|
|
else:
|
|
return string
|
|
|
|
|
|
def _get_traceback(tb, highlight):
    """Render the short traceback section appended to error messages.

    Records whose filename (first field) ends in ``check.py`` are dropped.
    Of the rest, only the slice ``[-5:-2]`` is shown: at most three frames,
    skipping the final two. Each shown frame is formatted by
    ``_format_traceback``, which receives *highlight*.
    """
    # Prune "check.py" from tb (hacky)
    kept = []
    for record in tb:
        if not record[0].endswith('check.py'):
            kept.append(record)
    # NOTE(review): why the final two frames are skipped isn't visible here;
    # presumably they are the error's own constructor and its caller.
    window = kept[-5:-2]
    count = len(window)
    lines = []
    for position, (path, line_no, func, source) in enumerate(window):
        lines.append(_format_traceback(path, line_no, func, source,
                                       position, count, highlight))
    body = '\n'.join(lines).strip()
    heading = color('Traceback:', 'blue', attrs=['bold'])
    return '\n\n\t{title}\n\t{tb}'.format(title=heading, tb=body)
|
|
|
|
|
|
def _format_traceback(path, line, fn, text, i, count, highlight):
    """Format one traceback frame as a line of a small tree diagram.

    Args:
        path: Filename of the frame; anything up to and including the last
            ``/thinc/`` is stripped.
        line: Line number of the frame.
        fn: Function name, rendered bold.
        text: Source text of the frame; only shown for the last frame.
        i: Index of this frame among those shown (also its tree depth).
        count: Number of frames shown.
        highlight: Forwarded to ``_format_user_error`` for the last frame.
    """
    is_last = i == count - 1
    branch = '└─' if is_last else '├─'
    indent = branch + '──' * i
    if '/thinc/' in path:
        filename = path.rsplit('/thinc/', 1)[1]
    else:
        filename = path
    snippet = _format_user_error(text, i, highlight) if is_last else ''
    return '\t{i} {fn} [{l}] in {p}{t}'.format(
        l=str(line),
        i=indent,
        t=snippet,
        fn=color(fn, attrs=['bold']),
        p=color(filename, attrs=['underline']))
|
|
|
|
|
|
def _format_user_error(text, i, highlight):
    """Format the source line of the frame shown last in the traceback.

    Args:
        text: Source text of the frame. ``None`` is tolerated and rendered
            as an empty string (assumes a captured stack entry may lack its
            source text — confirm for the Python versions targeted).
        i: Depth of the frame in the tree; that many spaces precede the red
            ``>>>`` marker.
        highlight: If truthy, occurrences of ``str(highlight)`` in *text*
            are coloured yellow.

    Returns:
        A string beginning with a newline and tab, to be appended to the
        frame's header line.
    """
    text = '' if text is None else text
    spacing = ' ' * i + color(' >>>', 'red')
    needle = str(highlight) if highlight else ''
    # str.replace with an empty needle would insert colour codes between
    # every character, so only substitute non-empty text.
    if needle:
        text = text.replace(needle, color(needle, 'yellow'))
    return '\n\t {sp} {t}'.format(sp=spacing, t=text)
|
|
|
|
|
|
def color(text, fg=None, attrs=None):
    """Return *text* wrapped in ANSI SGR escape sequences.

    Args:
        text: Value to style; inserted with ``str.format``.
        fg: Optional colour name, e.g. 'red', 'blue' or 'yellow'.
        attrs: Optional iterable of style names, e.g. 'bold', 'underline'.

    ``fg`` and ``attrs`` are looked up in one shared table, and unknown names
    are ignored. Attribute codes come before the colour code, and a reset
    sequence always follows the text. If the ``ANSI_COLORS_DISABLED``
    environment variable is set (to any value), *text* is returned as-is.
    """
    if os.getenv('ANSI_COLORS_DISABLED') is not None:
        return text
    codes = {'red': 31, 'blue': 34, 'yellow': 33, 'bold': 1, 'underline': 4}
    requested = list(attrs or [])
    if fg:
        requested.append(fg)
    prefix = ''.join('\x1b[{}m'.format(codes[name])
                     for name in requested if name in codes)
    return prefix + '{}\x1b[0m'.format(text)
|