"""Descriptors and class decorators used to declare attributes on a class.

The descriptors below read and write per-instance state that lives on the
owning object (``obj._dims`` and ``obj._mem``); the decorators attach
descriptors, hooks and describe-callbacks to a class.
"""
from functools import wraps


class AttributeDescription(object):
    """Base descriptor holding a human-readable description and a value.

    NOTE: the value is stored on the descriptor itself, so it is shared by
    every instance of the class the descriptor is attached to.
    """

    def __init__(self, text, value=None, *args, **kwargs):
        """Store the description ``text`` and an optional initial ``value``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # The attribute name is filled in later (by ``attributes`` or by
        # calling the descriptor), once it is known.
        self.name = None
        self.text = text
        self.value = value

    def __call__(self, attr, model):
        """Record ``attr`` as this descriptor's name; ``model`` is unused."""
        self.name = attr

    def __get__(self, obj, type=None):  # pragma: no cover
        """Return the value stored on the descriptor."""
        return self.value

    def __set__(self, obj, val):  # pragma: no cover
        """Replace the value stored on the descriptor."""
        self.value = val


class Dimension(AttributeDescription):
    """Descriptor backed by the owning object's ``_dims`` mapping."""

    def __get__(self, obj, type=None):
        """Return ``obj._dims[self.name]``, or ``None`` when it is unset."""
        return obj._dims.get(self.name, None)

    def __set__(self, obj, value):
        """Store ``value`` in ``obj._dims`` under this descriptor's name."""
        obj._dims[self.name] = value


class Weights(AttributeDescription):
    """Descriptor for lazily allocated storage held in ``obj._mem``.

    Storage is keyed by ``(obj.id, self.name)`` and is allocated on first
    access with the shape returned by ``get_shape(obj)``.
    """

    def __init__(self, text, get_shape, init=None):
        """Configure the descriptor.

        text: human-readable description.
        get_shape: callable taking the owning object and returning the shape
            passed to ``obj._mem.add``.
        init: optional callable invoked as ``init(data, obj.ops)`` right after
            the storage is allocated by a read.
        """
        self.name = None
        self.text = text
        self.get_shape = get_shape
        self.init = init

    def __get__(self, obj, type=None):
        """Return the storage for ``obj``, allocating it on first access."""
        key = (obj.id, self.name)
        if key in obj._mem:
            return obj._mem[key]
        shape = self.get_shape(obj)
        data = obj._mem.add(key, shape)
        if self.init is not None:
            self.init(data, obj.ops)
        return data

    def __set__(self, obj, val):
        """Copy ``val`` into the storage for ``obj`` (``data[:] = val``).

        The storage is allocated here if it does not exist yet, so assigning
        before the first read works instead of failing on a missing buffer.
        """
        key = (obj.id, self.name)
        if key in obj._mem:
            data = obj._mem.get(key)
        else:
            # ``init`` is deliberately skipped: the slice assignment below
            # overwrites the whole buffer anyway.
            data = obj._mem.add(key, self.get_shape(obj))
        data[:] = val


class Gradient(AttributeDescription):
    """Descriptor for the gradient buffer paired with a named parameter.

    The buffer is keyed by ``(obj.id, self.name)`` in ``obj._mem`` and is
    created with ``obj._mem.add_gradient`` against the parameter's key,
    ``(obj.id, param_name)``.
    """

    def __init__(self, param_name):
        """Remember which parameter this gradient belongs to."""
        self.name = None
        self.text = "Gradient of %s" % param_name
        self.param_name = param_name

    def __get__(self, obj, type=None):
        """Return the gradient buffer for ``obj``, creating it if missing."""
        key = (obj.id, self.name)
        if key in obj._mem:
            return obj._mem.get(key)
        param_key = (obj.id, self.param_name)
        grad = obj._mem.add_gradient(key, param_key)
        return grad

    def __set__(self, obj, val):
        """Copy ``val`` into the gradient buffer (``data[:] = val``).

        Goes through ``__get__`` so the buffer is created on demand; assigning
        before the first read therefore works instead of failing on a missing
        buffer.
        """
        data = self.__get__(obj)
        data[:] = val


class Synapses(Weights):
    """``Weights`` under a different name; adds no behaviour."""
    pass


class Biases(Weights):
    """``Weights`` under a different name; adds no behaviour."""
    pass


class Moment(Weights):
    """``Weights`` under a different name; adds no behaviour."""
    pass


def attributes(**specs):
    """Class decorator that attaches the given descriptors to a class.

    ``specs`` maps attribute names to descriptor objects. The decorated class
    must already expose a ``descriptions`` attribute accepted by ``dict()``;
    it is copied (so the parent's mapping is not mutated), updated with
    ``specs``, and every entry is then set on the class and told its name.

    Raises ValueError when called without any attribute.
    """
    if not specs:  # pragma: no cover
        raise ValueError("Must describe at least one attribute")

    def wrapped(cls):
        cls.descriptions = dict(cls.descriptions)
        cls.descriptions.update(specs)
        # Inherited descriptions are re-attached as well, not only the new
        # ones passed in ``specs``.
        for attr, desc in cls.descriptions.items():
            setattr(cls, attr, desc)
            desc.name = attr
        return cls

    return wrapped


def on_init(*callbacks):
    """Class decorator appending ``callbacks`` to ``cls.on_init_hooks``.

    The existing hooks are copied into a new list first, so the list
    inherited from a parent class is left untouched.
    """

    def wrapped(cls):
        cls.on_init_hooks = list(cls.on_init_hooks)
        cls.on_init_hooks.extend(callbacks)
        return cls

    return wrapped


def on_data(*callbacks):
    """Class decorator appending ``callbacks`` to ``cls.on_data_hooks``.

    The existing hooks are copied into a new list first, so the list
    inherited from a parent class is left untouched.
    """

    def wrapped(cls):
        cls.on_data_hooks = list(cls.on_data_hooks)
        cls.on_data_hooks.extend(callbacks)
        return cls

    return wrapped


def input(getter):
    """Class decorator that sets ``cls.describe_input`` to ``getter``."""

    def wrapped(cls):
        cls.describe_input = getter
        return cls

    return wrapped


def output(getter):
    """Class decorator that sets ``cls.describe_output`` to ``getter``."""

    def wrapped(cls):
        cls.describe_output = getter
        return cls

    return wrapped