You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123 lines
2.7 KiB

4 years ago
  1. __all__ = ["generic"]
  2. try:
  3. from types import ClassType, InstanceType
  4. classtypes = type, ClassType
  5. except ImportError:
  6. classtypes = type
  7. InstanceType = None
  8. def generic(func):
  9. """Create a simple generic function"""
  10. _sentinel = object()
  11. def _by_class(*args, **kw):
  12. cls = args[0].__class__
  13. for t in type(cls.__name__, (cls,object), {}).__mro__:
  14. f = _gbt(t, _sentinel)
  15. if f is not _sentinel:
  16. return f(*args, **kw)
  17. else:
  18. return func(*args, **kw)
  19. _by_type = {object: func, InstanceType: _by_class}
  20. _gbt = _by_type.get
  21. def when_type(*types):
  22. """Decorator to add a method that will be called for the given types"""
  23. for t in types:
  24. if not isinstance(t, classtypes):
  25. raise TypeError(
  26. "%r is not a type or class" % (t,)
  27. )
  28. def decorate(f):
  29. for t in types:
  30. if _by_type.setdefault(t,f) is not f:
  31. raise TypeError(
  32. "%r already has method for type %r" % (func, t)
  33. )
  34. return f
  35. return decorate
  36. _by_object = {}
  37. _gbo = _by_object.get
  38. def when_object(*obs):
  39. """Decorator to add a method to be called for the given object(s)"""
  40. def decorate(f):
  41. for o in obs:
  42. if _by_object.setdefault(id(o), (o,f))[1] is not f:
  43. raise TypeError(
  44. "%r already has method for object %r" % (func, o)
  45. )
  46. return f
  47. return decorate
  48. def dispatch(*args, **kw):
  49. f = _gbo(id(args[0]), _sentinel)
  50. if f is _sentinel:
  51. for t in type(args[0]).__mro__:
  52. f = _gbt(t, _sentinel)
  53. if f is not _sentinel:
  54. return f(*args, **kw)
  55. else:
  56. return func(*args, **kw)
  57. else:
  58. return f[1](*args, **kw)
  59. dispatch.__name__ = func.__name__
  60. dispatch.__dict__ = func.__dict__.copy()
  61. dispatch.__doc__ = func.__doc__
  62. dispatch.__module__ = func.__module__
  63. dispatch.when_type = when_type
  64. dispatch.when_object = when_object
  65. dispatch.default = func
  66. dispatch.has_object = lambda o: id(o) in _by_object
  67. dispatch.has_type = lambda t: t in _by_type
  68. return dispatch
  69. def test_suite():
  70. import doctest
  71. return doctest.DocFileSuite(
  72. 'README.txt',
  73. optionflags=doctest.ELLIPSIS|doctest.REPORT_ONLY_FIRST_FAILURE,
  74. )
  75. if __name__=='__main__':
  76. import unittest
  77. r = unittest.TextTestRunner()
  78. r.run(test_suite())