168 lines
5.7 KiB
Python
168 lines
5.7 KiB
Python
|
from collections import defaultdict, Sequence, Sized, Iterable, Callable
|
||
|
import inspect
|
||
|
import wrapt
|
||
|
from cytoolz import curry
|
||
|
from numpy import ndarray
|
||
|
from six import integer_types
|
||
|
|
||
|
from .exceptions import UndefinedOperatorError, DifferentLengthError
|
||
|
from .exceptions import ExpectedTypeError, ShapeMismatchError
|
||
|
from .exceptions import OutsideRangeError
|
||
|
|
||
|
def is_docs(arg_id, args, kwargs):
    '''Check that the argument at position ``arg_id`` is a sequence whose
    first element is a spaCy ``Doc``.

    Only the first element is inspected. ``kwargs`` is accepted but not
    used.

    Raises ExpectedTypeError if the argument is not a ``Sequence``, or if
    its first element is not a ``Doc``.
    '''
    # Imported inside the function, so spaCy is only loaded when this check
    # is actually executed.
    from spacy.tokens.doc import Doc

    candidate = args[arg_id]
    if not isinstance(candidate, Sequence):
        raise ExpectedTypeError(type(candidate), ['Sequence'])
    head = candidate[0]
    if not isinstance(head, Doc):
        raise ExpectedTypeError(type(head), ['spacy.tokens.doc.Doc'])
|
||
|
|
||
|
|
||
|
|
||
|
def equal_length(*args):
    '''Check that all positional arguments have the same length.

    Every argument must be ``Sized``; each argument after the first is
    compared against the length of the first one.

    Raises ExpectedTypeError for an argument that is not ``Sized``, and
    DifferentLengthError for the first argument whose length differs from
    that of ``args[0]``.
    '''
    for position, item in enumerate(args):
        if not isinstance(item, Sized):
            raise ExpectedTypeError(item, ['Sized'])
        if position == 0:
            # The first argument is the reference; nothing to compare yet.
            continue
        if len(item) != len(args[0]):
            raise DifferentLengthError(args, item)
|
||
|
|
||
|
|
||
|
def equal_axis(*args, **axis):
    '''Check that all arrays have the same size along one axis.

    args: the arrays to compare; each must be a numpy ``ndarray``.
    axis (keyword, default -1): the axis whose size must agree. Any other
        keyword arguments are ignored.

    Raises ExpectedTypeError if an argument is not an ``ndarray``,
    ShapeMismatchError if an array does not have enough dimensions for
    ``axis`` to exist, and DifferentLengthError if an array's size along
    ``axis`` differs from that of the first array.
    '''
    axis = axis.get('axis', -1)
    # Validate every argument up front, so that collecting the sizes below
    # can never fail on an argument that has not been checked yet.
    for arg in args:
        if not isinstance(arg, ndarray):
            raise ExpectedTypeError(arg, ['ndarray'])
        # Compare the axis against the number of dimensions (not against
        # the size of the axis) before indexing arg.shape, so a missing
        # axis is reported as a shape problem rather than an IndexError.
        if not -arg.ndim <= axis < arg.ndim:
            # Argument order mirrors the ShapeMismatchError call made in
            # has_shape: (actual shape, expected, shape spec).
            # NOTE(review): confirm against the exception's signature.
            raise ShapeMismatchError(arg.shape, axis, [])
    lengths = [arg.shape[axis] for arg in args]
    for arg, length in zip(args, lengths):
        if length != lengths[0]:
            raise DifferentLengthError(lengths, arg)
|
||
|
|
||
|
@curry
def has_shape(shape, arg_id, args, kwargs):
    '''Check that a particular argument is an array with a given shape. The
    shape may contain string attributes, which will be fetched from arg0 to
    the function (usually self).

    shape: the expected dimensions. Integer entries are used as given; any
        other entry is treated as an attribute name looked up on
        ``args[0]``. An attribute that is missing or None leaves that
        dimension unconstrained.
    arg_id: index into ``args`` of the argument to check.
    args: positional arguments of the checked call.
    kwargs: keyword arguments of the checked call (not used here).

    Raises ExpectedTypeError if the argument has no ``shape`` attribute, and
    ShapeMismatchError if the number of dimensions or any constrained
    dimension differs.
    '''
    self = args[0]
    arg = args[arg_id]
    if not hasattr(arg, 'shape'):
        raise ExpectedTypeError(arg, ['array'])
    # Resolve symbolic dimensions (attribute names) to concrete values once,
    # as a tuple, so both error paths report the expectation the same way.
    expected = tuple(
        dim if isinstance(dim, integer_types) else getattr(self, dim, None)
        for dim in shape)
    if len(shape) != len(arg.shape):
        raise ShapeMismatchError(arg.shape, expected, shape)
    for actual_dim, expected_dim in zip(arg.shape, expected):
        # Allow underspecified dimensions
        if expected_dim is not None and actual_dim != expected_dim:
            raise ShapeMismatchError(arg.shape, expected, shape)
|
||
|
|
||
|
|
||
|
def is_shape(arg_id, args, func_kwargs, **kwargs):
    '''Check that the argument at position ``arg_id`` is a valid shape: an
    iterable in which every item is an integer (per six's
    ``integer_types``) that is not negative.

    ``func_kwargs`` and ``kwargs`` are accepted but not used.

    Raises ExpectedTypeError if the argument is not iterable, or if any
    item is not an integer or is below zero.
    '''
    candidate = args[arg_id]
    if not isinstance(candidate, Iterable):
        raise ExpectedTypeError(candidate, ['iterable'])
    for dim in candidate:
        if isinstance(dim, integer_types) and dim >= 0:
            continue
        raise ExpectedTypeError(candidate, ['valid shape (positive ints)'])
|
||
|
|
||
|
|
||
|
def is_sequence(arg_id, args, kwargs):
    '''Check that the argument at position ``arg_id`` can be treated as a
    sequence: it is either an ``Iterable`` or defines ``__getitem__``.

    ``kwargs`` is accepted but not used.

    Raises ExpectedTypeError if neither condition holds.
    '''
    candidate = args[arg_id]
    if isinstance(candidate, Iterable) or hasattr(candidate, '__getitem__'):
        return
    raise ExpectedTypeError(candidate, ['iterable'])
|
||
|
|
||
|
|
||
|
def is_float(arg_id, args, func_kwargs, **kwargs):
    '''Check that the argument at position ``arg_id`` is a ``float``,
    optionally within inclusive bounds.

    Keyword options:
    min: if given, the value must not be smaller than this.
    max: if given, the value must not be larger than this.

    ``func_kwargs`` is accepted but not used.

    Raises ExpectedTypeError if the argument is not a float, and
    OutsideRangeError if it violates a supplied bound.
    '''
    value = args[arg_id]
    if not isinstance(value, float):
        raise ExpectedTypeError(value, ['float'])
    if 'min' in kwargs:
        lower = kwargs['min']
        if value < lower:
            raise OutsideRangeError(value, lower, '>=')
    if 'max' in kwargs:
        upper = kwargs['max']
        if value > upper:
            raise OutsideRangeError(value, upper, '<=')
|
||
|
|
||
|
|
||
|
def is_int(arg_id, args, func_kwargs, **kwargs):
    '''Check that the argument at position ``arg_id`` is an integer (per
    six's ``integer_types``), optionally within inclusive bounds.

    Keyword options:
    min: if given, the value must not be smaller than this.
    max: if given, the value must not be larger than this.

    ``func_kwargs`` is accepted but not used.

    Raises ExpectedTypeError if the argument is not an integer, and
    OutsideRangeError if it violates a supplied bound.
    '''
    value = args[arg_id]
    if not isinstance(value, integer_types):
        raise ExpectedTypeError(value, ['int'])
    if 'min' in kwargs:
        lower = kwargs['min']
        if value < lower:
            raise OutsideRangeError(value, lower, '>=')
    if 'max' in kwargs:
        upper = kwargs['max']
        if value > upper:
            raise OutsideRangeError(value, upper, '<=')
|
||
|
|
||
|
|
||
|
def is_array(arg_id, args, func_kwargs, **kwargs):
    '''Check that the argument at position ``arg_id`` is a numpy
    ``ndarray``.

    ``func_kwargs`` and ``kwargs`` are accepted but not used.

    Raises ExpectedTypeError otherwise.
    '''
    candidate = args[arg_id]
    if isinstance(candidate, ndarray):
        return
    raise ExpectedTypeError(candidate, ['ndarray'])
|
||
|
|
||
|
|
||
|
def is_int_array(arg_id, args, func_kwargs, **kwargs):
    '''Check that the argument at position ``arg_id`` is a numpy
    ``ndarray`` whose dtype kind contains ``'i'``.

    ``func_kwargs`` and ``kwargs`` are accepted but not used.

    Raises ExpectedTypeError if the argument is not an ndarray or its
    dtype kind does not contain ``'i'``.
    '''
    candidate = args[arg_id]
    if isinstance(candidate, ndarray) and 'i' in candidate.dtype.kind:
        return
    raise ExpectedTypeError(candidate, ['ndarray[int]'])
|
||
|
|
||
|
|
||
|
def operator_is_defined(op):
    '''Build a decorator that lets the wrapped callable run only if ``op``
    is present in the target object's ``_operators``.

    The target is the instance supplied by wrapt or, when that is None, the
    first positional argument.

    The decorated callable raises ExpectedTypeError if no target can be
    found, and UndefinedOperatorError if ``op`` is not in
    ``target._operators``; otherwise it returns the wrapped call's result.
    '''
    @wrapt.decorator
    def checker(wrapped, instance, args, kwargs):
        # Fall back to the first positional argument when wrapt did not
        # supply a bound instance.
        target = args[0] if instance is None else instance
        if target is None:
            raise ExpectedTypeError(target, ['Model'])
        known = target._operators
        if op in known:
            return wrapped(*args, **kwargs)
        raise UndefinedOperatorError(op, target, args[0], known)
    return checker
|
||
|
|
||
|
|
||
|
def arg(arg_id, *constraints):
    '''Build a decorator that attaches ``constraints`` to the positional
    argument at index ``arg_id`` of the decorated function.

    Constraints are collected in a ``checks`` dict (argument index -> list
    of checks) on the decorated function. A function that already has
    ``checks`` just gets the new constraints appended and is returned
    as-is; otherwise it is wrapped first. At call time each check is
    invoked as ``check(index, positional_args, kwargs)`` before the
    function itself runs.

    The decorated function raises ExpectedTypeError if a registered check
    is not callable, plus whatever the checks themselves raise.
    '''
    @wrapt.decorator
    def checked_function(wrapped, instance, args, kwargs):
        # for partial functions or other C-compiled functions
        if not hasattr(wrapped, 'checks'):  # pragma: no cover
            return wrapped(*args, **kwargs)
        # Put the bound instance (if any) back in front, so that index 0
        # refers to it.
        if instance is None:
            positional = list(args)
        else:
            positional = [instance] + list(args)
        for position, checks in wrapped.checks.items():
            for check in checks:
                if not isinstance(check, Callable):
                    raise ExpectedTypeError(check, ['Callable'])
                check(position, positional, kwargs)
        return wrapped(*args, **kwargs)

    def arg_check_adder(func):
        if not hasattr(func, 'checks'):
            # First constraint on this function: wrap it and start the
            # registry of checks.
            decorated = checked_function(func)
            decorated.checks = {arg_id: list(constraints)}
            return decorated
        # Already wrapped: only extend the registry.
        func.checks.setdefault(arg_id, []).extend(constraints)
        return func
    return arg_check_adder
|
||
|
|
||
|
|
||
|
def args(*constraints):
    '''Build a decorator that runs every constraint against the call's
    positional arguments before the decorated function executes.

    Each constraint is invoked as ``constraint(*args)``, in the order
    given.

    The decorated function raises ExpectedTypeError if a constraint is not
    callable, plus whatever the constraints themselves raise; otherwise it
    returns the wrapped call's result.
    '''
    @wrapt.decorator
    def arg_check_adder(wrapped, instance, args, kwargs):
        for constraint in constraints:
            if isinstance(constraint, Callable):
                constraint(*args)
            else:
                raise ExpectedTypeError(constraint, ['Callable'])
        return wrapped(*args, **kwargs)
    return arg_check_adder
|