laywerrobot/lib/python3.6/site-packages/tensorflow/python/keras/datasets/reuters.py

135 lines
4.8 KiB
Python
Raw Normal View History

2020-08-27 21:55:39 +02:00
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reuters topic classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
from tensorflow.python.keras.preprocessing.sequence import _remove_long_seq
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.datasets.reuters.load_data')
def load_data(path='reuters.npz',
              num_words=None,
              skip_top=0,
              maxlen=None,
              test_split=0.2,
              seed=113,
              start_char=1,
              oov_char=2,
              index_from=3,
              **kwargs):
  """Loads the Reuters newswire classification dataset.

  Arguments:
      path: where to cache the data (relative to `~/.keras/dataset`).
      num_words: max number of words to include. Words are ranked
          by how often they occur (in the training set) and only
          the most frequent words are kept. Word indices outside
          `[skip_top, num_words)` are replaced by `oov_char` (or dropped
          when `oov_char` is None). If falsy, the largest word index seen
          in the (shifted) data is used as the bound.
      skip_top: skip the top N most frequently occurring words
          (which may not be informative).
      maxlen: if set, sequences are filtered through `_remove_long_seq`,
          which, as its name suggests, presumably removes over-long
          sequences rather than truncating them (NOTE(review): confirm).
      test_split: Fraction of the dataset to be used as test data.
      seed: random seed for sample shuffling. A private RNG is used, so
          NumPy's global random state is not modified.
      start_char: The start of a sequence will be marked with this character.
          Set to 1 because 0 is usually the padding character.
      oov_char: words that were cut out because of the `num_words`
          or `skip_top` limit will be replaced with this character.
      index_from: index actual words with this index and higher.
      **kwargs: Used for backwards compatibility (only `nb_words`, the
          former name of `num_words`, is accepted).

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
      The `x` arrays have object dtype; each element is a list of word
      indices.

  Raises:
      TypeError: if any keyword argument other than `nb_words` is passed.

  Note that the 'out of vocabulary' character is only used for
  words that were present in the training set but are not included
  because they're not making the `num_words` cut here.
  Words that were not seen in the training set but are in the test set
  have simply been skipped.
  """
  # Legacy support
  if 'nb_words' in kwargs:
    logging.warning('The `nb_words` argument in `load_data` '
                    'has been renamed `num_words`.')
    num_words = kwargs.pop('nb_words')
  if kwargs:
    raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'reuters.npz',
      file_hash='87aedbeb0cb229e378797a632c1997b6')

  # The `x` array holds variable-length word sequences (object dtype). Since
  # NumPy 1.16.3, np.load defaults to allow_pickle=False and refuses to load
  # object arrays, so it must be enabled explicitly. The download is
  # requested with a `file_hash`, presumably verified by get_file (confirm).
  with np.load(path, allow_pickle=True) as f:
    xs, labels = f['x'], f['y']

  # A private RandomState yields the same permutation as seeding the global
  # RNG with `seed`, without clobbering the caller's global NumPy random state.
  rng = np.random.RandomState(seed)
  indices = np.arange(len(xs))
  rng.shuffle(indices)
  xs = xs[indices]
  labels = labels[indices]

  # Shift real word indices up by `index_from`, optionally prefixing each
  # sequence with `start_char`.
  if start_char is not None:
    xs = [[start_char] + [w + index_from for w in x] for x in xs]
  elif index_from:
    xs = [[w + index_from for w in x] for x in xs]

  if maxlen:
    xs, labels = _remove_long_seq(maxlen, xs, labels)

  if not num_words:
    # NOTE(review): because the filter below uses a strict `< num_words`
    # bound, the word whose index equals this maximum is itself treated as
    # out of range. Left unchanged to keep the produced data compatible.
    num_words = max([max(x) for x in xs])

  # by convention, use 2 as OOV word
  # reserve 'index_from' (=3 by default) characters:
  # 0 (padding), 1 (start), 2 (OOV)
  if oov_char is not None:
    xs = [[w if skip_top <= w < num_words else oov_char for w in x] for x in xs]
  else:
    xs = [[w for w in x if skip_top <= w < num_words] for x in xs]

  idx = int(len(xs) * (1 - test_split))
  # Sequences are ragged, so request an object array explicitly; newer NumPy
  # versions refuse to infer a dtype from ragged nested lists (NEP 34).
  x_train, y_train = np.array(xs[:idx], dtype='object'), np.array(labels[:idx])
  x_test, y_test = np.array(xs[idx:], dtype='object'), np.array(labels[idx:])
  return (x_train, y_train), (x_test, y_test)
@tf_export('keras.datasets.reuters.get_word_index')
def get_word_index(path='reuters_word_index.json'):
  """Fetches (downloading and caching if needed) the Reuters word index.

  Arguments:
      path: where to cache the data (relative to `~/.keras/dataset`).

  Returns:
      The word index dictionary, as decoded from the cached JSON file.
  """
  base_url = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  cached_file = get_file(
      path,
      origin=base_url + 'reuters_word_index.json',
      file_hash='4d44cc38712099c9e383dc6e5f11a921')
  with open(cached_file) as index_file:
    word_index = json.load(index_file)
  return word_index