135 lines
4.8 KiB
Python
135 lines
4.8 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Reuters topic classification dataset.
|
||
|
"""
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import json
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from tensorflow.python.keras.preprocessing.sequence import _remove_long_seq
|
||
|
from tensorflow.python.keras.utils.data_utils import get_file
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
from tensorflow.python.util.tf_export import tf_export
|
||
|
|
||
|
|
||
|
@tf_export('keras.datasets.reuters.load_data')
def load_data(path='reuters.npz',
              num_words=None,
              skip_top=0,
              maxlen=None,
              test_split=0.2,
              seed=113,
              start_char=1,
              oov_char=2,
              index_from=3,
              **kwargs):
  """Loads the Reuters newswire classification dataset.

  Arguments:
      path: where to cache the data (relative to `~/.keras/dataset`).
      num_words: max number of words to include. Words are ranked
          by how often they occur (in the training set) and only
          the most frequent words are kept. If falsy, it is set to the
          largest word index found in the data.
      skip_top: skip the top N most frequently occurring words
          (which may not be informative).
      maxlen: maximum sequence length. When set, the samples are passed
          through `_remove_long_seq`, which filters out over-long
          sequences together with their labels; sequences are not
          shortened.
      test_split: Fraction of the dataset to be used as test data.
      seed: random seed for sample shuffling.
      start_char: The start of a sequence will be marked with this character.
          Set to 1 because 0 is usually the padding character.
          If `None`, no start marker is prepended.
      oov_char: words that were cut out because of the `num_words`
          or `skip_top` limit will be replaced with this character.
          If `None`, such words are dropped from the sequence instead.
      index_from: index actual words with this index and higher.
      **kwargs: Used for backwards compatibility. Only the legacy
          `nb_words` alias of `num_words` is accepted.

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
      `x_train` and `x_test` are object arrays holding one list of word
      indices per sample.

  Raises:
      TypeError: if `kwargs` contains anything other than `nb_words`.

  Note that the 'out of vocabulary' character is only used for
  words that were present in the training set but are not included
  because they're not making the `num_words` cut here.
  Words that were not seen in the training set but are in the test set
  have simply been skipped.
  """
  # Legacy support
  if 'nb_words' in kwargs:
    logging.warning('The `nb_words` argument in `load_data` '
                    'has been renamed `num_words`.')
    num_words = kwargs.pop('nb_words')
  if kwargs:
    raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'reuters.npz',
      file_hash='87aedbeb0cb229e378797a632c1997b6')
  # The archive stores variable-length sequences as object arrays, which
  # NumPy only unpickles when explicitly allowed (the default became
  # `allow_pickle=False` in NumPy 1.16.3).
  # NOTE: unpickling executes code from the file; this relies on the archive
  # being the one requested above with a pinned `file_hash`.
  with np.load(path, allow_pickle=True) as f:
    xs, labels = f['x'], f['y']

  # Shuffle with a private generator so the caller's global NumPy random
  # state is left untouched; the permutation for a given seed is the same as
  # seeding and using the global generator.
  rng = np.random.RandomState(seed)
  indices = np.arange(len(xs))
  rng.shuffle(indices)
  xs = xs[indices]
  labels = labels[indices]

  # Shift word indices to make room for the reserved ones and, if requested,
  # prepend the start-of-sequence marker.
  if start_char is not None:
    xs = [[start_char] + [w + index_from for w in x] for x in xs]
  elif index_from:
    xs = [[w + index_from for w in x] for x in xs]

  if maxlen:
    xs, labels = _remove_long_seq(maxlen, xs, labels)

  if not num_words:
    num_words = max(max(x) for x in xs)

  # by convention, use 2 as OOV word
  # reserve 'index_from' (=3 by default) characters:
  # 0 (padding), 1 (start), 2 (OOV)
  if oov_char is not None:
    xs = [
        [w if skip_top <= w < num_words else oov_char for w in x] for x in xs
    ]
  else:
    xs = [[w for w in x if skip_top <= w < num_words] for x in xs]

  idx = int(len(xs) * (1 - test_split))
  # The sequences are ragged, so the sample arrays must be built with an
  # explicit object dtype; newer NumPy refuses to infer it.
  x_train = np.array(xs[:idx], dtype=object)
  y_train = np.array(labels[:idx])
  x_test = np.array(xs[idx:], dtype=object)
  y_test = np.array(labels[idx:])

  return (x_train, y_train), (x_test, y_test)
|
||
|
|
||
|
|
||
|
@tf_export('keras.datasets.reuters.get_word_index')
def get_word_index(path='reuters_word_index.json'):
  """Fetches the Reuters word index file and returns its contents.

  Arguments:
      path: where to cache the data (relative to `~/.keras/dataset`).

  Returns:
      The word index dictionary, as parsed from the cached JSON file.
  """
  base_url = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  # `get_file` returns the local path of the (possibly freshly downloaded)
  # JSON file.
  local_file = get_file(
      path,
      origin=base_url + 'reuters_word_index.json',
      file_hash='4d44cc38712099c9e383dc6e5f11a921')
  with open(local_file) as handle:
    word_index = json.load(handle)
  return word_index
|