124 lines
2.9 KiB
Python
124 lines
2.9 KiB
Python
|
from functools import wraps
|
||
|
|
||
|
|
||
|
class AttributeDescription(object):
    """Base descriptor that documents a model attribute.

    Holds a human-readable ``text`` description and an optional ``value``.
    ``name`` starts as ``None``. It is filled in when the descriptor is bound
    to an attribute name in one of three ways:

    - by the ``attributes`` class decorator, which assigns ``desc.name``;
    - by calling the instance as ``desc(attr, model)``;
    - automatically via ``__set_name__`` when the descriptor is declared
      directly in a class body (Python 3.6+).

    Note: this base implementation stores ``value`` on the descriptor object
    itself. It is therefore shared by every instance of the owning class;
    subclasses override ``__get__``/``__set__`` to store per-instance data.
    """

    def __init__(self, text, value=None, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored so that
        # callers may pass subclass-specific options uniformly.
        self.name = None
        self.text = text
        self.value = value

    def __set_name__(self, owner, name):
        # Invoked by Python 3.6+ at class-creation time when the descriptor
        # appears in a class body, so it learns its attribute name without
        # needing the decorator. Older interpreters simply never call it.
        self.name = name

    def __call__(self, attr, model):
        """Bind this description to the attribute name *attr*.

        *model* is accepted for interface compatibility but is not used here.
        """
        self.name = attr

    def __get__(self, obj, type=None):  # pragma: no cover
        # Returns the value held on the descriptor (shared, not per-instance).
        return self.value

    def __set__(self, obj, val):  # pragma: no cover
        # Overwrites the value held on the descriptor (affects all instances).
        self.value = val
|
||
|
|
||
|
|
||
|
class Dimension(AttributeDescription):
    """Descriptor for a model dimension, stored per instance in ``obj._dims``.

    Reading an unset dimension yields ``None`` rather than raising.
    """

    def __get__(self, obj, type=None):
        """Return this dimension's value on *obj*, or ``None`` if it is unset.

        When accessed on the class itself (``obj is None``), return the
        descriptor so class-level introspection (``Cls.attr``,
        ``hasattr(Cls, attr)``, ``help``) works instead of failing on
        ``None._dims``.
        """
        if obj is None:
            return self
        return obj._dims.get(self.name, None)

    def __set__(self, obj, value):
        """Record *value* for this dimension in ``obj._dims``."""
        obj._dims[self.name] = value
|
||
|
|
||
|
|
||
|
class Weights(AttributeDescription):
    """Descriptor for a parameter array stored in the owner's ``_mem``.

    Storage is keyed by ``(obj.id, self.name)`` and is allocated lazily on
    first access.

    Args:
        text: Human-readable description of the parameter.
        get_shape: Callable ``get_shape(obj)`` returning the shape passed to
            ``obj._mem.add`` when the array is first allocated.
        init: Optional callable ``init(data, ops)`` that fills a freshly
            allocated array in place; ``ops`` is taken from ``obj.ops``.
    """

    def __init__(self, text, get_shape, init=None):
        self.name = None
        self.text = text
        self.get_shape = get_shape
        self.init = init

    def __get__(self, obj, type=None):
        """Return the array for this parameter, allocating it if needed.

        On first access the array is allocated via ``obj._mem.add`` and, if
        an ``init`` callable was given, initialized with it. Accessed on the
        class itself (``obj is None``), the descriptor is returned.
        """
        if obj is None:
            return self
        key = (obj.id, self.name)
        if key in obj._mem:
            return obj._mem[key]
        shape = self.get_shape(obj)
        data = obj._mem.add(key, shape)
        if self.init is not None:
            self.init(data, obj.ops)
        return data

    def __set__(self, obj, val):
        """Copy *val* into this parameter's array, in place.

        If the array has not been allocated yet, it is allocated first.
        ``init`` is deliberately skipped in that case, because the
        assignment overwrites every element anyway. Previously, setting
        before the first read tried to write into an unallocated entry.
        """
        key = (obj.id, self.name)
        if key in obj._mem:
            data = obj._mem.get(key)
        else:
            data = obj._mem.add(key, self.get_shape(obj))
        data[:] = val
|
||
|
|
||
|
|
||
|
class Gradient(AttributeDescription):
    """Descriptor for the gradient array of another parameter.

    Storage is keyed by ``(obj.id, self.name)``. On first access the array is
    created with ``obj._mem.add_gradient``, linked to the parameter key
    ``(obj.id, param_name)``.

    Args:
        param_name: Attribute name of the parameter this gradient belongs to.
    """

    def __init__(self, param_name):
        self.name = None
        self.text = "Gradient of %s" % param_name
        self.param_name = param_name

    def _get_or_add(self, obj):
        # Fetch the gradient array, creating it via add_gradient on first use.
        key = (obj.id, self.name)
        if key in obj._mem:
            return obj._mem.get(key)
        param_key = (obj.id, self.param_name)
        return obj._mem.add_gradient(key, param_key)

    def __get__(self, obj, type=None):
        """Return the gradient array, allocating it on first access.

        Accessed on the class itself (``obj is None``), the descriptor is
        returned instead of failing on ``None.id``.
        """
        if obj is None:
            return self
        return self._get_or_add(obj)

    def __set__(self, obj, val):
        """Copy *val* into the gradient array in place, allocating if needed.

        Previously, setting before the first read tried to write into an
        unallocated entry.
        """
        data = self._get_or_add(obj)
        data[:] = val
|
||
|
|
||
|
|
||
|
class Synapses(Weights):
    """Marker subclass of ``Weights``; behaves exactly like ``Weights``.

    The distinct type presumably lets callers tell synapse (connection)
    weights apart from other parameters by class — the naming is the only
    evidence here.
    """
|
||
|
|
||
|
|
||
|
class Biases(Weights):
    """Marker subclass of ``Weights``; behaves exactly like ``Weights``.

    The distinct type presumably lets callers identify bias parameters by
    class — the naming is the only evidence here.
    """
|
||
|
|
||
|
|
||
|
class Moment(Weights):
    """Marker subclass of ``Weights``; behaves exactly like ``Weights``.

    The distinct type presumably lets callers identify moment (e.g.
    optimizer statistics) arrays by class — TODO confirm against callers.
    """
|
||
|
|
||
|
|
||
|
def attributes(**specs):
    """Build a class decorator that registers attribute descriptors.

    Each keyword maps an attribute name to a descriptor object. The decorator
    performs three steps on the class:

    1. Copies the class's ``descriptions`` mapping, so the parent's mapping
       is left untouched.
    2. Merges *specs* into the copy.
    3. Installs every description as a class attribute and sets its
       ``name`` to the attribute name.

    Raises:
        ValueError: if no specs are given.
    """
    if not specs:  # pragma: no cover
        raise ValueError("Must describe at least one attribute")

    def wrapped(cls):
        merged = cls.descriptions = dict(cls.descriptions)
        merged.update(specs)
        for attr_name in merged:
            description = merged[attr_name]
            setattr(cls, attr_name, description)
            description.name = attr_name
        return cls

    return wrapped
|
||
|
|
||
|
|
||
|
def on_init(*callbacks):
    """Build a class decorator that appends *callbacks* to ``on_init_hooks``.

    A new list is assigned to the decorated class, so the parent class's
    hook list is never mutated.
    """
    def wrapped(cls):
        cls.on_init_hooks = list(cls.on_init_hooks) + list(callbacks)
        return cls
    return wrapped
|
||
|
|
||
|
|
||
|
def on_data(*callbacks):
    """Build a class decorator that appends *callbacks* to ``on_data_hooks``.

    A fresh list replaces the class's ``on_data_hooks``, leaving any
    inherited list unchanged.
    """
    def decorate(klass):
        existing = list(klass.on_data_hooks)
        klass.on_data_hooks = existing + list(callbacks)
        return klass
    return decorate
|
||
|
|
||
|
|
||
|
def input(getter):
    """Build a class decorator that installs *getter* as ``describe_input``.

    Note: this module-level name shadows the built-in ``input``.
    """
    def decorate(klass):
        setattr(klass, "describe_input", getter)
        return klass
    return decorate
|
||
|
|
||
|
|
||
|
def output(getter):
    """Build a class decorator that installs *getter* as ``describe_output``."""
    def decorate(klass):
        setattr(klass, "describe_output", getter)
        return klass
    return decorate
|