# Source snapshot: gensim/test/test_fasttext_wrapper.py
# (vendored under laywerrobot/lib/python3.6/site-packages; retrieved 2020-08-27, 373 lines, 17 KiB)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""
Automated tests for checking transformation algorithms (the models package).
"""
import logging
import unittest
import os
import numpy
from gensim.models.wrappers import fasttext
from gensim.models import keyedvectors
from gensim.test.utils import datapath, get_tmpfile
logger = logging.getLogger(__name__)
class TestFastText(unittest.TestCase):
    """Tests for the `gensim.models.wrappers.fasttext` wrapper around the fastText CLI.

    Training tests require the fastText binary (located via the `FT_HOME` env
    variable) and are skipped when it is absent; loading/query tests run against
    pre-trained models bundled with gensim's test data.
    """

    def setUp(self):
        """Resolve the fastText binary path and load the bundled pre-trained model."""
        ft_home = os.environ.get('FT_HOME', None)
        # self.ft_path is None when the binary isn't available; training tests skip in that case.
        self.ft_path = os.path.join(ft_home, 'fasttext') if ft_home else None
        self.corpus_file = datapath('lee_background.cor')
        self.test_model_file = datapath('lee_fasttext')
        self.test_new_model_file = datapath('lee_fasttext_new')
        # Load pre-trained model to perform tests in case FastText binary isn't available in test environment
        self.test_model = fasttext.FastText.load_fasttext_format(self.test_model_file)

    def model_sanity(self, model):
        """Even tiny models trained on any corpus should pass these sanity checks"""
        self.assertEqual(model.wv.syn0.shape, (len(model.wv.vocab), model.vector_size))
        self.assertEqual(model.wv.syn0_ngrams.shape, (model.num_ngram_vectors, model.vector_size))

    def models_equal(self, model1, model2):
        """Assert two models have identical vocabularies and (near-)identical weights."""
        self.assertEqual(len(model1.wv.vocab), len(model2.wv.vocab))
        self.assertEqual(set(model1.wv.vocab.keys()), set(model2.wv.vocab.keys()))
        self.assertTrue(numpy.allclose(model1.wv.syn0, model2.wv.syn0))
        self.assertTrue(numpy.allclose(model1.wv.syn0_ngrams, model2.wv.syn0_ngrams))

    def testTraining(self):
        """Test model successfully trained, parameters and weights correctly loaded."""
        if self.ft_path is None:
            self.skipTest("FT_HOME env variable not set, skipping test")
        vocab_size, model_size = 1763, 10
        tmpf = get_tmpfile('gensim_fasttext_wrapper.tst')
        trained_model = fasttext.FastText.train(
            self.ft_path, self.corpus_file, size=model_size, output_file=tmpf
        )
        self.assertEqual(trained_model.wv.syn0.shape, (vocab_size, model_size))
        self.assertEqual(len(trained_model.wv.vocab), vocab_size)
        self.assertEqual(trained_model.wv.syn0_ngrams.shape[1], model_size)
        self.model_sanity(trained_model)
        # Tests temporary training files deleted
        self.assertFalse(os.path.exists('%s.bin' % tmpf))

    def testMinCount(self):
        """Tests words with frequency less than `min_count` absent from vocab"""
        if self.ft_path is None:
            self.skipTest("FT_HOME env variable not set, skipping test")
        tmpf = get_tmpfile('gensim_fasttext_wrapper.tst')
        test_model_min_count_5 = fasttext.FastText.train(
            self.ft_path, self.corpus_file, output_file=tmpf, size=10, min_count=5
        )
        # 'forests' occurs fewer than 5 times in the corpus, so it must be pruned...
        self.assertNotIn('forests', test_model_min_count_5.wv.vocab)
        test_model_min_count_1 = fasttext.FastText.train(
            self.ft_path, self.corpus_file, output_file=tmpf, size=10, min_count=1
        )
        # ...but kept when every word is retained.
        self.assertIn('forests', test_model_min_count_1.wv.vocab)

    def testModelSize(self):
        """Tests output vector dimensions are the same as the value for `size` param"""
        if self.ft_path is None:
            self.skipTest("FT_HOME env variable not set, skipping test")
        tmpf = get_tmpfile('gensim_fasttext_wrapper.tst')
        test_model_size_20 = fasttext.FastText.train(
            self.ft_path, self.corpus_file, output_file=tmpf, size=20
        )
        self.assertEqual(test_model_size_20.vector_size, 20)
        self.assertEqual(test_model_size_20.wv.syn0.shape[1], 20)
        self.assertEqual(test_model_size_20.wv.syn0_ngrams.shape[1], 20)

    def testPersistence(self):
        """Test storing/loading the entire model."""
        tmpf = get_tmpfile('gensim_fasttext_wrapper.tst')
        self.test_model.save(tmpf)
        loaded = fasttext.FastText.load(tmpf)
        self.models_equal(self.test_model, loaded)
        # sep_limit=0 forces large arrays into separate files; round-trip must still match.
        self.test_model.save(tmpf, sep_limit=0)
        self.models_equal(self.test_model, fasttext.FastText.load(tmpf))

    def testNormalizedVectorsNotSaved(self):
        """Test syn0norm/syn0_ngrams_norm aren't saved in model file"""
        tmpf = get_tmpfile('gensim_fasttext_wrapper.tst')
        self.test_model.init_sims()
        self.test_model.save(tmpf)
        loaded = fasttext.FastText.load(tmpf)
        self.assertTrue(loaded.wv.syn0norm is None)
        self.assertTrue(loaded.wv.syn0_ngrams_norm is None)
        # The same guarantee must hold when only the KeyedVectors are saved.
        wv = self.test_model.wv
        wv.save(tmpf)
        loaded_kv = keyedvectors.KeyedVectors.load(tmpf)
        self.assertTrue(loaded_kv.syn0norm is None)
        self.assertTrue(loaded_kv.syn0_ngrams_norm is None)

    def testLoadFastTextFormat(self):
        """Test model successfully loaded from fastText (old format) .bin file"""
        try:
            model = fasttext.FastText.load_fasttext_format(self.test_model_file)
        except Exception as exc:
            self.fail('Unable to load FastText model from file %s: %s' % (self.test_model_file, exc))
        vocab_size, model_size = 1762, 10
        self.assertEqual(model.wv.syn0.shape, (vocab_size, model_size))
        # NOTE: the original passed `model_size` as a third positional arg here,
        # which assertEqual treats as the failure *message* — dropped.
        self.assertEqual(len(model.wv.vocab), vocab_size)
        self.assertEqual(model.wv.syn0_ngrams.shape, (model.num_ngram_vectors, model_size))
        expected_vec = [
            -0.57144,
            -0.0085561,
            0.15748,
            -0.67855,
            -0.25459,
            -0.58077,
            -0.09913,
            1.1447,
            0.23418,
            0.060007
        ]  # obtained using ./fasttext print-word-vectors lee_fasttext.bin
        self.assertTrue(numpy.allclose(model["hundred"], expected_vec, atol=1e-4))
        # vector for oov words are slightly different from original FastText due to discarding unused ngrams
        # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext.bin
        expected_vec_oov = [
            -0.23825,
            -0.58482,
            -0.22276,
            -0.41215,
            0.91015,
            -1.6786,
            -0.26724,
            0.58818,
            0.57828,
            0.75801
        ]
        self.assertTrue(numpy.allclose(model["rejection"], expected_vec_oov, atol=1e-4))
        self.assertEqual(model.min_count, 5)
        self.assertEqual(model.window, 5)
        self.assertEqual(model.iter, 5)
        self.assertEqual(model.negative, 5)
        self.assertEqual(model.sample, 0.0001)
        self.assertEqual(model.bucket, 1000)
        self.assertEqual(model.wv.max_n, 6)
        self.assertEqual(model.wv.min_n, 3)
        self.model_sanity(model)

    def testLoadFastTextNewFormat(self):
        """Test model successfully loaded from fastText (new format) .bin file"""
        try:
            new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file)
        except Exception as exc:
            self.fail('Unable to load FastText model from file %s: %s' % (self.test_new_model_file, exc))
        vocab_size, model_size = 1763, 10
        self.assertEqual(new_model.wv.syn0.shape, (vocab_size, model_size))
        # NOTE: third positional arg (was the `msg` parameter) dropped, as above.
        self.assertEqual(len(new_model.wv.vocab), vocab_size)
        self.assertEqual(new_model.wv.syn0_ngrams.shape, (new_model.num_ngram_vectors, model_size))
        expected_vec = [
            -0.025627,
            -0.11448,
            0.18116,
            -0.96779,
            0.2532,
            -0.93224,
            0.3929,
            0.12679,
            -0.19685,
            -0.13179
        ]  # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin
        self.assertTrue(numpy.allclose(new_model["hundred"], expected_vec, atol=1e-4))
        # vector for oov words are slightly different from original FastText due to discarding unused ngrams
        # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin
        expected_vec_oov = [
            -0.53378,
            -0.19,
            0.013482,
            -0.86767,
            -0.21684,
            -0.89928,
            0.45124,
            0.18025,
            -0.14128,
            0.22508
        ]
        self.assertTrue(numpy.allclose(new_model["rejection"], expected_vec_oov, atol=1e-4))
        self.assertEqual(new_model.min_count, 5)
        self.assertEqual(new_model.window, 5)
        self.assertEqual(new_model.iter, 5)
        self.assertEqual(new_model.negative, 5)
        self.assertEqual(new_model.sample, 0.0001)
        self.assertEqual(new_model.bucket, 1000)
        self.assertEqual(new_model.wv.max_n, 6)
        self.assertEqual(new_model.wv.min_n, 3)
        self.model_sanity(new_model)

    def testLoadFileName(self):
        """Test model accepts input as both `/path/to/model` or `/path/to/model.bin`"""
        self.assertTrue(fasttext.FastText.load_fasttext_format(datapath('lee_fasttext_new')))
        self.assertTrue(fasttext.FastText.load_fasttext_format(datapath('lee_fasttext_new.bin')))

    def testLoadModelSupervised(self):
        """Test loading model with supervised learning labels"""
        with self.assertRaises(NotImplementedError):
            fasttext.FastText.load_fasttext_format(datapath('pang_lee_polarity_fasttext'))

    def testLoadModelWithNonAsciiVocab(self):
        """Test loading model with non-ascii words in vocab"""
        model = fasttext.FastText.load_fasttext_format(datapath('non_ascii_fasttext'))
        self.assertIn(u'který', model)
        try:
            vector = model[u'který']  # noqa:F841
        except UnicodeDecodeError:
            self.fail('Unable to access vector for utf8 encoded non-ascii word')

    def testLoadModelNonUtf8Encoding(self):
        """Test loading model with words in user-specified encoding"""
        model = fasttext.FastText.load_fasttext_format(datapath('cp852_fasttext'), encoding='cp852')
        self.assertIn(u'který', model)
        try:
            vector = model[u'který']  # noqa:F841
        except KeyError:
            self.fail('Unable to access vector for cp-852 word')

    def testNSimilarity(self):
        """Test n_similarity for in-vocab and out-of-vocab words"""
        # In vocab, sanity check
        self.assertTrue(numpy.allclose(self.test_model.n_similarity(['the', 'and'], ['and', 'the']), 1.0))
        self.assertEqual(self.test_model.n_similarity(['the'], ['and']), self.test_model.n_similarity(['and'], ['the']))
        # Out of vocab check
        self.assertTrue(numpy.allclose(self.test_model.n_similarity(['night', 'nights'], ['nights', 'night']), 1.0))
        self.assertEqual(
            self.test_model.n_similarity(['night'], ['nights']),
            self.test_model.n_similarity(['nights'], ['night'])
        )

    def testSimilarity(self):
        """Test similarity for in-vocab and out-of-vocab words"""
        # In vocab, sanity check
        self.assertTrue(numpy.allclose(self.test_model.similarity('the', 'the'), 1.0))
        self.assertEqual(self.test_model.similarity('the', 'and'), self.test_model.similarity('and', 'the'))
        # Out of vocab check
        self.assertTrue(numpy.allclose(self.test_model.similarity('nights', 'nights'), 1.0))
        self.assertEqual(self.test_model.similarity('night', 'nights'), self.test_model.similarity('nights', 'night'))

    def testMostSimilar(self):
        """Test most_similar for in-vocab and out-of-vocab words"""
        # In vocab, sanity check
        self.assertEqual(len(self.test_model.most_similar(positive=['the', 'and'], topn=5)), 5)
        self.assertEqual(self.test_model.most_similar('the'), self.test_model.most_similar(positive=['the']))
        # Out of vocab check
        self.assertEqual(len(self.test_model.most_similar(['night', 'nights'], topn=5)), 5)
        self.assertEqual(self.test_model.most_similar('nights'), self.test_model.most_similar(positive=['nights']))

    def testMostSimilarCosmul(self):
        """Test most_similar_cosmul for in-vocab and out-of-vocab words"""
        # In vocab, sanity check
        self.assertEqual(len(self.test_model.most_similar_cosmul(positive=['the', 'and'], topn=5)), 5)
        self.assertEqual(
            self.test_model.most_similar_cosmul('the'),
            self.test_model.most_similar_cosmul(positive=['the']))
        # Out of vocab check
        self.assertEqual(len(self.test_model.most_similar_cosmul(['night', 'nights'], topn=5)), 5)
        self.assertEqual(
            self.test_model.most_similar_cosmul('nights'),
            self.test_model.most_similar_cosmul(positive=['nights']))

    def testLookup(self):
        """Tests word vector lookup for in-vocab and out-of-vocab words"""
        # In vocab, sanity check
        self.assertIn('night', self.test_model.wv.vocab)
        self.assertTrue(numpy.allclose(self.test_model['night'], self.test_model[['night']]))
        # Out of vocab check
        self.assertNotIn('nights', self.test_model.wv.vocab)
        self.assertTrue(numpy.allclose(self.test_model['nights'], self.test_model[['nights']]))
        # Word with no ngrams in model
        self.assertRaises(KeyError, lambda: self.test_model['a!@'])

    def testContains(self):
        """Tests __contains__ for in-vocab and out-of-vocab words"""
        # In vocab, sanity check
        self.assertIn('night', self.test_model.wv.vocab)
        self.assertIn('night', self.test_model)
        # Out of vocab check
        self.assertNotIn('nights', self.test_model.wv.vocab)
        self.assertIn('nights', self.test_model)
        # Word with no ngrams in model
        self.assertNotIn('a!@', self.test_model.wv.vocab)
        self.assertNotIn('a!@', self.test_model)

    def testWmdistance(self):
        """Tests wmdistance for docs with in-vocab and out-of-vocab words"""
        doc = ['night', 'payment']
        oov_doc = ['nights', 'forests', 'payments']
        ngrams_absent_doc = ['a!@', 'b#$']
        # OOV words still have ngram vectors, so distance must be finite.
        dist = self.test_model.wmdistance(doc, oov_doc)
        self.assertNotEqual(float('inf'), dist)
        # Words with no ngrams at all have no representation: distance is inf.
        dist = self.test_model.wmdistance(doc, ngrams_absent_doc)
        self.assertEqual(float('inf'), dist)

    def testDoesntMatch(self):
        """Tests doesnt_match for list of out-of-vocab words"""
        oov_words = ['nights', 'forests', 'payments']
        # Out of vocab check
        for word in oov_words:
            self.assertNotIn(word, self.test_model.wv.vocab)
        try:
            self.test_model.doesnt_match(oov_words)
        except Exception:
            self.fail('model.doesnt_match raises exception for oov words')

    def testHash(self):
        """Tests FastText.ft_hash return values against those of the original C implementation."""
        ft_hash = fasttext.ft_hash('test')
        self.assertEqual(ft_hash, 2949673445)
        ft_hash = fasttext.ft_hash('word')
        self.assertEqual(ft_hash, 1788406269)

    def testConsistentDtype(self):
        """Test that the same dtype is returned for OOV words as for words in the vocabulary"""
        vocab_word = 'night'
        oov_word = 'wordnotpresentinvocabulary'
        self.assertIn(vocab_word, self.test_model.wv.vocab)
        self.assertNotIn(oov_word, self.test_model.wv.vocab)
        vocab_embedding = self.test_model[vocab_word]
        oov_embedding = self.test_model[oov_word]
        self.assertEqual(vocab_embedding.dtype, oov_embedding.dtype)

    def testPersistenceForOldVersions(self):
        """Test backward compatibility for models saved with versions < 3.0.0"""
        old_model_path = datapath('ft_model_2.3.0')
        loaded_model = fasttext.FastText.load(old_model_path)
        self.assertEqual(loaded_model.vector_size, 10)
        self.assertEqual(loaded_model.wv.syn0.shape[1], 10)
        self.assertEqual(loaded_model.wv.syn0_ngrams.shape[1], 10)
        # in-vocab word
        in_expected_vec = numpy.array([-2.44566941, -1.54802394, -2.61103821, -1.88549316, 1.02860415,
                                       1.19031894, 2.01627707, 1.98942184, -1.39095843, -0.65036952])
        self.assertTrue(numpy.allclose(loaded_model["the"], in_expected_vec, atol=1e-4))
        # out-of-vocab word
        out_expected_vec = numpy.array([-1.34948218, -0.8686831, -1.51483142, -1.0164026, 0.56272298,
                                        0.66228276, 1.06477463, 1.1355902, -0.80972326, -0.39845538])
        self.assertTrue(numpy.allclose(loaded_model["random_word"], out_expected_vec, atol=1e-4))
if __name__ == '__main__':
    # Verbose logging makes wrapper subprocess activity visible during a test run.
    log_format = '%(asctime)s : %(levelname)s : %(message)s'
    logging.basicConfig(format=log_format, level=logging.DEBUG)
    unittest.main()