237 lines
7.7 KiB
Python
237 lines
7.7 KiB
Python
# Important note for the deprecation cleanup of 0.20:
|
|
# All the functions and classes of this file have been deprecated in 0.18.
|
|
# When you remove this file please also remove the related files
|
|
# - 'sklearn/mixture/dpgmm.py'
|
|
# - 'sklearn/mixture/gmm.py'
|
|
# - 'sklearn/mixture/tests/test_gmm.py'
|
|
import unittest
|
|
import sys
|
|
|
|
import numpy as np
|
|
|
|
from sklearn.mixture import DPGMM, VBGMM
|
|
from sklearn.mixture.dpgmm import log_normalize
|
|
from sklearn.datasets import make_blobs
|
|
from sklearn.utils.testing import assert_array_less, assert_equal
|
|
from sklearn.utils.testing import assert_warns_message, ignore_warnings
|
|
from sklearn.mixture.tests.test_gmm import GMMTester
|
|
from sklearn.externals.six.moves import cStringIO as StringIO
|
|
from sklearn.mixture.dpgmm import digamma, gammaln
|
|
from sklearn.mixture.dpgmm import wishart_log_det, wishart_logz
|
|
|
|
|
|
# Module-level side effect: ask NumPy to report every category of
# floating-point error as a warning while this module is loaded.
np.seterr(all='warn')
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_class_weights():
    """Check that the class weights are updated by fitting.

    Both DPGMM and VBGMM are fitted with 10 components on the default
    ``make_blobs`` dataset.  Components that are predicted for at least one
    sample must end up with a weight above 0.1, all the others with a
    weight below 0.05.
    """
    n_components = 10
    # simple 3 cluster dataset
    X, y = make_blobs(random_state=1)
    for Model in [DPGMM, VBGMM]:
        dpgmm = Model(n_components=n_components, random_state=1, alpha=20,
                      n_iter=50)
        dpgmm.fit(X)

        # get indices of components that are used:
        indices = np.unique(dpgmm.predict(X))
        # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and later removed, so it raises AttributeError there.
        active = np.zeros(n_components, dtype=bool)
        active[indices] = True

        # used components are important
        assert_array_less(.1, dpgmm.weights_[active])
        # others are not
        assert_array_less(dpgmm.weights_[~active], .05)
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_verbose_boolean():
    """Check that ``verbose=True`` and ``verbose=1`` print the same thing.

    For DPGMM and VBGMM, two identically configured models are fitted, one
    with the boolean flag and one with the integer flag, and the first line
    each of them writes to stdout is compared.
    """
    # simple 3 cluster dataset
    X, y = make_blobs(random_state=1)
    for Model in [DPGMM, VBGMM]:
        dpgmm_bool = Model(n_components=10, random_state=1, alpha=20,
                           n_iter=50, verbose=True)
        dpgmm_int = Model(n_components=10, random_state=1, alpha=20,
                          n_iter=50, verbose=1)

        first_lines = []
        old_stdout = sys.stdout
        try:
            for model in (dpgmm_bool, dpgmm_int):
                # Each model gets its own buffer.  Sharing a single buffer
                # and rewinding it made the second read return the first
                # model's line again, so the comparison could never fail.
                captured = StringIO()
                sys.stdout = captured
                model.fit(X)
                captured.seek(0)
                first_lines.append(captured.readline())
        finally:
            # Always restore the real stdout, even if fitting raised.
            sys.stdout = old_stdout

        # Compare once stdout is restored so a failure is reported normally.
        bool_output, int_output = first_lines
        assert_equal(bool_output, int_output)
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_verbose_first_level():
    """Smoke test: fit DPGMM and VBGMM with ``verbose=1``.

    stdout is swapped for an in-memory buffer during the fit so the verbose
    output does not reach the console; nothing is asserted on its content.
    """
    # simple 3 cluster dataset
    data, _ = make_blobs(random_state=1)
    settings = dict(n_components=10, random_state=1, alpha=20, n_iter=50,
                    verbose=1)
    for estimator_class in (DPGMM, VBGMM):
        model = estimator_class(**settings)

        saved_stdout = sys.stdout
        sys.stdout = StringIO()
        try:
            model.fit(data)
        finally:
            # put the real stdout back whatever happened during the fit
            sys.stdout = saved_stdout
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_verbose_second_level():
    """Smoke test: fit DPGMM and VBGMM with ``verbose=2``.

    stdout is swapped for an in-memory buffer during the fit so the verbose
    output does not reach the console; nothing is asserted on its content.
    """
    # simple 3 cluster dataset
    data, _ = make_blobs(random_state=1)
    settings = dict(n_components=10, random_state=1, alpha=20, n_iter=50,
                    verbose=2)
    for estimator_class in (DPGMM, VBGMM):
        model = estimator_class(**settings)

        saved_stdout = sys.stdout
        sys.stdout = StringIO()
        try:
            model.fit(data)
        finally:
            # put the real stdout back whatever happened during the fit
            sys.stdout = saved_stdout
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_digamma():
    """Calling ``digamma(3)`` must emit its deprecation message."""
    expected_message = (
        "The function digamma is deprecated in 0.18 and will "
        "be removed in 0.20. Use scipy.special.digamma instead.")
    assert_warns_message(DeprecationWarning, expected_message, digamma, 3)
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_gammaln():
    """Calling ``gammaln(3)`` must emit its deprecation message."""
    expected_message = (
        "The function gammaln is deprecated in 0.18 and will "
        "be removed in 0.20. Use scipy.special.gammaln instead.")
    assert_warns_message(DeprecationWarning, expected_message, gammaln, 3)
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_log_normalize():
    """Check ``log_normalize`` warns and undoes a constant scaling.

    The input is ``log(2 * v)``; the value returned through
    ``assert_warns_message`` is expected to be close to ``v`` itself.
    """
    expected_message = ("The function log_normalize is deprecated in 0.18 "
                        "and will be removed in 0.20.")
    probabilities = np.array([0.1, 0.8, 0.01, 0.09])
    scaled_log = np.log(2 * probabilities)

    normalized = assert_warns_message(DeprecationWarning, expected_message,
                                      log_normalize, scaled_log)
    assert np.allclose(probabilities, normalized, rtol=0.01)
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_wishart_log_det():
    """Calling ``wishart_log_det`` must emit its deprecation message."""
    expected_message = ("The function wishart_log_det is deprecated in 0.18 "
                        "and will be removed in 0.20.")
    first = np.array([0.1, 0.8, 0.01, 0.09])
    second = np.array([0.2, 0.7, 0.05, 0.1])
    assert_warns_message(DeprecationWarning, expected_message,
                         wishart_log_det, first, second, 2, 4)
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_wishart_logz():
    """Calling ``wishart_logz`` must emit its deprecation message."""
    expected_message = ("The function wishart_logz is deprecated in 0.18 "
                        "and will be removed in 0.20.")
    identity = np.identity(3)
    assert_warns_message(DeprecationWarning, expected_message,
                         wishart_logz, 3, identity, 1, 3)
|
|
|
|
|
|
@ignore_warnings(category=DeprecationWarning)
def test_DPGMM_deprecation():
    """Instantiating ``DPGMM`` must emit its deprecation message."""
    expected_message = (
        "The `DPGMM` class is not working correctly and it's better to use "
        "`sklearn.mixture.BayesianGaussianMixture` class with parameter "
        "`weight_concentration_prior_type='dirichlet_process'` instead. "
        "DPGMM is deprecated in 0.18 and will be removed in 0.20.")
    assert_warns_message(DeprecationWarning, expected_message, DPGMM)
|
|
|
|
|
|
def do_model(self, **kwds):
    """Build a ``VBGMM`` with ``verbose=False`` and the given keywords.

    ``self`` is accepted but unused: the function is assigned as the
    ``model`` class attribute of ``VBGMMTester``, so it gets called as a
    bound method.
    """
    estimator = VBGMM(verbose=False, **kwds)
    return estimator
|
|
|
|
|
|
class DPGMMTester(GMMTester):
    """``GMMTester`` variant configured for the deprecated ``DPGMM``."""

    # estimator class used by the inherited tests
    model = DPGMM
    do_test_eval = False

    def score(self, g, train_obs):
        """Return the lower bound of ``g`` on ``train_obs``.

        The second output of ``g.score_samples`` is forwarded to
        ``g.lower_bound``; the first one is discarded.
        """
        _unused, second_output = g.score_samples(train_obs)
        bound = g.lower_bound(train_obs, second_output)
        return bound
|
|
|
|
|
|
class TestDPGMMWithSphericalCovars(unittest.TestCase, DPGMMTester):
    """Tests inherited via DPGMMTester, run with spherical covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'spherical'
|
|
|
|
|
|
class TestDPGMMWithDiagCovars(unittest.TestCase, DPGMMTester):
    """Tests inherited via DPGMMTester, run with diagonal covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'diag'
|
|
|
|
|
|
class TestDPGMMWithTiedCovars(unittest.TestCase, DPGMMTester):
    """Tests inherited via DPGMMTester, run with tied covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'tied'
|
|
|
|
|
|
class TestDPGMMWithFullCovars(unittest.TestCase, DPGMMTester):
    """Tests inherited via DPGMMTester, run with full covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'full'
|
|
|
|
|
|
def test_VBGMM_deprecation():
    """Instantiating ``VBGMM`` must emit its deprecation message."""
    expected_message = (
        "The `VBGMM` class is not working correctly and it's better to use "
        "`sklearn.mixture.BayesianGaussianMixture` class with parameter "
        "`weight_concentration_prior_type='dirichlet_distribution'` instead. "
        "VBGMM is deprecated in 0.18 and will be removed in 0.20.")
    assert_warns_message(DeprecationWarning, expected_message, VBGMM)
|
|
|
|
|
|
class VBGMMTester(GMMTester):
    """``GMMTester`` variant configured for the deprecated ``VBGMM``."""

    # factory used by the inherited tests; builds a non-verbose VBGMM
    model = do_model
    do_test_eval = False

    def score(self, g, train_obs):
        """Return the lower bound of ``g`` on ``train_obs``.

        The second output of ``g.score_samples`` is forwarded to
        ``g.lower_bound``; the first one is discarded.
        """
        _unused, second_output = g.score_samples(train_obs)
        bound = g.lower_bound(train_obs, second_output)
        return bound
|
|
|
|
|
|
class TestVBGMMWithSphericalCovars(unittest.TestCase, VBGMMTester):
    """Tests inherited via VBGMMTester, run with spherical covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'spherical'
|
|
|
|
|
|
class TestVBGMMWithDiagCovars(unittest.TestCase, VBGMMTester):
    """Tests inherited via VBGMMTester, run with diagonal covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'diag'
|
|
|
|
|
|
class TestVBGMMWithTiedCovars(unittest.TestCase, VBGMMTester):
    """Tests inherited via VBGMMTester, run with tied covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'tied'
|
|
|
|
|
|
class TestVBGMMWithFullCovars(unittest.TestCase, VBGMMTester):
    """Tests inherited via VBGMMTester, run with full covariances."""

    setUp = GMMTester._setUp
    covariance_type = 'full'
|
|
|
|
|
|
# VBGMM is deprecated: silence its DeprecationWarning here, as the other
# tests of this module that instantiate the models already do.
@ignore_warnings(category=DeprecationWarning)
def test_vbgmm_no_modify_alpha():
    """Check the ``alpha`` parameter and the fitted ``alpha_`` of VBGMM.

    The constructor must store ``alpha`` unchanged, and after one fitting
    iteration ``alpha_`` must equal ``alpha / n_components``.
    """
    alpha = 2.
    n_components = 3
    X, y = make_blobs(random_state=1)
    vbgmm = VBGMM(n_components=n_components, alpha=alpha, n_iter=1)
    assert_equal(vbgmm.alpha, alpha)
    assert_equal(vbgmm.fit(X).alpha_, float(alpha) / n_components)
|