73 lines
2.9 KiB
Python
73 lines
2.9 KiB
Python
|
"""Dependency tracking for checkpointable objects."""
|
||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
from tensorflow.python.training.checkpointable import base
|
||
|
from tensorflow.python.training.checkpointable import data_structures
|
||
|
|
||
|
|
||
|
class NotCheckpointable(object):
  """Mixin that flags subclasses as unsaveable by the object-based API.

  Some objects look checkpointable only because of what they inherit from
  (for example through `Layer`). Mixing in `NotCheckpointable` marks such
  objects as not checkpointable. Assigning an instance to attributes is still
  permitted; saving or restoring it is what is meant to raise an error. This
  class itself defines no behavior, so that check is performed by the
  save/restore machinery rather than here.
  """
|
||
|
|
||
|
|
||
|
class Checkpointable(base.CheckpointableBase):
  """Records named dependencies on other `Checkpointable` objects.

  A `Checkpointable` may depend on other `Checkpointable` objects; those
  dependencies are saved whenever the object that declares them is saved. A
  program is correctly saveable when its dependency graph satisfies this rule:
  if changing a global variable affects an object (for instance, changes how
  any of its methods behave), there is a chain of dependencies leading from
  that object to the variable.

  Dependencies are named edges. They are created implicitly whenever a
  `Checkpointable` object is assigned to an attribute of another
  `Checkpointable` object, e.g.:

  ```
  obj = Checkpointable()
  obj.v = ResourceVariable(0.)
  ```

  After that assignment, `obj` has a dependency named "v" on the variable.

  Rather than depending on other objects, a `Checkpointable` may also name
  `Tensor`s to save and restore directly (e.g. a `Variable` describing how it
  saves itself). See `Checkpointable._gather_saveables_for_checkpoint` for
  details.
  """

  def __setattr__(self, name, value):
    """Implements `self.foo = checkpointable` dependency tracking.

    When the instance's `_setattr_tracking` attribute is falsy, `value` is
    stored unchanged. Otherwise (including when `_setattr_tracking` is absent,
    which counts as enabled), `value` is first passed through
    `data_structures.sticky_attribute_assignment`. Whatever that helper
    returns is what gets stored under `name`.
    """
    if not getattr(self, "_setattr_tracking", True):
      # Tracking switched off: plain attribute assignment.
      super(Checkpointable, self).__setattr__(name, value)
      return
    tracked_value = data_structures.sticky_attribute_assignment(
        checkpointable=self, value=value, name=name)
    super(Checkpointable, self).__setattr__(name, tracked_value)

  def _no_dependency(self, value):
    """Returns `value` wrapped in `data_structures.NoDependency`.

    This overrides the `CheckpointableBase` hook so that the base class can
    disable dependency tracking for the given value.
    """
    wrapped = data_structures.NoDependency(value)
    return wrapped
|