laywerrobot/lib/python3.6/site-packages/gensim/test/test_api.py
2020-08-27 21:55:39 +02:00

79 lines
3.3 KiB
Python

import logging
import unittest
import os
import gensim.downloader as api
from gensim.downloader import base_dir
import shutil
import numpy as np
@unittest.skipIf(
os.environ.get("SKIP_NETWORK_TESTS", False) == "1",
"Skip network-related tests (probably SSL problems on this CI/OS)"
)
class TestApi(unittest.TestCase):
def test_base_dir_creation(self):
if os.path.isdir(base_dir):
shutil.rmtree(base_dir)
api._create_base_dir()
self.assertTrue(os.path.isdir(base_dir))
os.rmdir(base_dir)
def test_load_dataset(self):
dataset_path = os.path.join(base_dir, "__testing_matrix-synopsis", "__testing_matrix-synopsis.gz")
if os.path.isdir(base_dir):
shutil.rmtree(base_dir)
self.assertEqual(api.load("__testing_matrix-synopsis", return_path=True), dataset_path)
shutil.rmtree(base_dir)
self.assertEqual(len(list(api.load("__testing_matrix-synopsis"))), 1)
shutil.rmtree(base_dir)
def test_load_model(self):
if os.path.isdir(base_dir):
shutil.rmtree(base_dir)
vector_dead = np.array([
0.17403787, -0.10167074, -0.00950371, -0.10367849, -0.14034484,
-0.08751217, 0.10030612, 0.07677923, -0.32563496, 0.01929072,
0.20521086, -0.1617067, 0.00475458, 0.21956187, -0.08783089,
-0.05937332, 0.26528183, -0.06771874, -0.12369668, 0.12020949,
0.28731, 0.36735833, 0.28051138, -0.10407482, 0.2496888,
-0.19372769, -0.28719661, 0.11989869, -0.00393865, -0.2431484,
0.02725661, -0.20421691, 0.0328669, -0.26947051, -0.08068217,
-0.10245913, 0.1170633, 0.16583319, 0.1183883, -0.11217165,
0.1261425, -0.0319365, -0.15787181, 0.03753783, 0.14748634,
0.00414471, -0.02296237, 0.18336892, -0.23840059, 0.17924534
])
dataset_path = os.path.join(
base_dir, "__testing_word2vec-matrix-synopsis", "__testing_word2vec-matrix-synopsis.gz"
)
model = api.load("__testing_word2vec-matrix-synopsis")
vector_dead_calc = model["dead"]
self.assertTrue(np.allclose(vector_dead, vector_dead_calc))
shutil.rmtree(base_dir)
self.assertEqual(api.load("__testing_word2vec-matrix-synopsis", return_path=True), dataset_path)
shutil.rmtree(base_dir)
def test_multipart_load(self):
dataset_path = os.path.join(
base_dir, '__testing_multipart-matrix-synopsis', '__testing_multipart-matrix-synopsis.gz'
)
if os.path.isdir(base_dir):
shutil.rmtree(base_dir)
self.assertEqual(dataset_path, api.load("__testing_multipart-matrix-synopsis", return_path=True))
shutil.rmtree(base_dir)
dataset = api.load("__testing_multipart-matrix-synopsis")
self.assertEqual(len(list(dataset)), 1)
def test_info(self):
data = api.info("text8")
self.assertEqual(data["parts"], 1)
self.assertEqual(data["file_name"], 'text8.gz')
data = api.info()
self.assertEqual(sorted(data.keys()), sorted(['models', 'corpora']))
self.assertTrue(len(data['models']))
self.assertTrue(len(data['corpora']))
if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
unittest.main()