# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """An example of training and predicting with a TFTS estimator.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import sys import numpy as np import tensorflow as tf try: import matplotlib # pylint: disable=g-import-not-at-top matplotlib.use("TkAgg") # Need Tk for interactive plots. from matplotlib import pyplot # pylint: disable=g-import-not-at-top HAS_MATPLOTLIB = True except ImportError: # Plotting requires matplotlib, but the unit test running this code may # execute in an environment without it (i.e. matplotlib is not a build # dependency). We'd still like to test the TensorFlow-dependent parts of this # example, namely train_and_predict. HAS_MATPLOTLIB = False FLAGS = None def structural_ensemble_train_and_predict(csv_file_name): # Cycle between 5 latent values over a period of 100. This leads to a very # smooth periodic component (and a small model), which is a good fit for our # example data. Modeling high-frequency periodic variations will require a # higher cycle_num_latent_values. structural = tf.contrib.timeseries.StructuralEnsembleRegressor( periodicities=100, num_features=1, cycle_num_latent_values=5) return train_and_predict(structural, csv_file_name, training_steps=150) def ar_train_and_predict(csv_file_name): # An autoregressive model, with periodicity handled as a time-based # regression. Note that this requires windows of size 16 (input_window_size + # output_window_size) for training. ar = tf.contrib.timeseries.ARRegressor( periodicities=100, input_window_size=10, output_window_size=6, num_features=1, # Use the (default) normal likelihood loss to adaptively fit the # variance. SQUARED_LOSS overestimates variance when there are trends in # the series. loss=tf.contrib.timeseries.ARModel.NORMAL_LIKELIHOOD_LOSS) return train_and_predict(ar, csv_file_name, training_steps=600) def train_and_predict(estimator, csv_file_name, training_steps): """A simple example of training and predicting.""" # Read data in the default "time,value" CSV format with no header reader = tf.contrib.timeseries.CSVReader(csv_file_name) # Set up windowing and batching for training train_input_fn = tf.contrib.timeseries.RandomWindowInputFn( reader, batch_size=16, window_size=16) # Fit model parameters to data estimator.train(input_fn=train_input_fn, steps=training_steps) # Evaluate on the full dataset sequentially, collecting in-sample predictions # for a qualitative evaluation. Note that this loads the whole dataset into # memory. For quantitative evaluation, use RandomWindowChunker. evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader) evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1) # Predict starting after the evaluation (predictions,) = tuple(estimator.predict( input_fn=tf.contrib.timeseries.predict_continuation_input_fn( evaluation, steps=200))) times = evaluation["times"][0] observed = evaluation["observed"][0, :, 0] mean = np.squeeze(np.concatenate( [evaluation["mean"][0], predictions["mean"]], axis=0)) variance = np.squeeze(np.concatenate( [evaluation["covariance"][0], predictions["covariance"]], axis=0)) all_times = np.concatenate([times, predictions["times"]], axis=0) upper_limit = mean + np.sqrt(variance) lower_limit = mean - np.sqrt(variance) return times, observed, all_times, mean, upper_limit, lower_limit def make_plot(name, training_times, observed, all_times, mean, upper_limit, lower_limit): """Plot a time series in a new figure.""" pyplot.figure() pyplot.plot(training_times, observed, "b", label="training series") pyplot.plot(all_times, mean, "r", label="forecast") pyplot.plot(all_times, upper_limit, "g", label="forecast upper bound") pyplot.plot(all_times, lower_limit, "g", label="forecast lower bound") pyplot.fill_between(all_times, lower_limit, upper_limit, color="grey", alpha="0.2") pyplot.axvline(training_times[-1], color="k", linestyle="--") pyplot.xlabel("time") pyplot.ylabel("observations") pyplot.legend(loc=0) pyplot.title(name) def main(unused_argv): if not HAS_MATPLOTLIB: raise ImportError( "Please install matplotlib to generate a plot from this example.") make_plot("Structural ensemble", *structural_ensemble_train_and_predict(FLAGS.input_filename)) make_plot("AR", *ar_train_and_predict(FLAGS.input_filename)) pyplot.show() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--input_filename", type=str, required=True, help="Input csv file.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)