# Author: Gael Varoquaux
# License: BSD 3 clause

import pickle
import warnings

import numpy as np
import scipy.sparse as sp

import sklearn
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_false
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_no_warnings
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.testing import assert_dict_equal
from sklearn.utils.testing import ignore_warnings

from sklearn.base import BaseEstimator, clone, is_classifier
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn import datasets
from sklearn.utils import deprecated

from sklearn.base import TransformerMixin
from sklearn.utils.mocking import MockDataFrame


#############################################################################
# A few test classes
class MyEstimator(BaseEstimator):
    """Minimal estimator with two plain init parameters."""

    def __init__(self, l1=0, empty=None):
        self.l1 = l1
        self.empty = empty


class K(BaseEstimator):
    """Leaf estimator used as a nested parameter of ``T``."""

    def __init__(self, c=None, d=None):
        self.c = c
        self.d = d


class T(BaseEstimator):
    """Estimator whose parameters may themselves be estimators."""

    def __init__(self, a=None, b=None):
        self.a = a
        self.b = b


class ModifyInitParams(BaseEstimator):
    """Deprecated behavior.
    Equal parameters but with a type cast.
    Doesn't fulfill a is a
    """

    # The array default and the copy are intentional: this fixture exists to
    # store something other than the object it was given.
    def __init__(self, a=np.array([0])):
        self.a = a.copy()


class DeprecatedAttributeEstimator(BaseEstimator):
    """Estimator with a deprecated init parameter ``b`` renamed to ``a``."""

    def __init__(self, a=None, b=None):
        self.a = a
        if b is not None:
            # Actually emit the warning; merely instantiating
            # DeprecationWarning (as was done before) is a silent no-op.
            warnings.warn("b is deprecated and renamed 'a'",
                          DeprecationWarning)
            self.a = b

    @property
    @deprecated("Parameter 'b' is deprecated and renamed to 'a'")
    def b(self):
        return self._b


class Buggy(BaseEstimator):
    " A buggy estimator that does not set its parameters right. "

    def __init__(self, a=None):
        # Deliberately ignores the ``a`` argument.
        self.a = 1


class NoEstimator(object):
    """Object with fit/predict that does not derive from BaseEstimator."""

    def __init__(self):
        pass

    def fit(self, X=None, y=None):
        return self

    def predict(self, X=None):
        return None


class VargEstimator(BaseEstimator):
    """scikit-learn estimators shouldn't have vargs."""

    def __init__(self, *vargs):
        pass


#############################################################################
# The tests

def test_clone():
    # Tests that clone creates a correct deep copy.
    # We create an estimator, make a copy of its original state
    # (which, in this case, is the current state of the estimator),
    # and check that the obtained copy is a correct deep copy.

    from sklearn.feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    new_selector = clone(selector)
    assert_true(selector is not new_selector)
    assert_equal(selector.get_params(), new_selector.get_params())

    selector = SelectFpr(f_classif, alpha=np.zeros((10, 2)))
    new_selector = clone(selector)
    assert_true(selector is not new_selector)


def test_clone_2():
    # Tests that clone doesn't copy everything.
    # We first create an estimator, give it an own attribute, and
    # make a copy of its original state. Then we check that the copy doesn't
    # have the specific attribute we manually added to the initial estimator.

    from sklearn.feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    selector.own_attribute = "test"
    new_selector = clone(selector)
    assert_false(hasattr(new_selector, "own_attribute"))


def test_clone_buggy():
    # Check that clone raises an error on buggy estimators.
    buggy = Buggy()
    buggy.a = 2
    assert_raises(RuntimeError, clone, buggy)

    no_estimator = NoEstimator()
    assert_raises(TypeError, clone, no_estimator)

    varg_est = VargEstimator()
    assert_raises(RuntimeError, clone, varg_est)


def test_clone_empty_array():
    # Regression test for cloning estimators with empty arrays
    clf = MyEstimator(empty=np.array([]))
    clf2 = clone(clf)
    assert_array_equal(clf.empty, clf2.empty)

    clf = MyEstimator(empty=sp.csr_matrix(np.array([[0]])))
    clf2 = clone(clf)
    assert_array_equal(clf.empty.data, clf2.empty.data)


def test_clone_nan():
    # Regression test for cloning estimators with default parameter as np.nan
    clf = MyEstimator(empty=np.nan)
    clf2 = clone(clf)

    assert_true(clf.empty is clf2.empty)


def test_clone_copy_init_params():
    # test for deprecation warning when copying or casting an init parameter
    est = ModifyInitParams()
    message = ("Estimator ModifyInitParams modifies parameters in __init__. "
               "This behavior is deprecated as of 0.18 and support "
               "for this behavior will be removed in 0.20.")

    assert_warns_message(DeprecationWarning, message, clone, est)


def test_clone_sparse_matrices():
    # Every scipy.sparse attribute whose name ends in '_matrix' is tried as a
    # parameter value; the clone must keep both the class and the contents.
    sparse_matrix_classes = [
        getattr(sp, name)
        for name in dir(sp) if name.endswith('_matrix')]

    for cls in sparse_matrix_classes:
        sparse_matrix = cls(np.eye(5))
        clf = MyEstimator(empty=sparse_matrix)
        clf_cloned = clone(clf)
        assert_true(clf.empty.__class__ is clf_cloned.empty.__class__)
        assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray())


def test_repr():
    # Smoke test the repr of the base estimator.
    my_estimator = MyEstimator()
    repr(my_estimator)
    test = T(K(), K())
    assert_equal(
        repr(test),
        "T(a=K(c=None, d=None), b=K(c=None, d=None))"
    )

    some_est = T(a=["long_params"] * 1000)
    assert_equal(len(repr(some_est)), 415)


def test_str():
    # Smoke test the str of the base estimator
    my_estimator = MyEstimator()
    str(my_estimator)


def test_get_params():
    test = T(K(), K())

    assert_true('a__d' in test.get_params(deep=True))
    assert_true('a__d' not in test.get_params(deep=False))

    test.set_params(a__d=2)
    assert_true(test.a.d == 2)
    assert_raises(ValueError, test.set_params, a__a=2)


def test_get_params_deprecated():
    # deprecated attribute should not show up as params
    est = DeprecatedAttributeEstimator(a=1)

    assert_true('a' in est.get_params())
    assert_true('a' in est.get_params(deep=True))
    assert_true('a' in est.get_params(deep=False))

    assert_true('b' not in est.get_params())
    assert_true('b' not in est.get_params(deep=True))
    assert_true('b' not in est.get_params(deep=False))


def test_is_classifier():
    svc = SVC()
    assert_true(is_classifier(svc))
    assert_true(is_classifier(GridSearchCV(svc, {'C': [0.1, 1]})))
    assert_true(is_classifier(Pipeline([('svc', svc)])))
    assert_true(is_classifier(Pipeline(
        [('svc_cv', GridSearchCV(svc, {'C': [0.1, 1]}))])))


def test_set_params():
    # test nested estimator parameter setting
    clf = Pipeline([("svc", SVC())])
    # non-existing parameter in svc
    assert_raises(ValueError, clf.set_params, svc__stupid_param=True)
    # non-existing parameter of pipeline
    assert_raises(ValueError, clf.set_params, svm__stupid_param=True)
    # we don't currently catch if the things in pipeline are estimators
    # bad_pipeline = Pipeline([("bad", NoEstimator())])
    # assert_raises(AttributeError, bad_pipeline.set_params,
    #               bad__stupid_param=True)


def test_set_params_passes_all_parameters():
    # Make sure all parameters are passed together to set_params
    # of nested estimator. Regression test for #9944

    class TestDecisionTree(DecisionTreeClassifier):
        def set_params(self, **kwargs):
            super(TestDecisionTree, self).set_params(**kwargs)
            # expected_kwargs is in test scope
            assert kwargs == expected_kwargs
            return self

    expected_kwargs = {'max_depth': 5, 'min_samples_leaf': 2}
    for est in [Pipeline([('estimator', TestDecisionTree())]),
                GridSearchCV(TestDecisionTree(), {})]:
        est.set_params(estimator__max_depth=5,
                       estimator__min_samples_leaf=2)


def test_score_sample_weight():

    rng = np.random.RandomState(0)

    # test both ClassifierMixin and RegressorMixin
    estimators = [DecisionTreeClassifier(max_depth=2),
                  DecisionTreeRegressor(max_depth=2)]
    sets = [datasets.load_iris(),
            datasets.load_boston()]

    for est, ds in zip(estimators, sets):
        est.fit(ds.data, ds.target)
        # generate random sample weights
        sample_weight = rng.randint(1, 10, size=len(ds.target))
        # check that the score with and without sample weights are different
        assert_not_equal(est.score(ds.data, ds.target),
                         est.score(ds.data, ds.target,
                                   sample_weight=sample_weight),
                         msg="Unweighted and weighted scores "
                             "are unexpectedly equal")


def test_clone_pandas_dataframe():

    class DummyEstimator(BaseEstimator, TransformerMixin):
        """This is a dummy class for generating numerical features

        This feature extractor extracts numerical features from pandas data
        frame.

        Parameters
        ----------

        df: pandas data frame
            The pandas data frame parameter.

        Notes
        -----
        """
        def __init__(self, df=None, scalar_param=1):
            self.df = df
            self.scalar_param = scalar_param

        def fit(self, X, y=None):
            pass

        def transform(self, X):
            pass

    # build and clone estimator
    d = np.arange(10)
    df = MockDataFrame(d)
    e = DummyEstimator(df, scalar_param=1)
    cloned_e = clone(e)

    # the test
    assert_true((e.df == cloned_e.df).values.all())
    assert_equal(e.scalar_param, cloned_e.scalar_param)


def test_pickle_version_warning_is_not_raised_with_matching_version():
    iris = datasets.load_iris()
    tree = DecisionTreeClassifier().fit(iris.data, iris.target)
    tree_pickle = pickle.dumps(tree)
    assert_true(b"version" in tree_pickle)
    tree_restored = assert_no_warnings(pickle.loads, tree_pickle)

    # test that we can predict with the restored decision tree classifier
    score_of_original = tree.score(iris.data, iris.target)
    score_of_restored = tree_restored.score(iris.data, iris.target)
    assert_equal(score_of_original, score_of_restored)


class TreeBadVersion(DecisionTreeClassifier):
    """Tree whose pickled state carries the version string "something"."""

    def __getstate__(self):
        return dict(self.__dict__.items(), _sklearn_version="something")


pickle_error_message = (
    "Trying to unpickle estimator {estimator} from "
    "version {old_version} when using version "
    "{current_version}. This might "
    "lead to breaking code or invalid results. "
    "Use at your own risk.")


def test_pickle_version_warning_is_issued_upon_different_version():
    iris = datasets.load_iris()
    tree = TreeBadVersion().fit(iris.data, iris.target)
    tree_pickle_other = pickle.dumps(tree)
    message = pickle_error_message.format(estimator="TreeBadVersion",
                                          old_version="something",
                                          current_version=sklearn.__version__)
    assert_warns_message(UserWarning, message, pickle.loads, tree_pickle_other)


class TreeNoVersion(DecisionTreeClassifier):
    """Tree whose pickled state is the bare ``__dict__`` (no version key)."""

    def __getstate__(self):
        return self.__dict__


def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
    iris = datasets.load_iris()
    # TreeNoVersion's __getstate__ returns the bare __dict__ without any
    # version entry, like a pre-0.18 pickle
    tree = TreeNoVersion().fit(iris.data, iris.target)

    tree_pickle_noversion = pickle.dumps(tree)
    assert_false(b"version" in tree_pickle_noversion)
    message = pickle_error_message.format(estimator="TreeNoVersion",
                                          old_version="pre-0.18",
                                          current_version=sklearn.__version__)
    # check we got the warning about using pre-0.18 pickle
    assert_warns_message(UserWarning, message, pickle.loads,
                         tree_pickle_noversion)


def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
    iris = datasets.load_iris()
    tree = TreeNoVersion().fit(iris.data, iris.target)
    tree_pickle_noversion = pickle.dumps(tree)
    # Save the module name before entering ``try`` so that the ``finally``
    # clause can never hit an unbound name and mask the real failure.
    module_backup = TreeNoVersion.__module__
    try:
        TreeNoVersion.__module__ = "notsklearn"
        assert_no_warnings(pickle.loads, tree_pickle_noversion)
    finally:
        TreeNoVersion.__module__ = module_backup


class DontPickleAttributeMixin(object):
    """Mixin that blanks ``_attribute_not_pickled`` and flags restoration."""

    def __getstate__(self):
        data = self.__dict__.copy()
        data["_attribute_not_pickled"] = None
        return data

    def __setstate__(self, state):
        state["_restored"] = True
        self.__dict__.update(state)


class MultiInheritanceEstimator(BaseEstimator, DontPickleAttributeMixin):
    """Estimator combining BaseEstimator with the pickling mixin."""

    def __init__(self, attribute_pickled=5):
        self.attribute_pickled = attribute_pickled
        self._attribute_not_pickled = None


def test_pickling_when_getstate_is_overwritten_by_mixin():
    estimator = MultiInheritanceEstimator()
    estimator._attribute_not_pickled = "this attribute should not be pickled"

    serialized = pickle.dumps(estimator)
    estimator_restored = pickle.loads(serialized)
    assert_equal(estimator_restored.attribute_pickled, 5)
    assert_equal(estimator_restored._attribute_not_pickled, None)
    assert_true(estimator_restored._restored)


def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn():
    # Build the estimator and remember its module name before ``try``: if
    # either step failed inside the ``try`` body, ``finally`` would raise a
    # NameError and hide the original error.
    estimator = MultiInheritanceEstimator()
    text = "this attribute should not be pickled"
    estimator._attribute_not_pickled = text
    old_mod = type(estimator).__module__
    try:
        type(estimator).__module__ = "notsklearn"

        serialized = estimator.__getstate__()
        assert_dict_equal(serialized, {'_attribute_not_pickled': None,
                                       'attribute_pickled': 5})

        serialized['attribute_pickled'] = 4
        estimator.__setstate__(serialized)
        assert_equal(estimator.attribute_pickled, 4)
        assert_true(estimator._restored)
    finally:
        type(estimator).__module__ = old_mod


class SingleInheritanceEstimator(BaseEstimator):
    """Estimator overriding ``__getstate__`` directly in the subclass."""

    def __init__(self, attribute_pickled=5):
        self.attribute_pickled = attribute_pickled
        self._attribute_not_pickled = None

    def __getstate__(self):
        data = self.__dict__.copy()
        data["_attribute_not_pickled"] = None
        return data


@ignore_warnings(category=(UserWarning))
def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
    estimator = SingleInheritanceEstimator()
    estimator._attribute_not_pickled = "this attribute should not be pickled"

    serialized = pickle.dumps(estimator)
    estimator_restored = pickle.loads(serialized)
    assert_equal(estimator_restored.attribute_pickled, 5)
    assert_equal(estimator_restored._attribute_not_pickled, None)