Scikit-learnの決定木で各データの到達ノードを知る方法

概要

Pythonにはscikit-learnという便利な機械学習モジュールがあり、これを用いて決定木学習が可能です。しかし、学習結果や予測結果には学習データや予測データが木のどのノード(リーフ)に辿り着いたかの情報までは含まれておらず、何とか取得したかったので調べたところで、簡単にできたので結果をメモしておきます。
以降では、scikit-learnが持っているサンプルデータ(iris)を使って決定木学習をしたデータを対象に確認していきます。

#必要なモジュールのインポート
from sklearn.datasets import load_iris
from sklearn import tree
 
if __name__ == "__main__":
     
    #irisデータの読み込み
    iris = load_iris()
     
    #決定木学習
    clf = tree.DecisionTreeClassifier()
    clf.fit(iris.data, iris.target)
    

過去記事を参考に決定木を画像化するとこんな感じ。
iris

到達ノード番号の取得

残念ながら、clfのオブジェクトにはそれぞれのデータがどのノードに辿り着いたかは含まれていません。ソースコードを追ってみると、どうやら学習や予測部分はCで処理されており、Pythonから簡単に取得することは難しそうなことが分かりました。ただし、clfオブジェクトにはapplyというメンバ関数がおり、これを使うことで、新しく予測を実施し、各予測対象データがどのノードに辿り着くかを取得することが出来ることが分かりました。

#float32でないと動作しないので変換して入力
clf.tree_.apply(iris.data.astype(np.float32))
Out[0]: 
array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  5,
        5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
        5,  5, 14,  5,  5,  5,  5,  5,  5, 10,  5,  5,  5,  5,  5, 10,  5,
        5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5, 16, 16,
       16, 16, 16, 16,  6, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
        8, 16, 16, 16, 16, 16, 16, 15, 16, 16, 11, 16, 16, 16,  8,  8, 16,
       16, 16, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], dtype=int64)

irisのデータは150行あるので、全部をapplyすると150個の結果が返ってきます。上記で1とか5とか書いている値が辿り着いたノード番号です。このノード番号が決定木のどの部分に相当するか等は下記でまとめていますので参考にしてみてください。

Scikit-learnで学習した決定木構造の取得方法


まとめ

短いですが、各データの到達ノードを知る方法について調査結果をメモしておきました。clf.tree_.applyがキーワードです。


コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA