laywerrobot/lib/python3.6/site-packages/sklearn/ensemble/tests/test_base.py
2020-08-27 21:55:39 +02:00

131 lines
5.2 KiB
Python

"""
Testing for the base module (sklearn.ensemble.base).
"""
# Authors: Gilles Louppe
# License: BSD 3 clause
import numpy as np
from numpy.testing import assert_equal
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_true
from sklearn.datasets import load_iris
from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble.base import _set_random_states
from sklearn.linear_model import Perceptron
from collections import OrderedDict
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectFromModel
def test_base():
# Check BaseEnsemble methods.
ensemble = BaggingClassifier(
base_estimator=Perceptron(tol=1e-3, random_state=None), n_estimators=3)
iris = load_iris()
ensemble.fit(iris.data, iris.target)
ensemble.estimators_ = [] # empty the list and create estimators manually
ensemble._make_estimator()
random_state = np.random.RandomState(3)
ensemble._make_estimator(random_state=random_state)
ensemble._make_estimator(random_state=random_state)
ensemble._make_estimator(append=False)
assert_equal(3, len(ensemble))
assert_equal(3, len(ensemble.estimators_))
assert_true(isinstance(ensemble[0], Perceptron))
assert_equal(ensemble[0].random_state, None)
assert_true(isinstance(ensemble[1].random_state, int))
assert_true(isinstance(ensemble[2].random_state, int))
assert_not_equal(ensemble[1].random_state, ensemble[2].random_state)
np_int_ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3),
n_estimators=np.int32(3))
np_int_ensemble.fit(iris.data, iris.target)
def test_base_zero_n_estimators():
# Check that instantiating a BaseEnsemble with n_estimators<=0 raises
# a ValueError.
ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3),
n_estimators=0)
iris = load_iris()
assert_raise_message(ValueError,
"n_estimators must be greater than zero, got 0.",
ensemble.fit, iris.data, iris.target)
def test_base_not_int_n_estimators():
# Check that instantiating a BaseEnsemble with a string as n_estimators
# raises a ValueError demanding n_estimators to be supplied as an integer.
string_ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3),
n_estimators='3')
iris = load_iris()
assert_raise_message(ValueError,
"n_estimators must be an integer",
string_ensemble.fit, iris.data, iris.target)
float_ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3),
n_estimators=3.0)
assert_raise_message(ValueError,
"n_estimators must be an integer",
float_ensemble.fit, iris.data, iris.target)
def test_set_random_states():
# Linear Discriminant Analysis doesn't have random state: smoke test
_set_random_states(LinearDiscriminantAnalysis(), random_state=17)
clf1 = Perceptron(tol=1e-3, random_state=None)
assert_equal(clf1.random_state, None)
# check random_state is None still sets
_set_random_states(clf1, None)
assert_true(isinstance(clf1.random_state, int))
# check random_state fixes results in consistent initialisation
_set_random_states(clf1, 3)
assert_true(isinstance(clf1.random_state, int))
clf2 = Perceptron(tol=1e-3, random_state=None)
_set_random_states(clf2, 3)
assert_equal(clf1.random_state, clf2.random_state)
# nested random_state
def make_steps():
return [('sel', SelectFromModel(Perceptron(tol=1e-3,
random_state=None))),
('clf', Perceptron(tol=1e-3, random_state=None))]
est1 = Pipeline(make_steps())
_set_random_states(est1, 3)
assert_true(isinstance(est1.steps[0][1].estimator.random_state, int))
assert_true(isinstance(est1.steps[1][1].random_state, int))
assert_not_equal(est1.get_params()['sel__estimator__random_state'],
est1.get_params()['clf__random_state'])
# ensure multiple random_state parameters are invariant to get_params()
# iteration order
class AlphaParamPipeline(Pipeline):
def get_params(self, *args, **kwargs):
params = Pipeline.get_params(self, *args, **kwargs).items()
return OrderedDict(sorted(params))
class RevParamPipeline(Pipeline):
def get_params(self, *args, **kwargs):
params = Pipeline.get_params(self, *args, **kwargs).items()
return OrderedDict(sorted(params, reverse=True))
for cls in [AlphaParamPipeline, RevParamPipeline]:
est2 = cls(make_steps())
_set_random_states(est2, 3)
assert_equal(est1.get_params()['sel__estimator__random_state'],
est2.get_params()['sel__estimator__random_state'])
assert_equal(est1.get_params()['clf__random_state'],
est2.get_params()['clf__random_state'])