laywerrobot/lib/python3.6/site-packages/gensim/test/test_tmdiff.py
2020-08-27 21:55:39 +02:00

72 lines
2.9 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2016 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
import logging
import unittest
import numpy as np
from gensim.models import LdaModel
from gensim.test.utils import common_dictionary, common_corpus
class TestLdaDiff(unittest.TestCase):
    """Tests for ``LdaModel.diff``, which compares the topics of two models."""

    def setUp(self):
        """Train a small LDA model on the shared toy corpus used by every test."""
        self.dictionary = common_dictionary
        self.corpus = common_corpus
        self.num_topics = 5
        self.n_ann_terms = 10
        # A fixed seed makes the trained topics reproducible, so any failure can be replayed exactly.
        self.model = LdaModel(
            corpus=self.corpus, id2word=self.dictionary, num_topics=self.num_topics,
            passes=10, random_state=42,
        )

    def _assert_all_zero(self, arr):
        """Assert that every entry of ``arr`` is zero within ``np.allclose``'s default tolerance."""
        # With a reference value of 0 the relative tolerance contributes nothing, so atol=1e-8
        # reproduces ``np.allclose(arr, 0)`` while also reporting the mismatching entries on failure.
        np.testing.assert_allclose(arr, 0.0, rtol=0, atol=1e-8)

    def testBasic(self):
        """``diff`` returns correctly shaped distances and annotations in matrix and diagonal mode."""
        # Matrix mode: a (num_topics x num_topics) distance array plus a nested annotation list.
        mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms)
        self.assertEqual(mdiff.shape, (self.num_topics, self.num_topics))
        self.assertEqual(len(annotation), self.num_topics)
        self.assertEqual(len(annotation[0]), self.num_topics)

        # Diagonal mode: a 1-D distance array with one annotation per topic.
        mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms, diagonal=True)
        self.assertEqual(mdiff.shape, (self.num_topics,))
        self.assertEqual(len(annotation), self.num_topics)

    def testIdentity(self):
        """Comparing a model with itself yields zero self-distance and empty token differences."""
        for dist_name in ["hellinger", "kullback_leibler", "jaccard"]:
            # Matrix mode.
            mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms, distance=dist_name)
            for row in annotation:
                # int_tokens / diff_tokens presumably hold the intersecting / differing top terms.
                for (int_tokens, diff_tokens) in row:
                    self.assertEqual(diff_tokens, [])
                    self.assertEqual(len(int_tokens), self.n_ann_terms)
            # Each topic compared with itself must be at distance zero.
            self._assert_all_zero(np.diag(mdiff))
            if dist_name == "jaccard":
                # NOTE(review): off-diagonal jaccard entries are expected to be zero as well, presumably
                # because every topic's top-term set spans the tiny toy vocabulary -- confirm.
                self._assert_all_zero(mdiff)

            # Diagonal mode: every entry is a self-distance, so all must be zero for every distance
            # (this already covers jaccard; the former extra jaccard check here was a duplicate).
            mdiff, annotation = self.model.diff(
                self.model, n_ann_terms=self.n_ann_terms, distance=dist_name, diagonal=True)
            for (int_tokens, diff_tokens) in annotation:
                self.assertEqual(diff_tokens, [])
                self.assertEqual(len(int_tokens), self.n_ann_terms)
            self._assert_all_zero(mdiff)

    def testInput(self):
        """Invalid arguments to ``diff`` raise ``ValueError``."""
        # Unknown distance name.
        self.assertRaises(ValueError, self.model.diff, self.model, n_ann_terms=self.n_ann_terms, distance='something')
        # NOTE(review): this call passes both a non-model ``other`` and an unknown ``distance``, so it
        # cannot tell which validation raised; it presumably targets the ``other`` type check -- confirm.
        self.assertRaises(ValueError, self.model.diff, [], n_ann_terms=self.n_ann_terms, distance='something')
if __name__ == '__main__':
    # When run directly, show DEBUG-level log records (e.g. from model training) alongside the test run.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s : %(levelname)s : %(message)s',
    )
    unittest.main()