laywerrobot/lib/python3.6/site-packages/sklearn/utils/tests/test_testing.py
2020-08-27 21:55:39 +02:00

486 lines
14 KiB
Python

import warnings
import unittest
import sys
import numpy as np
from scipy import sparse
from sklearn.utils.deprecation import deprecated
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.testing import (
assert_true,
assert_raises,
assert_less,
assert_greater,
assert_less_equal,
assert_greater_equal,
assert_warns,
assert_no_warnings,
assert_equal,
set_random_state,
assert_raise_message,
ignore_warnings,
check_docstring_parameters,
assert_allclose_dense_sparse)
from sklearn.utils.testing import SkipTest
from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
def test_assert_less():
    """Check that assert_less accepts a < b and rejects a > b."""
    low, high = 0, 1
    assert_less(low, high)
    # Swapping the operands must make the helper raise.
    assert_raises(AssertionError, assert_less, high, low)
def test_assert_greater():
    """Check that assert_greater accepts a > b and rejects a < b."""
    low, high = 0, 1
    assert_greater(high, low)
    # Swapping the operands must make the helper raise.
    assert_raises(AssertionError, assert_greater, low, high)
def test_assert_less_equal():
    """Check assert_less_equal on strictly-smaller, equal and larger operands."""
    # Both a strictly smaller and an equal left operand must pass.
    for lhs, rhs in ((0, 1), (1, 1)):
        assert_less_equal(lhs, rhs)
    # A larger left operand must be rejected.
    assert_raises(AssertionError, assert_less_equal, 1, 0)
def test_assert_greater_equal():
    """Check assert_greater_equal on strictly-larger, equal and smaller operands."""
    # Both a strictly larger and an equal left operand must pass.
    for lhs, rhs in ((1, 0), (1, 1)):
        assert_greater_equal(lhs, rhs)
    # A smaller left operand must be rejected.
    assert_raises(AssertionError, assert_greater_equal, 0, 1)
def test_set_random_state():
    """Check set_random_state on estimators with and without random_state."""
    seed = 3

    # Linear Discriminant Analysis doesn't have random state: smoke test
    lda = LinearDiscriminantAnalysis()
    set_random_state(lda, seed)

    # The decision tree exposes random_state, so the seed must be stored on it.
    tree = DecisionTreeClassifier()
    set_random_state(tree, seed)
    assert_equal(tree.random_state, seed)
def test_assert_allclose_dense_sparse():
    """Check assert_allclose_dense_sparse on dense, sparse and mixed inputs."""
    dense = np.arange(9).reshape(3, 3)
    sparse_csc = sparse.csc_matrix(dense)
    tolerance_msg = "Not equal to tolerance "

    # Same container type on both sides: differing values must be reported,
    # identical values must be accepted.
    for mat in (dense, sparse_csc):
        assert_raise_message(AssertionError, tolerance_msg,
                             assert_allclose_dense_sparse, mat, mat * 2)
        assert_allclose_dense_sparse(mat, mat)

    # Mixing a dense array with a sparse matrix is expected to be refused.
    assert_raise_message(ValueError, "Can only compare two sparse",
                         assert_allclose_dense_sparse, dense, sparse_csc)

    # Two sparse matrices of different shapes (5x5 diagonal vs. 1x5 row).
    diagonal = sparse.diags(np.ones(5), offsets=0).tocsr()
    ones_row = sparse.csr_matrix(np.ones((1, 5)))
    assert_raise_message(AssertionError, "Arrays are not equal",
                         assert_allclose_dense_sparse, ones_row, diagonal)
def test_assert_raise_message():
    """Check assert_raise_message for matching, mismatching and absent errors."""
    def _raise_value_error(text):
        raise ValueError(text)

    def _do_nothing():
        pass

    # Right exception type and right message: passes silently.
    assert_raise_message(ValueError, "test", _raise_value_error, "test")

    # Right exception type but the message does not match.
    assert_raises(AssertionError,
                  assert_raise_message, ValueError, "something else",
                  _raise_value_error, "test")

    # A different exception type than the expected one is raised: the test
    # expects the original ValueError to surface.
    assert_raises(ValueError,
                  assert_raise_message, TypeError, "something else",
                  _raise_value_error, "test")

    # The callable does not raise at all.
    assert_raises(AssertionError,
                  assert_raise_message, ValueError, "test",
                  _do_nothing)

    # multiple exceptions in a tuple
    assert_raises(AssertionError,
                  assert_raise_message, (ValueError, AttributeError),
                  "test", _do_nothing)
def test_ignore_warning():
    """Check ignore_warnings as a plain wrapper, a decorator and a context manager.

    Each form is exercised with no category, with a single category and with
    a tuple of categories, against one helper that emits a single
    DeprecationWarning and another that emits a DeprecationWarning followed
    by a warning with no explicit category.
    """
    # This checks that the ignore_warnings decorator and context manager are
    # working as expected.
    def _warning_function():
        warnings.warn("deprecation warning", DeprecationWarning)

    def _multiple_warning_function():
        warnings.warn("deprecation warning", DeprecationWarning)
        # No category given here; the assertions below treat this one as a
        # UserWarning.
        warnings.warn("deprecation warning")

    # Check the function directly
    assert_no_warnings(ignore_warnings(_warning_function))
    assert_no_warnings(ignore_warnings(_warning_function,
                                       category=DeprecationWarning))
    # Ignoring an unrelated category must let the DeprecationWarning through.
    assert_warns(DeprecationWarning, ignore_warnings(_warning_function,
                                                     category=UserWarning))
    # With two warnings emitted, only the non-ignored category must remain.
    assert_warns(UserWarning,
                 ignore_warnings(_multiple_warning_function,
                                 category=DeprecationWarning))
    assert_warns(DeprecationWarning,
                 ignore_warnings(_multiple_warning_function,
                                 category=UserWarning))
    assert_no_warnings(ignore_warnings(_warning_function,
                                       category=(DeprecationWarning,
                                                 UserWarning)))

    # Check the decorator
    @ignore_warnings
    def decorator_no_warning():
        _warning_function()
        _multiple_warning_function()

    @ignore_warnings(category=(DeprecationWarning, UserWarning))
    def decorator_no_warning_multiple():
        _multiple_warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_warning():
        _warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_warning():
        _warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_multiple_warning():
        _multiple_warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_multiple_warning():
        _multiple_warning_function()

    assert_no_warnings(decorator_no_warning)
    assert_no_warnings(decorator_no_warning_multiple)
    assert_no_warnings(decorator_no_deprecation_warning)
    assert_warns(DeprecationWarning, decorator_no_user_warning)
    assert_warns(UserWarning, decorator_no_deprecation_multiple_warning)
    assert_warns(DeprecationWarning, decorator_no_user_multiple_warning)

    # Check the context manager
    def context_manager_no_warning():
        with ignore_warnings():
            _warning_function()

    def context_manager_no_warning_multiple():
        with ignore_warnings(category=(DeprecationWarning, UserWarning)):
            _multiple_warning_function()

    def context_manager_no_deprecation_warning():
        with ignore_warnings(category=DeprecationWarning):
            _warning_function()

    def context_manager_no_user_warning():
        with ignore_warnings(category=UserWarning):
            _warning_function()

    def context_manager_no_deprecation_multiple_warning():
        with ignore_warnings(category=DeprecationWarning):
            _multiple_warning_function()

    def context_manager_no_user_multiple_warning():
        with ignore_warnings(category=UserWarning):
            _multiple_warning_function()

    # Same expectations as for the decorator form above.
    assert_no_warnings(context_manager_no_warning)
    assert_no_warnings(context_manager_no_warning_multiple)
    assert_no_warnings(context_manager_no_deprecation_warning)
    assert_warns(DeprecationWarning, context_manager_no_user_warning)
    assert_warns(UserWarning, context_manager_no_deprecation_multiple_warning)
    assert_warns(DeprecationWarning, context_manager_no_user_multiple_warning)
# This class is inspired by numpy 1.7, with an alteration to check that the
# warning filters are reset after calls to assert_warns.
# This assert_warns behavior is specific to scikit-learn because
# `clean_warning_registry()` is called internally by assert_warns
# and clears all previous filters.
class TestWarns(unittest.TestCase):
    """Checks for assert_warns / assert_no_warnings and the warning filters."""

    def test_warn(self):
        """assert_warns must work despite an external filter and leave none."""
        def warn_and_return_three():
            warnings.warn("yo")
            return 3

        # Test that assert_warns is not impacted by externally set
        # filters and is reset internally.
        # This is because `clean_warning_registry()` is called internally by
        # assert_warns and clears all previous filters.
        warnings.simplefilter("ignore", UserWarning)
        returned = assert_warns(UserWarning, warn_and_return_three)
        assert_equal(returned, 3)

        # Test that the warning registry is empty after assert_warns
        remaining_filters = sys.modules['warnings'].filters
        assert_equal(remaining_filters, [])

        # The warning must now be visible to assert_no_warnings ...
        assert_raises(AssertionError, assert_no_warnings,
                      warn_and_return_three)
        # ... while a silent callable has its return value passed through.
        passthrough = assert_no_warnings(lambda x: x, 1)
        assert_equal(passthrough, 1)

    def test_warn_wrong_warning(self):
        """assert_warns must fail when only another category is emitted."""
        def warn_deprecation():
            warnings.warn("yo", DeprecationWarning)

        # Snapshot the filters so they can be put back whatever happens.
        saved_filters = sys.modules['warnings'].filters[:]
        try:
            # Should raise an AssertionError
            assert_warns(UserWarning, warn_deprecation)
        except AssertionError:
            wrongly_accepted = False
        else:
            wrongly_accepted = True
        finally:
            sys.modules['warnings'].filters = saved_filters

        if wrongly_accepted:
            raise AssertionError("wrong warning caught by assert_warn")
# Tests for docstrings:
def f_ok(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    # Fixture for test_check_docstring_parameters: the docstring above lists
    # the parameters in signature order and must be reported as correct, so
    # its text must not be edited.
    return a + b
def f_bad_sections(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Results
    -------
    c : list
        Parameter c
    """
    # Fixture for test_check_docstring_parameters: the "Results" heading is
    # the section name the test expects to be rejected as unknown, so the
    # docstring text must not be edited.
    return a + b
def f_bad_order(b, a):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    # Fixture for test_check_docstring_parameters: the signature is (b, a)
    # while the docstring documents a before b; the test expects an
    # "a != b" report, so neither may be reordered.
    return a + b
def f_missing(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a

    Returns
    -------
    c : list
        Parameter c
    """
    # Fixture for test_check_docstring_parameters: parameter b is left
    # undocumented on purpose; the test expects "arg mismatch: ['b']" unless
    # b is explicitly ignored.
    return a + b
def f_check_param_definition(a, b, c, d):
    """Function f

    Parameters
    ----------
    a: int
        Parameter a
    b:
        Parameter b
    c :
        Parameter c
    d:int
        Parameter d
    """
    # Fixture for test_check_docstring_parameters: every parameter line above
    # is malformed in a different way (spacing around the colon, missing
    # type) and the test compares the exact report produced for each one, so
    # the docstring text must not be edited.
    running = a + b
    running = running + c
    return running + d
class Klass(object):
    # Fixture class for test_check_docstring_parameters: both methods are
    # passed to check_docstring_parameters and are expected to be reported.
    # The docstrings (or their absence) are the test data; do not "fix" them.

    def f_missing(self, X, y):
        # Deliberately has no docstring: the test expects
        # "arg mismatch: ['X', 'y']" for this method.
        pass

    def f_bad_sections(self, X, y):
        """Function f

        Parameter
        ----------
        a : int
            Parameter a
        b : float
            Parameter b

        Results
        -------
        c : list
            Parameter c
        """
        # The singular "Parameter" heading above is intentional: the test
        # expects a RuntimeError mentioning 'Unknown section Parameter'.
        pass
class MockEst(object):
    # Minimal stand-in estimator. test_check_docstring_parameters wraps an
    # instance of it in MockMetaEstimator as the `delegate`, so it only has
    # to expose the method names that the meta-estimator delegates to.

    def __init__(self):
        """MockEstimator"""

    def fit(self, X, y):
        # Returns its input unchanged; no fitting is performed.
        return X

    def predict(self, X):
        return X

    def predict_proba(self, X):
        return X

    def score(self, X):
        # Constant score.
        return 1.
class MockMetaEstimator(object):
    # Fixture for test_check_docstring_parameters: every public method is
    # wrapped with if_delegate_has_method (some also with deprecated) and
    # carries a docstring for which the test expects a specific report.
    # The docstrings are test data and must be kept exactly as they are.
    # NOTE(review): several "Parameters" underlines below are 9 dashes long
    # (one shorter than the heading) while predict uses 10 - presumably
    # deliberate so that those sections are not recognised; confirm before
    # touching them.

    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        self.delegate = delegate

    # Expected report in the test: "predict y != X" (documents y, takes X).
    @if_delegate_has_method(delegate=('delegate'))
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    # Expected report in the test: "score arg mismatch: ['X']".
    # The method has no body besides its docstring, so it returns None.
    @deprecated("Testing a deprecated delegated method")
    @if_delegate_has_method(delegate=('delegate'))
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    # Expected report in the test: "predict_proba arg mismatch: ['X']".
    @if_delegate_has_method(delegate=('delegate'))
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    # Expected report in the test: "predict_log_proba arg mismatch: ['X']".
    @deprecated('Testing deprecated function with incorrect params')
    @if_delegate_has_method(delegate=('delegate'))
    def predict_log_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        y : ndarray
            Parameter X
        """
        return X

    # Expected report in the test: ".fit arg mismatch: ['X', 'y']".
    # The method has no body besides its docstring, so it returns None.
    @deprecated('Testing deprecated function with wrong params')
    @if_delegate_has_method(delegate=('delegate'))
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""
def test_check_docstring_parameters():
    """Run check_docstring_parameters against the docstring fixtures.

    Covers well-formed docstrings (no report expected), unknown section
    names (RuntimeError expected), signature/docstring mismatches on plain
    functions, methods and delegated or deprecated methods, and malformed
    parameter definitions.

    Raises
    ------
    SkipTest
        If Python is older than 3.5 or numpydoc cannot be imported.
    """
    # Use an explicit comparison instead of `assert` inside try/except:
    # assert statements are removed under `python -O`, which would silently
    # disable the version guard.
    if sys.version_info < (3, 5):
        raise SkipTest(
            "Python >= 3.5 is required to test the docstrings")
    try:
        import numpydoc  # noqa
    except ImportError:
        raise SkipTest(
            "numpydoc is required to test the docstrings")

    # Correct docstring, with and without ignored parameters.
    incorrect = check_docstring_parameters(f_ok)
    assert_equal(incorrect, [])
    incorrect = check_docstring_parameters(f_ok, ignore=['b'])
    assert_equal(incorrect, [])
    # The only undocumented parameter of f_missing is explicitly ignored.
    incorrect = check_docstring_parameters(f_missing, ignore=['b'])
    assert_equal(incorrect, [])

    # Misspelled section headings must be rejected.
    assert_raise_message(RuntimeError, 'Unknown section Results',
                         check_docstring_parameters, f_bad_sections)
    assert_raise_message(RuntimeError, 'Unknown section Parameter',
                         check_docstring_parameters, Klass.f_bad_sections)

    mock_meta = MockMetaEstimator(delegate=MockEst())
    # Each callable is paired with the text expected in its first report.
    # Keeping them as explicit pairs (rather than zipping two parallel
    # lists) makes it impossible for a case to be dropped silently when one
    # of the lists is edited.
    expected_reports = [
        ("a != b", f_bad_order),
        ("arg mismatch: ['b']", f_missing),
        ("arg mismatch: ['X', 'y']", Klass.f_missing),
        ("predict y != X", mock_meta.predict),
        ("predict_proba arg mismatch: ['X']", mock_meta.predict_proba),
        ("predict_log_proba arg mismatch: ['X']",
         mock_meta.predict_log_proba),
        ("score arg mismatch: ['X']", mock_meta.score),
        (".fit arg mismatch: ['X', 'y']", mock_meta.fit),
    ]
    for mess, f in expected_reports:
        incorrect = check_docstring_parameters(f)
        assert_true(len(incorrect) >= 1)
        assert_true(mess in incorrect[0],
                    '"%s" not in "%s"' % (mess, incorrect[0]))

    # Malformed "name : type" lines: the full list of reports is compared.
    incorrect = check_docstring_parameters(f_check_param_definition)
    assert_equal(
        incorrect,
        ['sklearn.utils.tests.test_testing.f_check_param_definition There was '
         'no space between the param name and colon ("a: int")',
         'sklearn.utils.tests.test_testing.f_check_param_definition There was '
         'no space between the param name and colon ("b:")',
         'sklearn.utils.tests.test_testing.f_check_param_definition Incorrect '
         'type definition for param: "c " (type definition was "")',
         'sklearn.utils.tests.test_testing.f_check_param_definition There was '
         'no space between the param name and colon ("d:int")'])