laywerrobot/lib/python3.6/site-packages/gensim/test/test_ldavowpalwabbit_wrapper.py
2020-08-27 21:55:39 +02:00

215 lines
7.6 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""
Tests for Vowpal Wabbit LDA wrapper.
Will not be run unless the environment variable 'VOWPAL_WABBIT_PATH' is set
and points to the `vw` executable.
"""
import logging
import unittest
import os
import os.path
import tempfile
from collections import defaultdict
import six
from gensim.corpora import Dictionary
import gensim.models.wrappers.ldavowpalwabbit as ldavowpalwabbit
from gensim.models.wrappers.ldavowpalwabbit import LdaVowpalWabbit
from gensim.test.utils import datapath
# Vars used in testing (the original note cites "Deerwester" from the web
# tutorial): five reference topics of ten words each, which
# test_topic_coherence compares against the topics learned by the model.
TOPIC_WORDS = list(map(str.split, [
    'cat lion leopard mouse jaguar lynx cheetah tiger kitten puppy',
    'engine car wheel brakes tyre motor suspension cylinder exhaust clutch',
    'alice bob robert tim sue rachel dave harry alex jim',
    'c cplusplus go python haskell scala java ruby csharp erlang',
    'eggs ham mushrooms cereal coffee beans tea juice sausages bacon',
]))
def get_corpus():
    """Load the test corpus and its dictionary from the gensim test data.

    The dictionary is read from 'ldavowpalwabbit.dict.txt' via
    ``Dictionary.load_from_text``; each line of 'ldavowpalwabbit.txt' is
    stripped, split on whitespace and converted with ``doc2bow``.

    Returns:
        tuple: ``(corpus, dictionary)`` where ``corpus`` is a list holding
        one ``doc2bow`` result per line of the text file.
    """
    documents_file = datapath('ldavowpalwabbit.txt')
    vocabulary_file = datapath('ldavowpalwabbit.dict.txt')
    dictionary = Dictionary.load_from_text(vocabulary_file)

    corpus = []
    with open(documents_file) as documents:
        for line in documents:
            tokens = line.strip().split()
            corpus.append(dictionary.doc2bow(tokens))

    return corpus, dictionary
class TestLdaVowpalWabbit(unittest.TestCase):
    """Tests for the LdaVowpalWabbit wrapper.

    The tests are skipped unless the environment variable
    'VOWPAL_WABBIT_PATH' is set. Where ``unittest.SkipTest`` is unavailable
    (python 2.6), ``setUp`` leaves ``self.vw_path`` as ``None`` and every
    test returns early instead.
    """

    def setUp(self):
        """Read the ``vw`` path from the environment and load the fixtures.

        Sets ``self.vw_path``, ``self.corpus`` and ``self.dictionary``.

        Raises:
            unittest.SkipTest: if 'VOWPAL_WABBIT_PATH' is unset or empty.
        """
        vw_path = os.environ.get('VOWPAL_WABBIT_PATH', None)
        if not vw_path:
            msg = "Environment variable 'VOWPAL_WABBIT_PATH' not specified, skipping tests"

            try:
                raise unittest.SkipTest(msg)
            except AttributeError:
                # couldn't find a way of skipping tests in python 2.6
                self.vw_path = None

        corpus, dictionary = get_corpus()
        self.vw_path = vw_path
        self.corpus = corpus
        self.dictionary = dictionary

    def _train_lda(self, corpus):
        """Train a model on ``corpus`` with the settings shared by the tests."""
        return LdaVowpalWabbit(
            self.vw_path, corpus=corpus, passes=10, chunksize=256,
            id2word=self.dictionary, cleanup_files=True, alpha=0.1,
            eta=0.1, num_topics=len(TOPIC_WORDS), random_seed=1
        )

    def _assert_perplexity_in_range(self, result):
        """Check that ``result`` lies strictly between -5 and -1."""
        self.assertTrue(result < -1, 'log perplexity %r is not below -1' % (result,))
        self.assertTrue(result > -5, 'log perplexity %r is not above -5' % (result,))

    def test_save_load(self):
        """Test loading/saving LdaVowpalWabbit model."""
        if not self.vw_path:  # for python 2.6
            return

        lda = self._train_lda(self.corpus)

        with tempfile.NamedTemporaryFile() as fhandle:
            lda.save(fhandle.name)
            lda2 = LdaVowpalWabbit.load(fhandle.name)

            # ensure public fields are saved/loaded correctly; a single list
            # of names keeps the "saved" and "loaded" sides in sync
            field_names = (
                'alpha', 'chunksize', 'cleanup_files',
                'decay', 'eta', 'gamma_threshold',
                'id2word', 'num_terms', 'num_topics',
                'passes', 'random_seed', 'vw_path',
            )
            saved_fields = [getattr(lda, name) for name in field_names]
            loaded_fields = [getattr(lda2, name) for name in field_names]
            self.assertEqual(saved_fields, loaded_fields)

            # ensure topic matrices are saved/loaded correctly
            saved_topics = lda.show_topics(num_topics=5, num_words=10)
            loaded_topics = lda2.show_topics(num_topics=5, num_words=10)
            self.assertEqual(loaded_topics, saved_topics)

    def test_model_update(self):
        """Test updating existing LdaVowpalWabbit model."""
        if not self.vw_path:  # for python 2.6
            return

        # train on the first document only, then feed in the rest
        lda = self._train_lda([self.corpus[0]])
        lda.update(self.corpus[1:])

        self._assert_perplexity_in_range(lda.log_perplexity(self.corpus))

    def test_perplexity(self):
        """Test LdaVowpalWabbit perplexity is within expected range."""
        if not self.vw_path:  # for python 2.6
            return

        lda = self._train_lda(self.corpus)

        # varies, but should be between -1 and -5
        self._assert_perplexity_in_range(lda.log_perplexity(self.corpus))

    def test_topic_coherence(self):
        """Test LdaVowpalWabbit topic coherence."""
        if not self.vw_path:  # for python 2.6
            return

        # setUp already loaded the same corpus/dictionary via get_corpus()
        lda = self._train_lda(self.corpus)
        lda.print_topics(5, 10)

        known_topics = [frozenset(words) for words in TOPIC_WORDS]

        n_coherent = 0
        for topic_id in range(lda.num_topics):
            # the word is the second item of each entry returned by show_topic
            topic_words = [entry[1] for entry in lda.show_topic(topic_id, topn=20)]

            # largest number of this topic's words found in one known topic
            best_overlap = max(
                sum(1 for word in topic_words if word in known)
                for known in known_topics
            )

            # if at least 6 words fall in the same known topic, consider it coherent
            if best_overlap >= 6:
                n_coherent += 1

        # not 100% deterministic, but should always get 3+ coherent topics
        self.assertTrue(
            n_coherent >= 3,
            'expected at least 3 coherent topics, got %d' % n_coherent
        )

    def test_corpus_to_vw(self):
        """Test corpus to Vowpal Wabbit format conversion."""
        if not self.vw_path:  # for python 2.6
            return

        corpus = [
            [(0, 5), (7, 1), (5, 3), (0, 2)],
            [(7, 2), (2, 1), (3, 11)],
            [(1, 1)],
            [],
            [(5, 2), (0, 1)]
        ]
        # one expected output line per document; the empty document is a bare '|'
        expected = '\n'.join([
            '| 0:5 7:1 5:3 0:2',
            '| 7:2 2:1 3:11',
            '| 1:1',
            '|',
            '| 5:2 0:1',
        ])
        result = '\n'.join(ldavowpalwabbit.corpus_to_vw(corpus))
        self.assertEqual(result, expected)

    def testvwmodel2ldamodel(self):
        """Test copying of VWModel to LdaModel"""
        if not self.vw_path:
            return

        tm1 = LdaVowpalWabbit(vw_path=self.vw_path, corpus=self.corpus, num_topics=2, id2word=self.dictionary)
        tm2 = ldavowpalwabbit.vwmodel2ldamodel(tm1)

        for document in self.corpus:
            # compare the first entry each model returns for the document
            # (presumably a (topic_id, probability) pair - TODO confirm)
            element1_1, element1_2 = tm1[document][0]
            element2_1, element2_2 = tm2[document][0]
            self.assertAlmostEqual(element1_1, element2_1)
            self.assertAlmostEqual(element1_2, element2_2, 5)
            # '%s' rather than '%d': '%d' would truncate non-integer values
            logging.debug('%s %s', element1_1, element2_1)
            logging.debug('%s %s', element1_2, element2_2)
# Script entry point: turn on verbose logging, then run the tests.
if __name__ == '__main__':
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s : %(levelname)s : %(message)s'
    )
    unittest.main()