"""
|
|
The :mod:`sklearn.model_selection._validation` module includes classes and
|
|
functions to validate the model.
|
|
"""
|
|
|
|
# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
# Gael Varoquaux <gael.varoquaux@normalesup.org>
|
|
# Olivier Grisel <olivier.grisel@ensta.org>
|
|
# Raghav RV <rvraghav93@gmail.com>
|
|
# License: BSD 3 clause
|
|
|
|
from __future__ import print_function
|
|
from __future__ import division
|
|
|
|
import warnings
|
|
import numbers
|
|
import time
|
|
|
|
import numpy as np
|
|
import scipy.sparse as sp
|
|
|
|
from ..base import is_classifier, clone
|
|
from ..utils import indexable, check_random_state, safe_indexing
|
|
from ..utils.deprecation import DeprecationDict
|
|
from ..utils.validation import _is_arraylike, _num_samples
|
|
from ..utils.metaestimators import _safe_split
|
|
from ..externals.joblib import Parallel, delayed, logger
|
|
from ..externals.six.moves import zip
|
|
from ..metrics.scorer import check_scoring, _check_multimetric_scoring
|
|
from ..exceptions import FitFailedWarning
|
|
from ._split import check_cv
|
|
from ..preprocessing import LabelEncoder
|
|
|
|
|
|
__all__ = ['cross_validate', 'cross_val_score', 'cross_val_predict',
|
|
'permutation_test_score', 'learning_curve', 'validation_curve']
|
|
|
|
|
|
def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
                   n_jobs=1, verbose=0, fit_params=None,
                   pre_dispatch='2*n_jobs', return_train_score="warn"):
    """Evaluate metric(s) by cross-validation and also record fit/score times.

    Read more in the :ref:`User Guide <multimetric_cross_validation>`.

    Parameters
    ----------
    estimator : estimator object implementing 'fit'
        The object to use to fit the data.

    X : array-like
        The data to fit. Can be, for example, a list or an array.

    y : array-like, optional, default: None
        The target variable to try to predict in the case of
        supervised learning.

    groups : array-like, with shape (n_samples,), optional
        Group labels for the samples used while splitting the dataset into
        train/test set.

    scoring : string, callable, list/tuple, dict or None, default: None
        A single string (see :ref:`scoring_parameter`) or a callable
        (see :ref:`scoring`) to evaluate the predictions on the test set.

        For evaluating multiple metrics, either give a list of (unique)
        strings or a dict with names as keys and callables as values.

        NOTE that when using custom scorers, each scorer should return a
        single value. Metric functions returning a list/array of values can
        be wrapped into multiple scorers that return one value each.

        See :ref:`multimetric_grid_search` for an example.

        If None, the estimator's default scorer (if available) is used.

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a `(Stratified)KFold`,
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, test splits.

        For integer/None inputs, if the estimator is a classifier and ``y`` is
        either binary or multiclass, :class:`StratifiedKFold` is used. In all
        other cases, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

    n_jobs : integer, optional
        The number of CPUs to use to do the computation. -1 means
        'all CPUs'.

    verbose : integer, optional
        The verbosity level.

    fit_params : dict, optional
        Parameters to pass to the fit method of the estimator.

    pre_dispatch : int, or string, optional
        Controls the number of jobs that get dispatched during parallel
        execution. Reducing this number can be useful to avoid an
        explosion of memory consumption when more jobs get dispatched
        than CPUs can process. This parameter can be:

            - None, in which case all the jobs are immediately
              created and spawned. Use this for lightweight and
              fast-running jobs, to avoid delays due to on-demand
              spawning of the jobs

            - An int, giving the exact number of total jobs that are
              spawned

            - A string, giving an expression as a function of n_jobs,
              as in '2*n_jobs'

    return_train_score : boolean, optional
        Whether to include train scores.

        Current default is ``'warn'``, which behaves as ``True`` in addition
        to raising a warning when a training score is looked up.
        That default will be changed to ``False`` in 0.21.
        Computing training scores is used to get insights on how different
        parameter settings impact the overfitting/underfitting trade-off.
        However, computing the scores on the training set can be
        computationally expensive and is not strictly required to select
        the parameters that yield the best generalization performance.

    Returns
    -------
    scores : dict of float arrays of shape=(n_splits,)
        Array of scores of the estimator for each run of the cross validation.

        A dict of arrays containing the score/time arrays for each scorer is
        returned. The possible keys for this ``dict`` are:

            ``test_score``
                The score array for test scores on each cv split.
            ``train_score``
                The score array for train scores on each cv split.
                This is available only if ``return_train_score`` parameter
                is ``True``.
            ``fit_time``
                The time for fitting the estimator on the train
                set for each cv split.
            ``score_time``
                The time for scoring the estimator on the test set for each
                cv split. (Note that time for scoring on the train set is not
                included even if ``return_train_score`` is set to ``True``.)

    Examples
    --------
    >>> from sklearn import datasets, linear_model
    >>> from sklearn.model_selection import cross_validate
    >>> from sklearn.metrics.scorer import make_scorer
    >>> from sklearn.metrics import confusion_matrix
    >>> from sklearn.svm import LinearSVC
    >>> diabetes = datasets.load_diabetes()
    >>> X = diabetes.data[:150]
    >>> y = diabetes.target[:150]
    >>> lasso = linear_model.Lasso()

    Single metric evaluation using ``cross_validate``

    >>> cv_results = cross_validate(lasso, X, y, return_train_score=False)
    >>> sorted(cv_results.keys())  # doctest: +ELLIPSIS
    ['fit_time', 'score_time', 'test_score']
    >>> cv_results['test_score']  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    array([ 0.33...,  0.08...,  0.03...])

    Multiple metric evaluation using ``cross_validate``
    (please refer to the ``scoring`` parameter doc for more information)

    >>> scores = cross_validate(lasso, X, y,
    ...                         scoring=('r2', 'neg_mean_squared_error'))
    >>> print(scores['test_neg_mean_squared_error'])  # doctest: +ELLIPSIS
    [-3635.5... -3573.3... -6114.7...]
    >>> print(scores['train_r2'])  # doctest: +ELLIPSIS
    [ 0.28...  0.39...  0.22...]

    See Also
    --------
    :func:`sklearn.model_selection.cross_val_score`:
        Run cross-validation for single metric evaluation.

    :func:`sklearn.metrics.make_scorer`:
        Make a scorer from a performance metric or loss function.

    """
    X, y, groups = indexable(X, y, groups)

    cv = check_cv(cv, y, classifier=is_classifier(estimator))
    scorers, _ = _check_multimetric_scoring(estimator, scoring=scoring)

    # We clone the estimator to make sure that all the folds are
    # independent, and that it is pickle-able.
    parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
                        pre_dispatch=pre_dispatch)
    scores = parallel(
        delayed(_fit_and_score)(
            clone(estimator), X, y, scorers, train, test, verbose, None,
            fit_params, return_train_score=return_train_score,
            return_times=True)
        for train, test in cv.split(X, y, groups))

    if return_train_score:
        train_scores, test_scores, fit_times, score_times = zip(*scores)
        train_scores = _aggregate_score_dicts(train_scores)
    else:
        test_scores, fit_times, score_times = zip(*scores)
    test_scores = _aggregate_score_dicts(test_scores)

    # TODO: replace by a dict in 0.21
    ret = DeprecationDict() if return_train_score == 'warn' else {}
    ret['fit_time'] = np.array(fit_times)
    ret['score_time'] = np.array(score_times)

    for name in scorers:
        ret['test_%s' % name] = np.array(test_scores[name])
        if return_train_score:
            key = 'train_%s' % name
            ret[key] = np.array(train_scores[name])
            if return_train_score == 'warn':
                message = (
                    'You are accessing a training score ({!r}), '
                    'which will not be available by default '
                    'any more in 0.21. If you need training scores, '
                    'please set return_train_score=True').format(key)
                # warn on key access
                ret.add_warning(key, message, FutureWarning)

    return ret


def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
                    n_jobs=1, verbose=0, fit_params=None,
                    pre_dispatch='2*n_jobs'):
    """Evaluate a score by cross-validation

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    estimator : estimator object implementing 'fit'
        The object to use to fit the data.

    X : array-like
        The data to fit. Can be, for example, a list or an array.

    y : array-like, optional, default: None
        The target variable to try to predict in the case of
        supervised learning.

    groups : array-like, with shape (n_samples,), optional
        Group labels for the samples used while splitting the dataset into
        train/test set.

    scoring : string, callable or None, optional, default: None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``.

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a `(Stratified)KFold`,
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, test splits.

        For integer/None inputs, if the estimator is a classifier and ``y`` is
        either binary or multiclass, :class:`StratifiedKFold` is used. In all
        other cases, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

    n_jobs : integer, optional
        The number of CPUs to use to do the computation. -1 means
        'all CPUs'.

    verbose : integer, optional
        The verbosity level.

    fit_params : dict, optional
        Parameters to pass to the fit method of the estimator.

    pre_dispatch : int, or string, optional
        Controls the number of jobs that get dispatched during parallel
        execution. Reducing this number can be useful to avoid an
        explosion of memory consumption when more jobs get dispatched
        than CPUs can process. This parameter can be:

            - None, in which case all the jobs are immediately
              created and spawned. Use this for lightweight and
              fast-running jobs, to avoid delays due to on-demand
              spawning of the jobs

            - An int, giving the exact number of total jobs that are
              spawned

            - A string, giving an expression as a function of n_jobs,
              as in '2*n_jobs'

    Returns
    -------
    scores : array of float, shape=(len(list(cv)),)
        Array of scores of the estimator for each run of the cross validation.

    Examples
    --------
    >>> from sklearn import datasets, linear_model
    >>> from sklearn.model_selection import cross_val_score
    >>> diabetes = datasets.load_diabetes()
    >>> X = diabetes.data[:150]
    >>> y = diabetes.target[:150]
    >>> lasso = linear_model.Lasso()
    >>> print(cross_val_score(lasso, X, y))  # doctest: +ELLIPSIS
    [ 0.33150734  0.08022311  0.03531764]

    See Also
    --------
    :func:`sklearn.model_selection.cross_validate`:
        To run cross-validation on multiple metrics and also to return
        train scores, fit times and score times.

    :func:`sklearn.metrics.make_scorer`:
        Make a scorer from a performance metric or loss function.

    """
    # To ensure multimetric format is not supported
    scorer = check_scoring(estimator, scoring=scoring)

    cv_results = cross_validate(estimator=estimator, X=X, y=y, groups=groups,
                                scoring={'score': scorer}, cv=cv,
                                return_train_score=False,
                                n_jobs=n_jobs, verbose=verbose,
                                fit_params=fit_params,
                                pre_dispatch=pre_dispatch)
    return cv_results['test_score']


def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                   parameters, fit_params, return_train_score=False,
                   return_parameters=False, return_n_test_samples=False,
                   return_times=False, error_score='raise'):
    """Fit estimator and compute scores for a given dataset split.

    Parameters
    ----------
    estimator : estimator object implementing 'fit'
        The object to use to fit the data.

    X : array-like of shape at least 2D
        The data to fit.

    y : array-like, optional, default: None
        The target variable to try to predict in the case of
        supervised learning.

    scorer : A single callable or dict mapping scorer name to the callable
        If it is a single callable, the return value for ``train_scores`` and
        ``test_scores`` is a single float.

        For a dict, it should be one mapping the scorer name to the scorer
        callable object / function.

        The callable object / function should have signature
        ``scorer(estimator, X, y)``.

    train : array-like, shape (n_train_samples,)
        Indices of training samples.

    test : array-like, shape (n_test_samples,)
        Indices of test samples.

    verbose : integer
        The verbosity level.

    error_score : 'raise' (default) or numeric
        Value to assign to the score if an error occurs in estimator fitting.
        If set to 'raise', the error is raised. If a numeric value is given,
        FitFailedWarning is raised. This parameter does not affect the refit
        step, which will always raise the error.

    parameters : dict or None
        Parameters to be set on the estimator.

    fit_params : dict or None
        Parameters that will be passed to ``estimator.fit``.

    return_train_score : boolean, optional, default: False
        Compute and return score on training set.

    return_parameters : boolean, optional, default: False
        Return parameters that have been used for the estimator.

    return_n_test_samples : boolean, optional, default: False
        Whether to return the ``n_test_samples``.

    return_times : boolean, optional, default: False
        Whether to return the fit/score times.

    Returns
    -------
    train_scores : dict of scorer name -> float, optional
        Score on training set (for all the scorers),
        returned only if `return_train_score` is `True`.

    test_scores : dict of scorer name -> float, optional
        Score on testing set (for all the scorers).

    n_test_samples : int
        Number of test samples.

    fit_time : float
        Time spent for fitting in seconds.

    score_time : float
        Time spent for scoring in seconds.

    parameters : dict or None, optional
        The parameters that have been evaluated.
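
    Examples
    --------
    A minimal sketch of the return layout for a single-metric scorer
    (illustrative only; the estimator, data and timings below are
    placeholders, not verified output):

    >>> import numpy as np                                # doctest: +SKIP
    >>> from sklearn.linear_model import Ridge            # doctest: +SKIP
    >>> from sklearn.metrics.scorer import check_scoring  # doctest: +SKIP
    >>> X = np.arange(20.).reshape(10, 2)                 # doctest: +SKIP
    >>> y = np.arange(10.)                                # doctest: +SKIP
    >>> scorer = check_scoring(Ridge(), scoring='r2')     # doctest: +SKIP
    >>> _fit_and_score(Ridge(), X, y, scorer,
    ...                train=np.arange(8), test=np.arange(8, 10),
    ...                verbose=0, parameters=None, fit_params=None,
    ...                return_times=True)                 # doctest: +SKIP
    [0.99..., 0.002..., 0.001...]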
"""
|
|
if verbose > 1:
|
|
if parameters is None:
|
|
msg = ''
|
|
else:
|
|
msg = '%s' % (', '.join('%s=%s' % (k, v)
|
|
for k, v in parameters.items()))
|
|
print("[CV] %s %s" % (msg, (64 - len(msg)) * '.'))
|
|
|
|
# Adjust length of sample weights
|
|
fit_params = fit_params if fit_params is not None else {}
|
|
fit_params = dict([(k, _index_param_value(X, v, train))
|
|
for k, v in fit_params.items()])
|
|
|
|
test_scores = {}
|
|
train_scores = {}
|
|
if parameters is not None:
|
|
estimator.set_params(**parameters)
|
|
|
|
start_time = time.time()
|
|
|
|
X_train, y_train = _safe_split(estimator, X, y, train)
|
|
X_test, y_test = _safe_split(estimator, X, y, test, train)
|
|
|
|
is_multimetric = not callable(scorer)
|
|
n_scorers = len(scorer.keys()) if is_multimetric else 1
|
|
|
|
try:
|
|
if y_train is None:
|
|
estimator.fit(X_train, **fit_params)
|
|
else:
|
|
estimator.fit(X_train, y_train, **fit_params)
|
|
|
|
except Exception as e:
|
|
# Note fit time as time until error
|
|
fit_time = time.time() - start_time
|
|
score_time = 0.0
|
|
if error_score == 'raise':
|
|
raise
|
|
elif isinstance(error_score, numbers.Number):
|
|
if is_multimetric:
|
|
test_scores = dict(zip(scorer.keys(),
|
|
[error_score, ] * n_scorers))
|
|
if return_train_score:
|
|
train_scores = dict(zip(scorer.keys(),
|
|
[error_score, ] * n_scorers))
|
|
else:
|
|
test_scores = error_score
|
|
if return_train_score:
|
|
train_scores = error_score
|
|
warnings.warn("Classifier fit failed. The score on this train-test"
|
|
" partition for these parameters will be set to %f. "
|
|
"Details: \n%r" % (error_score, e), FitFailedWarning)
|
|
else:
|
|
raise ValueError("error_score must be the string 'raise' or a"
|
|
" numeric value. (Hint: if using 'raise', please"
|
|
" make sure that it has been spelled correctly.)")
|
|
|
|
else:
|
|
fit_time = time.time() - start_time
|
|
# _score will return dict if is_multimetric is True
|
|
test_scores = _score(estimator, X_test, y_test, scorer, is_multimetric)
|
|
score_time = time.time() - start_time - fit_time
|
|
if return_train_score:
|
|
train_scores = _score(estimator, X_train, y_train, scorer,
|
|
is_multimetric)
|
|
|
|
if verbose > 2:
|
|
if is_multimetric:
|
|
for scorer_name, score in test_scores.items():
|
|
msg += ", %s=%s" % (scorer_name, score)
|
|
else:
|
|
msg += ", score=%s" % test_scores
|
|
if verbose > 1:
|
|
total_time = score_time + fit_time
|
|
end_msg = "%s, total=%s" % (msg, logger.short_format_time(total_time))
|
|
print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
|
|
|
|
ret = [train_scores, test_scores] if return_train_score else [test_scores]
|
|
|
|
if return_n_test_samples:
|
|
ret.append(_num_samples(X_test))
|
|
if return_times:
|
|
ret.extend([fit_time, score_time])
|
|
if return_parameters:
|
|
ret.append(parameters)
|
|
return ret
|
|
|
|
|
|
def _score(estimator, X_test, y_test, scorer, is_multimetric=False):
    """Compute the score(s) of an estimator on a given test set.

    Will return a single float if is_multimetric is False and a dict of
    floats if is_multimetric is True.
    """
    if is_multimetric:
        return _multimetric_score(estimator, X_test, y_test, scorer)
    else:
        if y_test is None:
            score = scorer(estimator, X_test)
        else:
            score = scorer(estimator, X_test, y_test)

        if hasattr(score, 'item'):
            try:
                # e.g. unwrap memmapped scalars
                score = score.item()
            except ValueError:
                # non-scalar?
                pass

        if not isinstance(score, numbers.Number):
            raise ValueError("scoring must return a number, got %s (%s) "
                             "instead. (scorer=%r)"
                             % (str(score), type(score), scorer))
    return score


def _multimetric_score(estimator, X_test, y_test, scorers):
    """Return a dict of scores for multimetric scoring."""
    scores = {}

    for name, scorer in scorers.items():
        if y_test is None:
            score = scorer(estimator, X_test)
        else:
            score = scorer(estimator, X_test, y_test)

        if hasattr(score, 'item'):
            try:
                # e.g. unwrap memmapped scalars
                score = score.item()
            except ValueError:
                # non-scalar?
                pass
        scores[name] = score

        if not isinstance(score, numbers.Number):
            raise ValueError("scoring must return a number, got %s (%s) "
                             "instead. (scorer=%s)"
                             % (str(score), type(score), name))
    return scores


def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
                      verbose=0, fit_params=None, pre_dispatch='2*n_jobs',
                      method='predict'):
    """Generate cross-validated estimates for each input data point

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    estimator : estimator object implementing 'fit' and 'predict'
        The object to use to fit the data.

    X : array-like
        The data to fit. Can be, for example, a list or an array of at
        least 2d.

    y : array-like, optional, default: None
        The target variable to try to predict in the case of
        supervised learning.

    groups : array-like, with shape (n_samples,), optional
        Group labels for the samples used while splitting the dataset into
        train/test set.

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a `(Stratified)KFold`,
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, test splits.

        For integer/None inputs, if the estimator is a classifier and ``y`` is
        either binary or multiclass, :class:`StratifiedKFold` is used. In all
        other cases, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

    n_jobs : integer, optional
        The number of CPUs to use to do the computation. -1 means
        'all CPUs'.

    verbose : integer, optional
        The verbosity level.

    fit_params : dict, optional
        Parameters to pass to the fit method of the estimator.

    pre_dispatch : int, or string, optional
        Controls the number of jobs that get dispatched during parallel
        execution. Reducing this number can be useful to avoid an
        explosion of memory consumption when more jobs get dispatched
        than CPUs can process. This parameter can be:

            - None, in which case all the jobs are immediately
              created and spawned. Use this for lightweight and
              fast-running jobs, to avoid delays due to on-demand
              spawning of the jobs

            - An int, giving the exact number of total jobs that are
              spawned

            - A string, giving an expression as a function of n_jobs,
              as in '2*n_jobs'

    method : string, optional, default: 'predict'
        Invokes the passed method name of the passed estimator. For
        method='predict_proba', the columns correspond to the classes
        in sorted order.

    Returns
    -------
    predictions : ndarray
        This is the result of calling ``method``.

    Notes
    -----
    In the case that one or more classes are absent in a training portion, a
    default score needs to be assigned to all instances for that class if
    ``method`` produces columns per class, as in {'decision_function',
    'predict_proba', 'predict_log_proba'}. For ``predict_proba`` this value
    is 0. In order to ensure finite output, we approximate negative infinity
    by the minimum finite float value for the dtype in other cases.

    Examples
    --------
    >>> from sklearn import datasets, linear_model
    >>> from sklearn.model_selection import cross_val_predict
    >>> diabetes = datasets.load_diabetes()
    >>> X = diabetes.data[:150]
    >>> y = diabetes.target[:150]
    >>> lasso = linear_model.Lasso()
    >>> y_pred = cross_val_predict(lasso, X, y)
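
    Illustrative only (not part of the verified doctests): for a classifier
    and ``method='predict_proba'``, each column corresponds to a class in
    sorted order:

    >>> from sklearn.linear_model import LogisticRegression  # doctest: +SKIP
    >>> iris = datasets.load_iris()                          # doctest: +SKIP
    >>> proba = cross_val_predict(LogisticRegression(), iris.data,
    ...                           iris.target,
    ...                           method='predict_proba')    # doctest: +SKIP
    >>> proba.shape                                          # doctest: +SKIP
    (150, 3)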
"""
|
|
X, y, groups = indexable(X, y, groups)
|
|
|
|
cv = check_cv(cv, y, classifier=is_classifier(estimator))
|
|
|
|
if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
|
|
le = LabelEncoder()
|
|
y = le.fit_transform(y)
|
|
|
|
# We clone the estimator to make sure that all the folds are
|
|
# independent, and that it is pickle-able.
|
|
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
|
|
pre_dispatch=pre_dispatch)
|
|
prediction_blocks = parallel(delayed(_fit_and_predict)(
|
|
clone(estimator), X, y, train, test, verbose, fit_params, method)
|
|
for train, test in cv.split(X, y, groups))
|
|
|
|
# Concatenate the predictions
|
|
predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
|
|
test_indices = np.concatenate([indices_i
|
|
for _, indices_i in prediction_blocks])
|
|
|
|
if not _check_is_permutation(test_indices, _num_samples(X)):
|
|
raise ValueError('cross_val_predict only works for partitions')
|
|
|
|
inv_test_indices = np.empty(len(test_indices), dtype=int)
|
|
inv_test_indices[test_indices] = np.arange(len(test_indices))
|
|
|
|
# Check for sparse predictions
|
|
if sp.issparse(predictions[0]):
|
|
predictions = sp.vstack(predictions, format=predictions[0].format)
|
|
else:
|
|
predictions = np.concatenate(predictions)
|
|
return predictions[inv_test_indices]
|
|
|
|
|
|
def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
                     method):
    """Fit estimator and predict values for a given dataset split.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    estimator : estimator object implementing 'fit' and 'predict'
        The object to use to fit the data.

    X : array-like of shape at least 2D
        The data to fit.

    y : array-like, optional, default: None
        The target variable to try to predict in the case of
        supervised learning.

    train : array-like, shape (n_train_samples,)
        Indices of training samples.

    test : array-like, shape (n_test_samples,)
        Indices of test samples.

    verbose : integer
        The verbosity level.

    fit_params : dict or None
        Parameters that will be passed to ``estimator.fit``.

    method : string
        Invokes the passed method name of the passed estimator.

    Returns
    -------
    predictions : sequence
        Result of calling 'estimator.method'

    test : array-like
        This is the value of the test parameter
    """
    # Adjust length of sample weights
    fit_params = fit_params if fit_params is not None else {}
    fit_params = dict([(k, _index_param_value(X, v, train))
                      for k, v in fit_params.items()])

    X_train, y_train = _safe_split(estimator, X, y, train)
    X_test, _ = _safe_split(estimator, X, y, test, train)

    if y_train is None:
        estimator.fit(X_train, **fit_params)
    else:
        estimator.fit(X_train, y_train, **fit_params)
    func = getattr(estimator, method)
    predictions = func(X_test)
    if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
        n_classes = len(set(y))
        if n_classes != len(estimator.classes_):
            recommendation = (
                'To fix this, use a cross-validation '
                'technique resulting in properly '
                'stratified folds')
            warnings.warn('Number of classes in training fold ({}) does '
                          'not match total number of classes ({}). '
                          'Results may not be appropriate for your use case. '
                          '{}'.format(len(estimator.classes_),
                                      n_classes, recommendation),
                          RuntimeWarning)
            if method == 'decision_function':
                if (predictions.ndim == 2 and
                        predictions.shape[1] != len(estimator.classes_)):
                    # This handles the case when the shape of predictions
                    # does not match the number of classes used to train
                    # it with. This case is found when sklearn.svm.SVC is
                    # set to `decision_function_shape='ovo'`.
                    raise ValueError('Output shape {} of {} does not match '
                                     'number of classes ({}) in fold. '
                                     'Irregular decision_function outputs '
                                     'are not currently supported by '
                                     'cross_val_predict'.format(
                                         predictions.shape, method,
                                         len(estimator.classes_)))
                if len(estimator.classes_) <= 2:
                    # In this special case, `predictions` contains a 1D array.
                    raise ValueError('Only {} class/es in training fold, this '
                                     'is not supported for decision_function '
                                     'with imbalanced folds. {}'.format(
                                         len(estimator.classes_),
                                         recommendation))

            float_min = np.finfo(predictions.dtype).min
            default_values = {'decision_function': float_min,
                              'predict_log_proba': float_min,
                              'predict_proba': 0}
            predictions_for_all_classes = np.full((_num_samples(predictions),
                                                   n_classes),
                                                  default_values[method])
            predictions_for_all_classes[:, estimator.classes_] = predictions
            predictions = predictions_for_all_classes
    return predictions, test


def _check_is_permutation(indices, n_samples):
    """Check whether indices is a reordering of the array np.arange(n_samples)

    Parameters
    ----------
    indices : ndarray
        integer array to test
    n_samples : int
        number of expected elements

    Returns
    -------
    is_partition : bool
        True iff sorted(indices) is np.arange(n_samples)
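
    Examples
    --------
    A quick illustration of the check on toy inputs:

    >>> import numpy as np
    >>> _check_is_permutation(np.array([2, 0, 1]), 3)
    True
    >>> _check_is_permutation(np.array([0, 0, 1]), 3)
    False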
"""
|
|
if len(indices) != n_samples:
|
|
return False
|
|
hit = np.zeros(n_samples, dtype=bool)
|
|
hit[indices] = True
|
|
if not np.all(hit):
|
|
return False
|
|
return True
|
|
|
|
|
|
def _index_param_value(X, v, indices):
    """Private helper function for parameter value indexing."""
    if not _is_arraylike(v) or _num_samples(v) != _num_samples(X):
        # pass through: skip indexing
        return v
    if sp.issparse(v):
        v = v.tocsr()
    return safe_indexing(v, indices)


def permutation_test_score(estimator, X, y, groups=None, cv=None,
                           n_permutations=100, n_jobs=1, random_state=0,
                           verbose=0, scoring=None):
    """Evaluate the significance of a cross-validated score with permutations

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    estimator : estimator object implementing 'fit'
        The object to use to fit the data.

    X : array-like of shape at least 2D
        The data to fit.

    y : array-like
        The target variable to try to predict in the case of
        supervised learning.

    groups : array-like, with shape (n_samples,), optional
        Labels to constrain permutation within groups, i.e. ``y`` values
        are permuted among samples with the same group identifier.
        When not specified, ``y`` values are permuted among all samples.

        When a grouped cross-validator is used, the group labels are
        also passed on to the ``split`` method of the cross-validator. The
        cross-validator uses them for grouping the samples while splitting
        the dataset into train/test set.

    scoring : string, callable or None, optional, default: None
        A single string (see :ref:`scoring_parameter`) or a callable
        (see :ref:`scoring`) to evaluate the predictions on the test set.

        If None the estimator's default scorer, if available, is used.

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a `(Stratified)KFold`,
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, test splits.

        For integer/None inputs, if the estimator is a classifier and ``y`` is
        either binary or multiclass, :class:`StratifiedKFold` is used. In all
        other cases, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

    n_permutations : integer, optional
        Number of times to permute ``y``.

    n_jobs : integer, optional
        The number of CPUs to use to do the computation. -1 means
        'all CPUs'.

    random_state : int, RandomState instance or None, optional (default=0)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    verbose : integer, optional
        The verbosity level.

    Returns
    -------
    score : float
        The true score without permuting targets.

    permutation_scores : array, shape (n_permutations,)
        The scores obtained for each permutation.

    pvalue : float
        The p-value, which approximates the probability that the score would
        be obtained by chance. This is calculated as:

        `(C + 1) / (n_permutations + 1)`

        where C is the number of permutations whose score >= the true score.

        The best possible p-value is 1/(n_permutations + 1), the worst is 1.0.

    Notes
    -----
    This function implements Test 1 in:

        Ojala and Garriga. Permutation Tests for Studying Classifier
        Performance. The Journal of Machine Learning Research (2010)
        vol. 11

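    Examples
    --------
    A minimal sketch (illustrative only; exact scores depend on the data,
    the estimator and the random permutations):

    >>> from sklearn.datasets import load_iris     # doctest: +SKIP
    >>> from sklearn.svm import SVC                # doctest: +SKIP
    >>> X, y = load_iris(return_X_y=True)          # doctest: +SKIP
    >>> score, perm_scores, pvalue = permutation_test_score(
    ...     SVC(kernel='linear'), X, y,
    ...     n_permutations=100)                    # doctest: +SKIP
    >>> pvalue                                     # doctest: +SKIP
    0.00990...
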
"""
|
|
X, y, groups = indexable(X, y, groups)
|
|
|
|
cv = check_cv(cv, y, classifier=is_classifier(estimator))
|
|
scorer = check_scoring(estimator, scoring=scoring)
|
|
random_state = check_random_state(random_state)
|
|
|
|
# We clone the estimator to make sure that all the folds are
|
|
# independent, and that it is pickle-able.
|
|
score = _permutation_test_score(clone(estimator), X, y, groups, cv, scorer)
|
|
permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
|
|
delayed(_permutation_test_score)(
|
|
clone(estimator), X, _shuffle(y, groups, random_state),
|
|
groups, cv, scorer)
|
|
for _ in range(n_permutations))
|
|
permutation_scores = np.array(permutation_scores)
|
|
pvalue = (np.sum(permutation_scores >= score) + 1.0) / (n_permutations + 1)
|
|
return score, permutation_scores, pvalue
|
|
|
|
|
|
permutation_test_score.__test__ = False # to avoid a pb with nosetests
|
|
|
|
|
|
def _permutation_test_score(estimator, X, y, groups, cv, scorer):
    """Auxiliary function for permutation_test_score"""
    avg_score = []
    for train, test in cv.split(X, y, groups):
        X_train, y_train = _safe_split(estimator, X, y, train)
        X_test, y_test = _safe_split(estimator, X, y, test, train)
        estimator.fit(X_train, y_train)
        avg_score.append(scorer(estimator, X_test, y_test))
    return np.mean(avg_score)


def _shuffle(y, groups, random_state):
    """Return a shuffled copy of y, optionally shuffling only within the
    same groups.
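
    Illustrative only (output depends on the random state): with groups
    ``[0, 0, 1, 1]``, the first two and the last two entries of ``y`` are
    only swapped among themselves:

    >>> import numpy as np                                   # doctest: +SKIP
    >>> rng = np.random.RandomState(0)                       # doctest: +SKIP
    >>> _shuffle(np.arange(4), np.array([0, 0, 1, 1]), rng)  # doctest: +SKIP
    array([1, 0, 3, 2])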
    """
    if groups is None:
        indices = random_state.permutation(len(y))
    else:
        indices = np.arange(len(groups))
        for group in np.unique(groups):
            this_mask = (groups == group)
            indices[this_mask] = random_state.permutation(indices[this_mask])
    return safe_indexing(y, indices)


def learning_curve(estimator, X, y, groups=None,
                   train_sizes=np.linspace(0.1, 1.0, 5), cv=None, scoring=None,
                   exploit_incremental_learning=False, n_jobs=1,
                   pre_dispatch="all", verbose=0, shuffle=False,
                   random_state=None):
    """Learning curve.

    Determines cross-validated training and test scores for different training
    set sizes.

    A cross-validation generator splits the whole dataset k times in training
    and test data. Subsets of the training set with varying sizes will be used
    to train the estimator and a score for each training subset size and the
    test set will be computed. Afterwards, the scores will be averaged over
    all k runs for each training subset size.

    Read more in the :ref:`User Guide <learning_curve>`.

    Parameters
    ----------
    estimator : object type that implements the "fit" and "predict" methods
        An object of that type which is cloned for each validation.

    X : array-like, shape (n_samples, n_features)
        Training vector, where n_samples is the number of samples and
        n_features is the number of features.

    y : array-like, shape (n_samples) or (n_samples, n_features), optional
        Target relative to X for classification or regression;
        None for unsupervised learning.

    groups : array-like, with shape (n_samples,), optional
        Group labels for the samples used while splitting the dataset into
        train/test set.

    train_sizes : array-like, shape (n_ticks,), dtype float or int
        Relative or absolute numbers of training examples that will be used to
        generate the learning curve. If the dtype is float, it is regarded as a
        fraction of the maximum size of the training set (that is determined
        by the selected validation method), i.e. it has to be within (0, 1].
        Otherwise it is interpreted as absolute sizes of the training sets.
        Note that for classification the number of samples usually has to
        be big enough to contain at least one sample from each class.
        (default: np.linspace(0.1, 1.0, 5))

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a `(Stratified)KFold`,
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, test splits.

        For integer/None inputs, if the estimator is a classifier and ``y`` is
        either binary or multiclass, :class:`StratifiedKFold` is used. In all
        other cases, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

    scoring : string, callable or None, optional, default: None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``.

    exploit_incremental_learning : boolean, optional, default: False
        If the estimator supports incremental learning, this will be
        used to speed up fitting for different training set sizes.

    n_jobs : integer, optional
        Number of jobs to run in parallel (default 1).

    pre_dispatch : integer or string, optional
        Number of predispatched jobs for parallel execution (default is
        all). The option can reduce the allocated memory. The string can
        be an expression like '2*n_jobs'.

    verbose : integer, optional
        Controls the verbosity: the higher, the more messages.

    shuffle : boolean, optional
        Whether to shuffle training data before taking prefixes of it
        based on ``train_sizes``.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`. Used when ``shuffle`` is True.

    Returns
    -------
    train_sizes_abs : array, shape = (n_unique_ticks,), dtype int
        Numbers of training examples that have been used to generate the
        learning curve. Note that the number of ticks might be less
        than n_ticks because duplicate entries will be removed.

    train_scores : array, shape (n_ticks, n_cv_folds)
        Scores on training sets.

    test_scores : array, shape (n_ticks, n_cv_folds)
        Scores on test set.

    Notes
    -----
    See :ref:`examples/model_selection/plot_learning_curve.py
    <sphx_glr_auto_examples_model_selection_plot_learning_curve.py>`

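    Examples
    --------
    A minimal sketch (illustrative only; the exact sizes and scores depend
    on the data, the estimator and the cv splitter):

    >>> import numpy as np                               # doctest: +SKIP
    >>> from sklearn.datasets import load_iris           # doctest: +SKIP
    >>> from sklearn.tree import DecisionTreeClassifier  # doctest: +SKIP
    >>> X, y = load_iris(return_X_y=True)                # doctest: +SKIP
    >>> sizes, train_scores, test_scores = learning_curve(
    ...     DecisionTreeClassifier(), X, y,
    ...     train_sizes=np.linspace(0.1, 1.0, 3))        # doctest: +SKIP
    >>> sizes                                            # doctest: +SKIP
    array([ 10,  55, 100])
    >>> train_scores.shape, test_scores.shape            # doctest: +SKIP
    ((3, 3), (3, 3))
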
"""
|
|
if exploit_incremental_learning and not hasattr(estimator, "partial_fit"):
|
|
raise ValueError("An estimator must support the partial_fit interface "
|
|
"to exploit incremental learning")
|
|
X, y, groups = indexable(X, y, groups)
|
|
|
|
cv = check_cv(cv, y, classifier=is_classifier(estimator))
|
|
# Store it as list as we will be iterating over the list multiple times
|
|
cv_iter = list(cv.split(X, y, groups))
|
|
|
|
scorer = check_scoring(estimator, scoring=scoring)
|
|
|
|
n_max_training_samples = len(cv_iter[0][0])
|
|
# Because the lengths of folds can be significantly different, it is
|
|
# not guaranteed that we use all of the available training data when we
|
|
# use the first 'n_max_training_samples' samples.
|
|
train_sizes_abs = _translate_train_sizes(train_sizes,
|
|
n_max_training_samples)
|
|
n_unique_ticks = train_sizes_abs.shape[0]
|
|
if verbose > 0:
|
|
print("[learning_curve] Training set sizes: " + str(train_sizes_abs))
|
|
|
|
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
|
|
verbose=verbose)
|
|
|
|
if shuffle:
|
|
rng = check_random_state(random_state)
|
|
cv_iter = ((rng.permutation(train), test) for train, test in cv_iter)
|
|
|
|
if exploit_incremental_learning:
|
|
classes = np.unique(y) if is_classifier(estimator) else None
|
|
out = parallel(delayed(_incremental_fit_estimator)(
|
|
clone(estimator), X, y, classes, train, test, train_sizes_abs,
|
|
scorer, verbose) for train, test in cv_iter)
|
|
else:
|
|
train_test_proportions = []
|
|
for train, test in cv_iter:
|
|
for n_train_samples in train_sizes_abs:
|
|
train_test_proportions.append((train[:n_train_samples], test))
|
|
|
|
out = parallel(delayed(_fit_and_score)(
|
|
clone(estimator), X, y, scorer, train, test,
|
|
verbose, parameters=None, fit_params=None, return_train_score=True)
|
|
for train, test in train_test_proportions)
|
|
out = np.array(out)
|
|
n_cv_folds = out.shape[0] // n_unique_ticks
|
|
out = out.reshape(n_cv_folds, n_unique_ticks, 2)
|
|
|
|
out = np.asarray(out).transpose((2, 1, 0))
|
|
|
|
return train_sizes_abs, out[0], out[1]
|
|
|
|
|
|
def _translate_train_sizes(train_sizes, n_max_training_samples):
    """Determine absolute sizes of training subsets and validate 'train_sizes'.

    Examples:
        _translate_train_sizes([0.5, 1.0], 10) -> [5, 10]
        _translate_train_sizes([5, 10], 10) -> [5, 10]

    Parameters
    ----------
    train_sizes : array-like, shape (n_ticks,), dtype float or int
        Numbers of training examples that will be used to generate the
        learning curve. If the dtype is float, it is regarded as a
        fraction of 'n_max_training_samples', i.e. it has to be within (0, 1].

    n_max_training_samples : int
        Maximum number of training samples (upper bound of 'train_sizes').

    Returns
    -------
    train_sizes_abs : array, shape (n_unique_ticks,), dtype int
        Numbers of training examples that will be used to generate the
        learning curve. Note that the number of ticks might be less
        than n_ticks because duplicate entries will be removed.
    """
    train_sizes_abs = np.asarray(train_sizes)
    n_ticks = train_sizes_abs.shape[0]
    n_min_required_samples = np.min(train_sizes_abs)
    n_max_required_samples = np.max(train_sizes_abs)
    if np.issubdtype(train_sizes_abs.dtype, np.floating):
        if n_min_required_samples <= 0.0 or n_max_required_samples > 1.0:
            raise ValueError("train_sizes has been interpreted as fractions "
                             "of the maximum number of training samples and "
                             "must be within (0, 1], but is within [%f, %f]."
                             % (n_min_required_samples,
                                n_max_required_samples))
        train_sizes_abs = (train_sizes_abs * n_max_training_samples).astype(
                             dtype=np.int, copy=False)
        train_sizes_abs = np.clip(train_sizes_abs, 1,
                                  n_max_training_samples)
    else:
        if (n_min_required_samples <= 0 or
                n_max_required_samples > n_max_training_samples):
            raise ValueError("train_sizes has been interpreted as absolute "
                             "numbers of training samples and must be within "
                             "(0, %d], but is within [%d, %d]."
                             % (n_max_training_samples,
                                n_min_required_samples,
                                n_max_required_samples))

    train_sizes_abs = np.unique(train_sizes_abs)
    if n_ticks > train_sizes_abs.shape[0]:
        warnings.warn("Removed duplicate entries from 'train_sizes'. Number "
                      "of ticks will be less than the size of "
                      "'train_sizes' (%d instead of %d)."
                      % (train_sizes_abs.shape[0], n_ticks), RuntimeWarning)

    return train_sizes_abs


def _incremental_fit_estimator(estimator, X, y, classes, train, test,
                               train_sizes, scorer, verbose):
    """Train estimator on training subsets incrementally and compute scores."""
    train_scores, test_scores = [], []
    partitions = zip(train_sizes, np.split(train, train_sizes)[:-1])
    for n_train_samples, partial_train in partitions:
        train_subset = train[:n_train_samples]
        X_train, y_train = _safe_split(estimator, X, y, train_subset)
        X_partial_train, y_partial_train = _safe_split(estimator, X, y,
                                                       partial_train)
        X_test, y_test = _safe_split(estimator, X, y, test, train_subset)
        if y_partial_train is None:
            estimator.partial_fit(X_partial_train, classes=classes)
        else:
            estimator.partial_fit(X_partial_train, y_partial_train,
                                  classes=classes)
        train_scores.append(_score(estimator, X_train, y_train, scorer))
        test_scores.append(_score(estimator, X_test, y_test, scorer))
    return np.array((train_scores, test_scores)).T


def validation_curve(estimator, X, y, param_name, param_range, groups=None,
                     cv=None, scoring=None, n_jobs=1, pre_dispatch="all",
                     verbose=0):
    """Validation curve.

    Determine training and test scores for varying parameter values.

    Compute scores for an estimator with different values of a specified
    parameter. This is similar to grid search with one parameter. However,
    this will also compute training scores and is merely a utility for
    plotting the results.

    Read more in the :ref:`User Guide <learning_curve>`.

    Parameters
    ----------
    estimator : object type that implements the "fit" and "predict" methods
        An object of that type which is cloned for each validation.

    X : array-like, shape (n_samples, n_features)
        Training vector, where n_samples is the number of samples and
        n_features is the number of features.

    y : array-like, shape (n_samples) or (n_samples, n_features), optional
        Target relative to X for classification or regression;
        None for unsupervised learning.

    param_name : string
        Name of the parameter that will be varied.

    param_range : array-like, shape (n_values,)
        The values of the parameter that will be evaluated.

    groups : array-like, with shape (n_samples,), optional
        Group labels for the samples used while splitting the dataset into
        train/test set.

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a `(Stratified)KFold`,
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, test splits.

        For integer/None inputs, if the estimator is a classifier and ``y`` is
        either binary or multiclass, :class:`StratifiedKFold` is used. In all
        other cases, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

    scoring : string, callable or None, optional, default: None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``.

    n_jobs : integer, optional
        Number of jobs to run in parallel (default 1).

    pre_dispatch : integer or string, optional
        Number of predispatched jobs for parallel execution (default is
        all). The option can reduce the allocated memory. The string can
        be an expression like '2*n_jobs'.

    verbose : integer, optional
        Controls the verbosity: the higher, the more messages.

    Returns
    -------
    train_scores : array, shape (n_ticks, n_cv_folds)
        Scores on training sets.

    test_scores : array, shape (n_ticks, n_cv_folds)
        Scores on test set.

    Notes
    -----
    See :ref:`sphx_glr_auto_examples_model_selection_plot_validation_curve.py`

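    Examples
    --------
    A minimal sketch (illustrative only; exact scores depend on the data
    and the estimator):

    >>> from sklearn.datasets import load_iris    # doctest: +SKIP
    >>> from sklearn.svm import SVC               # doctest: +SKIP
    >>> X, y = load_iris(return_X_y=True)         # doctest: +SKIP
    >>> train_scores, test_scores = validation_curve(
    ...     SVC(), X, y, param_name='gamma',
    ...     param_range=[1e-3, 1e-1, 10.0])       # doctest: +SKIP
    >>> train_scores.shape                        # doctest: +SKIP
    (3, 3)
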
"""
|
|
X, y, groups = indexable(X, y, groups)
|
|
|
|
cv = check_cv(cv, y, classifier=is_classifier(estimator))
|
|
scorer = check_scoring(estimator, scoring=scoring)
|
|
|
|
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
|
|
verbose=verbose)
|
|
out = parallel(delayed(_fit_and_score)(
|
|
clone(estimator), X, y, scorer, train, test, verbose,
|
|
parameters={param_name: v}, fit_params=None, return_train_score=True)
|
|
# NOTE do not change order of iteration to allow one time cv splitters
|
|
for train, test in cv.split(X, y, groups) for v in param_range)
|
|
out = np.asarray(out)
|
|
n_params = len(param_range)
|
|
n_cv_folds = out.shape[0] // n_params
|
|
out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0))
|
|
|
|
return out[0], out[1]
|
|
|
|
|
|
def _aggregate_score_dicts(scores):
    """Aggregate a list of dicts into a dict of numpy ndarrays.

    The aggregated output of _fit_and_score will be a list of dicts
    of the form [{'prec': 0.1, 'acc': 1.0}, {'prec': 0.1, 'acc': 1.0}, ...].
    Convert it to a dict of arrays {'prec': np.array([0.1 ...]), ...}.

    Parameters
    ----------
    scores : list of dicts
        List of dicts of the scores for all scorers. This is a flat list,
        assumed originally to be of row major order.

    Examples
    --------
    >>> scores = [{'a': 1, 'b': 10}, {'a': 2, 'b': 2}, {'a': 3, 'b': 3},
    ...           {'a': 10, 'b': 10}]       # doctest: +SKIP
    >>> _aggregate_score_dicts(scores)      # doctest: +SKIP
    {'a': array([ 1,  2,  3, 10]),
     'b': array([10,  2,  3, 10])}
    """
    out = {}
    for key in scores[0]:
        out[key] = np.asarray([score[key] for score in scores])
    return out