概要
Pythonにはscikit-learnという便利な機械学習モジュールがあり、これを用いて決定木学習が可能です。しかし、学習結果や予測結果には学習データや予測データが木のどのノード(リーフ)に辿り着いたかの情報までは含まれておらず、何とか取得したかったので調べたところで、簡単にできたので結果をメモしておきます。
以降では、scikit-learnが持っているサンプルデータ(iris)を使って決定木学習をしたデータを対象に確認していきます。
#必要なモジュールのインポート from sklearn.datasets import load_iris from sklearn import tree if __name__ == "__main__": #irisデータの読み込み iris = load_iris() #決定木学習 clf = tree.DecisionTreeClassifier() clf.fit(iris.data, iris.target)
過去記事を参考に決定木を画像化するとこんな感じ。
到達ノード番号の取得
残念ながら、clfのオブジェクトにはそれぞれのデータがどのノードに辿り着いたかは含まれていません。ソースコードを追ってみると、どうやら学習や予測部分はCで処理されており、Pythonから簡単に取得することは難しそうなことが分かりました。ただし、clfオブジェクトにはapplyというメンバ関数がおり、これを使うことで、新しく予測を実施し、各予測対象データがどのノードに辿り着くかを取得することが出来ることが分かりました。
#float32でないと動作しないので変換して入力 clf.tree_.apply(iris.data.astype(np.float32)) Out[0]: array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 14, 5, 5, 5, 5, 5, 5, 10, 5, 5, 5, 5, 5, 10, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 16, 16, 16, 16, 16, 16, 6, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 8, 16, 16, 16, 16, 16, 16, 15, 16, 16, 11, 16, 16, 16, 8, 8, 16, 16, 16, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], dtype=int64)
irisのデータは150行あるので、全部をapplyすると150個の結果が返ってきます。上記で1とか5とか書いている値が辿り着いたノード番号です。このノード番号が決定木のどの部分に相当するか等は下記でまとめていますので参考にしてみてください。
まとめ
短いですが、各データの到達ノードを知る方法について調査結果をメモしておきました。clf.tree_.applyがキーワードです。