2

我在查看TF Slim介绍性文档,并且从我所了解的情况来看,每次运行只需要一批图像数据(32幅图像)。很显然,人们想通过这个循环来训练许多不同的批次。介绍不包括这一点。这怎么能正确完成。我想应该有一些方法来指定一个加载批处理函数,它应该在开始批处理训练事件时自动调用,但我似乎无法在介绍中找到一个简单的例子。Tensorflow Slim的批量培训

# Note that this may take several minutes. 

import os 

from datasets import flowers 
from nets import inception 
from preprocessing import inception_preprocessing 

slim = tf.contrib.slim 
image_size = inception.inception_v1.default_image_size 


def get_init_fn(): 
    """Returns a function run by the chief worker to warm-start the training.""" 
    checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"] 

    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes] 

    variables_to_restore = [] 
    for var in slim.get_model_variables(): 
     excluded = False 
     for exclusion in exclusions: 
      if var.op.name.startswith(exclusion): 
       excluded = True 
       break 
     if not excluded: 
      variables_to_restore.append(var) 

    return slim.assign_from_checkpoint_fn(
     os.path.join(checkpoints_dir, 'inception_v1.ckpt'), 
     variables_to_restore) 


train_dir = '/tmp/inception_finetuned/' 

with tf.Graph().as_default(): 
    tf.logging.set_verbosity(tf.logging.INFO) 

    dataset = flowers.get_split('train', flowers_data_dir) 
    images, _, labels = load_batch(dataset, height=image_size, width=image_size) 

    # Create the model, use the default arg scope to configure the batch norm parameters. 
    with slim.arg_scope(inception.inception_v1_arg_scope()): 
     logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True) 

    # Specify the loss function: 
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes) 
    slim.losses.softmax_cross_entropy(logits, one_hot_labels) 
    total_loss = slim.losses.get_total_loss() 

    # Create some summaries to visualize the training process: 
    tf.scalar_summary('losses/Total Loss', total_loss) 

    # Specify the optimizer and create the train op: 
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01) 
    train_op = slim.learning.create_train_op(total_loss, optimizer) 

    # Run the training: 
    final_loss = slim.learning.train(
     train_op, 
     logdir=train_dir, 
     init_fn=get_init_fn(), 
     number_of_steps=2) 


print('Finished training. Last batch loss %f' % final_loss) 
+0

是不是代码示例中的函数load_batch undefined共享?我不熟悉你的例子,但我会开始阅读这个功能,以了解批处理过程。 – pltrdy

+0

它在这里给出https://github.com/tensorflow/models/blob/master/slim/slim_walkthough.ipynb但是除了获得批量外,这没有任何作用。 –

+0

所以你基本上只需要迭代批次? – pltrdy

回答

1

slim.learning.train函数包含一个训练循环,所以你给的代码不会对图像的多个批次的事实火车。

请参阅here in the source code,其中train_step_fn在while循环内被调用。 train_step(默认值为train_step_fn)包含行sess.run([train_op, global_step]...),该行实际上在单批图像上运行训练操作。

+0

好吧,我在load_batch函数中放了一个print语句,并且训练了超过1步,发现加载批处理函数只被调用一次,所以这意味着相同的数据被用于多个步骤,因此这个问题。 –

+0

此外,我没有在调用learning.train时指定load_batch函数,那么它如何才能“知道”使用它来加载新批次? –

+0

我已经做了更多的研究,看起来有一个队列可以从每批自动加载的地方建立起来。为了测试这个,我在这里有一个相关的问题http://stackoverflow.com/questions/41868871/tensorflow-slim-debugging-during-training。请尽可能评论。 –