You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123 lines
2.9 KiB

4 years ago
  1. from functools import wraps
  2. class AttributeDescription(object):
  3. def __init__(self, text, value=None, *args, **kwargs):
  4. self.name = None
  5. self.text = text
  6. self.value = value
  7. def __call__(self, attr, model):
  8. self.name = attr
  9. def __get__(self, obj, type=None): # pragma: no cover
  10. return self.value
  11. def __set__(self, obj, val): # pragma: no cover
  12. self.value = val
  13. class Dimension(AttributeDescription):
  14. def __get__(self, obj, type=None):
  15. return obj._dims.get(self.name, None)
  16. def __set__(self, obj, value):
  17. obj._dims[self.name] = value
  18. class Weights(AttributeDescription):
  19. def __init__(self, text, get_shape, init=None):
  20. self.name = None
  21. self.text = text
  22. self.get_shape = get_shape
  23. self.init = init
  24. def __get__(self, obj, type=None):
  25. key = (obj.id, self.name)
  26. if key in obj._mem:
  27. return obj._mem[key]
  28. else:
  29. shape = self.get_shape(obj)
  30. data = obj._mem.add(key, shape)
  31. if self.init is not None:
  32. self.init(data, obj.ops)
  33. return data
  34. def __set__(self, obj, val):
  35. data = obj._mem.get((obj.id, self.name))
  36. data[:] = val
  37. class Gradient(AttributeDescription):
  38. def __init__(self, param_name):
  39. self.name = None
  40. self.text = "Gradient of %s" % param_name
  41. self.param_name = param_name
  42. def __get__(self, obj, type=None):
  43. key = (obj.id, self.name)
  44. if key in obj._mem:
  45. return obj._mem.get(key)
  46. else:
  47. param_key = (obj.id, self.param_name)
  48. grad = obj._mem.add_gradient(key, param_key)
  49. return grad
  50. def __set__(self, obj, val):
  51. data = obj._mem.get((obj.id, self.name))
  52. data[:] = val
  53. class Synapses(Weights):
  54. pass
  55. class Biases(Weights):
  56. pass
  57. class Moment(Weights):
  58. pass
  59. def attributes(**specs):
  60. if not specs: # pragma: no cover
  61. raise ValueError("Must describe at least one attribute")
  62. def wrapped(cls):
  63. cls.descriptions = dict(cls.descriptions)
  64. cls.descriptions.update(specs)
  65. for attr, desc in cls.descriptions.items():
  66. setattr(cls, attr, desc)
  67. desc.name = attr
  68. return cls
  69. return wrapped
  70. def on_init(*callbacks):
  71. def wrapped(cls):
  72. cls.on_init_hooks = list(cls.on_init_hooks)
  73. cls.on_init_hooks.extend(callbacks)
  74. return cls
  75. return wrapped
  76. def on_data(*callbacks):
  77. def wrapped(cls):
  78. cls.on_data_hooks = list(cls.on_data_hooks)
  79. cls.on_data_hooks.extend(callbacks)
  80. return cls
  81. return wrapped
  82. def input(getter):
  83. def wrapped(cls):
  84. cls.describe_input = getter
  85. return cls
  86. return wrapped
  87. def output(getter):
  88. def wrapped(cls):
  89. cls.describe_output = getter
  90. return cls
  91. return wrapped