308 lines
13 KiB
Python
308 lines
13 KiB
Python
|
"""
|
||
|
Testing for export functions of decision trees (sklearn.tree.export).
|
||
|
"""
|
||
|
|
||
|
from re import finditer, search
|
||
|
|
||
|
from numpy.random import RandomState
|
||
|
|
||
|
from sklearn.base import is_classifier
|
||
|
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||
|
from sklearn.ensemble import GradientBoostingClassifier
|
||
|
from sklearn.tree import export_graphviz
|
||
|
from sklearn.externals.six import StringIO
|
||
|
from sklearn.utils.testing import (assert_in, assert_equal, assert_raises,
|
||
|
assert_less_equal, assert_raises_regex,
|
||
|
assert_raise_message)
|
||
|
from sklearn.exceptions import NotFittedError
|
||
|
|
||
|
# Toy dataset shared by the tests in this module.
# Six two-feature samples: the first three have negative coordinates,
# the last three have positive coordinates.
X = [
    [-2, -1],
    [-1, -1],
    [-1, -2],
    [1, 1],
    [1, 2],
    [2, 1],
]
# Single-output target: -1 for the first three samples, 1 for the rest.
y = [-1] * 3 + [1] * 3
# Two-output target, passed to fit() by the multi-output test case.
y2 = [
    [-1, 1],
    [-1, 1],
    [-1, 1],
    [1, 2],
    [1, 2],
    [1, 3],
]
# Per-sample weights, passed as ``sample_weight`` by the weighted test case.
w = [1, 1, 1, 0.5, 0.5, 0.5]
# Target in which every sample carries the same label ("degraded" set).
y_degraded = [1] * 6
|
||
|
|
||
|
|
||
|
def test_graphviz_toy():
    """Check the exact DOT text returned by ``export_graphviz``.

    Small trees are fitted on the module-level toy data and the string
    returned by ``export_graphviz(..., out_file=None)`` is compared, character
    for character, with a hand-written expected DOT source.  The cases cover
    the default export, ``feature_names``, ``class_names``, the plotting
    options, ``max_depth`` truncation, a weighted multi-output classifier,
    a regressor and a classifier fitted on a single-class target.
    """
    # Check correctness of export_graphviz
    clf = DecisionTreeClassifier(max_depth=3,
                                 min_samples_split=2,
                                 criterion="gini",
                                 random_state=2)
    clf.fit(X, y)

    # Test export code
    contents1 = export_graphviz(clf, out_file=None)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box] ;\n' \
                '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
                'value = [3, 3]"] ;\n' \
                '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
                '0 -> 1 [labeldistance=2.5, labelangle=45, ' \
                'headlabel="True"] ;\n' \
                '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' \
                '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
                'headlabel="False"] ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test with feature_names
    contents1 = export_graphviz(clf, feature_names=["feature0", "feature1"],
                                out_file=None)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box] ;\n' \
                '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
                'value = [3, 3]"] ;\n' \
                '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
                '0 -> 1 [labeldistance=2.5, labelangle=45, ' \
                'headlabel="True"] ;\n' \
                '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' \
                '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
                'headlabel="False"] ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test with class_names
    contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box] ;\n' \
                '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
                'value = [3, 3]\\nclass = yes"] ;\n' \
                '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' \
                'class = yes"] ;\n' \
                '0 -> 1 [labeldistance=2.5, labelangle=45, ' \
                'headlabel="True"] ;\n' \
                '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n' \
                'class = no"] ;\n' \
                '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
                'headlabel="False"] ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test plot_options
    # NOTE: with special_characters=True the labels are HTML-like, so the
    # "less than or equal" sign must be spelled as the entity "&le;".  A
    # literal non-ASCII character here would never match the exported text
    # (and would break parsing on Python 2 without an encoding declaration).
    contents1 = export_graphviz(clf, filled=True, impurity=False,
                                proportion=True, special_characters=True,
                                rounded=True, out_file=None)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box, style="filled, rounded", color="black", ' \
                'fontname=helvetica] ;\n' \
                'edge [fontname=helvetica] ;\n' \
                '0 [label=<X<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>' \
                'value = [0.5, 0.5]>, fillcolor="#e5813900"] ;\n' \
                '1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, ' \
                'fillcolor="#e58139ff"] ;\n' \
                '0 -> 1 [labeldistance=2.5, labelangle=45, ' \
                'headlabel="True"] ;\n' \
                '2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, ' \
                'fillcolor="#399de5ff"] ;\n' \
                '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
                'headlabel="False"] ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test max_depth
    contents1 = export_graphviz(clf, max_depth=0,
                                class_names=True, out_file=None)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box] ;\n' \
                '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
                'value = [3, 3]\\nclass = y[0]"] ;\n' \
                '1 [label="(...)"] ;\n' \
                '0 -> 1 ;\n' \
                '2 [label="(...)"] ;\n' \
                '0 -> 2 ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test max_depth with plot_options
    contents1 = export_graphviz(clf, max_depth=0, filled=True,
                                out_file=None, node_ids=True)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box, style="filled", color="black"] ;\n' \
                '0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \
                'samples = 6\\nvalue = [3, 3]", fillcolor="#e5813900"] ;\n' \
                '1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
                '0 -> 1 ;\n' \
                '2 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
                '0 -> 2 ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test multi-output with weighted samples
    clf = DecisionTreeClassifier(max_depth=2,
                                 min_samples_split=2,
                                 criterion="gini",
                                 random_state=2)
    clf = clf.fit(X, y2, sample_weight=w)

    contents1 = export_graphviz(clf, filled=True,
                                impurity=False, out_file=None)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box, style="filled", color="black"] ;\n' \
                '0 [label="X[0] <= 0.0\\nsamples = 6\\n' \
                'value = [[3.0, 1.5, 0.0]\\n' \
                '[3.0, 1.0, 0.5]]", fillcolor="#e5813900"] ;\n' \
                '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' \
                '[3, 0, 0]]", fillcolor="#e58139ff"] ;\n' \
                '0 -> 1 [labeldistance=2.5, labelangle=45, ' \
                'headlabel="True"] ;\n' \
                '2 [label="X[0] <= 1.5\\nsamples = 3\\n' \
                'value = [[0.0, 1.5, 0.0]\\n' \
                '[0.0, 1.0, 0.5]]", fillcolor="#e5813986"] ;\n' \
                '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
                'headlabel="False"] ;\n' \
                '3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n' \
                '[0, 1, 0]]", fillcolor="#e58139ff"] ;\n' \
                '2 -> 3 ;\n' \
                '4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \
                '[0.0, 0.0, 0.5]]", fillcolor="#e58139ff"] ;\n' \
                '2 -> 4 ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test regression output with plot_options
    clf = DecisionTreeRegressor(max_depth=3,
                                min_samples_split=2,
                                criterion="mse",
                                random_state=2)
    clf.fit(X, y)

    contents1 = export_graphviz(clf, filled=True, leaves_parallel=True,
                                out_file=None, rotate=True, rounded=True)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box, style="filled, rounded", color="black", ' \
                'fontname=helvetica] ;\n' \
                'graph [ranksep=equally, splines=polyline] ;\n' \
                'edge [fontname=helvetica] ;\n' \
                'rankdir=LR ;\n' \
                '0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \
                'value = 0.0", fillcolor="#e5813980"] ;\n' \
                '1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", ' \
                'fillcolor="#e5813900"] ;\n' \
                '0 -> 1 [labeldistance=2.5, labelangle=-45, ' \
                'headlabel="True"] ;\n' \
                '2 [label="mse = 0.0\\nsamples = 3\\nvalue = 1.0", ' \
                'fillcolor="#e58139ff"] ;\n' \
                '0 -> 2 [labeldistance=2.5, labelangle=45, ' \
                'headlabel="False"] ;\n' \
                '{rank=same ; 0} ;\n' \
                '{rank=same ; 1; 2} ;\n' \
                '}'

    assert_equal(contents1, contents2)

    # Test classifier with degraded learning set
    clf = DecisionTreeClassifier(max_depth=3)
    clf.fit(X, y_degraded)

    contents1 = export_graphviz(clf, filled=True, out_file=None)
    contents2 = 'digraph Tree {\n' \
                'node [shape=box, style="filled", color="black"] ;\n' \
                '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \
                'fillcolor="#e5813900"] ;\n' \
                '}'

    assert_equal(contents1, contents2)
|
||
|
|
||
|
|
||
|
def test_graphviz_errors():
    """Verify that ``export_graphviz`` raises on invalid input."""
    tree = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

    # Exporting before fit() must raise NotFittedError.
    sink = StringIO()
    assert_raises(NotFittedError, export_graphviz, tree, sink)

    tree.fit(X, y)

    # feature_names needs one entry per feature (the toy data has two), so
    # both a too-short and a too-long list must be rejected with a message
    # that reports the offending length.
    for names in (["a"], ["a", "b", "c"]):
        expected = ("Length of feature_names, "
                    "%d does not match number of features, 2" % len(names))
        assert_raise_message(ValueError, expected, export_graphviz, tree,
                             None, feature_names=names)

    # An empty class_names list cannot be indexed by the exporter.
    sink = StringIO()
    assert_raises(IndexError, export_graphviz, tree, sink, class_names=[])

    # precision must be a non-negative integer.
    sink = StringIO()
    for bad_precision, fragment in ((-1, "should be greater or equal"),
                                    ("1", "should be an integer")):
        assert_raises_regex(ValueError, fragment, export_graphviz, tree,
                            sink, precision=bad_precision)
|
||
|
|
||
|
|
||
|
def test_friedman_mse_in_graphviz():
    """Check that exported nodes name the ``friedman_mse`` criterion.

    A ``friedman_mse`` regressor and the first tree of each boosting stage
    of a ``GradientBoostingClassifier`` are all written to one shared
    buffer; every bracketed attribute list that mentions ``samples`` must
    also mention ``friedman_mse``.
    """
    clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
    clf.fit(X, y)
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data)

    clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
    clf.fit(X, y)
    for estimator in clf.estimators_:
        # Each entry of estimators_ is indexable; only its first tree
        # is exported here.
        export_graphviz(estimator[0], out_file=dot_data)

    # Raw string: "\[" is not a valid str escape and triggers a
    # DeprecationWarning/SyntaxWarning on recent Python versions.
    for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
        assert_in("friedman_mse", finding.group())
|
||
|
|
||
|
|
||
|
def test_precision():
    """Check that ``precision`` controls the digits written to the DOT text.

    A depth-1 regressor and a depth-1 classifier are fitted on seeded random
    data and exported with ``precision`` 4 and 3.  Impurities and thresholds
    must show exactly ``precision`` decimals; node values may show fewer.
    """
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)
    # The zip arguments are evaluated left to right, so each generator draws
    # its features first and its targets second; keep this order to preserve
    # the seeded data.  The loop targets are named so they do not shadow the
    # module-level toy ``X`` and ``y``.
    for X_rand, y_rand, clf in zip(
            (rng_reg.random_sample((5, 2)),
             rng_clf.random_sample((1000, 4))),
            (rng_reg.random_sample((5, )),
             rng_clf.randint(2, size=(1000, ))),
            (DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
                                   max_depth=1),
             DecisionTreeClassifier(max_depth=1, random_state=0))):

        clf.fit(X_rand, y_rand)
        for precision in (4, 3):
            dot_data = export_graphviz(clf, out_file=None, precision=precision,
                                       proportion=True)

            # With the current random state, the impurity and the threshold
            # will have the number of precision set in the export_graphviz
            # function. We will check the number of precision with a strict
            # equality. The value reported will have only 2 precision and
            # therefore, only a less equal comparison will be done.
            #
            # The matched fragment "\.\d+" includes the decimal point, hence
            # the comparisons against ``precision + 1``.  All patterns are
            # raw strings so that "\d" and "\." are not treated as (invalid)
            # str escapes.

            # check value
            for finding in finditer(r"value = \d+\.\d+", dot_data):
                assert_less_equal(
                    len(search(r"\.\d+", finding.group()).group()),
                    precision + 1)

            # check impurity
            if is_classifier(clf):
                pattern = r"gini = \d+\.\d+"
            else:
                pattern = r"friedman_mse = \d+\.\d+"

            for finding in finditer(pattern, dot_data):
                assert_equal(len(search(r"\.\d+", finding.group()).group()),
                             precision + 1)

            # check threshold
            for finding in finditer(r"<= \d+\.\d+", dot_data):
                assert_equal(len(search(r"\.\d+", finding.group()).group()),
                             precision + 1)
|