You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

123 lines
2.9 KiB

from functools import wraps
class AttributeDescription(object):
    """Base descriptor that describes a named attribute of a model class.

    ``name`` starts out as ``None`` and is bound later, either through
    ``__call__`` or by assigning ``desc.name`` directly. This base
    implementation keeps one ``value`` on the descriptor object itself.
    Every instance of the owning class therefore reads and writes the same
    value, and so does class-level access.
    """

    def __init__(self, text, value=None, *args, **kwargs):
        # Any extra positional/keyword arguments are accepted and ignored.
        self.text = text
        self.value = value
        self.name = None

    def __call__(self, attr, model):
        """Bind this description to attribute name *attr*; *model* is unused."""
        self.name = attr

    def __get__(self, obj, type=None):  # pragma: no cover
        stored = self.value
        return stored

    def __set__(self, obj, val):  # pragma: no cover
        self.value = val
class Dimension(AttributeDescription):
    """Descriptor that proxies one entry of the owning instance's ``_dims``.

    Reading returns ``obj._dims.get(name, None)``, falling back to ``None``
    when the entry is absent. Writing stores the value under ``name``.
    Assumes the owner provides a ``_dims`` mapping (TODO confirm against the
    model class).

    NOTE(review): class-level access passes ``obj=None``, so ``__get__``
    raises ``AttributeError`` on ``None._dims`` instead of returning the
    descriptor.
    """

    def __get__(self, obj, type=None):
        dims = obj._dims
        return dims.get(self.name, None)

    def __set__(self, obj, value):
        dims = obj._dims
        dims[self.name] = value
class Weights(AttributeDescription):
    """Descriptor for a parameter array stored in the owner's ``_mem``.

    The array lives under the key ``(obj.id, self.name)``. It is allocated
    lazily on first access with ``obj._mem.add(key, get_shape(obj))``. When
    an ``init`` callable was supplied, the array is then initialised in
    place with ``init(data, obj.ops)``.
    """

    def __init__(self, text, get_shape, init=None):
        """Create the descriptor.

        Args:
            text: Human-readable description of the parameter.
            get_shape: Callable ``get_shape(obj)`` returning the shape to
                allocate for the owning object.
            init: Optional callable ``init(data, ops)``, applied right after
                the array is allocated.
        """
        # The attribute name is bound later (``__call__`` or a direct
        # assignment to ``.name``). The base-class ``value`` is not set here
        # because this class overrides both ``__get__`` and ``__set__``.
        self.name = None
        self.text = text
        self.get_shape = get_shape
        self.init = init

    def __get__(self, obj, type=None):
        """Return the parameter array, allocating and initialising it on first use."""
        key = (obj.id, self.name)
        if key in obj._mem:
            return obj._mem[key]
        shape = self.get_shape(obj)
        data = obj._mem.add(key, shape)
        if self.init is not None:
            self.init(data, obj.ops)
        return data

    def __set__(self, obj, val):
        """Copy *val* into the parameter array in place.

        The array is fetched through ``__get__``, so assigning to a
        parameter that has never been read allocates it first. A plain
        ``obj._mem.get(key)`` lookup has nothing to copy into for an
        unallocated parameter (presumably ``None``, as with ``dict.get``),
        and ``data[:] = val`` would then fail.
        """
        data = self.__get__(obj, type(obj))
        data[:] = val
class Gradient(AttributeDescription):
    """Descriptor for the gradient buffer paired with a named parameter.

    The gradient lives in the owner's ``_mem`` under ``(obj.id, self.name)``.
    On first access it is created with
    ``obj._mem.add_gradient(key, (obj.id, param_name))``.
    """

    def __init__(self, param_name):
        """Create a gradient description for the parameter *param_name*."""
        # The attribute name is bound later (``__call__`` or a direct
        # assignment to ``.name``).
        self.name = None
        self.text = "Gradient of %s" % param_name
        self.param_name = param_name

    def __get__(self, obj, type=None):
        """Return the gradient buffer, creating it on first use."""
        key = (obj.id, self.name)
        if key in obj._mem:
            return obj._mem.get(key)
        param_key = (obj.id, self.param_name)
        return obj._mem.add_gradient(key, param_key)

    def __set__(self, obj, val):
        """Copy *val* into the gradient buffer in place.

        The buffer is fetched through ``__get__``, so assigning before the
        gradient was ever read creates it first. A plain ``obj._mem.get``
        lookup has nothing to copy into in that case (presumably ``None``),
        and ``data[:] = val`` would fail. Whether ``add_gradient`` requires
        the parameter to be allocated already depends on ``_mem`` (not
        visible here).
        """
        data = self.__get__(obj, type(obj))
        data[:] = val
class Synapses(Weights):
    """Marker subclass of ``Weights``; it adds no behaviour of its own.

    Presumably used to tag synapse (connection-weight) parameters.
    """
class Biases(Weights):
    """Marker subclass of ``Weights``; it adds no behaviour of its own.

    Presumably used to tag bias-vector parameters.
    """
class Moment(Weights):
    """Marker subclass of ``Weights``; it adds no behaviour of its own.

    Presumably used to tag optimiser moment buffers.
    """
def attributes(**specs):
    """Return a class decorator that installs attribute descriptions.

    The decorator merges *specs* into a fresh copy of the class's
    ``descriptions`` dict. Copying means an inherited dict is never mutated.
    It then sets every description, inherited ones included, as a class
    attribute and assigns each description's ``name`` to its key.

    Raises:
        ValueError: if no specs are given.
    """
    if not specs:  # pragma: no cover
        raise ValueError("Must describe at least one attribute")

    def decorate(cls):
        merged = dict(cls.descriptions)
        merged.update(specs)
        cls.descriptions = merged
        for attr_name in merged:
            descriptor = merged[attr_name]
            setattr(cls, attr_name, descriptor)
            descriptor.name = attr_name
        return cls

    return decorate
def on_init(*callbacks):
    """Return a class decorator that appends *callbacks* to ``on_init_hooks``.

    The decorated class gets a new list made of its current (possibly
    inherited) hooks followed by *callbacks*, so a parent's list is never
    mutated.
    """
    def decorate(cls):
        cls.on_init_hooks = list(cls.on_init_hooks) + list(callbacks)
        return cls

    return decorate
def on_data(*callbacks):
    """Return a class decorator that appends *callbacks* to ``on_data_hooks``.

    The decorated class gets a new list made of its current (possibly
    inherited) hooks followed by *callbacks*, so a parent's list is never
    mutated.
    """
    def decorate(cls):
        cls.on_data_hooks = list(cls.on_data_hooks) + list(callbacks)
        return cls

    return decorate
def input(getter):
    """Return a class decorator that sets ``describe_input`` to *getter*.

    *getter* is stored as a plain class attribute. A plain function therefore
    binds as a method when accessed on an instance. Note that this name
    shadows the ``input`` builtin within this module.
    """
    def decorate(cls):
        setattr(cls, "describe_input", getter)
        return cls

    return decorate
def output(getter):
    """Return a class decorator that sets ``describe_output`` to *getter*.

    *getter* is stored as a plain class attribute. A plain function therefore
    binds as a method when accessed on an instance.
    """
    def decorate(cls):
        setattr(cls, "describe_output", getter)
        return cls

    return decorate