2017-05-24 26 views
0

我正在研究GAN,并决定使用HyperGAN来实现我的算法。它是使用TensorFlow的DCGAN封装。 HyperGAN使用TF的检查点方法保存输出。如何获得张量流模型的输出值和输入值?

后来,我试图使用运行负载的模型:

import tensorflow as tf 
sess=tf.Session()  
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 
sess.run(tf.global_variables_initializer()) 

然而,由于它的一个GAN,它需要一个输入潜在向量并输出图像。这是使用

out_image = sess.run(last_node, feed_dict(input_node: value)) 

完成,但因为我装的模式,我不知道是什么的最后一个节点的名称是什么,输入节点占位符的名称是。如何获取用于创建图形的名称?我尝试使用TensorBoard进行可视化,但该图很大,因此卡住了。

回答

0

你应该尽量在图形中打印张量清单:

with tf.Graph().as_default() as graph: 
.... 

count = 0 
for op in graph.get_operations(): 
    print op.values() 
    count+=1 
    if count == 50: 
     assert False 

为了看图形的第一个50个节点,你会看到这样的事情:

(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,) 

我把计数放在那里,因为通常终端会输出很多张数,以至于初始输入节点名称在终端中消失。

最后,干脆注释掉线计数使用:

#count = 0 
for op in graph.get_operations(): 
    print op.values() 
    #count+=1 
    #if count == 50: 
    # assert False 

拿到打印出来的最后几个节点(即你的输出节点)。

相关问题