TensorFlowでモデルの保存・読み込み
例示モデルの設計
input : 任意のデータ数×3次元ベクトル
w : 3×1次元の空行列
output = input × w
任意の個数の3次元ベクトルをゼロベクトルに変形する。
保存
import tensorflow as tf import numpy as np input = tf.placeholder(tf.int32, shape=(None,3)) w = tf.Variable(np.array([[0,0,0]]).T, dtype=tf.int32, name="v1") output = tf.matmul(input,w) feed_dict = dict() var_in = [[1,1,1],[2,2,2],[3,3,3]] feed_dict = {input : var_in} sess = tf.Session() sess.run(tf.global_variables_initializer()) sess.run(output, feed_dict=feed_dict) saver = tf.train.Saver() saver.save(sess, "./hoge.ckpt")
読み込みと実行
import tensorflow as tf import numpy as np input = tf.placeholder(tf.int32, shape=(None,3)) w = tf.Variable(np.array([[3,3,3]]).T, dtype=tf.int32, name="v1") output = tf.matmul(input,w) feed_dict = dict() var_in = [[1,1,1],[2,2,2],[3,3,3]] feed_dict = {input : var_in} saver = tf.train.Saver() sess = tf.Session() saver.restore(sess, "hoge.ckpt") sess.run(output, feed_dict=feed_dict)
結果にwの変更np.array([[3,3,3]]).T
が反映されない(=保存時のパラメータが読み込まれている)ことに注目しよう。
読み込むときの注意
・モデル作成時と同じplaceholderとVariableを宣言しなければならない。Variableの値は適当で構わない。
・Variableを初期化tf.global_variables_initializer()
すると読み込んだパラメータが消去されるので注意。
・Saverの宣言順番は逆でも構わないtf.train.Saver()
はセッションオブジェクトの作成tf.Session()
の前に行う。