562 lines
21 KiB
Python
562 lines
21 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
# pylint: disable=protected-access
|
||
|
"""Home of estimator related functions.
|
||
|
"""
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import os
|
||
|
import re
|
||
|
from tensorflow.python.client import session
|
||
|
from tensorflow.python.estimator import estimator as estimator_lib
|
||
|
from tensorflow.python.estimator import export as export_lib
|
||
|
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||
|
from tensorflow.python.estimator import run_config as run_config_lib
|
||
|
from tensorflow.python.framework import ops
|
||
|
from tensorflow.python.framework import random_seed
|
||
|
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||
|
from tensorflow.python.framework import tensor_util
|
||
|
from tensorflow.python.keras import backend as K
|
||
|
from tensorflow.python.keras import models
|
||
|
from tensorflow.python.keras import optimizers
|
||
|
from tensorflow.python.keras.engine.base_layer import Layer
|
||
|
from tensorflow.python.keras.engine.network import Network
|
||
|
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||
|
from tensorflow.python.ops import check_ops
|
||
|
from tensorflow.python.ops import math_ops
|
||
|
from tensorflow.python.ops import metrics as metrics_module
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
from tensorflow.python.saved_model import signature_constants
|
||
|
from tensorflow.python.training import distribute as distribute_lib
|
||
|
from tensorflow.python.training import saver as saver_lib
|
||
|
from tensorflow.python.training import training_util
|
||
|
from tensorflow.python.training.checkpointable import base as checkpointable
|
||
|
from tensorflow.python.training.checkpointable import data_structures
|
||
|
|
||
|
|
||
|
# Signature key under which the generated model_fn registers its
# `PredictOutput` in the `export_outputs` of the returned EstimatorSpec.
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||
|
|
||
|
|
||
|
def _cast_tensor_to_floatx(x):
  """Returns `x` expressed in the Keras default float dtype.

  Args:
    x: An object exposing a `dtype` attribute (a tensor).

  Returns:
    `x` itself when its dtype already equals `K.floatx()`, otherwise the
    result of `math_ops.cast(x, K.floatx())`.
  """
  if x.dtype == K.floatx():
    # Already in the desired dtype; avoid adding a redundant cast op.
    return x
  return math_ops.cast(x, K.floatx())
|
||
|
|
||
|
|
||
|
def _convert_tensor(x):
  """Turns `x` into a tensor and casts numeric tensors to Keras floatx.

  Args:
    x: A tensor, or a value (e.g. a numpy array) convertible with
      `convert_to_tensor_or_sparse_tensor`.

  Returns:
    The converted value; numeric tensors are passed through
    `_cast_tensor_to_floatx`, everything else is returned as converted.
  """
  tensor = x
  if not tensor_util.is_tensor(tensor):
    # Non-tensor input such as a numpy array.
    tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(tensor)
  # The conversion above has to happen first: is_numeric_tensor returns False
  # if provided with a numpy array.
  if check_ops.is_numeric_tensor(tensor):
    return _cast_tensor_to_floatx(tensor)
  return tensor
|
||
|
|
||
|
|
||
|
def _any_weight_initialized(keras_model):
|
||
|
"""Check if any weights has been initialized in the Keras model.
|
||
|
|
||
|
Args:
|
||
|
keras_model: An instance of compiled keras model.
|
||
|
|
||
|
Returns:
|
||
|
boolean, True if at least one weight has been initialized, else False.
|
||
|
Currently keras initialize all weights at get_session().
|
||
|
"""
|
||
|
if keras_model is None:
|
||
|
return False
|
||
|
for layer in keras_model.layers:
|
||
|
for weight in layer.weights:
|
||
|
if hasattr(weight, '_keras_initialized'):
|
||
|
return True
|
||
|
return False
|
||
|
|
||
|
|
||
|
def _create_ordered_io(keras_model, estimator_io, is_input=True):
|
||
|
"""Create a list of tensors from IO dictionary based on Keras IO order.
|
||
|
|
||
|
Args:
|
||
|
keras_model: An instance of compiled keras model.
|
||
|
estimator_io: The features or labels (dict or plain array) from model_fn.
|
||
|
is_input: True if dictionary is for inputs.
|
||
|
|
||
|
Returns:
|
||
|
A list of tensors based on Keras IO order.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: if dictionary keys cannot be found in Keras model input_names
|
||
|
or output_names.
|
||
|
"""
|
||
|
if isinstance(estimator_io, (list, tuple)):
|
||
|
# Case currently not supported by most built-in input_fn,
|
||
|
# but it's good to have for sanity
|
||
|
return [_convert_tensor(x) for x in estimator_io]
|
||
|
elif isinstance(estimator_io, dict):
|
||
|
if is_input:
|
||
|
if keras_model._is_graph_network:
|
||
|
keras_io_names = keras_model.input_names
|
||
|
else:
|
||
|
keras_io_names = [
|
||
|
'input_%d' % i for i in range(1, len(estimator_io) + 1)]
|
||
|
else:
|
||
|
if keras_model._is_graph_network:
|
||
|
keras_io_names = keras_model.output_names
|
||
|
else:
|
||
|
keras_io_names = [
|
||
|
'output_%d' % i for i in range(1, len(estimator_io) + 1)]
|
||
|
|
||
|
for key in estimator_io:
|
||
|
if key not in keras_io_names:
|
||
|
raise ValueError(
|
||
|
'Cannot find %s with name "%s" in Keras Model. '
|
||
|
'It needs to match one '
|
||
|
'of the following: %s' % ('input' if is_input else 'output', key,
|
||
|
', '.join(keras_io_names)))
|
||
|
tensors = [_convert_tensor(estimator_io[io_name])
|
||
|
for io_name in keras_io_names]
|
||
|
return tensors
|
||
|
else:
|
||
|
# Plain array.
|
||
|
return _convert_tensor(estimator_io)
|
||
|
|
||
|
|
||
|
def _in_place_subclassed_model_reset(model):
  """Resets a subclassed model in place as a substitute for cloning it.

  Subclassed models cannot be cloned because their topology is not
  serializable. To "instantiate" an identical model in a new TF graph, the
  original model object is reused with its state cleared: every layer held as
  an attribute is replaced by a fresh layer created from the old layer's
  config, and the build state (`built`, `inputs`, `outputs`) is reset.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However the original state can be restored fully by calling
  `_in_place_subclassed_model_state_restoration`, which reads the
  `_original_attributes_cache` stored on the model here.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as inner layer, or
      holds a non-empty list/tuple attribute made only of layers.
  """
  assert not model._is_graph_network  # Only makes sense for subclassed networks

  # Collect every layer-valued attribute of the model, keyed by attribute name.
  attributes_cache = {}
  for attr_name in dir(model):
    try:
      attr_value = getattr(model, attr_name)
    except (AttributeError, ValueError, TypeError):
      # Attributes that cannot be read are ignored.
      continue
    if isinstance(attr_value, Layer):
      attributes_cache[attr_name] = attr_value
      assert attr_value in model._layers
      continue
    holds_sequence = (
        isinstance(attr_value, (list, tuple)) and
        attr_name not in ('layers', '_layers'))
    # Handle case: list/tuple of layers (also tracked by the Network API).
    if (holds_sequence and attr_value and
        all(isinstance(item, Layer) for item in attr_value)):
      raise ValueError('We do not support the use of list-of-layers '
                       'attributes in subclassed models used with '
                       '`model_to_estimator` at this time. Found list '
                       'model: %s' % attr_name)

  # Swap each original layer for a freshly constructed one, assigned back
  # under the attribute name the original was found at.
  layers_to_names = {
      layer: attr_name for attr_name, layer in attributes_cache.items()}
  original_layers = model._layers[:]
  model._layers = []
  for old_layer in original_layers:  # Layer order is preserved.
    layer_config = old_layer.get_config()
    # This will not work for nested subclassed models used as layers.
    # This would be theoretically possible to support, but would add
    # complexity. Only do it if users complain.
    if isinstance(old_layer, Network) and not old_layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % old_layer)
    replacement = old_layer.__class__.from_config(layer_config)
    setattr(model, layers_to_names[old_layer], replacement)

  # Cache original model build attributes (in addition to layers), but only
  # when no earlier cache is pending and the model has actually been built.
  no_pending_cache = (
      not hasattr(model, '_original_attributes_cache') or
      model._original_attributes_cache is None)
  if no_pending_cache and model.built:
    build_attribute_names = (
        'inputs',
        'outputs',
        '_feed_outputs',
        '_feed_output_names',
        '_feed_output_shapes',
        '_feed_loss_fns',
        'loss_weights_list',
        'targets',
        '_feed_targets',
        'sample_weight_modes',
        'weighted_metrics',
        'metrics_names',
        'metrics_tensors',
        'metrics_updates',
        'stateful_metric_names',
        'total_loss',
        'sample_weights',
        '_feed_sample_weights',
        'train_function',
        'test_function',
        'predict_function',
        '_collected_trainable_weights',
        '_feed_inputs',
        '_feed_input_names',
        '_feed_input_shapes',
        'optimizer',
    )
    for attr_name in build_attribute_names:
      attributes_cache[attr_name] = getattr(model, attr_name)
  model._original_attributes_cache = attributes_cache

  # Reset built state
  model.built = False
  model.inputs = None
  model.outputs = None
|
||
|
|
||
|
|
||
|
def _in_place_subclassed_model_state_restoration(model):
|
||
|
"""Restores the original state of a model after it was "reset".
|
||
|
|
||
|
This undoes this action of `_in_place_subclassed_model_reset`.
|
||
|
|
||
|
Args:
|
||
|
model: Instance of a Keras model created via subclassing, on which
|
||
|
`_in_place_subclassed_model_reset` was previously called.
|
||
|
"""
|
||
|
assert not model._is_graph_network
|
||
|
# Restore layers and build attributes
|
||
|
if (hasattr(model, '_original_attributes_cache') and
|
||
|
model._original_attributes_cache is not None):
|
||
|
# Models have sticky attribute assignment, so we want to be careful to add
|
||
|
# back the previous attributes and track Layers by their original names
|
||
|
# without adding dependencies on "utility" attributes which Models exempt
|
||
|
# when they're constructed.
|
||
|
model._layers = data_structures.NoDependency([])
|
||
|
for name, value in model._original_attributes_cache.items():
|
||
|
if not isinstance(value, checkpointable.CheckpointableBase):
|
||
|
# If this value is not already checkpointable, it's probably that way
|
||
|
# for a reason; we don't want to start tracking data structures that the
|
||
|
# original Model didn't.
|
||
|
value = data_structures.NoDependency(value)
|
||
|
setattr(model, name, value)
|
||
|
model._original_attributes_cache = None
|
||
|
else:
|
||
|
# Restore to the state of a never-called model.
|
||
|
model.built = False
|
||
|
model.inputs = None
|
||
|
model.outputs = None
|
||
|
|
||
|
|
||
|
def _clone_and_build_model(mode,
                           keras_model,
                           custom_objects,
                           features=None,
                           labels=None):
  """Clone and build the given keras_model.

  Graph-network models are cloned with `models.clone_model`; subclassed models
  are reset in place with `_in_place_subclassed_model_reset` and reused. Unless
  the mode is PREDICT, the resulting model is compiled with the loss, metrics
  and (re-created, unless it is a `TFOptimizer`) optimizer of `keras_model`,
  and the optimizer's `iterations` is bound to the global step.

  Args:
    mode: training mode, one of the `ModeKeys` values.
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects, used as a
      `CustomObjectScope` while cloning. May be None or empty.
    features: Dict of tensors.
    labels: Dict of tensors, or single tensor instance.

  Returns:
    The newly built model.
  """
  # Set to True during training, False for inference.
  K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

  # Get list of inputs.
  if features is None:
    input_tensors = None
  else:
    input_tensors = _create_ordered_io(keras_model,
                                       estimator_io=features,
                                       is_input=True)
  # Get list of outputs.
  if labels is None:
    target_tensors = None
  elif isinstance(labels, dict):
    target_tensors = _create_ordered_io(keras_model,
                                        estimator_io=labels,
                                        is_input=False)
  else:
    target_tensors = [
        _convert_tensor(labels)
    ]

  if keras_model._is_graph_network:
    if custom_objects:
      with CustomObjectScope(custom_objects):
        model = models.clone_model(keras_model, input_tensors=input_tensors)
    else:
      model = models.clone_model(keras_model, input_tensors=input_tensors)
  else:
    # Subclassed models cannot be cloned; reuse the instance after clearing
    # its state.
    model = keras_model
    _in_place_subclassed_model_reset(model)
    if input_tensors is not None:
      model._set_inputs(input_tensors)

  # Compile/Build model
  # Mode keys are compared by value, like the TRAIN check above: an identity
  # test would treat an equal string coming from another source (for instance
  # a deserialized config) as a different mode.
  if mode == model_fn_lib.ModeKeys.PREDICT:
    if isinstance(model, models.Sequential):
      model.build()
  else:
    if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
      optimizer = keras_model.optimizer
    else:
      optimizer_config = keras_model.optimizer.get_config()
      optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
    optimizer.iterations = training_util.get_or_create_global_step()

    model.compile(
        optimizer,
        keras_model.loss,
        metrics=keras_model.metrics,
        loss_weights=keras_model.loss_weights,
        sample_weight_mode=keras_model.sample_weight_mode,
        weighted_metrics=keras_model.weighted_metrics,
        target_tensors=target_tensors)
  return model
|
||
|
|
||
|
|
||
|
def _create_keras_model_fn(keras_model, custom_objects=None):
|
||
|
"""Creates model_fn for keras Estimator.
|
||
|
|
||
|
Args:
|
||
|
keras_model: an instance of compiled keras model.
|
||
|
custom_objects: Dictionary for custom objects.
|
||
|
|
||
|
Returns:
|
||
|
The model_fn for a keras Estimator.
|
||
|
"""
|
||
|
|
||
|
def model_fn(features, labels, mode):
|
||
|
"""model_fn for keras Estimator."""
|
||
|
model = _clone_and_build_model(mode, keras_model, custom_objects, features,
|
||
|
labels)
|
||
|
model_output_names = []
|
||
|
# We need to make sure that the output names of the last layer in the model
|
||
|
# is the same for each of the cloned models. This is required for mirrored
|
||
|
# strategy when we call regroup.
|
||
|
if distribute_lib.has_distribution_strategy():
|
||
|
for name in model.output_names:
|
||
|
name = re.compile(r'_\d$').sub('', name)
|
||
|
model_output_names.append(name)
|
||
|
else:
|
||
|
model_output_names = model.output_names
|
||
|
|
||
|
# Get inputs to EstimatorSpec
|
||
|
predictions = dict(zip(model_output_names, model.outputs))
|
||
|
|
||
|
loss = None
|
||
|
train_op = None
|
||
|
eval_metric_ops = None
|
||
|
|
||
|
# Set loss and metric only during train and evaluate.
|
||
|
if mode is not model_fn_lib.ModeKeys.PREDICT:
|
||
|
if mode is model_fn_lib.ModeKeys.TRAIN:
|
||
|
model._make_train_function() # pylint: disable=protected-access
|
||
|
else:
|
||
|
model._make_test_function() # pylint: disable=protected-access
|
||
|
loss = model.total_loss
|
||
|
|
||
|
if model.metrics:
|
||
|
# TODO(fchollet): support stateful metrics
|
||
|
eval_metric_ops = {}
|
||
|
# When each metric maps to an output
|
||
|
if isinstance(model.metrics, dict):
|
||
|
for i, output_name in enumerate(model.metrics.keys()):
|
||
|
metric_name = model.metrics[output_name]
|
||
|
if callable(metric_name):
|
||
|
metric_name = metric_name.__name__
|
||
|
# When some outputs use the same metric
|
||
|
if list(model.metrics.values()).count(metric_name) > 1:
|
||
|
metric_name += '_' + output_name
|
||
|
eval_metric_ops[metric_name] = metrics_module.mean(
|
||
|
model.metrics_tensors[i - len(model.metrics)])
|
||
|
else:
|
||
|
for i, metric_name in enumerate(model.metrics):
|
||
|
if callable(metric_name):
|
||
|
metric_name = metric_name.__name__
|
||
|
eval_metric_ops[metric_name] = metrics_module.mean(
|
||
|
model.metrics_tensors[i])
|
||
|
|
||
|
# Set train_op only during train.
|
||
|
if mode is model_fn_lib.ModeKeys.TRAIN:
|
||
|
train_op = model.train_function.updates_op
|
||
|
|
||
|
if not model._is_graph_network:
|
||
|
# Reset model state to original state,
|
||
|
# to avoid `model_fn` being destructive for the initial model argument.
|
||
|
_in_place_subclassed_model_state_restoration(keras_model)
|
||
|
return model_fn_lib.EstimatorSpec(
|
||
|
mode=mode,
|
||
|
predictions=predictions,
|
||
|
loss=loss,
|
||
|
train_op=train_op,
|
||
|
eval_metric_ops=eval_metric_ops,
|
||
|
export_outputs={
|
||
|
_DEFAULT_SERVING_KEY:
|
||
|
export_lib.export_output.PredictOutput(predictions)
|
||
|
})
|
||
|
|
||
|
return model_fn
|
||
|
|
||
|
|
||
|
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
  """Save first checkpoint for the keras Estimator.

  Does nothing when `estimator.model_dir` already holds a checkpoint.
  Otherwise the model is rebuilt for training in a fresh graph, optionally
  loaded with `keras_weights`, and saved as `keras_model.ckpt` inside the
  estimator's model directory.

  Args:
    keras_model: an instance of compiled keras model.
    estimator: keras estimator.
    custom_objects: Dictionary for custom objects.
    keras_weights: A flat list of Numpy arrays for weights of given keras_model.
  """
  # Load weights and save to checkpoint only if there is no checkpoint yet.
  if saver_lib.latest_checkpoint(estimator.model_dir):
    return

  checkpoint_path = os.path.join(estimator.model_dir, 'keras_model.ckpt')
  with ops.Graph().as_default():
    random_seed.set_random_seed(estimator.config.tf_random_seed)
    training_util.create_global_step()
    train_model = _clone_and_build_model(
        model_fn_lib.ModeKeys.TRAIN, keras_model, custom_objects)
    # save to checkpoint
    with session.Session(config=estimator._session_config) as sess:
      if keras_weights:
        train_model.set_weights(keras_weights)
      # Make update ops and initialize all variables.
      if not train_model.train_function:
        # pylint: disable=protected-access
        train_model._make_train_function()
        K._initialize_variables(sess)
        # pylint: enable=protected-access
      saver_lib.Saver().save(sess, checkpoint_path)
|
||
|
|
||
|
|
||
|
def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  """Constructs an `Estimator` instance from given keras model.

  For usage example, please see
  @{$guide/estimators$creating_estimators_from_keras_models}.

  Args:
    keras_model: A compiled Keras model object. This argument is mutually
      exclusive with `keras_model_path`.
    keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
      format, which can be generated with the `save()` method of a Keras model.
      This argument is mutually exclusive with `keras_model`.
    custom_objects: Dictionary for custom objects. It is used both when loading
      the model from `keras_model_path` and when cloning the model.
    model_dir: Directory to save Estimator model parameters, graph, summary
      files for TensorBoard, etc.
    config: Configuration object. A dict is expanded into the keyword
      arguments of `RunConfig`.

  Returns:
    An Estimator from given keras model.

  Raises:
    ValueError: if neither keras_model nor keras_model_path was given.
    ValueError: if both keras_model and keras_model_path was given.
    ValueError: if the keras_model_path is a GCS URI.
    ValueError: if keras_model has not been compiled.
  """
  if not (keras_model or keras_model_path):
    raise ValueError(
        'Either `keras_model` or `keras_model_path` needs to be provided.')
  if keras_model and keras_model_path:
    raise ValueError(
        'Please specify either `keras_model` or `keras_model_path`, '
        'but not both.')

  if not keras_model:
    if keras_model_path.startswith(
        'gs://') or 'storage.googleapis.com' in keras_model_path:
      raise ValueError(
          '%s is not a local path. Please copy the model locally first.' %
          keras_model_path)
    logging.info('Loading models from %s', keras_model_path)
    # Forward the custom objects so that saved models containing custom
    # layers, losses or metrics can be deserialized.
    keras_model = models.load_model(
        keras_model_path, custom_objects=custom_objects)
  else:
    logging.info('Using the Keras model provided.')

  if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:
    raise ValueError(
        'The given keras model has not been compiled yet. '
        'Please compile the model with `model.compile()` '
        'before calling `model_to_estimator()`.')

  if isinstance(config, dict):
    config = run_config_lib.RunConfig(**config)

  keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
  estimator = estimator_lib.Estimator(
      keras_model_fn, model_dir=model_dir, config=config)

  # Check if we need to call get_weights:
  if _any_weight_initialized(keras_model):
    keras_weights = keras_model.get_weights()
    # Warn if config passed to estimator tries to update GPUOptions. If a
    # session has already been created, the GPUOptions passed to the first
    # session sticks.
    if estimator._session_config.HasField('gpu_options'):
      logging.warning(
          'The Keras backend session has already been set. '
          'The _session_config passed to model_to_estimator will not be used.')
  else:
    # Pass the config into keras backend's default session.
    sess = session.Session(config=estimator._session_config)
    K.set_session(sess)
    keras_weights = None

  if keras_model._is_graph_network:
    # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
    _save_first_checkpoint(keras_model,
                           estimator,
                           custom_objects,
                           keras_weights)
  elif keras_model.built:
    logging.warning('You are creating an Estimator from a Keras model '
                    'manually subclassed from `Model`, that was '
                    'already called on some inputs (and thus already had '
                    'weights). We are currently unable to preserve '
                    'the model\'s state (its weights) '
                    'as part of the estimator '
                    'in this case. Be warned that the estimator '
                    'has been created using '
                    'a freshly initialized version of your model.\n'
                    'Note that this doesn\'t affect the state of the '
                    'model instance you passed as `keras_model` argument.')
  return estimator
|