signed

QiShunwang

“诚信为本、客户至上”

GAN模型的学习(11)———Build graph

2021/6/3 15:45:20   来源:

首先跑一遍Model,然后saver = tf.train.Saver()保存和加载模型
saver.save(sess, ‘路径 + 模型文件名’)
tf.train.Saver()
NOTE:

Tensorflow 会自动生成4个文件 第一个文件为 model.ckpt.meta,保存了 Tensorflow
计算图的结构,可以简单理解为神经网络的网络结构。 model.ckpt.index 和
model.ckpt.data--of- 文件保存了所有变量的取值。 最后一个文件为 checkpoint
文件,保存了一个目录下所有的模型文件列表。

feed_dict_test_init = {gan.test_path_placeholder: test_paths}
feed_dict_train_init = {gan.path_placeholder: paths}

这里解释不好,也不强行解释了
在tf.Session()中,有个tf.ConfigProto()它的主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算,在这里也不过多解释了,详情见:Tensorflow中tf.ConfigProto()详解
正式进入session

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

初始化全局变量,初始化局部变量

train_handle = sess.run(gan.train_iterator.string_handle())
test_handle = sess.run(gan.test_iterator.string_handle())

返回表示此迭代器的字符串值 tf.Tensor

print("restore_last--->",args.restore_last)
print("model_checkpoint_path--->",ckpt.model_checkpoint_path)
if args.restore_last and ckpt.model_checkpoint_path:
    # Continue training saved model
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('{} restored.'.format(ckpt.model_checkpoint_path))
else:
    if args.restore_path:
        new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path))
        new_saver.restore(sess, args.restore_path)
        print('{} restored.'.format(args.restore_path))

tensorflow保存和恢复模型saver.restor

tf.train.Saver() 与tf.train.import_meta_graph要点

for epoch in range(config.num_epochs):

    sess.run(gan.train_iterator.initializer, feed_dict=feed_dict_train_init)

    # Run diagnostics
    G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle,
        start_time, epoch, args.name, G_loss_best, D_loss_best)

    while True:
        try:
            # Update generator
            # for _ in range(8):
            feed_dict = {gan.training_phase: True, gan.handle: train_handle}
            sess.run(gan.G_train_op, feed_dict=feed_dict)

            # Update discriminator 
            step, _ = sess.run([gan.D_global_step, gan.D_train_op], feed_dict=feed_dict)

            if step % config.diagnostic_steps == 0:
                G_loss_best, D_loss_best = Utils.run_diagnostics(gan, config, directories, sess, saver, train_handle,
                    start_time, epoch, args.name, G_loss_best, D_loss_best)
                Utils.single_plot(epoch, step, sess, gan, train_handle, args.name, config)
                # for _ in range(4):
                #    sess.run(gan.G_train_op, feed_dict=feed_dict)


        except tf.errors.OutOfRangeError:
            print('End of epoch!')
            break

        except KeyboardInterrupt:
            save_path = saver.save(sess, os.path.join(directories.checkpoints,
                '{}_last.ckpt'.format(args.name)), global_step=epoch)
            print('Interrupted, model saved to: ', save_path)
            sys.exit()

argparse.ArgumentParser()用法解析