You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

167 lines
5.7 KiB

4 years ago
  1. from collections import defaultdict, Sequence, Sized, Iterable, Callable
  2. import inspect
  3. import wrapt
  4. from cytoolz import curry
  5. from numpy import ndarray
  6. from six import integer_types
  7. from .exceptions import UndefinedOperatorError, DifferentLengthError
  8. from .exceptions import ExpectedTypeError, ShapeMismatchError
  9. from .exceptions import OutsideRangeError
  10. def is_docs(arg_id, args, kwargs):
  11. from spacy.tokens.doc import Doc
  12. docs = args[arg_id]
  13. if not isinstance(docs, Sequence):
  14. raise ExpectedTypeError(type(docs), ['Sequence'])
  15. if not isinstance(docs[0], Doc):
  16. raise ExpectedTypeError(type(docs[0]), ['spacy.tokens.doc.Doc'])
  17. def equal_length(*args):
  18. '''Check that arguments have the same length.
  19. '''
  20. for i, arg in enumerate(args):
  21. if not isinstance(arg, Sized):
  22. raise ExpectedTypeError(arg, ['Sized'])
  23. if i >= 1 and len(arg) != len(args[0]):
  24. raise DifferentLengthError(args, arg)
  25. def equal_axis(*args, **axis):
  26. '''Check that elements have the same dimension on specified axis.
  27. '''
  28. axis = axis.get('axis', -1)
  29. for i, arg in enumerate(args):
  30. if not isinstance(arg, ndarray):
  31. raise ExpectedTypeError(arg, ['ndarray'])
  32. if axis >= 0 and (axis+1) < args[i].shape[axis]:
  33. raise ShapeError(
  34. "Shape: %s. Expected at least %d dimensions",
  35. shape, axis)
  36. if i >= 1 and arg.shape[axis] != args[0].shape[axis]:
  37. lengths = [a.shape[axis] for a in args]
  38. raise DifferentLengthError(lengths, arg)
  39. @curry
  40. def has_shape(shape, arg_id, args, kwargs):
  41. '''Check that a particular argument is an array with a given shape. The
  42. shape may contain string attributes, which will be fetched from arg0 to
  43. the function (usually self).
  44. '''
  45. self = args[0]
  46. arg = args[arg_id]
  47. if not hasattr(arg, 'shape'):
  48. raise ExpectedTypeError(arg, ['array'])
  49. shape_values = []
  50. for dim in shape:
  51. if not isinstance(dim, integer_types):
  52. dim = getattr(self, dim, None)
  53. shape_values.append(dim)
  54. if len(shape) != len(arg.shape):
  55. raise ShapeMismatchError(arg.shape, tuple(shape_values), shape)
  56. for i, dim in enumerate(shape_values):
  57. # Allow underspecified dimensions
  58. if dim != None and arg.shape[i] != dim:
  59. raise ShapeMismatchError(arg.shape, shape_values, shape)
  60. def is_shape(arg_id, args, func_kwargs, **kwargs):
  61. arg = args[arg_id]
  62. if not isinstance(arg, Iterable):
  63. raise ExpectedTypeError(arg, ['iterable'])
  64. for value in arg:
  65. if not isinstance(value, integer_types) or value < 0:
  66. raise ExpectedTypeError(arg, ['valid shape (positive ints)'])
  67. def is_sequence(arg_id, args, kwargs):
  68. arg = args[arg_id]
  69. if not isinstance(arg, Iterable) and not hasattr(arg, '__getitem__'):
  70. raise ExpectedTypeError(arg, ['iterable'])
  71. def is_float(arg_id, args, func_kwargs, **kwargs):
  72. arg = args[arg_id]
  73. if not isinstance(arg, float):
  74. raise ExpectedTypeError(arg, ['float'])
  75. if 'min' in kwargs and arg < kwargs['min']:
  76. raise OutsideRangeError(arg, kwargs['min'], '>=')
  77. if 'max' in kwargs and arg > kwargs['max']:
  78. raise OutsideRangeError(arg, kwargs['max'], '<=')
  79. def is_int(arg_id, args, func_kwargs, **kwargs):
  80. arg = args[arg_id]
  81. if not isinstance(arg, integer_types):
  82. raise ExpectedTypeError(arg, ['int'])
  83. if 'min' in kwargs and arg < kwargs['min']:
  84. raise OutsideRangeError(arg, kwargs['min'], '>=')
  85. if 'max' in kwargs and arg > kwargs['max']:
  86. raise OutsideRangeError(arg, kwargs['max'], '<=')
  87. def is_array(arg_id, args, func_kwargs, **kwargs):
  88. arg = args[arg_id]
  89. if not isinstance(arg, ndarray):
  90. raise ExpectedTypeError(arg, ['ndarray'])
  91. def is_int_array(arg_id, args, func_kwargs, **kwargs):
  92. arg = args[arg_id]
  93. if not isinstance(arg, ndarray) or 'i' not in arg.dtype.kind:
  94. raise ExpectedTypeError(arg, ['ndarray[int]'])
  95. def operator_is_defined(op):
  96. @wrapt.decorator
  97. def checker(wrapped, instance, args, kwargs):
  98. if instance is None:
  99. instance = args[0]
  100. if instance is None:
  101. raise ExpectedTypeError(instance, ['Model'])
  102. if op not in instance._operators:
  103. raise UndefinedOperatorError(op, instance, args[0], instance._operators)
  104. else:
  105. return wrapped(*args, **kwargs)
  106. return checker
  107. def arg(arg_id, *constraints):
  108. @wrapt.decorator
  109. def checked_function(wrapped, instance, args, kwargs):
  110. # for partial functions or other C-compiled functions
  111. if not hasattr(wrapped, 'checks'): # pragma: no cover
  112. return wrapped(*args, **kwargs)
  113. if instance is not None:
  114. fix_args = [instance] + list(args)
  115. else:
  116. fix_args = list(args)
  117. for arg_id, checks in wrapped.checks.items():
  118. for check in checks:
  119. if not isinstance(check, Callable):
  120. raise ExpectedTypeError(check, ['Callable'])
  121. check(arg_id, fix_args, kwargs)
  122. return wrapped(*args, **kwargs)
  123. def arg_check_adder(func):
  124. if hasattr(func, 'checks'):
  125. func.checks.setdefault(arg_id, []).extend(constraints)
  126. return func
  127. else:
  128. wrapped = checked_function(func)
  129. wrapped.checks = {arg_id: list(constraints)}
  130. return wrapped
  131. return arg_check_adder
  132. def args(*constraints):
  133. @wrapt.decorator
  134. def arg_check_adder(wrapped, instance, args, kwargs):
  135. for check in constraints:
  136. if not isinstance(check, Callable):
  137. raise ExpectedTypeError(check, ['Callable'])
  138. check(*args)
  139. return wrapped(*args, **kwargs)
  140. return arg_check_adder