laywerrobot/lib/python3.6/site-packages/gensim/test/test_wordrank_wrapper.py
2020-08-27 21:55:39 +02:00

81 lines
2.8 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""
Automated tests for checking transformation algorithms (the models package).
"""
import logging
import unittest
import os
import numpy
from gensim.models.wrappers import wordrank
from gensim.test.utils import datapath, get_tmpfile
class TestWordrank(unittest.TestCase):
def setUp(self):
wr_home = os.environ.get('WR_HOME', None)
self.wr_path = wr_home if wr_home else None
self.corpus_file = datapath('lee.cor')
self.out_name = 'testmodel'
self.wr_file = datapath('test_glove.txt')
if not self.wr_path:
return
self.test_model = wordrank.Wordrank.train(
self.wr_path, self.corpus_file, self.out_name, iter=6,
dump_period=5, period=5, np=4, cleanup_files=True
)
def testLoadWordrankFormat(self):
"""Test model successfully loaded from Wordrank format file"""
model = wordrank.Wordrank.load_wordrank_model(self.wr_file)
vocab_size, dim = 76, 50
self.assertEqual(model.syn0.shape, (vocab_size, dim))
self.assertEqual(len(model.vocab), vocab_size)
os.remove(self.wr_file + '.w2vformat')
def testEnsemble(self):
"""Test ensemble of two embeddings"""
if not self.wr_path:
return
new_emb = self.test_model.ensemble_embedding(self.wr_file, self.wr_file)
self.assertEqual(new_emb.shape, (76, 50))
os.remove(self.wr_file + '.w2vformat')
def testPersistence(self):
"""Test storing/loading the entire model"""
if not self.wr_path:
return
tmpf = get_tmpfile('gensim_wordrank.test')
self.test_model.save(tmpf)
loaded = wordrank.Wordrank.load(tmpf)
self.models_equal(self.test_model, loaded)
def testSimilarity(self):
"""Test n_similarity for vocab words"""
if not self.wr_path:
return
self.assertTrue(numpy.allclose(self.test_model.n_similarity(['the', 'and'], ['and', 'the']), 1.0))
self.assertEqual(self.test_model.similarity('the', 'and'), self.test_model.similarity('the', 'and'))
def testLookup(self):
if not self.wr_path:
return
self.assertTrue(numpy.allclose(self.test_model['night'], self.test_model[['night']]))
def models_equal(self, model, model2):
self.assertEqual(len(model.vocab), len(model2.vocab))
self.assertEqual(set(model.vocab.keys()), set(model2.vocab.keys()))
self.assertTrue(numpy.allclose(model.syn0, model2.syn0))
if __name__ == '__main__':
    # When run as a script, log everything (DEBUG and above) with a
    # timestamp and level prefix, then hand control to the unittest runner.
    log_format = '%(asctime)s : %(levelname)s : %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=log_format)
    unittest.main()