laywerrobot/lib/python3.6/site-packages/tensorflow/python/estimator/keras.py
2020-08-27 21:55:39 +02:00

561 lines
21 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Home of estimator related functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export as export_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
# Key under which the Keras model predictions are registered in the
# `export_outputs` dict of the `EstimatorSpec` built by the Keras `model_fn`.
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
def _cast_tensor_to_floatx(x):
  """Return `x` expressed in the Keras default float dtype.

  Args:
    x: A tensor exposing a `dtype` attribute.

  Returns:
    `x` itself when its dtype already equals `K.floatx()`, otherwise the
    result of casting `x` to `K.floatx()`.
  """
  target_dtype = K.floatx()
  return x if x.dtype == target_dtype else math_ops.cast(x, target_dtype)
def _convert_tensor(x):
  """Turn `x` into a tensor and cast numeric tensors to the Keras floatx.

  Args:
    x: A tensor, or a value (e.g. a numpy array) convertible to a tensor or
      sparse tensor.

  Returns:
    The converted tensor; numeric tensors are cast to `K.floatx()`.
  """
  tensor = x
  if not tensor_util.is_tensor(tensor):
    # Not a tensor yet (e.g. a numpy array): convert it first, because
    # `is_numeric_tensor` reports False for numpy arrays.
    tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(tensor)
  if not check_ops.is_numeric_tensor(tensor):
    return tensor
  return _cast_tensor_to_floatx(tensor)
def _any_weight_initialized(keras_model):
  """Tell whether the Keras model holds at least one initialized weight.

  A weight counts as initialized when it carries a `_keras_initialized`
  attribute. Currently keras initializes all weights at get_session().

  Args:
    keras_model: An instance of compiled keras model, or None.

  Returns:
    boolean, True if any weight of any layer is marked as initialized,
    False otherwise (including when `keras_model` is None).
  """
  if keras_model is None:
    return False
  return any(
      hasattr(weight, '_keras_initialized')
      for layer in keras_model.layers
      for weight in layer.weights)
def _create_ordered_io(keras_model, estimator_io, is_input=True):
"""Create a list of tensors from IO dictionary based on Keras IO order.
Args:
keras_model: An instance of compiled keras model.
estimator_io: The features or labels (dict or plain array) from model_fn.
is_input: True if dictionary is for inputs.
Returns:
A list of tensors based on Keras IO order.
Raises:
ValueError: if dictionary keys cannot be found in Keras model input_names
or output_names.
"""
if isinstance(estimator_io, (list, tuple)):
# Case currently not supported by most built-in input_fn,
# but it's good to have for sanity
return [_convert_tensor(x) for x in estimator_io]
elif isinstance(estimator_io, dict):
if is_input:
if keras_model._is_graph_network:
keras_io_names = keras_model.input_names
else:
keras_io_names = [
'input_%d' % i for i in range(1, len(estimator_io) + 1)]
else:
if keras_model._is_graph_network:
keras_io_names = keras_model.output_names
else:
keras_io_names = [
'output_%d' % i for i in range(1, len(estimator_io) + 1)]
for key in estimator_io:
if key not in keras_io_names:
raise ValueError(
'Cannot find %s with name "%s" in Keras Model. '
'It needs to match one '
'of the following: %s' % ('input' if is_input else 'output', key,
', '.join(keras_io_names)))
tensors = [_convert_tensor(estimator_io[io_name])
for io_name in keras_io_names]
return tensors
else:
# Plain array.
return _convert_tensor(estimator_io)
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not serializable.
  To "instantiate" an identical model in a new TF graph, we reuse the original
  model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However the original state can be restored fully by calling
  `_in_place_subclassed_model_state_restoration`.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as inner layer, or
      holds a non-empty list/tuple attribute made only of layers.
  """
  assert not model._is_graph_network  # Only makes sense for subclassed networks
  # Retrieve all layers tracked by the model as well as their attribute names
  attributes_cache = {}
  for name in dir(model):
    try:
      value = getattr(model, name)
    except (AttributeError, ValueError, TypeError):
      # Some attributes cannot be read in the current state; skip them.
      continue
    if isinstance(value, Layer):
      attributes_cache[name] = value
      assert value in model._layers
    elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
      # Handle case: list/tuple of layers (also tracked by the Network API).
      if value and all(isinstance(val, Layer) for val in value):
        raise ValueError('We do not support the use of list-of-layers '
                         'attributes in subclassed models used with '
                         '`model_to_estimator` at this time. Found list '
                         'model: %s' % name)

  layers_to_names = {value: key for key, value in attributes_cache.items()}
  original_layers = model._layers[:]

  # Reject nested subclassed models BEFORE touching the model and before
  # calling `get_config()` on them: this guarantees the documented ValueError
  # is the error raised, and that a rejected model is left unmodified.
  # This would be theoretically possible to support, but would add complexity.
  # Only do it if users complain.
  for layer in original_layers:
    if isinstance(layer, Network) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)

  # Replace layers on the model with fresh layers
  model._layers = []
  for layer in original_layers:  # We preserve layer order.
    config = layer.get_config()
    fresh_layer = layer.__class__.from_config(config)
    name = layers_to_names[layer]
    # Assigning through the attribute name lets the model track the new layer.
    setattr(model, name, fresh_layer)

  # Cache original model build attributes (in addition to layers)
  if (not hasattr(model, '_original_attributes_cache') or
      model._original_attributes_cache is None):
    if model.built:
      attributes_to_cache = [
          'inputs',
          'outputs',
          '_feed_outputs',
          '_feed_output_names',
          '_feed_output_shapes',
          '_feed_loss_fns',
          'loss_weights_list',
          'targets',
          '_feed_targets',
          'sample_weight_modes',
          'weighted_metrics',
          'metrics_names',
          'metrics_tensors',
          'metrics_updates',
          'stateful_metric_names',
          'total_loss',
          'sample_weights',
          '_feed_sample_weights',
          'train_function',
          'test_function',
          'predict_function',
          '_collected_trainable_weights',
          '_feed_inputs',
          '_feed_input_names',
          '_feed_input_shapes',
          'optimizer',
      ]
      for name in attributes_to_cache:
        attributes_cache[name] = getattr(model, name)
  model._original_attributes_cache = attributes_cache

  # Reset built state
  model.built = False
  model.inputs = None
  model.outputs = None
def _in_place_subclassed_model_state_restoration(model):
  """Bring a "reset" subclassed model back to its original state.

  This is the inverse of `_in_place_subclassed_model_reset`.

  Args:
    model: Instance of a Keras model created via subclassing, on which
      `_in_place_subclassed_model_reset` was previously called.
  """
  assert not model._is_graph_network
  cached_attributes = getattr(model, '_original_attributes_cache', None)
  if cached_attributes is None:
    # Nothing was cached: restore to the state of a never-called model.
    model.built = False
    model.inputs = None
    model.outputs = None
    return

  # Restore layers and build attributes.
  # Models have sticky attribute assignment, so we want to be careful to add
  # back the previous attributes and track Layers by their original names
  # without adding dependencies on "utility" attributes which Models exempt
  # when they're constructed.
  model._layers = data_structures.NoDependency([])
  for attr_name, attr_value in cached_attributes.items():
    if isinstance(attr_value, checkpointable.CheckpointableBase):
      restored = attr_value
    else:
      # If this value is not already checkpointable, it's probably that way
      # for a reason; we don't want to start tracking data structures that the
      # original Model didn't.
      restored = data_structures.NoDependency(attr_value)
    setattr(model, attr_name, restored)
  model._original_attributes_cache = None
def _clone_and_build_model(mode,
                           keras_model,
                           custom_objects,
                           features=None,
                           labels=None):
  """Clone and build the given keras_model.

  Graph networks are cloned with `models.clone_model`; subclassed models are
  reset in place and reused, so for them the returned model is `keras_model`
  itself.

  Args:
    mode: training mode, one of the `model_fn_lib.ModeKeys` values.
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    features: Dict of tensors.
    labels: Dict of tensors, or single tensor instance.

  Returns:
    The newly built model. It is compiled for every mode except PREDICT.
  """
  # Set to True during training, False for inference.
  K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

  # Get list of inputs.
  if features is None:
    input_tensors = None
  else:
    input_tensors = _create_ordered_io(keras_model,
                                       estimator_io=features,
                                       is_input=True)
  # Get list of outputs.
  if labels is None:
    target_tensors = None
  elif isinstance(labels, dict):
    target_tensors = _create_ordered_io(keras_model,
                                        estimator_io=labels,
                                        is_input=False)
  else:
    target_tensors = [
        _convert_tensor(labels)
    ]

  if keras_model._is_graph_network:
    if custom_objects:
      with CustomObjectScope(custom_objects):
        model = models.clone_model(keras_model, input_tensors=input_tensors)
    else:
      model = models.clone_model(keras_model, input_tensors=input_tensors)
  else:
    # Subclassed models cannot be cloned: reuse the instance after a reset.
    model = keras_model
    _in_place_subclassed_model_reset(model)
    if input_tensors is not None:
      model._set_inputs(input_tensors)

  # Compile/Build model
  # Mode keys are compared by value (as for TRAIN above), not by identity, so
  # an equal mode string that is a distinct object is still recognized.
  if mode == model_fn_lib.ModeKeys.PREDICT:
    if isinstance(model, models.Sequential):
      model.build()
  else:
    if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
      optimizer = keras_model.optimizer
    else:
      # Rebuild the optimizer from its config so the clone gets its own
      # instance.
      optimizer_config = keras_model.optimizer.get_config()
      optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
    # Tie the optimizer step counter to the global step.
    optimizer.iterations = training_util.get_or_create_global_step()

    model.compile(
        optimizer,
        keras_model.loss,
        metrics=keras_model.metrics,
        loss_weights=keras_model.loss_weights,
        sample_weight_mode=keras_model.sample_weight_mode,
        weighted_metrics=keras_model.weighted_metrics,
        target_tensors=target_tensors)
  return model
def _create_keras_model_fn(keras_model, custom_objects=None):
  """Creates model_fn for keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.

  Returns:
    The model_fn for a keras Estimator. It takes `(features, labels, mode)`
    and returns a `model_fn_lib.EstimatorSpec`.
  """
  # Compiled once here instead of once per output name inside `model_fn`.
  # NOTE(review): the pattern only strips a single trailing digit (`_0`..`_9`);
  # a suffix such as `_10` is left untouched. Confirm whether that is intended.
  trailing_index = re.compile(r'_\d$')

  def model_fn(features, labels, mode):
    """model_fn for keras Estimator."""
    model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                   labels)
    # We need to make sure that the output names of the last layer in the model
    # is the same for each of the cloned models. This is required for mirrored
    # strategy when we call regroup.
    if distribute_lib.has_distribution_strategy():
      model_output_names = [
          trailing_index.sub('', name) for name in model.output_names]
    else:
      model_output_names = model.output_names

    # Get inputs to EstimatorSpec
    predictions = dict(zip(model_output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Mode keys are compared by value rather than identity, so an equal mode
    # string that is a distinct object is still recognized.
    is_train = mode == model_fn_lib.ModeKeys.TRAIN

    # Set loss and metric only during train and evaluate.
    if mode != model_fn_lib.ModeKeys.PREDICT:
      if is_train:
        model._make_train_function()  # pylint: disable=protected-access
      else:
        model._make_test_function()  # pylint: disable=protected-access
      loss = model.total_loss

      if model.metrics:
        # TODO(fchollet): support stateful metrics
        eval_metric_ops = {}
        # When each metric maps to an output
        if isinstance(model.metrics, dict):
          for i, output_name in enumerate(model.metrics.keys()):
            metric_name = model.metrics[output_name]
            if callable(metric_name):
              metric_name = metric_name.__name__
            # When some outputs use the same metric
            # NOTE(review): for a callable metric, `metric_name` is already its
            # `__name__` string here while the dict values are still the
            # callables, so this count only matches string-specified metrics.
            if list(model.metrics.values()).count(metric_name) > 1:
              metric_name += '_' + output_name
            # Indexes the last `len(model.metrics)` entries of metrics_tensors.
            eval_metric_ops[metric_name] = metrics_module.mean(
                model.metrics_tensors[i - len(model.metrics)])
        else:
          for i, metric_name in enumerate(model.metrics):
            if callable(metric_name):
              metric_name = metric_name.__name__
            eval_metric_ops[metric_name] = metrics_module.mean(
                model.metrics_tensors[i])

    # Set train_op only during train.
    if is_train:
      train_op = model.train_function.updates_op

    if not model._is_graph_network:
      # Reset model state to original state,
      # to avoid `model_fn` being destructive for the initial model argument.
      _in_place_subclassed_model_state_restoration(keras_model)
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs={
            _DEFAULT_SERVING_KEY:
            export_lib.export_output.PredictOutput(predictions)
        })

  return model_fn
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
  """Write an initial checkpoint for the keras Estimator when none exists.

  Does nothing if `estimator.model_dir` already contains a checkpoint.
  Otherwise the model is rebuilt in TRAIN mode inside a fresh graph, the given
  weights are loaded into it, and the result is saved as `keras_model.ckpt`
  in the estimator's model directory.

  Args:
    keras_model: an instance of compiled keras model.
    estimator: keras estimator.
    custom_objects: Dictionary for custom objects.
    keras_weights: A flat list of Numpy arrays for weights of given keras_model.
  """
  if saver_lib.latest_checkpoint(estimator.model_dir):
    # A checkpoint is already present; leave it alone.
    return

  with ops.Graph().as_default():
    random_seed.set_random_seed(estimator.config.tf_random_seed)
    training_util.create_global_step()
    fresh_model = _clone_and_build_model(
        model_fn_lib.ModeKeys.TRAIN, keras_model, custom_objects)
    # save to checkpoint
    with session.Session(config=estimator._session_config) as sess:
      if keras_weights:
        fresh_model.set_weights(keras_weights)
      # Make update ops and initialize all variables.
      if not fresh_model.train_function:
        # pylint: disable=protected-access
        fresh_model._make_train_function()
        K._initialize_variables(sess)
        # pylint: enable=protected-access
      checkpoint_writer = saver_lib.Saver()
      checkpoint_prefix = os.path.join(estimator.model_dir, 'keras_model.ckpt')
      checkpoint_writer.save(sess, checkpoint_prefix)
def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  """Constructs an `Estimator` instance from given keras model.

  For usage example, please see
  @{$guide/estimators$creating_estimators_from_keras_models}.

  Args:
    keras_model: A compiled Keras model object. This argument is mutually
      exclusive with `keras_model_path`.
    keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
      format, which can be generated with the `save()` method of a Keras model.
      This argument is mutually exclusive with `keras_model`.
    custom_objects: Dictionary for custom objects. It is also used when
      loading the model from `keras_model_path`.
    model_dir: Directory to save Estimator model parameters, graph, summary
      files for TensorBoard, etc.
    config: Configuration object, or a dict of keyword arguments used to build
      a `RunConfig`.

  Returns:
    An Estimator from given keras model.

  Raises:
    ValueError: if neither keras_model nor keras_model_path was given.
    ValueError: if both keras_model and keras_model_path was given.
    ValueError: if the keras_model_path is a GCS URI.
    ValueError: if keras_model has not been compiled.
  """
  if not (keras_model or keras_model_path):
    raise ValueError(
        'Either `keras_model` or `keras_model_path` needs to be provided.')
  if keras_model and keras_model_path:
    raise ValueError(
        'Please specify either `keras_model` or `keras_model_path`, '
        'but not both.')

  if not keras_model:
    if keras_model_path.startswith(
        'gs://') or 'storage.googleapis.com' in keras_model_path:
      raise ValueError(
          '%s is not a local path. Please copy the model locally first.' %
          keras_model_path)
    logging.info('Loading models from %s', keras_model_path)
    # Forward custom_objects so saved models with custom layers/losses can be
    # deserialized; with the default None this matches a plain load_model call.
    keras_model = models.load_model(
        keras_model_path, custom_objects=custom_objects)
  else:
    logging.info('Using the Keras model provided.')

  # A compiled model carries a non-empty `optimizer` attribute.
  if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:
    raise ValueError(
        'The given keras model has not been compiled yet. '
        'Please compile the model with `model.compile()` '
        'before calling `model_to_estimator()`.')

  if isinstance(config, dict):
    config = run_config_lib.RunConfig(**config)

  keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
  estimator = estimator_lib.Estimator(
      keras_model_fn, model_dir=model_dir, config=config)

  # Check if we need to call get_weights:
  if _any_weight_initialized(keras_model):
    keras_weights = keras_model.get_weights()
    # Warn if config passed to estimator tries to update GPUOptions. If a
    # session has already been created, the GPUOptions passed to the first
    # session sticks.
    if estimator._session_config.HasField('gpu_options'):
      logging.warning(
          'The Keras backend session has already been set. '
          'The _session_config passed to model_to_estimator will not be used.')
  else:
    # Pass the config into keras backend's default session.
    sess = session.Session(config=estimator._session_config)
    K.set_session(sess)
    keras_weights = None

  if keras_model._is_graph_network:
    # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
    _save_first_checkpoint(keras_model,
                           estimator,
                           custom_objects,
                           keras_weights)
  elif keras_model.built:
    logging.warning('You are creating an Estimator from a Keras model '
                    'manually subclassed from `Model`, that was '
                    'already called on some inputs (and thus already had '
                    'weights). We are currently unable to preserve '
                    'the model\'s state (its weights) '
                    'as part of the estimator '
                    'in this case. Be warned that the estimator '
                    'has been created using '
                    'a freshly initialized version of your model.\n'
                    'Note that this doesn\'t affect the state of the '
                    'model instance you passed as `keras_model` argument.')

  return estimator