Scikit-learnで学習した決定木をETEを使って可視化するモジュール(eteview)を構築してみた

概要

Scikit-learnで学習した決定木(パターン分類)をGraphvizではなく、ETEというライブラリを使って可視化するモジュール(eteview)を構築してみました。使い方は以下の通りです。

#必要なモジュールのインポート
from sklearn.datasets import load_iris
from sklearn import tree

if __name__ == "__main__":
    
    #irisデータの読み込み
    iris = load_iris()
     
    #決定木学習
    clf = tree.DecisionTreeClassifier(max_depth=2)
    clf.fit(iris.data, iris.target)
    
    #作ったライブラリで可視化
    eteview(iris, clf)

使うとこんな画像が出力できます。
mytree

ETEってそもそもどんなもの?という人は公式サイトや下記記事を見てみてください。

ETEを使ってPythonで決定木をシャレオツに表示


eteviewの開発コンセプト

eteviewは以下のコンセプトを元に構築しました。

  1. sklearnの決定木オブジェクトを直接引数にとれること
  2. graphvizでは出せないノードごとの円グラフやクラスを示す画像を表示できること
  3. リーフごとに振り分けられた学習データの特徴が分かること

sklearnの決定木オブジェクトを直接引数にとれること

ETEは有用なのですが、特殊なインプットファイルしか受け付けておらず、sklearnの決定木構造を受け付けることはできません。ここは大きな障壁ですので必ず自動化する必要がありました。
eteviewでは、sklearnの決定木オブジェクトの情報を読み込み、同じ構造の決定木をETE上に自動的に作ることができるようにしています。

graphvizでは出せないノードごとの円グラフやクラスを示す画像を表示できること

graphvizで描ける図(下記参照)は、分析者目線ではそこそこ有用なのですが、人に見せて説明するには見た目がちょっと。。。
iris
eteviewでは、ノードごとにデータの割合を示した円グラフを表示し、さらにその円グラフの大きさをデータの数で表現することで、ノードの特徴を分かりやすく表示できるようにしました。また、sklearnno
bunchオブジェクトのtarget_namesと同じ名前の画像ファイル(拡張子は.jpg等)を用意しておくことで、その画像が末端ノード(リーフ)に表示されるようにしました。
bunchオブジェクトが良くわからない、という方は下記をご覧ください。

PandasのデータフレームをScikit-learnの入力データに変換する方法(2.コーディング)


リーフごとに振り分けられた学習データの特徴が分かること

決定木が大きくなってくると、学習結果の各リーフにどんな特徴のデータが分けられているのかなかなか分からないですよね。分岐条件を目で追っていくのは面倒ですし、他の特徴量がどうなっているか分からないのもマイナスポイントです。そこで、eteviewでは割り当てられた学習データがどのような特徴なのかが一目でわかるように、各特徴量ごとのヒストグラムで表示する仕様としました。

eteviewの入出力

eteviewには7つの入力引数があります。また、出力は画像ファイルのみで戻り値はありません。help(eteview)と打つと、以下の説明が表示されます。基本的にはbunch, clf, ymax, figextを使います。

    """
    <概要>
    scikit-learnの決定木をeteを使って可視化するファンクション
     
    <引数>
    bunch:scikit-learn形式に変換したデータ
    clf:決定木オブジェクト(学習済み)
    ymax:ヒストグラムのy軸の最大値(デフォルト30)
    figext:クラス画像ファイルの拡張子(デフォルトjpg)
    outfilename:出力画像ファイル名(デフォルトmytree.png)
    outfiledpi:出力画像ファイルの解像度(デフォルト300)
    fontsize:グラフの文字フォント(デフォルト15)
     
    <出力>
    None:出力なし
    """

スクリプト

eteviewの全体スクリプトは以下の通りです。
GitHubにて、ソースコードを公開しています。

#必要なモジュールのインポート
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import tree
from ete3 import Tree, TreeStyle, TextFace, PieChartFace, ImgFace

def eteview(bunch, clf, ymax=30, figext="jpg", outfilename="mytree.png", outfiledpi=300, fontsize=15):
    """
    <概要>
    scikit-learnの決定木をeteを使って可視化するファンクション
     
    <引数>
    bunch:scikit-learn形式に変換したデータ
    clf:決定木オブジェクト(学習済み)
    ymax:ヒストグラムのy軸の最大値(デフォルト30)
    figext:クラス画像ファイルの拡張子(デフォルトjpg)
    outfilename:出力画像ファイル名(デフォルトmytree.png)
    outfiledpi:出力画像ファイルの解像度(デフォルト300)
    fontsize:グラフの文字フォント(デフォルト15)
     
    <出力>
    None:出力なし
    """
    
    #データフレーム作成
    df = pd.DataFrame(data=bunch.data, columns=bunch.feature_names)
    
    #各カラムの最大値と最小値を取得(グラフ作成時に必要になる)
    maxList = df.max()
    minList = df.min()
    
    #各データの到達ノードIDを取得
    df['#NAMES'] = clf.tree_.apply(bunch.data.astype(np.float32))
    leafList = df['#NAMES'].unique()
    leafList.sort()    

    #グラフで使う色(好みで変えてください)
    cmap = plt.get_cmap('gist_rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(bunch.feature_names))]
    deffont = plt.rcParams["font.size"]
    #グラフのフォントサイズ
    plt.rcParams["font.size"] = fontsize
    
    #到達ノードIDごとに要素ごとのヒストグラムを作成し保存
    for i in leafList:
        fig = plt.figure(figsize=(3*len(bunch.feature_names), 3), dpi=300)
        tdf = df[df['#NAMES'] == i]
        
        for j, c in enumerate(bunch.feature_names):
            ax = fig.add_subplot(1, len(bunch.feature_names), j+1)
            tdf.plot.hist(title=c, color=colors[j],
            bins=np.arange(minList, maxList, (maxList-minList)/30), xlim=(minList, maxList), ylim=(0, ymax))
        
        fig.tight_layout()
        fig.savefig(str(i) + ".jpeg")
        plt.close(fig)
    
    #eteのTreeインスタンスを構築
    tree = Tree()
    
    #treeオブジェクトを構築する
    for i in range(clf.tree_.node_count):
        #ルートノードの名称は0とする
        if i == 0:
            tree.name = str(0)
        
        #親ノードを設定
        node = tree.search_nodes(name=str(i))[0]
        
        #ノードごとに配分の円グラフを作成
        Pie = PieChartFace(percents=clf.tree_.value[i][0] / clf.tree_.value[i].sum() * 100
        , width=clf.tree_.n_node_samples[i]
        , height=clf.tree_.n_node_samples[i])
        Pie.opacity = 0.8
        Pie.hz_align = 1
        Pie.vt_align = 1
        
        #円グラフをセット
        node.add_face(Pie, column=2, position="branch-right")

        #左側の子ノードに関する処理
        if clf.tree_.children_left[i] > -1:
            #ノード名称はtreeのリストIDと一致させる
            node.add_child(name=str(clf.tree_.children_left[i]))
            #子ノードに移る
            node = tree.search_nodes(name=str(clf.tree_.children_left[i]))[0]
            #分岐条件を追加
            node.add_face(TextFace(bunch.feature_names[clf.tree_.feature[i]]), column=0, position="branch-top")
            node.add_face(TextFace(u"≦" + "{0:.2f}".format(clf.tree_.threshold[i])), column=1, position="branch-bottom")
            #親ノードに戻っておく
            node = tree.search_nodes(name=str(i))[0]
        
        #右側の子ノードに関する処理(上記と同様)
        if clf.tree_.children_right[i] > -1:
            node.add_child(name=str(clf.tree_.children_right[i]))
            node = tree.search_nodes(name=str(clf.tree_.children_right[i]))[0]
            node.add_face(TextFace(bunch.feature_names[clf.tree_.feature[i]]), column=0, position="branch-top")
            node.add_face(TextFace(">" + "{0:.2f}".format(clf.tree_.threshold[i])), column=1, position="branch-bottom")
            node = tree.search_nodes(name=str(i))[0]
        
        #リーフノードに関する処理
        if clf.tree_.children_left[i] == -1 and clf.tree_.children_right[i] == -1:
            
            #リーフの情報を取得
            text1 = "{0:.0f}".format(clf.tree_.value[i][0][np.argmax(clf.tree_.value.T, axis=0)[0][i]] / clf.tree_.n_node_samples[i] * 100) + "%"
            text2 = "{0:.0f}".format(clf.tree_.value[i][0][np.argmax(clf.tree_.value.T, axis=0)[0][i]]) +"/"+ "{0:.0f}".format(clf.tree_.n_node_samples[i])
            
            #リーフの情報を書き込み
            node.add_face(TextFace(bunch.target_names[np.argmax(clf.tree_.value.T, axis=0)[0][i]])
            , column=4, position="branch-right")
            node.add_face(TextFace(text1)
            , column=4, position="branch-right")
            node.add_face(TextFace(text2)
            , column=4, position="branch-right")
            
            #クラスに対応した画像を設置
            imgface = ImgFace(bunch.target_names[np.argmax(clf.tree_.value.T, axis=0)[0][i]] + "." + figext, height=80)
            imgface.margin_left = 10
            imgface.margin_right = 10            
            node.add_face(imgface, column=3, position="branch-right")
            
            #作成したヒストグラムを設置
            imgface2 = ImgFace(str(i) + ".jpeg", height=150)
            node.add_face(imgface2, column=4, position="aligned")
    
    #不要な要素を表示しないように設定    
    ts = TreeStyle()
    ts.show_leaf_name = False
    ts.show_scale = False
    
    #ファイル保存
    tree.render(outfilename, dpi=outfiledpi, tree_style=ts)
    
    #グラフのフォントサイズを元に戻す
    plt.rcParams["font.size"] = deffont

まとめ

Scikit-learnで構築した決定木をETEで簡単に表示するライブラリ「eteview」を作ってみました。Graphvizで作成する決定木よりも説明しやすい絵が描けると思いますのでぜひ使ってみてください。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA