124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
|
"""
|
||
|
Testing for mean shift clustering methods
|
||
|
|
||
|
"""
|
||
|
|
||
|
import numpy as np
|
||
|
import warnings
|
||
|
|
||
|
from scipy import sparse
|
||
|
|
||
|
from sklearn.utils.testing import assert_equal
|
||
|
from sklearn.utils.testing import assert_false
|
||
|
from sklearn.utils.testing import assert_true
|
||
|
from sklearn.utils.testing import assert_array_equal
|
||
|
from sklearn.utils.testing import assert_raise_message
|
||
|
|
||
|
from sklearn.cluster import MeanShift
|
||
|
from sklearn.cluster import mean_shift
|
||
|
from sklearn.cluster import estimate_bandwidth
|
||
|
from sklearn.cluster import get_bin_seeds
|
||
|
from sklearn.datasets.samples_generator import make_blobs
|
||
|
|
||
|
|
||
|
n_clusters = 3
|
||
|
centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
|
||
|
X, _ = make_blobs(n_samples=300, n_features=2, centers=centers,
|
||
|
cluster_std=0.4, shuffle=True, random_state=11)
|
||
|
|
||
|
|
||
|
def test_estimate_bandwidth():
|
||
|
# Test estimate_bandwidth
|
||
|
bandwidth = estimate_bandwidth(X, n_samples=200)
|
||
|
assert_true(0.9 <= bandwidth <= 1.5)
|
||
|
|
||
|
|
||
|
def test_mean_shift():
|
||
|
# Test MeanShift algorithm
|
||
|
bandwidth = 1.2
|
||
|
|
||
|
ms = MeanShift(bandwidth=bandwidth)
|
||
|
labels = ms.fit(X).labels_
|
||
|
labels_unique = np.unique(labels)
|
||
|
n_clusters_ = len(labels_unique)
|
||
|
assert_equal(n_clusters_, n_clusters)
|
||
|
|
||
|
cluster_centers, labels = mean_shift(X, bandwidth=bandwidth)
|
||
|
labels_unique = np.unique(labels)
|
||
|
n_clusters_ = len(labels_unique)
|
||
|
assert_equal(n_clusters_, n_clusters)
|
||
|
|
||
|
|
||
|
def test_estimate_bandwidth_with_sparse_matrix():
|
||
|
# Test estimate_bandwidth with sparse matrix
|
||
|
X = sparse.lil_matrix((1000, 1000))
|
||
|
msg = "A sparse matrix was passed, but dense data is required."
|
||
|
assert_raise_message(TypeError, msg, estimate_bandwidth, X, 200)
|
||
|
|
||
|
|
||
|
def test_parallel():
|
||
|
ms1 = MeanShift(n_jobs=2)
|
||
|
ms1.fit(X)
|
||
|
|
||
|
ms2 = MeanShift()
|
||
|
ms2.fit(X)
|
||
|
|
||
|
assert_array_equal(ms1.cluster_centers_, ms2.cluster_centers_)
|
||
|
assert_array_equal(ms1.labels_, ms2.labels_)
|
||
|
|
||
|
|
||
|
def test_meanshift_predict():
|
||
|
# Test MeanShift.predict
|
||
|
ms = MeanShift(bandwidth=1.2)
|
||
|
labels = ms.fit_predict(X)
|
||
|
labels2 = ms.predict(X)
|
||
|
assert_array_equal(labels, labels2)
|
||
|
|
||
|
|
||
|
def test_meanshift_all_orphans():
|
||
|
# init away from the data, crash with a sensible warning
|
||
|
ms = MeanShift(bandwidth=0.1, seeds=[[-9, -9], [-10, -10]])
|
||
|
msg = "No point was within bandwidth=0.1"
|
||
|
assert_raise_message(ValueError, msg, ms.fit, X,)
|
||
|
|
||
|
|
||
|
def test_unfitted():
|
||
|
# Non-regression: before fit, there should be not fitted attributes.
|
||
|
ms = MeanShift()
|
||
|
assert_false(hasattr(ms, "cluster_centers_"))
|
||
|
assert_false(hasattr(ms, "labels_"))
|
||
|
|
||
|
|
||
|
def test_bin_seeds():
|
||
|
# Test the bin seeding technique which can be used in the mean shift
|
||
|
# algorithm
|
||
|
# Data is just 6 points in the plane
|
||
|
X = np.array([[1., 1.], [1.4, 1.4], [1.8, 1.2],
|
||
|
[2., 1.], [2.1, 1.1], [0., 0.]])
|
||
|
|
||
|
# With a bin coarseness of 1.0 and min_bin_freq of 1, 3 bins should be
|
||
|
# found
|
||
|
ground_truth = set([(1., 1.), (2., 1.), (0., 0.)])
|
||
|
test_bins = get_bin_seeds(X, 1, 1)
|
||
|
test_result = set([tuple(p) for p in test_bins])
|
||
|
assert_true(len(ground_truth.symmetric_difference(test_result)) == 0)
|
||
|
|
||
|
# With a bin coarseness of 1.0 and min_bin_freq of 2, 2 bins should be
|
||
|
# found
|
||
|
ground_truth = set([(1., 1.), (2., 1.)])
|
||
|
test_bins = get_bin_seeds(X, 1, 2)
|
||
|
test_result = set([tuple(p) for p in test_bins])
|
||
|
assert_true(len(ground_truth.symmetric_difference(test_result)) == 0)
|
||
|
|
||
|
# With a bin size of 0.01 and min_bin_freq of 1, 6 bins should be found
|
||
|
# we bail and use the whole data here.
|
||
|
with warnings.catch_warnings(record=True):
|
||
|
test_bins = get_bin_seeds(X, 0.01, 1)
|
||
|
assert_array_equal(test_bins, X)
|
||
|
|
||
|
# tight clusters around [0, 0] and [1, 1], only get two bins
|
||
|
X, _ = make_blobs(n_samples=100, n_features=2, centers=[[0, 0], [1, 1]],
|
||
|
cluster_std=0.1, random_state=0)
|
||
|
test_bins = get_bin_seeds(X, 1)
|
||
|
assert_array_equal(test_bins, [[0, 0], [1, 1]])
|