|
|
import inspect
from collections import defaultdict

try:  # The ABC aliases in ``collections`` were removed in Python 3.10.
    from collections.abc import Callable, Iterable, Sequence, Sized
except ImportError:  # pragma: no cover - Python 2 fallback
    from collections import Callable, Iterable, Sequence, Sized

import wrapt
from cytoolz import curry
from numpy import ndarray
from six import integer_types

from .exceptions import UndefinedOperatorError, DifferentLengthError
from .exceptions import ExpectedTypeError, ShapeMismatchError
from .exceptions import OutsideRangeError
-
def is_docs(arg_id, args, kwargs):
    '''Check that ``args[arg_id]`` is a sequence of spaCy ``Doc`` objects.

    Only the first element is type-checked (a spot check, not a full scan).
    An empty sequence is accepted: there is nothing to check.

    Args:
        arg_id (int): index of the argument to check within ``args``.
        args (sequence): positional arguments of the checked call.
        kwargs (dict): keyword arguments of the checked call (unused).

    Raises:
        ExpectedTypeError: if the argument is not a ``Sequence``, or its first
            element is not a ``spacy.tokens.doc.Doc``.
    '''
    docs = args[arg_id]
    if not isinstance(docs, Sequence):
        raise ExpectedTypeError(type(docs), ['Sequence'])
    if not docs:
        # Previously ``docs[0]`` raised IndexError on an empty sequence.
        return
    # Function-local import, presumably so spaCy is only needed when this
    # check actually has an element to inspect.
    from spacy.tokens.doc import Doc
    if not isinstance(docs[0], Doc):
        raise ExpectedTypeError(type(docs[0]), ['spacy.tokens.doc.Doc'])
-
-
-
def equal_length(*args):
    '''Raise unless every argument is ``Sized`` and all share one length.

    Arguments are type-checked in order; each argument after the first is
    compared against the length of ``args[0]``.

    Raises:
        ExpectedTypeError: an argument does not implement ``Sized``.
        DifferentLengthError: an argument's length differs from ``args[0]``.
    '''
    for position, item in enumerate(args):
        if not isinstance(item, Sized):
            raise ExpectedTypeError(item, ['Sized'])
        if position == 0:
            continue  # the first argument supplies the reference length
        if len(item) != len(args[0]):
            raise DifferentLengthError(args, item)
-
-
- def equal_axis(*args, **axis):
- '''Check that elements have the same dimension on specified axis.
- '''
- axis = axis.get('axis', -1)
- for i, arg in enumerate(args):
- if not isinstance(arg, ndarray):
- raise ExpectedTypeError(arg, ['ndarray'])
- if axis >= 0 and (axis+1) < args[i].shape[axis]:
- raise ShapeError(
- "Shape: %s. Expected at least %d dimensions",
- shape, axis)
- if i >= 1 and arg.shape[axis] != args[0].shape[axis]:
- lengths = [a.shape[axis] for a in args]
- raise DifferentLengthError(lengths, arg)
-
@curry
def has_shape(shape, arg_id, args, kwargs):
    '''Check that a particular argument is an array with a given shape.

    ``shape`` entries may be ints or attribute names; names are resolved with
    ``getattr`` on ``args[0]`` (usually ``self``). A name that ``args[0]``
    lacks resolves to ``None``, which matches any size in that position.
    Because of ``@curry``, ``shape`` can be bound first and the result used
    as an ``(arg_id, args, kwargs)`` check.

    Raises:
        ExpectedTypeError: if the argument has no ``shape`` attribute.
        ShapeMismatchError: if the rank differs, or a resolved dimension
            does not match the argument's size in that position.
    '''
    self = args[0]
    arg = args[arg_id]
    if not hasattr(arg, 'shape'):
        raise ExpectedTypeError(arg, ['array'])
    shape_values = tuple(
        dim if isinstance(dim, integer_types) else getattr(self, dim, None)
        for dim in shape)
    if len(shape) != len(arg.shape):
        raise ShapeMismatchError(arg.shape, shape_values, shape)
    for actual, expected in zip(arg.shape, shape_values):
        # ``None`` marks an underspecified dimension. Use an identity test:
        # ``!= None`` misbehaves when an attribute resolves to an array.
        if expected is not None and actual != expected:
            raise ShapeMismatchError(arg.shape, shape_values, shape)
-
-
def is_shape(arg_id, args, func_kwargs, **kwargs):
    '''Check that ``args[arg_id]`` is an iterable of non-negative ints.

    ``func_kwargs`` and ``kwargs`` are accepted for signature compatibility
    and are not used. The argument is iterated once.

    Raises:
        ExpectedTypeError: if the argument is not iterable, or any entry is
            not an integer or is negative.
    '''
    candidate = args[arg_id]
    if not isinstance(candidate, Iterable):
        raise ExpectedTypeError(candidate, ['iterable'])
    for dim in candidate:
        if isinstance(dim, integer_types) and dim >= 0:
            continue
        raise ExpectedTypeError(candidate, ['valid shape (positive ints)'])
-
-
def is_sequence(arg_id, args, kwargs):
    '''Check that ``args[arg_id]`` can be iterated or indexed.

    Anything that is ``Iterable`` or defines ``__getitem__`` is accepted.

    Raises:
        ExpectedTypeError: if the argument is neither.
    '''
    candidate = args[arg_id]
    if isinstance(candidate, Iterable) or hasattr(candidate, '__getitem__'):
        return
    raise ExpectedTypeError(candidate, ['iterable'])
-
-
def is_float(arg_id, args, func_kwargs, **kwargs):
    '''Check that ``args[arg_id]`` is a ``float`` within optional bounds.

    The type check is strict: ``int`` values are rejected. The keyword
    arguments ``min`` and ``max`` are inclusive bounds, each applied only
    when present. ``func_kwargs`` is accepted for signature compatibility
    and is not used.

    Raises:
        ExpectedTypeError: if the argument is not a ``float``.
        OutsideRangeError: if it lies below ``min`` or above ``max``.
    '''
    value = args[arg_id]
    if not isinstance(value, float):
        raise ExpectedTypeError(value, ['float'])
    if 'min' in kwargs:
        lower = kwargs['min']
        if value < lower:
            raise OutsideRangeError(value, lower, '>=')
    if 'max' in kwargs:
        upper = kwargs['max']
        if value > upper:
            raise OutsideRangeError(value, upper, '<=')
-
-
def is_int(arg_id, args, func_kwargs, **kwargs):
    '''Check that ``args[arg_id]`` is an integer within optional bounds.

    The keyword arguments ``min`` and ``max`` are inclusive bounds, each
    applied only when present. ``func_kwargs`` is accepted for signature
    compatibility and is not used.

    Raises:
        ExpectedTypeError: if the argument is not one of ``integer_types``.
        OutsideRangeError: if it lies below ``min`` or above ``max``.
    '''
    value = args[arg_id]
    if not isinstance(value, integer_types):
        raise ExpectedTypeError(value, ['int'])
    if 'min' in kwargs:
        lower = kwargs['min']
        if value < lower:
            raise OutsideRangeError(value, lower, '>=')
    if 'max' in kwargs:
        upper = kwargs['max']
        if value > upper:
            raise OutsideRangeError(value, upper, '<=')
-
-
def is_array(arg_id, args, func_kwargs, **kwargs):
    '''Check that ``args[arg_id]`` is a numpy ``ndarray``.

    ``func_kwargs`` and ``kwargs`` are accepted for signature compatibility
    and are not used.

    Raises:
        ExpectedTypeError: if the argument is not an ``ndarray``.
    '''
    candidate = args[arg_id]
    if isinstance(candidate, ndarray):
        return
    raise ExpectedTypeError(candidate, ['ndarray'])
-
-
def is_int_array(arg_id, args, func_kwargs, **kwargs):
    '''Check that ``args[arg_id]`` is an ``ndarray`` with a signed-int dtype.

    Only numpy dtype kind ``'i'`` passes, so unsigned (``'u'``) integer
    arrays are rejected. NOTE(review): confirm that rejecting unsigned ids
    is intended.

    Raises:
        ExpectedTypeError: if the argument is not such an array.
    '''
    candidate = args[arg_id]
    is_signed_int = (isinstance(candidate, ndarray)
                     and 'i' in candidate.dtype.kind)
    if not is_signed_int:
        raise ExpectedTypeError(candidate, ['ndarray[int]'])
-
-
def operator_is_defined(op):
    '''Return a decorator that only allows the call if ``op`` is registered.

    The model is the bound ``instance`` or, when wrapt reports no instance,
    the first positional argument. The call proceeds only if ``op`` is in
    the model's ``_operators``.

    Raises:
        ExpectedTypeError: if no model can be found for the call.
        UndefinedOperatorError: if ``op`` is not in ``_operators``.
    '''
    @wrapt.decorator
    def checker(wrapped, instance, args, kwargs):
        if instance is None:
            # Guard against an empty ``args``. The original indexed
            # ``args[0]`` unconditionally, raising IndexError instead of the
            # intended ExpectedTypeError below.
            instance = args[0] if args else None
        if instance is None:
            raise ExpectedTypeError(instance, ['Model'])
        if op not in instance._operators:
            # NOTE(review): on an unbound call args[0] is the model itself,
            # so it is also reported as the operand here.
            other = args[0] if args else None
            raise UndefinedOperatorError(op, instance, other,
                                         instance._operators)
        return wrapped(*args, **kwargs)
    return checker
-
-
def arg(arg_id, *constraints):
    '''Decorator factory: attach ``constraints`` to positional arg ``arg_id``.

    At call time, every registered check is invoked as
    ``check(position, positional_args, kwargs)``. ``positional_args`` has the
    bound instance (if any) prepended, so ``self`` is position 0 for methods.
    Checks are validated as callables at call time, not at decoration time.

    Stacking several ``@arg(...)`` decorators reuses the first wrapper: later
    decorators see its ``checks`` dict and extend it instead of wrapping again.

    Raises (at call time):
        ExpectedTypeError: if a registered check is not callable.
    '''
    @wrapt.decorator
    def checked_function(wrapped, instance, args, kwargs):
        # Partials and C-compiled callables may not carry ``checks``.
        if not hasattr(wrapped, 'checks'):  # pragma: no cover
            return wrapped(*args, **kwargs)
        if instance is None:
            positional = list(args)
        else:
            positional = [instance] + list(args)
        for position, checks in wrapped.checks.items():
            for check in checks:
                if not isinstance(check, Callable):
                    raise ExpectedTypeError(check, ['Callable'])
                check(position, positional, kwargs)
        return wrapped(*args, **kwargs)

    def arg_check_adder(func):
        if not hasattr(func, 'checks'):
            wrapper = checked_function(func)
            # Assigning on the wrapt proxy forwards the attribute to ``func``,
            # which is how ``checked_function`` later finds
            # ``wrapped.checks``.
            wrapper.checks = {arg_id: list(constraints)}
            return wrapper
        func.checks.setdefault(arg_id, []).extend(constraints)
        return func

    return arg_check_adder
-
-
def args(*constraints):
    '''Decorator factory: run whole-call constraints before the wrapped call.

    Each constraint is called with the call's positional arguments unpacked,
    e.g. ``equal_length(*args)``. The bound instance is not included, and
    keyword arguments are not passed to constraints.

    Raises (at call time):
        ExpectedTypeError: if a constraint is not callable.
    '''
    def run_constraints(positional):
        # Validate and run each constraint in turn, in declaration order.
        for constraint in constraints:
            if not isinstance(constraint, Callable):
                raise ExpectedTypeError(constraint, ['Callable'])
            constraint(*positional)

    @wrapt.decorator
    def checked_call(wrapped, instance, args, kwargs):
        run_constraints(args)
        return wrapped(*args, **kwargs)

    return checked_call
|