79 lines
3.3 KiB
Python
79 lines
3.3 KiB
Python
import logging
|
|
import unittest
|
|
import os
|
|
import gensim.downloader as api
|
|
from gensim.downloader import base_dir
|
|
import shutil
|
|
import numpy as np
|
|
|
|
|
|
@unittest.skipIf(
    os.environ.get("SKIP_NETWORK_TESTS") == "1",
    "Skip network-related tests (probably SSL problems on this CI/OS)"
)
class TestApi(unittest.TestCase):
    """Tests for ``gensim.downloader`` (imported here as ``api``).

    The whole class is skipped when the ``SKIP_NETWORK_TESTS`` environment
    variable is set to ``"1"``.

    NOTE: most tests delete ``base_dir`` recursively before they start, so
    anything already stored under that directory is removed when they run.
    """

    def _reset_base_dir(self):
        """Remove ``base_dir`` now and schedule its removal after the test.

        The removal is registered with ``addCleanup`` so that it also happens
        when an assertion fails part-way through a test; ``ignore_errors`` keeps
        the cleanup quiet when the test already removed the directory itself.
        """
        if os.path.isdir(base_dir):
            shutil.rmtree(base_dir)
        self.addCleanup(shutil.rmtree, base_dir, ignore_errors=True)

    def test_base_dir_creation(self):
        """``api._create_base_dir()`` must create ``base_dir`` as a directory."""
        self._reset_base_dir()
        api._create_base_dir()
        self.assertTrue(os.path.isdir(base_dir))
        # os.rmdir (not rmtree) on purpose: it only succeeds on an empty
        # directory, so it doubles as a check that nothing else was created.
        os.rmdir(base_dir)

    def test_load_dataset(self):
        """Load the testing corpus both as a path and as an iterable."""
        dataset_path = os.path.join(base_dir, "__testing_matrix-synopsis", "__testing_matrix-synopsis.gz")
        self._reset_base_dir()
        self.assertEqual(api.load("__testing_matrix-synopsis", return_path=True), dataset_path)
        # Wipe the directory so the second load starts from scratch as well.
        shutil.rmtree(base_dir)
        self.assertEqual(len(list(api.load("__testing_matrix-synopsis"))), 1)

    def test_load_model(self):
        """Load the testing word2vec model and compare the vector for "dead"."""
        self._reset_base_dir()
        # Reference values for model["dead"]; compared with np.allclose below.
        vector_dead = np.array([
            0.17403787, -0.10167074, -0.00950371, -0.10367849, -0.14034484,
            -0.08751217, 0.10030612, 0.07677923, -0.32563496, 0.01929072,
            0.20521086, -0.1617067, 0.00475458, 0.21956187, -0.08783089,
            -0.05937332, 0.26528183, -0.06771874, -0.12369668, 0.12020949,
            0.28731, 0.36735833, 0.28051138, -0.10407482, 0.2496888,
            -0.19372769, -0.28719661, 0.11989869, -0.00393865, -0.2431484,
            0.02725661, -0.20421691, 0.0328669, -0.26947051, -0.08068217,
            -0.10245913, 0.1170633, 0.16583319, 0.1183883, -0.11217165,
            0.1261425, -0.0319365, -0.15787181, 0.03753783, 0.14748634,
            0.00414471, -0.02296237, 0.18336892, -0.23840059, 0.17924534
        ])
        dataset_path = os.path.join(
            base_dir, "__testing_word2vec-matrix-synopsis", "__testing_word2vec-matrix-synopsis.gz"
        )
        model = api.load("__testing_word2vec-matrix-synopsis")
        vector_dead_calc = model["dead"]
        self.assertTrue(np.allclose(vector_dead, vector_dead_calc))
        # Wipe the directory so the path-returning load starts from scratch.
        shutil.rmtree(base_dir)
        self.assertEqual(api.load("__testing_word2vec-matrix-synopsis", return_path=True), dataset_path)

    def test_multipart_load(self):
        """Load the testing multipart corpus both as a path and as an iterable."""
        dataset_path = os.path.join(
            base_dir, '__testing_multipart-matrix-synopsis', '__testing_multipart-matrix-synopsis.gz'
        )
        self._reset_base_dir()
        self.assertEqual(dataset_path, api.load("__testing_multipart-matrix-synopsis", return_path=True))
        # Wipe the directory so the second load starts from scratch as well.
        shutil.rmtree(base_dir)
        dataset = api.load("__testing_multipart-matrix-synopsis")
        self.assertEqual(len(list(dataset)), 1)

    def test_info(self):
        """Check ``api.info`` for a single entry and for the full listing."""
        data = api.info("text8")
        self.assertEqual(data["parts"], 1)
        self.assertEqual(data["file_name"], 'text8.gz')
        # Without arguments the result must have exactly these two top-level
        # keys, each mapping to a non-empty collection.
        data = api.info()
        self.assertEqual(sorted(data.keys()), sorted(['models', 'corpora']))
        self.assertTrue(len(data['models']))
        self.assertTrue(len(data['corpora']))
|
|
|
|
|
|
if __name__ == '__main__':
    # When run as a script, log at INFO level with timestamps, then hand
    # control to the unittest command-line runner.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s : %(levelname)s : %(message)s',
    )
    unittest.main()
|