486 lines
14 KiB
Python
486 lines
14 KiB
Python
import warnings
|
|
import unittest
|
|
import sys
|
|
import numpy as np
|
|
from scipy import sparse
|
|
|
|
from sklearn.utils.deprecation import deprecated
|
|
from sklearn.utils.metaestimators import if_delegate_has_method
|
|
from sklearn.utils.testing import (
|
|
assert_true,
|
|
assert_raises,
|
|
assert_less,
|
|
assert_greater,
|
|
assert_less_equal,
|
|
assert_greater_equal,
|
|
assert_warns,
|
|
assert_no_warnings,
|
|
assert_equal,
|
|
set_random_state,
|
|
assert_raise_message,
|
|
ignore_warnings,
|
|
check_docstring_parameters,
|
|
assert_allclose_dense_sparse)
|
|
|
|
from sklearn.utils.testing import SkipTest
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
|
|
|
|
|
|
def test_assert_less():
    """assert_less accepts a < b and rejects everything else."""
    assert_less(0, 1)
    assert_raises(AssertionError, assert_less, 1, 0)
    # assert_less is a strict comparison (assert_less_equal is the
    # non-strict variant), so equal operands must fail as well.
    assert_raises(AssertionError, assert_less, 1, 1)
|
|
|
|
|
|
def test_assert_greater():
    """assert_greater accepts a > b and rejects everything else."""
    assert_greater(1, 0)
    assert_raises(AssertionError, assert_greater, 0, 1)
    # assert_greater is a strict comparison (assert_greater_equal is the
    # non-strict variant), so equal operands must fail as well.
    assert_raises(AssertionError, assert_greater, 1, 1)
|
|
|
|
|
|
def test_assert_less_equal():
    """assert_less_equal passes for a < b and a == b, and fails for a > b."""
    for lhs, rhs in [(0, 1), (1, 1)]:
        assert_less_equal(lhs, rhs)
    assert_raises(AssertionError, assert_less_equal, 1, 0)
|
|
|
|
|
|
def test_assert_greater_equal():
    """assert_greater_equal passes for a > b and a == b, and fails for a < b."""
    for lhs, rhs in [(1, 0), (1, 1)]:
        assert_greater_equal(lhs, rhs)
    assert_raises(AssertionError, assert_greater_equal, 0, 1)
|
|
|
|
|
|
def test_set_random_state():
    """set_random_state tolerates estimators without random_state and sets it
    on estimators that have one."""
    # LinearDiscriminantAnalysis has no random_state: this is only a smoke
    # test that the call does not raise.
    set_random_state(LinearDiscriminantAnalysis(), 3)

    clf = DecisionTreeClassifier()
    set_random_state(clf, 3)
    assert_equal(clf.random_state, 3)
|
|
|
|
|
|
def test_assert_allclose_dense_sparse():
    """assert_allclose_dense_sparse on dense arrays, sparse matrices and
    invalid dense/sparse or shape combinations."""
    dense = np.arange(9).reshape(3, 3)
    csc = sparse.csc_matrix(dense)
    tol_msg = "Not equal to tolerance "

    for mat in (dense, csc):
        # Doubling every entry must be reported as a tolerance failure ...
        assert_raise_message(AssertionError, tol_msg,
                             assert_allclose_dense_sparse, mat, mat * 2)
        # ... whereas comparing a matrix with itself passes.
        assert_allclose_dense_sparse(mat, mat)

    # A dense array cannot be compared with a sparse matrix.
    assert_raise_message(ValueError, "Can only compare two sparse",
                         assert_allclose_dense_sparse, dense, csc)

    # Two sparse matrices of different shapes: (1, 5) versus (5, 5).
    identity = sparse.diags(np.ones(5), offsets=0).tocsr()
    row = sparse.csr_matrix(np.ones((1, 5)))
    assert_raise_message(AssertionError, "Arrays are not equal",
                         assert_allclose_dense_sparse, row, identity)
|
|
|
|
|
|
def test_assert_raise_message():
    """assert_raise_message checks both the exception type and its message."""
    def _raise_ValueError(message):
        raise ValueError(message)

    def _no_raise():
        pass

    # Expected type, and the message contains the expected text: passes.
    assert_raise_message(ValueError, "test",
                         _raise_ValueError, "test")

    # Expected type, but the expected text is missing from the message.
    assert_raises(AssertionError,
                  assert_raise_message, ValueError, "something else",
                  _raise_ValueError, "test")

    # An exception of an unexpected type propagates to the caller.
    assert_raises(ValueError,
                  assert_raise_message, TypeError, "something else",
                  _raise_ValueError, "test")

    # Nothing is raised at all.
    assert_raises(AssertionError,
                  assert_raise_message, ValueError, "test",
                  _no_raise)

    # multiple exceptions in a tuple: raising any member of the tuple
    # with the right message passes ...
    assert_raise_message((ValueError, AttributeError), "test",
                         _raise_ValueError, "test")
    # ... while raising none of them is still a failure.
    assert_raises(AssertionError,
                  assert_raise_message, (ValueError, AttributeError),
                  "test", _no_raise)
|
|
|
|
|
|
def test_ignore_warning():
    # Check that ignore_warnings silences the requested warning categories
    # (every category by default) and lets the others through, whether it is
    # used as a plain wrapper, as a decorator or as a context manager.
    def _deprecation_only():
        warnings.warn("deprecation warning", DeprecationWarning)

    def _deprecation_and_user():
        warnings.warn("deprecation warning", DeprecationWarning)
        # No category given, so this one is emitted as a UserWarning.
        warnings.warn("deprecation warning")

    def _check(cases):
        # ``expected`` is the warning category that must still be emitted,
        # or None when ``func`` must not emit any warning at all. Cases are
        # checked in list order.
        for expected, func in cases:
            if expected is None:
                assert_no_warnings(func)
            else:
                assert_warns(expected, func)

    both = (DeprecationWarning, UserWarning)

    # Check the function directly
    _check([
        (None, ignore_warnings(_deprecation_only)),
        (None, ignore_warnings(_deprecation_only,
                               category=DeprecationWarning)),
        (DeprecationWarning, ignore_warnings(_deprecation_only,
                                             category=UserWarning)),
        (UserWarning, ignore_warnings(_deprecation_and_user,
                                      category=DeprecationWarning)),
        (DeprecationWarning, ignore_warnings(_deprecation_and_user,
                                             category=UserWarning)),
        (None, ignore_warnings(_deprecation_only, category=both)),
    ])

    # Check the decorator
    @ignore_warnings
    def _deco_silence_all():
        _deprecation_only()
        _deprecation_and_user()

    @ignore_warnings(category=both)
    def _deco_silence_both_multiple():
        _deprecation_and_user()

    @ignore_warnings(category=DeprecationWarning)
    def _deco_silence_deprecation():
        _deprecation_only()

    @ignore_warnings(category=UserWarning)
    def _deco_silence_user():
        _deprecation_only()

    @ignore_warnings(category=DeprecationWarning)
    def _deco_silence_deprecation_multiple():
        _deprecation_and_user()

    @ignore_warnings(category=UserWarning)
    def _deco_silence_user_multiple():
        _deprecation_and_user()

    _check([
        (None, _deco_silence_all),
        (None, _deco_silence_both_multiple),
        (None, _deco_silence_deprecation),
        (DeprecationWarning, _deco_silence_user),
        (UserWarning, _deco_silence_deprecation_multiple),
        (DeprecationWarning, _deco_silence_user_multiple),
    ])

    # Check the context manager
    def _ctx_silence_all():
        with ignore_warnings():
            _deprecation_only()

    def _ctx_silence_both_multiple():
        with ignore_warnings(category=both):
            _deprecation_and_user()

    def _ctx_silence_deprecation():
        with ignore_warnings(category=DeprecationWarning):
            _deprecation_only()

    def _ctx_silence_user():
        with ignore_warnings(category=UserWarning):
            _deprecation_only()

    def _ctx_silence_deprecation_multiple():
        with ignore_warnings(category=DeprecationWarning):
            _deprecation_and_user()

    def _ctx_silence_user_multiple():
        with ignore_warnings(category=UserWarning):
            _deprecation_and_user()

    _check([
        (None, _ctx_silence_all),
        (None, _ctx_silence_both_multiple),
        (None, _ctx_silence_deprecation),
        (DeprecationWarning, _ctx_silence_user),
        (UserWarning, _ctx_silence_deprecation_multiple),
        (DeprecationWarning, _ctx_silence_user_multiple),
    ])
|
|
|
|
|
|
# This class is adapted from numpy 1.7, with an alteration to check that
# warning filters are reset after calls to assert_warns.
# This assert_warns behavior is specific to scikit-learn because
# `clean_warning_registry()` is called internally by assert_warns
# and clears all previously installed filters.
|
|
class TestWarns(unittest.TestCase):
    """Tests for assert_warns / assert_no_warnings and their filter reset."""

    def test_warn(self):
        def emit_user_warning():
            warnings.warn("yo")
            return 3

        # Install an "ignore" filter: assert_warns must still see the
        # warning, because clean_warning_registry() (called internally by
        # assert_warns) discards every previously installed filter.
        warnings.simplefilter("ignore", UserWarning)
        returned = assert_warns(UserWarning, emit_user_warning)
        assert_equal(returned, 3)

        # Consequently no filter is left behind once assert_warns returns.
        assert_equal(sys.modules['warnings'].filters, [])

        assert_raises(AssertionError, assert_no_warnings, emit_user_warning)
        assert_equal(assert_no_warnings(lambda x: x, 1), 1)

    def test_warn_wrong_warning(self):
        def emit_deprecation():
            warnings.warn("yo", DeprecationWarning)

        saved_filters = sys.modules['warnings'].filters[:]
        try:
            # A DeprecationWarning must not satisfy a UserWarning check.
            assert_warns(UserWarning, emit_deprecation)
        except AssertionError:
            unexpectedly_passed = False
        else:
            unexpectedly_passed = True
        finally:
            # Restore whatever filters were active before assert_warns ran.
            sys.modules['warnings'].filters = saved_filters

        if unexpectedly_passed:
            raise AssertionError("wrong warning caught by assert_warn")
|
|
|
|
|
|
# Tests for docstrings:
|
|
|
|
# Fixture for test_check_docstring_parameters. The docstring itself is the
# test data, so leave it exactly as is. Its Parameters section matches the
# signature (a, b), and the test expects check_docstring_parameters to
# report no problems for it.
def f_ok(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
|
|
|
|
|
# Fixture for test_check_docstring_parameters (the docstring is test data).
# "Results" is not an accepted section name; the test expects a
# RuntimeError mentioning 'Unknown section Results'.
def f_bad_sections(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Results
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
|
|
|
|
|
# Fixture for test_check_docstring_parameters (the docstring is test data).
# The signature is (b, a) but the parameters are documented as (a, b); the
# test expects a reported problem containing "a != b".
def f_bad_order(b, a):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
|
|
|
|
|
# Fixture for test_check_docstring_parameters (the docstring is test data).
# Parameter ``b`` is not documented. The test expects "arg mismatch: ['b']"
# by default, and no problems when called with ignore=['b'].
def f_missing(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
|
|
|
|
|
|
# Fixture for test_check_docstring_parameters (the docstring is test data).
# Each parameter line is malformed on purpose:
#   "a: int" / "b:" / "d:int" -> no space between the name and the colon;
#   "c :"                      -> colon present but no type definition.
# The test expects exactly one reported problem per parameter, in order.
def f_check_param_definition(a, b, c, d):
    """Function f

    Parameters
    ----------
    a: int
        Parameter a
    b:
        Parameter b
    c :
        Parameter c
    d:int
        Parameter d
    """
    return a + b + c + d
|
|
|
|
|
|
class Klass(object):
    # Fixture class for test_check_docstring_parameters. The method
    # docstrings (or their absence) are test data and must not be edited.

    # No docstring at all; the test expects "arg mismatch: ['X', 'y']".
    def f_missing(self, X, y):
        pass

    # Uses "Parameter" (singular) as a section header; the test expects a
    # RuntimeError mentioning 'Unknown section Parameter'.
    def f_bad_sections(self, X, y):
        """Function f

        Parameter
        ----------
        a : int
            Parameter a
        b : float
            Parameter b

        Results
        -------
        c : list
            Parameter c
        """
        pass
|
|
|
|
|
|
class MockEst(object):
    """Stateless estimator stub whose methods simply echo their input.

    It serves as the ``delegate`` of MockMetaEstimator, so that methods
    guarded by ``if_delegate_has_method`` exist on the meta-estimator.
    """

    def __init__(self):
        # Nothing to initialise: the stub keeps no state.
        pass

    def fit(self, X, y):
        """Return ``X`` unchanged; ``y`` is ignored."""
        return X

    def predict(self, X):
        """Return ``X`` unchanged."""
        return X

    def predict_proba(self, X):
        """Return ``X`` unchanged."""
        return X

    def score(self, X):
        """Return the constant float score ``1.0``; ``X`` is ignored."""
        return 1.
|
|
|
|
|
|
class MockMetaEstimator(object):
    # Meta-estimator fixture for test_check_docstring_parameters. The methods
    # below are passed to check_docstring_parameters, so their docstrings
    # (including the deliberately malformed ones) are test data and must stay
    # exactly as written.
    # Note: ``delegate=('delegate')`` is just the string 'delegate'; the
    # parentheses do not create a tuple.

    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        self.delegate = delegate

    # Documents ``y`` while the signature takes ``X``; the test expects a
    # problem containing "predict y != X".
    @if_delegate_has_method(delegate=('delegate'))
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    # Deprecated wrapper around a delegated method with no body beyond the
    # docstring. The test expects "score arg mismatch: ['X']".
    @deprecated("Testing a deprecated delegated method")
    @if_delegate_has_method(delegate=('delegate'))
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    # Documents ``X``, yet the test expects
    # "predict_proba arg mismatch: ['X']".
    # NOTE(review): this presumably relies on the 9-dash underline (one
    # short of "Parameters") not being recognised as a section header.
    # Do not "fix" the underline without updating the test.
    @if_delegate_has_method(delegate=('delegate'))
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    # Deprecated; documents ``y`` for an ``X`` argument. The test expects
    # "predict_log_proba arg mismatch: ['X']".
    @deprecated('Testing deprecated function with incorrect params')
    @if_delegate_has_method(delegate=('delegate'))
    def predict_log_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        y : ndarray
            Parameter X
        """
        return X

    # Deprecated, with no Parameters section at all. Despite the docstring's
    # wording, the test does check it and expects
    # ".fit arg mismatch: ['X', 'y']".
    @deprecated('Testing deprecated function with wrong params')
    @if_delegate_has_method(delegate=('delegate'))
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""
|
|
|
|
|
|
def test_check_docstring_parameters():
    """check_docstring_parameters reports the expected problem for each
    fixture function defined in this module."""
    # These checks need numpydoc and Python >= 3.5. The version is tested
    # explicitly rather than with ``assert``: ``python -O`` strips assert
    # statements, which would stop the skip from happening. It is also
    # tested before the import, so that an import failure other than
    # ImportError on an old interpreter cannot turn the skip into an error.
    skip_msg = "numpydoc is required to test the docstrings"
    if sys.version_info < (3, 5):
        raise SkipTest(skip_msg)
    try:
        import numpydoc  # noqa
    except ImportError:
        raise SkipTest(skip_msg)

    incorrect = check_docstring_parameters(f_ok)
    assert_equal(incorrect, [])
    incorrect = check_docstring_parameters(f_ok, ignore=['b'])
    assert_equal(incorrect, [])
    incorrect = check_docstring_parameters(f_missing, ignore=['b'])
    assert_equal(incorrect, [])
    assert_raise_message(RuntimeError, 'Unknown section Results',
                         check_docstring_parameters, f_bad_sections)
    assert_raise_message(RuntimeError, 'Unknown section Parameter',
                         check_docstring_parameters, Klass.f_bad_sections)

    mock_meta = MockMetaEstimator(delegate=MockEst())

    # Each function is paired with a fragment of the first problem that
    # check_docstring_parameters must report for it.
    messages = ["a != b", "arg mismatch: ['b']", "arg mismatch: ['X', 'y']",
                "predict y != X",
                "predict_proba arg mismatch: ['X']",
                "predict_log_proba arg mismatch: ['X']",
                "score arg mismatch: ['X']",
                ".fit arg mismatch: ['X', 'y']"]
    funcs = [f_bad_order, f_missing, Klass.f_missing,
             mock_meta.predict, mock_meta.predict_proba,
             mock_meta.predict_log_proba,
             mock_meta.score, mock_meta.fit]
    # zip() silently drops unmatched entries, so keep the two lists aligned.
    assert_equal(len(messages), len(funcs))

    for mess, f in zip(messages, funcs):
        incorrect = check_docstring_parameters(f)
        assert_true(len(incorrect) >= 1)
        assert_true(mess in incorrect[0],
                    '"%s" not in "%s"' % (mess, incorrect[0]))

    incorrect = check_docstring_parameters(f_check_param_definition)
    assert_equal(
        incorrect,
        ['sklearn.utils.tests.test_testing.f_check_param_definition There was '
         'no space between the param name and colon ("a: int")',
         'sklearn.utils.tests.test_testing.f_check_param_definition There was '
         'no space between the param name and colon ("b:")',
         'sklearn.utils.tests.test_testing.f_check_param_definition Incorrect '
         'type definition for param: "c " (type definition was "")',
         'sklearn.utils.tests.test_testing.f_check_param_definition There was '
         'no space between the param name and colon ("d:int")'])
|