""" General tests for all estimators in sklearn. """ # Authors: Andreas Mueller # Gael Varoquaux gael.varoquaux@normalesup.org # License: BSD 3 clause from __future__ import print_function import os import warnings import sys import re import pkgutil from sklearn.externals.six import PY3 from sklearn.utils.testing import assert_false, clean_warning_registry from sklearn.utils.testing import all_estimators from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_in from sklearn.utils.testing import ignore_warnings from sklearn.utils.testing import _named_check import sklearn from sklearn.cluster.bicluster import BiclusterMixin from sklearn.linear_model.base import LinearClassifierMixin from sklearn.utils.estimator_checks import ( _yield_all_checks, set_checking_parameters, check_parameters_default_constructible, check_no_fit_attributes_set_in_init, check_class_weight_balanced_linear_classifier) def test_all_estimator_no_base_class(): # test that all_estimators doesn't find abstract classes. for name, Estimator in all_estimators(): msg = ("Base estimators such as {0} should not be included" " in all_estimators").format(name) assert_false(name.lower().startswith('base'), msg=msg) def test_all_estimators(): # Test that estimators are default-constructible, cloneable # and have working repr. estimators = all_estimators(include_meta_estimators=True) # Meta sanity-check to make sure that the estimator introspection runs # properly assert_greater(len(estimators), 0) for name, Estimator in estimators: # some can just not be sensibly default constructed yield (_named_check(check_parameters_default_constructible, name), name, Estimator) def test_non_meta_estimators(): # input validation etc for non-meta estimators estimators = all_estimators() for name, Estimator in estimators: if issubclass(Estimator, BiclusterMixin): continue if name.startswith("_"): continue estimator = Estimator() # check this on class yield _named_check( check_no_fit_attributes_set_in_init, name), name, Estimator for check in _yield_all_checks(name, estimator): set_checking_parameters(estimator) yield _named_check(check, name), name, estimator def test_configure(): # Smoke test the 'configure' step of setup, this tests all the # 'configure' functions in the setup.pys in the scikit cwd = os.getcwd() setup_path = os.path.abspath(os.path.join(sklearn.__path__[0], '..')) setup_filename = os.path.join(setup_path, 'setup.py') if not os.path.exists(setup_filename): return try: os.chdir(setup_path) old_argv = sys.argv sys.argv = ['setup.py', 'config'] clean_warning_registry() with warnings.catch_warnings(): # The configuration spits out warnings when not finding # Blas/Atlas development headers warnings.simplefilter('ignore', UserWarning) if PY3: with open('setup.py') as f: exec(f.read(), dict(__name__='__main__')) else: execfile('setup.py', dict(__name__='__main__')) finally: sys.argv = old_argv os.chdir(cwd) def test_class_weight_balanced_linear_classifiers(): classifiers = all_estimators(type_filter='classifier') clean_warning_registry() with warnings.catch_warnings(record=True): linear_classifiers = [ (name, clazz) for name, clazz in classifiers if ('class_weight' in clazz().get_params().keys() and issubclass(clazz, LinearClassifierMixin))] for name, Classifier in linear_classifiers: yield _named_check(check_class_weight_balanced_linear_classifier, name), name, Classifier @ignore_warnings def test_import_all_consistency(): # Smoke test to check that any name in a __all__ list is actually defined # in the namespace of the module or package. pkgs = pkgutil.walk_packages(path=sklearn.__path__, prefix='sklearn.', onerror=lambda _: None) submods = [modname for _, modname, _ in pkgs] for modname in submods + ['sklearn']: if ".tests." in modname: continue package = __import__(modname, fromlist="dummy") for name in getattr(package, '__all__', ()): if getattr(package, name, None) is None: raise AttributeError( "Module '{0}' has no attribute '{1}'".format( modname, name)) def test_root_import_all_completeness(): EXCEPTIONS = ('utils', 'tests', 'base', 'setup') for _, modname, _ in pkgutil.walk_packages(path=sklearn.__path__, onerror=lambda _: None): if '.' in modname or modname.startswith('_') or modname in EXCEPTIONS: continue assert_in(modname, sklearn.__all__) def test_all_tests_are_importable(): # Ensure that for each contentful subpackage, there is a test directory # within it that is also a subpackage (i.e. a directory with __init__.py) HAS_TESTS_EXCEPTIONS = re.compile(r'''(?x) \.externals(\.|$)| \.tests(\.|$)| \._ ''') lookup = dict((name, ispkg) for _, name, ispkg in pkgutil.walk_packages(sklearn.__path__, prefix='sklearn.')) missing_tests = [name for name, ispkg in lookup.items() if ispkg and not HAS_TESTS_EXCEPTIONS.search(name) and name + '.tests' not in lookup] assert_equal(missing_tests, [], '{0} do not have `tests` subpackages. Perhaps they require ' '__init__.py or an add_subpackage directive in the parent ' 'setup.py'.format(missing_tests))