きっと続かんブログ

勉強したことや人に言いたいことを書く。

TensorFlowでモデルの保存・読み込み

例示モデルの設計

input : 任意のデータ数×3次元ベクトル
w : 3×1次元の空行列
output = input × w

任意の個数の3次元ベクトルをゼロベクトルに変形する。

保存

import tensorflow as tf
import numpy as np

input = tf.placeholder(tf.int32, shape=(None,3))
w = tf.Variable(np.array([[0,0,0]]).T, dtype=tf.int32, name="v1")
output = tf.matmul(input,w)

feed_dict = dict()
var_in = [[1,1,1],[2,2,2],[3,3,3]]
feed_dict = {input : var_in}

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(output, feed_dict=feed_dict)

saver = tf.train.Saver()
saver.save(sess, "./hoge.ckpt")

読み込みと実行

import tensorflow as tf
import numpy as np

input = tf.placeholder(tf.int32, shape=(None,3))
w = tf.Variable(np.array([[3,3,3]]).T, dtype=tf.int32, name="v1")
output = tf.matmul(input,w)

feed_dict = dict()
var_in = [[1,1,1],[2,2,2],[3,3,3]]
feed_dict = {input : var_in}

saver = tf.train.Saver()

sess = tf.Session()
saver.restore(sess, "hoge.ckpt")

sess.run(output, feed_dict=feed_dict)

結果にwの変更np.array([[3,3,3]]).Tが反映されない(=保存時のパラメータが読み込まれている)ことに注目しよう。

読み込むときの注意

・モデル作成時と同じplaceholderとVariableを宣言しなければならない。Variableの値は適当で構わない。
・Variableを初期化tf.global_variables_initializer()すると読み込んだパラメータが消去されるので注意。
Saverの宣言tf.train.Saver()はセッションオブジェクトの作成tf.Session()の前に行う。 順番は逆でも構わない