laywerrobot/lib/python3.6/site-packages/nltk/test/unit/test_classify.py

56 lines
1.4 KiB
Python
Raw Normal View History

2020-08-27 21:55:39 +02:00
# -*- coding: utf-8 -*-
"""
Unit tests for nltk.classify. See also: nltk/test/classify.doctest
"""
from __future__ import absolute_import
from nose import SkipTest
from nltk import classify
# Labelled training data: (featureset, label) pairs over the binary
# features a, b and c.  Some featuresets occur more than once, and with
# different labels, so the classifier sees conflicting evidence.
TRAIN = [
    ({'a': 1, 'b': 1, 'c': 1}, 'y'),
    ({'a': 1, 'b': 1, 'c': 1}, 'x'),
    ({'a': 1, 'b': 1, 'c': 0}, 'y'),
    ({'a': 0, 'b': 1, 'c': 1}, 'x'),
    ({'a': 0, 'b': 1, 'c': 1}, 'y'),
    ({'a': 0, 'b': 0, 'c': 1}, 'y'),
    ({'a': 0, 'b': 1, 'c': 0}, 'x'),
    ({'a': 0, 'b': 0, 'c': 0}, 'x'),
    ({'a': 0, 'b': 1, 'c': 1}, 'y'),
]
# Unlabelled featuresets to classify; the trailing notes say how often each
# one occurs in TRAIN and with which labels.
TEST = [
    {'a': 1, 'b': 0, 'c': 1},  # unseen
    {'a': 1, 'b': 0, 'c': 0},  # unseen
    {'a': 0, 'b': 1, 'c': 1},  # seen 3 times, labels=y,y,x
    {'a': 0, 'b': 1, 'c': 0},  # seen 1 time, label=x
]
# Expected probabilities, one row per TEST entry (same order), given as
# (P('x'), P('y')).  assert_classifier_correct compares them with an
# absolute tolerance of 1e-2.
RESULTS = [
    (0.16, 0.84),  # TEST[0]
    (0.46, 0.54),  # TEST[1]
    (0.41, 0.59),  # TEST[2]
    (0.76, 0.24),  # TEST[3]
]
def assert_classifier_correct(algorithm):
    """
    Train a MaxentClassifier on TRAIN with ``algorithm`` and check its output.

    For every featureset in TEST, the probabilities the trained classifier
    assigns to the labels 'x' and 'y' must each lie within 1e-2 of the
    matching row in RESULTS.

    :param algorithm: algorithm name passed through to
        ``classify.MaxentClassifier.train``.
    :raises SkipTest: if training raises LookupError or AttributeError
        (presumably because the external backend is not available --
        confirm against nltk's MEGAM/TADM configuration).
    """
    try:
        model = classify.MaxentClassifier.train(
            TRAIN, algorithm, trace=0, max_iter=1000
        )
    except (LookupError, AttributeError) as err:
        # Training could not run at all: report a skip, not a failure.
        raise SkipTest(str(err))

    for expected, features in zip(RESULTS, TEST):
        want_x, want_y = expected
        dist = model.prob_classify(features)

        # Check 'x' before looking up 'y' so a failure surfaces on the
        # first mismatching label.
        got_x = dist.prob('x')
        assert abs(got_x - want_x) < 1e-2, (got_x, want_x)

        got_y = dist.prob('y')
        assert abs(got_y - want_y) < 1e-2, (got_y, want_y)
def test_megam():
    # Run the shared MaxentClassifier check with the 'MEGAM' algorithm.
    assert_classifier_correct(algorithm='MEGAM')
def test_tadm():
    # Run the shared MaxentClassifier check with the 'TADM' algorithm.
    assert_classifier_correct(algorithm='TADM')