laywerrobot/lib/python3.6/site-packages/sklearn/tests/test_base.py

471 lines
15 KiB
Python
Raw Normal View History

2020-08-27 21:55:39 +02:00
# Author: Gael Varoquaux
# License: BSD 3 clause
import numpy as np
import scipy.sparse as sp
import sklearn
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_false
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_no_warnings
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.testing import assert_dict_equal
from sklearn.utils.testing import ignore_warnings
from sklearn.base import BaseEstimator, clone, is_classifier
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn import datasets
from sklearn.utils import deprecated
from sklearn.base import TransformerMixin
from sklearn.utils.mocking import MockDataFrame
import pickle
#############################################################################
# A few test classes
class MyEstimator(BaseEstimator):
    """Minimal estimator storing two constructor parameters unchanged.

    Parameters
    ----------
    l1 : object, default 0
        Stored as ``self.l1``.
    empty : object, default None
        Stored as ``self.empty``; the clone tests in this file pass
        arrays, sparse matrices and NaN through it.
    """

    def __init__(self, l1=0, empty=None):
        self.empty = empty
        self.l1 = l1
class K(BaseEstimator):
    """Leaf estimator with parameters ``c`` and ``d``.

    Instances are nested inside ``T`` by the repr/get_params tests.
    """

    def __init__(self, c=None, d=None):
        self.d = d
        self.c = c
class T(BaseEstimator):
    """Container estimator with parameters ``a`` and ``b``.

    The tests in this file pass ``K`` instances (or a long list) as values.
    """

    def __init__(self, a=None, b=None):
        self.b = b
        self.a = a
class ModifyInitParams(BaseEstimator):
    """Estimator exhibiting deprecated constructor behavior.

    The stored attribute is equal to the constructor argument but is a
    copy of it, so ``estimator.a is a`` does not hold.
    """

    def __init__(self, a=np.array([0])):
        # Deliberately store a copy instead of the argument itself; this
        # is the behavior test_clone_copy_init_params exercises.
        duplicated = a.copy()
        self.a = duplicated
class DeprecatedAttributeEstimator(BaseEstimator):
    """Estimator whose parameter ``b`` is a deprecated alias of ``a``.

    Parameters
    ----------
    a : object, default None
        Stored as ``self.a``.
    b : object, default None
        Deprecated. When not None, a DeprecationWarning is emitted and the
        value overrides ``a``.
    """

    def __init__(self, a=None, b=None):
        self.a = a
        if b is not None:
            # Local import: keeps this block self-contained.
            import warnings
            # Actually emit the warning. The previous code only
            # instantiated DeprecationWarning, which has no effect.
            warnings.warn("b is deprecated and renamed 'a'",
                          DeprecationWarning)
            self.a = b

    @property
    @deprecated("Parameter 'b' is deprecated and renamed to 'a'")
    def b(self):
        # NOTE: ``_b`` is never assigned anywhere in this class, so reading
        # ``b`` cannot return a value here; test_get_params_deprecated only
        # checks that 'b' is absent from get_params().
        return self._b
class Buggy(BaseEstimator):
    """A buggy estimator that does not set its parameters right.

    The constructor ignores the value of ``a`` and always stores ``1``.
    """

    def __init__(self, a=None):
        # Intentionally wrong: the argument is discarded.
        self.a = 1
class NoEstimator(object):
    """Plain object with ``fit``/``predict`` that is not a BaseEstimator."""

    def __init__(self):
        pass

    def fit(self, X=None, y=None):
        """Ignore the inputs and return ``self``."""
        return self

    def predict(self, X=None):
        """Ignore the input and return ``None``."""
        return None
class VargEstimator(BaseEstimator):
    """Estimator taking ``*vargs``, which scikit-learn estimators shouldn't.

    test_clone_buggy expects ``clone`` to raise RuntimeError for it.
    """

    def __init__(self, *vargs):
        # The positional arguments are accepted and discarded.
        pass
#############################################################################
# The tests
def test_clone():
    """Check that ``clone`` returns a distinct estimator with equal params.

    An estimator is created, cloned, and the clone is checked to be a
    different object whose ``get_params()`` matches the original. The
    identity check is then repeated with an array-valued parameter.
    """
    from sklearn.feature_selection import SelectFpr, f_classif

    original = SelectFpr(f_classif, alpha=0.1)
    duplicate = clone(original)
    assert_true(original is not duplicate)
    assert_equal(original.get_params(), duplicate.get_params())

    # Same check with a numpy array as the parameter value.
    original = SelectFpr(f_classif, alpha=np.zeros((10, 2)))
    duplicate = clone(original)
    assert_true(original is not duplicate)
def test_clone_2():
    """Check that ``clone`` does not copy attributes that are not params.

    An extra attribute is attached to the estimator by hand; the clone
    must not carry it.
    """
    from sklearn.feature_selection import SelectFpr, f_classif

    original = SelectFpr(f_classif, alpha=0.1)
    original.own_attribute = "test"

    duplicate = clone(original)
    has_extra = hasattr(duplicate, "own_attribute")
    assert_false(has_extra)
def test_clone_buggy():
    """Check that ``clone`` raises on estimators it cannot handle."""
    # Constructor that does not store its parameter as given.
    broken = Buggy()
    broken.a = 2
    assert_raises(RuntimeError, clone, broken)

    # Object that is not a BaseEstimator at all.
    assert_raises(TypeError, clone, NoEstimator())

    # Constructor using *vargs.
    assert_raises(RuntimeError, clone, VargEstimator())
def test_clone_empty_array():
    """Regression test: clone estimators holding empty/sparse arrays."""
    # Empty dense array as parameter value.
    original = MyEstimator(empty=np.array([]))
    duplicate = clone(original)
    assert_array_equal(original.empty, duplicate.empty)

    # 1x1 CSR matrix holding a single zero; compare the ``data`` arrays.
    sparse_param = sp.csr_matrix(np.array([[0]]))
    original = MyEstimator(empty=sparse_param)
    duplicate = clone(original)
    assert_array_equal(original.empty.data, duplicate.empty.data)
def test_clone_nan():
    """Regression test: clone an estimator whose parameter is ``np.nan``.

    The clone must hold the very same NaN object as the original.
    """
    original = MyEstimator(empty=np.nan)
    duplicate = clone(original)

    same_object = original.empty is duplicate.empty
    assert_true(same_object)
def test_clone_copy_init_params():
    """Expect a DeprecationWarning when ``__init__`` copies a parameter."""
    expected_warning = (
        "Estimator ModifyInitParams modifies parameters in __init__. "
        "This behavior is deprecated as of 0.18 and support "
        "for this behavior will be removed in 0.20.")
    estimator = ModifyInitParams()
    assert_warns_message(DeprecationWarning, expected_warning,
                         clone, estimator)
def test_clone_sparse_matrices():
    """Check that cloning preserves the class and content of sparse params.

    Every attribute of ``scipy.sparse`` whose name ends in ``_matrix`` is
    treated as a sparse matrix class and tried in turn.
    """
    matrix_names = [attr for attr in dir(sp) if attr.endswith('_matrix')]
    matrix_types = [getattr(sp, attr) for attr in matrix_names]

    for matrix_type in matrix_types:
        original = MyEstimator(empty=matrix_type(np.eye(5)))
        duplicate = clone(original)

        assert_true(original.empty.__class__ is duplicate.empty.__class__)
        assert_array_equal(original.empty.toarray(),
                           duplicate.empty.toarray())
def test_repr():
    """Smoke test the repr of the base estimator."""
    # Must simply not raise.
    repr(MyEstimator())

    # Nested estimators are rendered inline.
    nested = T(K(), K())
    expected = "T(a=K(c=None, d=None), b=K(c=None, d=None))"
    assert_equal(repr(nested), expected)

    # A very long parameter value yields a repr of exactly 415 characters.
    lengthy = T(a=["long_params"] * 1000)
    assert_equal(len(repr(lengthy)), 415)
def test_str():
    """Smoke test the str of the base estimator: it must not raise."""
    estimator = MyEstimator()
    str(estimator)
def test_get_params():
    """Check deep/shallow ``get_params`` and nested ``set_params``."""
    nested = T(K(), K())

    # 'a__d' is a parameter of the nested K stored under ``a``.
    assert_true('a__d' in nested.get_params(deep=True))
    assert_false('a__d' in nested.get_params(deep=False))

    nested.set_params(a__d=2)
    assert_true(nested.a.d == 2)

    # K has no parameter named ``a``.
    assert_raises(ValueError, nested.set_params, a__a=2)
def test_get_params_deprecated():
    """A deprecated attribute should not show up in ``get_params``."""
    estimator = DeprecatedAttributeEstimator(a=1)
    call_variants = ({}, {'deep': True}, {'deep': False})

    # 'a' is always reported ...
    for kwargs in call_variants:
        assert_true('a' in estimator.get_params(**kwargs))

    # ... while the deprecated 'b' never is.
    for kwargs in call_variants:
        assert_true('b' not in estimator.get_params(**kwargs))
def test_is_classifier():
    """``is_classifier`` is true for an SVC, bare or wrapped."""
    svc = SVC()

    def grid_search():
        # Fresh grid search (and fresh grid dict) around the shared SVC.
        return GridSearchCV(svc, {'C': [0.1, 1]})

    assert_true(is_classifier(svc))
    assert_true(is_classifier(grid_search()))
    assert_true(is_classifier(Pipeline([('svc', svc)])))
    assert_true(is_classifier(Pipeline([('svc_cv', grid_search())])))
def test_set_params():
    """Check error handling of nested parameter setting on a Pipeline."""
    pipeline = Pipeline([("svc", SVC())])

    invalid_settings = (
        # non-existing parameter in svc
        {'svc__stupid_param': True},
        # non-existing parameter of pipeline
        {'svm__stupid_param': True},
    )
    for kwargs in invalid_settings:
        assert_raises(ValueError, pipeline.set_params, **kwargs)

    # we don't currently catch if the things in pipeline are estimators
    # bad_pipeline = Pipeline([("bad", NoEstimator())])
    # assert_raises(AttributeError, bad_pipeline.set_params,
    #               bad__stupid_param=True)
def test_set_params_passes_all_parameters():
    """All nested parameters must reach the sub-estimator in one call.

    Regression test for #9944: the nested estimator's ``set_params`` must
    receive every parameter together rather than one at a time.
    """
    expected_kwargs = {'max_depth': 5, 'min_samples_leaf': 2}

    class TestDecisionTree(DecisionTreeClassifier):
        def set_params(self, **kwargs):
            super(TestDecisionTree, self).set_params(**kwargs)
            # expected_kwargs comes from the enclosing test scope
            assert kwargs == expected_kwargs
            return self

    wrappers = [
        Pipeline([('estimator', TestDecisionTree())]),
        GridSearchCV(TestDecisionTree(), {}),
    ]
    for wrapper in wrappers:
        wrapper.set_params(estimator__max_depth=5,
                           estimator__min_samples_leaf=2)
def test_score_sample_weight():
    """``score`` must change when random sample weights are supplied.

    Covers both a classifier (iris) and a regressor (boston).
    """
    rng = np.random.RandomState(0)

    # test both ClassifierMixin and RegressorMixin
    cases = [
        (DecisionTreeClassifier(max_depth=2), datasets.load_iris()),
        (DecisionTreeRegressor(max_depth=2), datasets.load_boston()),
    ]
    for estimator, dataset in cases:
        X, y = dataset.data, dataset.target
        estimator.fit(X, y)

        # generate random sample weights
        sample_weight = rng.randint(1, 10, size=len(y))

        # the scores with and without sample weights must differ
        unweighted = estimator.score(X, y)
        weighted = estimator.score(X, y, sample_weight=sample_weight)
        assert_not_equal(unweighted, weighted,
                         msg="Unweighted and weighted scores "
                             "are unexpectedly equal")
def test_clone_pandas_dataframe():
    """Check that cloning keeps a data-frame-like parameter intact.

    A ``MockDataFrame`` stands in for a pandas data frame.
    """

    class DummyEstimator(BaseEstimator, TransformerMixin):
        """Dummy transformer holding a data frame as a parameter.

        Parameters
        ----------
        df : pandas data frame
            The pandas data frame parameter.
        scalar_param : object, default 1
            Stored as ``self.scalar_param``.
        """

        def __init__(self, df=None, scalar_param=1):
            self.df = df
            self.scalar_param = scalar_param

        def fit(self, X, y=None):
            pass

        def transform(self, X):
            pass

    # build and clone estimator
    frame = MockDataFrame(np.arange(10))
    original = DummyEstimator(frame, scalar_param=1)
    duplicate = clone(original)

    # the test: element-wise equal frames and an equal scalar parameter
    frames_match = (original.df == duplicate.df).values.all()
    assert_true(frames_match)
    assert_equal(original.scalar_param, duplicate.scalar_param)
def test_pickle_version_warning_is_not_raised_with_matching_version():
    """Unpickling under the same version must not warn.

    Also checks that the pickle contains a version marker and that the
    restored tree scores identically to the original.
    """
    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    fitted = DecisionTreeClassifier().fit(X, y)
    payload = pickle.dumps(fitted)
    assert_true(b"version" in payload)

    restored = assert_no_warnings(pickle.loads, payload)

    # test that we can predict with the restored decision tree classifier
    original_score = fitted.score(X, y)
    restored_score = restored.score(X, y)
    assert_equal(original_score, restored_score)
class TreeBadVersion(DecisionTreeClassifier):
    """Decision tree whose pickled state claims version ``"something"``."""

    def __getstate__(self):
        # Copy the instance dict and force a bogus version entry.
        state = dict(self.__dict__)
        state["_sklearn_version"] = "something"
        return state
# Template of the expected version-mismatch warning; the tests below fill in
# the ``estimator``, ``old_version`` and ``current_version`` fields.
pickle_error_message = (
    "Trying to unpickle estimator {estimator} from version {old_version} "
    "when using version {current_version}. This might lead to breaking "
    "code or invalid results. Use at your own risk.")
def test_pickle_version_warning_is_issued_upon_different_version():
    """Unpickling a state with a foreign version must emit a UserWarning."""
    iris = datasets.load_iris()
    fitted = TreeBadVersion().fit(iris.data, iris.target)
    payload = pickle.dumps(fitted)

    expected_warning = pickle_error_message.format(
        estimator="TreeBadVersion",
        old_version="something",
        current_version=sklearn.__version__)
    assert_warns_message(UserWarning, expected_warning,
                         pickle.loads, payload)
class TreeNoVersion(DecisionTreeClassifier):
    """Decision tree whose pickled state is the bare instance dict."""

    def __getstate__(self):
        # Return the instance dict itself, without adding anything to it.
        return vars(self)
def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
    """A pickle without version info must warn as a ``pre-0.18`` pickle."""
    iris = datasets.load_iris()

    # TreeNoVersion's __getstate__ returns the bare instance dict, so the
    # pickle carries no version entry (checked just below).
    fitted = TreeNoVersion().fit(iris.data, iris.target)
    payload = pickle.dumps(fitted)
    assert_false(b"version" in payload)

    expected_warning = pickle_error_message.format(
        estimator="TreeNoVersion",
        old_version="pre-0.18",
        current_version=sklearn.__version__)

    # check we got the warning about using pre-0.18 pickle
    assert_warns_message(UserWarning, expected_warning,
                         pickle.loads, payload)
def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
    """No warning when the unpickled class does not claim to be sklearn's.

    ``TreeNoVersion.__module__`` is temporarily set to ``"notsklearn"``
    and restored afterwards, whatever the outcome.
    """
    iris = datasets.load_iris()
    tree = TreeNoVersion().fit(iris.data, iris.target)
    tree_pickle_noversion = pickle.dumps(tree)

    # Take the backup before entering ``try`` so that ``finally`` can never
    # reference an unbound name.
    module_backup = TreeNoVersion.__module__
    try:
        TreeNoVersion.__module__ = "notsklearn"
        assert_no_warnings(pickle.loads, tree_pickle_noversion)
    finally:
        TreeNoVersion.__module__ = module_backup
class DontPickleAttributeMixin(object):
    """Mixin customizing pickling of the instance dict.

    ``__getstate__`` blanks ``_attribute_not_pickled``; ``__setstate__``
    marks the instance with ``_restored = True``.
    """

    def __getstate__(self):
        """Return a copy of ``__dict__`` with ``_attribute_not_pickled=None``.

        The live instance is not modified.
        """
        data = self.__dict__.copy()
        data["_attribute_not_pickled"] = None
        return data

    def __setstate__(self, state):
        """Load ``state`` into the instance and set ``_restored`` to True.

        The caller's ``state`` mapping is left untouched; the flag is set
        on the instance instead of being injected into ``state``.
        """
        self.__dict__.update(state)
        # Set after the update so the flag is True even if ``state``
        # carries its own ``_restored`` entry.
        self._restored = True
class MultiInheritanceEstimator(BaseEstimator, DontPickleAttributeMixin):
    """Estimator combining BaseEstimator with the pickling mixin.

    Parameters
    ----------
    attribute_pickled : object, default 5
        Stored as ``self.attribute_pickled``.
    """

    def __init__(self, attribute_pickled=5):
        # Private slot that the mixin's __getstate__ resets to None.
        self._attribute_not_pickled = None
        self.attribute_pickled = attribute_pickled
def test_pickling_when_getstate_is_overwritten_by_mixin():
    """Pickle round trip must honor the mixin's get/set state hooks."""
    original = MultiInheritanceEstimator()
    original._attribute_not_pickled = "this attribute should not be pickled"

    restored = pickle.loads(pickle.dumps(original))

    assert_equal(restored.attribute_pickled, 5)
    assert_equal(restored._attribute_not_pickled, None)
    assert_true(restored._restored)
def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn():
    """State hooks of the mixin must work for a non-sklearn module name.

    The estimator class' ``__module__`` is temporarily set to
    ``"notsklearn"`` and restored afterwards, whatever the outcome.
    """
    # Build the fixture and take the backup before entering ``try``: if any
    # of this fails, the real error is reported instead of a NameError
    # raised from the ``finally`` clause.
    estimator = MultiInheritanceEstimator()
    text = "this attribute should not be pickled"
    estimator._attribute_not_pickled = text
    old_mod = type(estimator).__module__
    try:
        type(estimator).__module__ = "notsklearn"

        serialized = estimator.__getstate__()
        assert_dict_equal(serialized, {'_attribute_not_pickled': None,
                                       'attribute_pickled': 5})

        serialized['attribute_pickled'] = 4
        estimator.__setstate__(serialized)
        assert_equal(estimator.attribute_pickled, 4)
        assert_true(estimator._restored)
    finally:
        type(estimator).__module__ = old_mod
class SingleInheritanceEstimator(BaseEstimator):
    """Estimator overriding ``__getstate__`` directly in the child class.

    Parameters
    ----------
    attribute_pickled : object, default 5
        Stored as ``self.attribute_pickled``.
    """

    def __init__(self, attribute_pickled=5):
        self._attribute_not_pickled = None
        self.attribute_pickled = attribute_pickled

    def __getstate__(self):
        # Copy of the instance dict with the private attribute blanked out.
        return dict(self.__dict__, _attribute_not_pickled=None)
@ignore_warnings(category=(UserWarning))
def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
    """Pickle round trip must honor a ``__getstate__`` defined in the child.

    UserWarnings are ignored for the duration of this test.
    """
    original = SingleInheritanceEstimator()
    original._attribute_not_pickled = "this attribute should not be pickled"

    restored = pickle.loads(pickle.dumps(original))

    assert_equal(restored.attribute_pickled, 5)
    assert_equal(restored._attribute_not_pickled, None)