72 lines
2.9 KiB
Python
72 lines
2.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright (C) 2016 Radim Rehurek <radimrehurek@seznam.cz>
|
|
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
|
|
|
|
import logging
|
|
import unittest
|
|
import numpy as np
|
|
|
|
from gensim.models import LdaModel
|
|
from gensim.test.utils import common_dictionary, common_corpus
|
|
|
|
|
|
class TestLdaDiff(unittest.TestCase):
    """Tests for ``LdaModel.diff`` in which a trained model is compared against itself."""

    # Append custom messages to unittest's default failure output instead of replacing it,
    # so the compared values stay visible next to the distance/case context.
    longMessage = True

    def setUp(self):
        """Train a fresh 5-topic LDA model on the shared test corpus before every test."""
        self.dictionary = common_dictionary
        self.corpus = common_corpus
        self.num_topics = 5
        self.n_ann_terms = 10
        self.model = LdaModel(corpus=self.corpus, id2word=self.dictionary, num_topics=self.num_topics, passes=10)

    def _assert_all_zero(self, values, context):
        """Fail unless every entry of `values` is within 1e-8 of zero.

        The tolerance mirrors the default ``atol`` of ``np.allclose``; the relative
        tolerance is irrelevant because the reference values are all zero.
        Unlike ``assertTrue(np.allclose(...))`` the failure report lists the mismatching values.
        """
        np.testing.assert_allclose(
            values, np.zeros_like(values), rtol=0, atol=1e-8,
            err_msg="expected all zeros (%s)" % context
        )

    def testBasic(self):
        """The difference array and the annotation have the expected dimensions."""
        # test for matrix case
        mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms)

        self.assertEqual(mdiff.shape, (self.num_topics, self.num_topics))
        self.assertEqual(len(annotation), self.num_topics)
        self.assertEqual(len(annotation[0]), self.num_topics)

        # test for diagonal case
        mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms, diagonal=True)

        self.assertEqual(mdiff.shape, (self.num_topics,))
        self.assertEqual(len(annotation), self.num_topics)

    def testIdentity(self):
        """Diffing a model with itself gives zero distances and no differing tokens."""
        for dist_name in ["hellinger", "kullback_leibler", "jaccard"]:
            # test for matrix case
            context = "distance=%r, matrix case" % dist_name
            mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms, distance=dist_name)

            for row in annotation:
                for (int_tokens, diff_tokens) in row:
                    self.assertEqual(diff_tokens, [], context)
                    self.assertEqual(len(int_tokens), self.n_ann_terms, context)

            # Every topic compared with itself must be at distance zero.
            self._assert_all_zero(np.diag(mdiff), context + ", diagonal of the matrix")

            if dist_name == "jaccard":
                # For jaccard the test additionally requires the whole matrix to vanish.
                self._assert_all_zero(mdiff, context + ", full matrix")

            # test for diagonal case
            context = "distance=%r, diagonal case" % dist_name
            mdiff, annotation = \
                self.model.diff(self.model, n_ann_terms=self.n_ann_terms, distance=dist_name, diagonal=True)

            for (int_tokens, diff_tokens) in annotation:
                self.assertEqual(diff_tokens, [], context)
                self.assertEqual(len(int_tokens), self.n_ann_terms, context)

            # One unconditional check covers all distances here; a separate
            # jaccard-only repetition of the same assertion would add nothing.
            self._assert_all_zero(mdiff, context)

    def testInput(self):
        """Invalid arguments to ``diff`` raise ``ValueError``."""
        self.assertRaises(ValueError, self.model.diff, self.model, n_ann_terms=self.n_ann_terms, distance='something')
        # NOTE(review): this call also passes an invalid distance, so it may not isolate the
        # invalid-`other` check - confirm against the validation order in LdaModel.diff.
        self.assertRaises(ValueError, self.model.diff, [], n_ann_terms=self.n_ann_terms, distance='something')
|
|
|
|
|
|
if __name__ == '__main__':
    # When run as a script: log at DEBUG level with timestamps, then hand over to the unittest runner.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s : %(levelname)s : %(message)s'
    )
    unittest.main()
|