PythonでGraphvizを使ってニューラルネットワークを描画する

概要

説明資料を作るときに意外と書くのが面倒なニューラルネットワークの構成図を、PythonのPydotplus(又はPydot)からGraphvizを使って簡単に作るモジュール(neuralViz)を作成しました。以下イメージです。

構成を記載するだけで使えるように作っています。

neuralViz([4, 5, 5, 3])

環境構築

Anacondaがインストールされている前提で、本モジュールを使うために以下の環境を導入してください。

プログラム

以下がソースコードです。適当に張り付けて使ってください。

#必要なモジュールのインポート
import numpy as np
import pydotplus

def neuralViz(shape, fileName='default.png', plot=True, Wlist=None):
     
    #形状をnp配列に変換
    args = np.array(shape)
     
    #空のmatrixを作成(adjacency_matrix作成用)
    matrix = np.zeros((args.sum(), args.sum()))
    
    #結合の強さに関するmatrixを取得したいときに使います。(本記事では未使用)
    if Wlist is not None:
        wmatrix = np.zeros((args.sum(), args.sum())) 
     
    #結ぶノード間に1を入力
    for i in range(len(args)-1):
        tmp1 = args[:i+1].sum()
        tmp0 = args[:i].sum()
        matrix[tmp1:tmp1+args[i+1],tmp0:tmp0+args[i]] = 1   
              
        if  Wlist is not None:
            wmatrix[tmp1:tmp1+args[i+1],tmp0:tmp0+args[i]] = Wlist[i].data
     
    #GraphVizの機能を使って描画
    g=pydotplus.graph_from_adjacency_matrix(matrix.T.tolist(), node_prefix=0)
 
    #何もしないとNodeが作られないので足す。ついでに文字を消す。
    for i in range(args.sum()):
        n = pydotplus.Node(i+1)
        n.set_fontsize(0)
        g.add_node(n)
    
    if plot == True:
        #グラフ書き出し('default.png')    
        g.write_png(fileName, prog='dot')
        
    else:
        return g, matrix.T, wmatrix.T 

いろいろ試してみる

せっかく作ったのでいろいろ試してみました。

単純・複雑な構成にトライ

ちょっと複雑な構成。

neuralViz([3, 4, 6, 3, 2, 5])


逆に単純な構成。

neuralViz([1, 1, 1])


どちらもうまく描けてます。

ノードの文字色や塗りつぶし色を変えてみる

gを返すように改造すると、g.get_node_list()やg.get_edge_list()でノードやエッジの情報を編集できるので試してみました。

#まずモデル作成(モジュールの最後にreturn gを追加してください)
g = neuralViz([3, 5, 5, 4])
#ノード一覧を取得
nodeList = g.get_node_list()
#適当なノードを選択しいろいろ編集
node = nodeList[10]
node.set_color('green')
node.set_style('filled')
node = nodeList[11]
node.set_color('red')
#書き出し
g.write_png("test.png", prog='dot')


是非試してみてください。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA