2017-04-30 140 views
-1

我正在使用tf.train.Saversaverestore保存和恢复TensorFlow模型。在恢复过程中,我正在加载新的输入数据。该restore方法抛出这个错误:TensorFlow变量名称 - 保存/恢复中的分配错误

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [1334,3] rhs shape= [1246,3] [[Node: save/Assign_6 = Assign[T=DT_FLOAT, _class=["loc:@Variable_2"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_2, save/RestoreV2_6)]]

这似乎是说,问题是出在Variable_2,但一个人如何确定哪些变量的代码对应于Variable_2

回答

-1

当您创建一个新的变量得到它是一个独特的名字。 Saver.restore在检查点查看同名。如果你需要一些初始化的变量来自不同的关卡有不同的名称,请看看tf.contrib.framework.init_from_checkpoint

+0

谢谢,但我不太以下错误。我使用的是保存检查点来加载检查点的相同代码;在保存和恢复之间没有创建新的变量。 –

0
  • 如果要恢复的模式,这样就前馈,则该模型的形状和型号杜彦武应该是一样的,当你保存它
  • 所以上面的错误是,当你正在恢复的模型中的一个说其中保存有张形状[1246,3],但您要指派给一个张量,其形状为[1334,3]
  • 明确知道哪些变量是指的名字,你可以指定唯一的名称张量,例如a = tf.placeholder("float", [3, 3], name="tensor_a")
  • 所以,现在在恢复模式,你知道你的模型与NAME =“tensor_a”,这是3倍形状的图的张量3
  • 快速教程在代码:

    # Create some variables. 
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer) 
    v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer) 
    
    inc_v1 = v1.assign(v1+1) 
    dec_v2 = v2.assign(v2-1) 
    
    # Add an op to initialize the variables. 
    init_op = tf.global_variables_initializer() 
    
    # Add ops to save and restore all the variables. 
    saver = tf.train.Saver() 
    
    # Later, launch the model, initialize the variables, do some work, and save the 
    # variables to disk. 
    with tf.Session() as sess: 
        sess.run(init_op) 
        # Do some work with the model. 
        inc_v1.op.run() 
        dec_v2.op.run() 
        # Save the variables to disk. 
        save_path = saver.save(sess, "/tmp/model.ckpt") 
        print("Model saved in file: %s" % save_path) 
    
    tf.reset_default_graph() 
    
    # Create some variables. 
    d1 = tf.get_variable("v1", shape=[3]) 
    d2 = tf.get_variable("v2", shape=[5]) 
    
    # Add ops to save and restore all the variables. 
    saver = tf.train.Saver() 
    
    # Later, launch the model, use the saver to restore variables from disk, and 
    # do some work with the model. 
    with tf.Session() as sess: 
        # Restore variables from disk. 
        saver.restore(sess, "/tmp/model.ckpt") 
        print("Model restored.") 
        # Check the values of the variables 
        print("v1 : %s" % d1.eval()) 
        print("v2 : %s" % d2.eval()) 
    
  • 如果你在上面的代码D1注意到且v1具有相同的形状,现在如果你改变任何国税发可变形状会扔给你一个错误,它类似于你越来越