2017-08-01 82 views
0

我的问题是关于如何从多个(或分片)tfrecords获取批量输入。我已阅读示例https://github.com/tensorflow/models/blob/master/inception/inception/image_processing.py#L410。基本流程是,以训练集为例,(1)首先从这些文件名中生成一系列记录(例如,train-000-of-005,train-001-of-005,...),(2),生成一个列表并将它们送入tf.train.string_input_producer (3)同时生成一个tf.RandomShuffleQueue做其他的事情,(4)使用tf.train.batch_join来生成批量输入。有没有更简单的方法来处理来自tfrecords的批量输入?

我认为这很复杂,我不确定这个过程的逻辑。在我的情况下,我有一个.npy文件列表,我想要生成分片tfrecords(多个分离的tfrecords,而不仅仅是一个大文件)。这些.npy文件中的每一个都包含不同数量的正面和负面样本(2个类别)。一个基本的方法是生成一个大的tfrecord文件。但该文件太大(~20Gb)。所以我诉诸分片tfrecords。有没有更简单的方法来做到这一点?谢谢。

回答

11

整个过程使用Dataset API简化。这里有两个部分:(1): Convert numpy array to tfrecords(2,3,4): read the tfrecords to generate batches。从numpy的阵列tfrecords的

1. 创建:

def npy_to_tfrecords(...): 
     # write records to a tfrecords file 
     writer = tf.python_io.TFRecordWriter(output_file) 

     # Loop through all the features you want to write 
     for ... : 
      let say X is of np.array([[...][...]]) 
      let say y is of np.array[[0/1]] 

     # Feature contains a map of string to feature proto objects 
     feature = {} 
     feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=X.flatten())) 
     feature['y'] = tf.train.Feature(int64_list=tf.train.Int64List(value=y)) 

     # Construct the Example proto object 
     example = tf.train.Example(features=tf.train.Features(feature=feature)) 

     # Serialize the example to a string 
     serialized = example.SerializeToString() 

     # write the serialized objec to the disk 
     writer.write(serialized) 
     writer.close() 

2. 使用DataSet API阅读tfrecords(tensorflow> = 1.2):

# Creates a dataset that reads all of the examples from filenames. 
    filenames = ["file1.tfrecord", "file2.tfrecord", ..."fileN.tfrecord"] 
    dataset = tf.contrib.data.TFRecordDataset(filenames) 

    # example proto decode 
    def _parse_function(example_proto): 
     keys_to_features = {'X':tf.FixedLenFeature((shape_of_npy_array), tf.float32), 
          'y': tf.FixedLenFeature((), tf.int64, default_value=0)} 
     parsed_features = tf.parse_single_example(example_proto, keys_to_features) 
    return parsed_features['X'], parsed_features['y'] 

    # Parse the record into tensors. 
    dataset = dataset.map(_parse_function) 

    # Shuffle the dataset 
    dataset = dataset.shuffle(buffer_size=10000) 

    # Repeat the input indefinitly 
    dataset = dataset.repeat() 

    # Generate batches 
    dataset = dataset.batch(batch_size) 

    # Create a one-shot iterator 
    iterator = dataset.make_one_shot_iterator() 

    # Get batch X and y 
    X, y = iterator.get_next() 
+0

啊,我非常感谢您的详细回答!你救了我的命! – mining

+0

嗨,先生,这个api是否支持'tf.train.shuffle_batch' api中的'num_threads'或'capacity'?在我的情况下,如果网络很小,那么GPU中的执行速度要比数据加载速度快,这会导致GPU空闲时间。所以我想排队取数据总是满的。谢谢。 – mining

+2

检查:https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map –

相关问题