在 sklearn > 0.24 的版本上,有一个 API(sklearn.tree.export_text),可以导出决策树的决策路径信息,导出之后格式如下:

|--- petal length (cm) <= 2.45
| |--- value: [0.00]
|--- petal length (cm) > 2.45
| |--- petal width (cm) <= 1.75
| | |--- petal length (cm) <= 4.95
| | | |--- petal width (cm) <= 1.65
| | | | |--- value: [1.00]
| | | |--- petal width (cm) > 1.65
| | | | |--- value: [2.00]
| | |--- petal length (cm) > 4.95
| | | |--- petal width (cm) <= 1.55
| | | | |--- value: [2.00]
| | | |--- petal width (cm) > 1.55
| | | | |--- sepal length (cm) <= 6.95
| | | | | |--- value: [1.00]
| | | | |--- sepal length (cm) > 6.95
| | | | | |--- value: [2.00]
| |--- petal width (cm) > 1.75
| | |--- petal length (cm) <= 4.85
| | | |--- sepal length (cm) <= 5.95
| | | | |--- value: [1.00]
| | | |--- sepal length (cm) > 5.95
| | | | |--- value: [2.00]
| | |--- petal length (cm) > 4.85
| | | |--- value: [2.00]
但是当python版本限定在2.7.5的时候,只能用 sklearn == 0.20.0,不存在这个API,好在还可以从 sklearn 的源码扒一下实现,然后套一下:
import numpy as np

from sklearn.base import is_classifier
from sklearn.tree import _tree
def _compute_depth(tree, node):
"""
Returns the depth of the subtree rooted in node.
"""
def compute_depth_(
current_node, current_depth, children_left, children_right, depths
):
depths += [current_depth]
left = children_left[current_node]
right = children_right[current_node]
if left != -1 and right != -1:
compute_depth_(
left, current_depth + 1, children_left, children_right, depths
)
compute_depth_(
right, current_depth + 1, children_left, children_right, depths
)
depths = []
compute_depth_(node, 1, tree.children_left, tree.children_right, depths)
return max(depths)
def export_text(
    decision_tree,
    feature_names=None,
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
):
    """Build a text report showing the rules of a decision tree.

    Backport of ``sklearn.tree.export_text`` for environments pinned to
    sklearn == 0.20.  The upstream signature's bare ``*`` keyword-only
    marker was dropped because it is Python-3-only syntax and this
    snippet targets Python 2.7; all parameter names and defaults are
    unchanged, so existing keyword calls keep working.

    Parameters
    ----------
    decision_tree : object
        The decision tree estimator to be exported.
        It can be an instance of
        DecisionTreeClassifier or DecisionTreeRegressor.
    feature_names : list of str, default=None
        A list of length n_features containing the feature names.
        If None generic names will be used ("feature_0", "feature_1", ...).
    max_depth : int, default=10
        Only the first max_depth levels of the tree are exported.
        Truncated branches will be marked with "...".
    spacing : int, default=3
        Number of spaces between edges. The higher it is, the wider the result.
    decimals : int, default=2
        Number of decimal digits to display.
    show_weights : bool, default=False
        If true the classification weights will be exported on each leaf.
        The classification weights are the number of samples each class.

    Returns
    -------
    report : str
        Text summary of all the rules in the decision tree.

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> iris = load_iris()
    >>> X = iris['data']
    >>> y = iris['target']
    >>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
    >>> decision_tree = decision_tree.fit(X, y)
    >>> r = export_text(decision_tree, feature_names=iris['feature_names'])
    >>> print(r)
    |--- petal width (cm) <= 0.80
    | |--- class: 0
    |--- petal width (cm) > 0.80
    | |--- petal width (cm) <= 1.75
    | | |--- class: 1
    | |--- petal width (cm) > 1.75
    | | |--- class: 2
    """
    # check_is_fitted is not imported here (old sklearn moved it around);
    # the ``tree_`` attribute access below fails loudly if unfitted anyway.
    tree_ = decision_tree.tree_
    # Use the already-imported ``is_classifier`` instead of
    # ``isinstance(decision_tree, DecisionTreeClassifier)`` — the original
    # referenced DecisionTreeClassifier without ever importing it
    # (NameError at runtime).  Equivalent for decision-tree estimators.
    is_classification = is_classifier(decision_tree)
    if is_classification:
        class_names = decision_tree.classes_
    right_child_fmt = "{} {} <= {}\n"
    left_child_fmt = "{} {} > {}\n"
    truncation_fmt = "{} {}\n"

    if max_depth < 0:
        # fixed typo in the original message ("bust be")
        raise ValueError("max_depth must be >= 0, given %d" % max_depth)
    if feature_names is not None and len(feature_names) != tree_.n_features:
        raise ValueError(
            "feature_names must contain %d elements, got %d"
            % (tree_.n_features, len(feature_names))
        )
    if spacing <= 0:
        raise ValueError("spacing must be > 0, given %d" % spacing)
    if decimals < 0:
        raise ValueError("decimals must be >= 0, given %d" % decimals)

    if is_classification:
        value_fmt = "{}{} weights: {}\n"
        if not show_weights:
            value_fmt = "{}{}{}\n"
    else:
        value_fmt = "{}{} value: {}\n"

    if feature_names:
        # One entry per tree node: the split feature's name, or None for
        # leaves (marked TREE_UNDEFINED in tree_.feature).
        feature_names_ = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else None
            for i in tree_.feature
        ]
    else:
        feature_names_ = ["feature_{}".format(i) for i in tree_.feature]

    # Upstream accumulates the report on the function object itself; keep
    # that observable behavior for compatibility.
    export_text.report = ""

    def _add_leaf(value, class_name, indent):
        # Append one leaf line: "[v0, v1, ...]" and/or "class: name".
        val = ""
        if show_weights or not is_classification:
            val = ["{1:.{0}f}, ".format(decimals, v) for v in value]
            val = "[" + "".join(val)[:-2] + "]"
        if is_classification:
            val += " class: " + str(class_name)
        export_text.report += value_fmt.format(indent, "", val)

    def print_tree_recurse(node, depth):
        # Pre-order traversal, writing one line per split/leaf into
        # export_text.report.
        indent = ("|" + (" " * spacing)) * depth
        indent = indent[:-spacing] + "-" * spacing

        if tree_.n_outputs == 1:
            value = tree_.value[node][0]
        else:
            value = tree_.value[node].T[0]
        class_name = np.argmax(value)

        if tree_.n_classes[0] != 1 and tree_.n_outputs == 1:
            class_name = class_names[class_name]

        if depth <= max_depth + 1:
            info_fmt = ""
            info_fmt_left = info_fmt
            info_fmt_right = info_fmt

            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                # feature_names_ is indexed by node id (one entry per node).
                name = feature_names_[node]
                threshold = tree_.threshold[node]
                threshold = "{1:.{0}f}".format(decimals, threshold)
                export_text.report += right_child_fmt.format(indent, name, threshold)
                export_text.report += info_fmt_left
                print_tree_recurse(tree_.children_left[node], depth + 1)

                export_text.report += left_child_fmt.format(indent, name, threshold)
                export_text.report += info_fmt_right
                print_tree_recurse(tree_.children_right[node], depth + 1)
            else:  # leaf
                _add_leaf(value, class_name, indent)
        else:
            # Past max_depth: either a plain leaf or a truncation marker.
            subtree_depth = _compute_depth(tree_, node)
            if subtree_depth == 1:
                _add_leaf(value, class_name, indent)
            else:
                trunc_report = "truncated branch of depth %d" % subtree_depth
                export_text.report += truncation_fmt.format(indent, trunc_report)

    print_tree_recurse(0, 1)
    return export_text.report
试了一下,完美运行。