276 lines
9.4 KiB
Python
276 lines
9.4 KiB
Python
from itertools import chain, product
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import scipy.sparse as sp
|
|
from scipy.linalg import pinv2
|
|
from scipy.sparse.csgraph import laplacian
|
|
|
|
from sklearn.utils.testing import (assert_equal, assert_raises, assert_true,
|
|
assert_almost_equal, assert_array_equal,
|
|
SkipTest, assert_raises_regex,
|
|
assert_greater_equal, ignore_warnings)
|
|
from sklearn.utils import check_random_state
|
|
from sklearn.utils import deprecated
|
|
from sklearn.utils import resample
|
|
from sklearn.utils import safe_mask
|
|
from sklearn.utils import column_or_1d
|
|
from sklearn.utils import safe_indexing
|
|
from sklearn.utils import shuffle
|
|
from sklearn.utils import gen_even_slices
|
|
from sklearn.utils.extmath import pinvh
|
|
from sklearn.utils.arpack import eigsh
|
|
from sklearn.utils.mocking import MockDataFrame
|
|
|
|
|
|
def test_make_rng():
    # Check the check_random_state utility function behavior
    global_rng = np.random.mtrand._rand

    # None and the np.random module must both map to numpy's global
    # RandomState singleton
    for seed in (None, np.random):
        assert_true(check_random_state(seed) is global_rng)

    # An integer seed must reproduce the stream of RandomState(seed)
    reference = np.random.RandomState(42)
    drawn = check_random_state(42).randint(100)
    assert_true(drawn == reference.randint(100))

    # An existing RandomState instance must be returned as-is
    instance = np.random.RandomState(42)
    assert_true(check_random_state(instance) is instance)

    # A different integer seed must give a different first draw
    reference = np.random.RandomState(42)
    drawn = check_random_state(43).randint(100)
    assert_true(drawn != reference.randint(100))

    # Anything that is not a valid seed must be rejected
    assert_raises(ValueError, check_random_state, "some invalid seed")
|
|
|
|
|
|
def test_deprecated():
    # Test whether the deprecated decorator issues appropriate warnings
    # Copied almost verbatim from http://docs.python.org/library/warnings.html

    def check_single_deprecation(records):
        # Exactly one warning was recorded, it is a DeprecationWarning and
        # its message mentions "deprecated" (case-insensitive).
        assert_equal(len(records), 1)
        first = records[0]
        assert_true(issubclass(first.category, DeprecationWarning))
        assert_true("deprecated" in str(first.message).lower())

    # First a function...
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated()
        def ham():
            return "spam"

        result = ham()

        assert_equal(result, "spam")     # function must remain usable
        check_single_deprecation(caught)

    # ... then a class.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated("don't use this")
        class Ham(object):
            SPAM = 1

        instance = Ham()

        assert_true(hasattr(instance, "SPAM"))  # class must remain usable
        check_single_deprecation(caught)
|
|
|
|
|
|
def test_resample():
    # Border case not worth mentioning in doctests: no arrays at all
    assert_true(resample() is None)

    # Check that invalid arguments yield ValueError
    invalid_calls = [
        # arrays of length 1 and 2
        (([0], [0, 1]), {}),
        # 3 samples requested from 2 without replacement
        (([0, 1], [0, 1]), {"replace": False, "n_samples": 3}),
        # arbitrary extra keyword argument
        (([0, 1], [0, 1]), {"meaning_of_life": 42}),
    ]
    for args, kwargs in invalid_calls:
        assert_raises(ValueError, resample, *args, **kwargs)

    # Issue:6581, n_samples can be more when replace is True (default).
    oversampled = resample([1, 2], n_samples=5)
    assert_equal(len(oversampled), 5)
|
|
|
|
|
|
def test_safe_mask():
    # Check that the mask produced by safe_mask selects the expected number
    # of rows from both a dense array and a CSR matrix.
    rng = check_random_state(0)
    dense = rng.rand(5, 4)
    sparse = sp.csr_matrix(dense)
    selector = [False, False, True, True, True]
    n_selected = 3  # number of True entries in selector

    selector = safe_mask(dense, selector)
    assert_equal(dense[selector].shape[0], n_selected)

    # The mask returned for the dense input is reused as the input here.
    selector = safe_mask(sparse, selector)
    assert_equal(sparse[selector].shape[0], n_selected)
|
|
|
|
|
|
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_pinvh_simple_real():
    # pinvh of a real symmetric matrix: the product with the input is
    # expected to be the 3x3 identity.
    base = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=np.float64)
    gram = np.dot(base, base.T)  # symmetric by construction
    gram_pinv = pinvh(gram)
    product_with_pinv = np.dot(gram, gram_pinv)
    assert_almost_equal(product_with_pinv, np.eye(3))
|
|
|
|
|
|
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_pinvh_nonpositive():
    # pinvh must agree with scipy's pinv2 on a matrix whose leading
    # singular value has had its sign flipped.
    base = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
    gram = np.dot(base, base.T)
    left, sing_vals, right_t = np.linalg.svd(gram)
    sing_vals[0] *= -1
    # the rebuilt matrix is now symmetric non-positive and singular
    flipped = np.dot(left * sing_vals, right_t)
    expected = pinv2(flipped)
    actual = pinvh(flipped)
    assert_almost_equal(expected, actual)
|
|
|
|
|
|
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_pinvh_simple_complex():
    # Same check as the real-valued case, on a complex matrix multiplied by
    # its own conjugate transpose.
    real_part = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
    imag_part = np.array([[10, 8, 7], [6, 5, 4], [3, 2, 1]])
    base = real_part + 1j * imag_part
    hermitian = np.dot(base, base.conj().T)
    hermitian_pinv = pinvh(hermitian)
    assert_almost_equal(np.dot(hermitian, hermitian_pinv), np.eye(3))
|
|
|
|
|
|
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_arpack_eigsh_initialization():
    # Non-regression test that shows null-space computation is better with
    # initialization of eigsh from [-1,1] instead of [0,1]
    rng = check_random_state(42)
    n_eigenvalues = 5

    dense = rng.rand(50, 50)
    gram = np.dot(dense.T, dense)  # create s.p.d. matrix
    n_rows = gram.shape[0]
    shifted = laplacian(gram) + 1e-7 * np.identity(n_rows)

    # Test if eigsh is working correctly
    # New initialization [-1,1] (as in original ARPACK)
    # Was [0,1] before, with which this test could fail
    start_vector = rng.uniform(-1, 1, shifted.shape[0])
    eigenvalues, _ = eigsh(shifted, k=n_eigenvalues, sigma=0.0,
                           v0=start_vector)

    # Eigenvalues of s.p.d. matrix should be nonnegative, the first one
    # returned is the smallest
    smallest = eigenvalues[0]
    assert_greater_equal(smallest, 0)
|
|
|
|
|
|
def test_column_or_1d():
    # column_or_1d must flatten the target types listed in accepted_types
    # and raise ValueError for every other example.
    accepted_types = ("binary", "multiclass", "continuous")

    cases = [
        ("binary", ["spam", "egg", "spam"]),
        ("binary", [0, 1, 0, 1]),
        ("continuous", np.arange(10) / 20.),
        ("multiclass", [1, 2, 3]),
        ("multiclass", [0, 1, 2, 2, 0]),
        ("multiclass", [[1], [2], [3]]),
        ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
        ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
    ]

    for label, target in cases:
        if label not in accepted_types:
            assert_raises(ValueError, column_or_1d, target)
            continue
        assert_array_equal(column_or_1d(target), np.ravel(target))
|
|
|
|
|
|
def test_safe_indexing():
    # Indexing a list of lists and the equivalent ndarray must give the
    # same rows, and both must match plain numpy fancy indexing.
    rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    picks = np.array([1, 2])

    from_list = safe_indexing(rows, picks)
    from_array = safe_indexing(np.array(rows), picks)

    from_list_as_array = np.array(from_list)
    assert_array_equal(from_list_as_array, from_array)
    assert_array_equal(np.array(from_list), np.array(rows)[picks])
|
|
|
|
|
|
def test_safe_indexing_pandas():
    # Check that safe_indexing on a pandas DataFrame selects the same rows
    # as safe_indexing on the underlying ndarray, including when the data
    # and/or the indices are read-only. Skipped when pandas is missing.
    try:
        import pandas as pd
    except ImportError:
        raise SkipTest("Pandas not found")
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_df = pd.DataFrame(X)
    inds = np.array([1, 2])
    X_df_indexed = safe_indexing(X_df, inds)
    # The reference result is computed from the ndarray, not from the
    # DataFrame again: indexing X_df twice would compare a value with
    # itself and could never fail.
    X_indexed = safe_indexing(X, inds)
    assert_array_equal(np.array(X_df_indexed), X_indexed)
    # fun with read-only data in dataframes
    # this happens in joblib memmapping
    X.setflags(write=False)
    X_df_readonly = pd.DataFrame(X)
    inds_readonly = inds.copy()
    inds_readonly.setflags(write=False)

    # every combination of writeable / read-only data and indices
    for this_df, this_inds in product([X_df, X_df_readonly],
                                      [inds, inds_readonly]):
        # warnings raised while indexing are recorded and discarded
        with warnings.catch_warnings(record=True):
            X_df_indexed = safe_indexing(this_df, this_inds)

        assert_array_equal(np.array(X_df_indexed), X_indexed)
|
|
|
|
|
|
def test_safe_indexing_mock_pandas():
    # Check that safe_indexing on a MockDataFrame selects the same rows as
    # safe_indexing on the ndarray it wraps.
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_df = MockDataFrame(X)
    inds = np.array([1, 2])
    X_df_indexed = safe_indexing(X_df, inds)
    # The reference result is computed from the ndarray, not from the mock
    # frame again: indexing X_df twice would compare a value with itself
    # and could never fail.
    X_indexed = safe_indexing(X, inds)
    assert_array_equal(np.array(X_df_indexed), X_indexed)
|
|
|
|
|
|
def test_shuffle_on_ndim_equals_three():
    # Check that shuffle accepts a 3-dimensional array and that the result
    # holds the same first-axis sub-arrays as the input, in any order.
    def to_tuple(A):    # to make the inner arrays hashable
        return tuple(tuple(tuple(C) for C in B) for B in A)

    A = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # A.shape = (2,2,2)
    S = set(to_tuple(A))
    # shuffle returns a shuffled copy and leaves its input untouched, so
    # the returned value is what has to be inspected; asserting on A
    # itself would pass regardless of what shuffle did.
    A_shuffled = shuffle(A)  # shouldn't raise a ValueError for dim = 3
    assert_equal(set(to_tuple(A_shuffled)), S)
|
|
|
|
|
|
def test_shuffle_dont_convert_to_array():
    # Check that shuffle does not try to convert its inputs to numpy arrays
    # with float dtypes and lets any indexable datastructure pass through.
    str_list = ['a', 'b', 'c']
    obj_array = np.array(['a', 'b', 'c'], dtype=object)
    int_list = [1, 2, 3]
    frame_rows = np.array([['a', 0],
                           ['b', 1],
                           ['c', 2]],
                          dtype=object)
    mock_frame = MockDataFrame(frame_rows)
    sparse_mat = sp.csc_matrix(np.arange(6).reshape(3, 2))

    shuffled = shuffle(str_list, obj_array, int_list, mock_frame,
                       sparse_mat, random_state=0)
    str_list_s, obj_array_s, int_list_s, mock_frame_s, sparse_mat_s = shuffled

    # plain list of strings stays a list
    assert_equal(str_list_s, ['c', 'b', 'a'])
    assert_equal(type(str_list_s), list)

    # object array keeps its dtype
    assert_array_equal(obj_array_s, ['c', 'b', 'a'])
    assert_equal(obj_array_s.dtype, object)

    # plain list of ints stays a list
    assert_equal(int_list_s, [3, 2, 1])
    assert_equal(type(int_list_s), list)

    # mock dataframe keeps its type
    expected_frame_rows = np.array([['c', 2],
                                    ['b', 1],
                                    ['a', 0]],
                                   dtype=object)
    assert_array_equal(mock_frame_s, expected_frame_rows)
    assert_equal(type(mock_frame_s), MockDataFrame)

    # sparse matrix rows follow the same permutation
    expected_sparse_rows = np.array([[4, 5],
                                     [2, 3],
                                     [0, 1]])
    assert_array_equal(sparse_mat_s.toarray(), expected_sparse_rows)
|
|
|
|
|
|
def test_gen_even_slices():
    # check that gen_even_slices contains all samples
    some_range = range(10)
    pieces = (some_range[chunk] for chunk in gen_even_slices(10, 3))
    joined_range = list(chain.from_iterable(pieces))
    assert_array_equal(some_range, joined_range)

    # check that passing negative n_packs raises an error; the generator is
    # only advanced by next() inside assert_raises_regex
    expected_message = ("gen_even_slices got n_packs=-1, must be"
                        " >=1")
    slices = gen_even_slices(10, -1)
    assert_raises_regex(ValueError, expected_message, next, slices)
|