2014-11-20 43 views
7

我正在训练一个神经网络,大约有五千兆字节的数据存储为numpy阵列。数据被分成100000行的块,并且我已经以随机顺序对所有块进行了6个循环的训练。不幸的是,网络已经开始过度适应。我认为它仍然有能力更密切地拟合数据;我怀疑每个组块内部的规律性开始互相矛盾,我需要更彻底地对数据进行洗牌,以便能够以不同的组合进行训练。我想在尝试获取更多培训数据之前尝试一下。统一洗牌5千兆字节的numpy数据

有没有人知道一个很好的方式来产生360万(非常长)的numpy数据行的新排列?我想过使用one of these技术,但使用numpy.savetxt编写这些阵列产生令人难以置信的巨大的文件,我不能告诉如何操纵标准npy文件中的单个行,有助于解决此问题。

现在,我最好的想法是在数据中创建成对索引(c, r)的置换,其中c选择一个块,r从该块中选择一行。我可以将每行存储在一个新的预分配数组中,然后保存它。但我想知道是否有一个不太可怕的I/O绑定解决方案。有没有一些原则性的方法可以将随机的块组合在一起,直到你得到一个统计上独立于开始置换的置换?

+0

你可以捕鱼yates的行,然后fisher yates的列?既然你只是交换单独的行/列,它不应该滥用你的记忆。如果速度是问题,你可以把它作为C扩展(你需要大量的掉期使它实际上是随机的)。 – 2014-11-20 21:48:59

+0

对不起,我不清楚 - 我不需要洗牌列,只是行。只是没有好的方法将它全部加载到内存中,也不是一些更明显的基于磁盘的方法。 – senderle 2014-11-20 22:25:42

回答

6

迄今为止我尝试过的东西中,PyTables解决方案当前是最好的,其次是使用numpy支持memmapped数组的解决方案。但PyTables解决方案并不简单。如果您使用整数的整数数组来直接索引PyTables数组,它非常缓慢。更快的是以下两步过程:

  1. 使用布尔索引数组选择数组的一个随机子集。 这必须以块状方式完成。如果将索引数组直接传递给PyTables数组,则速度很慢。
    • 预分配一个numpy数组并创建一个将PyTables数组分割成块的切片列表。
    • 将每个块完全读取到中,然后使用索引数组的相应块为该块选择正确的值。
    • 将选定的值存储在预分配的数组中。
  2. 然后对预分配的数组进行洗牌。

这个过程产生一个像普通混洗过程一样随机的置换。如果这看起来不明显,请考虑这一点:​​。这种方法足够快,可以在每个训练周期进行随机洗牌。它也能够将数据压缩至〜650M--几乎达到90%的通货紧缩。

这是我目前的实施;这对于语料库中的每个训练块都被调用一次。 (返回数组别处洗牌。)

def _h5_fast_bool_ix(self, h5_array, ix, read_chunksize=100000): 
    '''Iterate over an h5 array chunkwise to select a random subset 
    of the array. `h5_array` should be the array itself; `ix` should 
    be a boolean index array with as many values as `h5_array` has 
    rows; and you can optionally set the number of rows to read per 
    chunk with `read_chunksize` (default is 100000). For some reason 
    this is much faster than using `ix` to index the array directly.''' 

    n_chunks = h5_array.shape[0]/read_chunksize 
    slices = [slice(i * read_chunksize, (i + 1) * read_chunksize) 
       for i in range(n_chunks)] 

    a = numpy.empty((ix.sum(), h5_array.shape[1]), dtype=float) 
    a_start = 0 
    for sl in slices: 
     chunk = h5_array[sl][ix[sl]] 
     a_end = a_start + chunk.shape[0] 
     a[a_start:a_end] = chunk 
     a_start = a_end 

    return a 

这有点疯狂,我认为为O(n^2)方法(遍历整个PyTables阵列,每块),在这种情况下速度比的O( n)方法(随机选择一行中的每一行)。但是,嘿,它的作品。稍微间接一点,这可以适应加载任意非随机排列,但是这增加了比它在这里值得的复杂性。

mmap解决方案仅供参考,对于那些因任何原因需要纯粹numpy解决方案的人员。它在大约25分钟内洗牌所有数据,而上述解决方案在不到一半的时间内管理相同数据。这也应该线性扩展,因为mmap允许(相对)高效的随机访问。

import numpy 
import os 
import random 

X = [] 
Y = [] 

for filename in os.listdir('input'): 
    X.append(numpy.load(os.path.join('input', filename), mmap_mode='r')) 

for filename in os.listdir('output'): 
    Y.append(numpy.load(os.path.join('output', filename), mmap_mode='r')) 

indices = [(chunk, row) for chunk, rows in enumerate(X) 
         for row in range(rows.shape[0])] 
random.shuffle(indices) 

newchunks = 50 
newchunksize = len(indices)/newchunks 

for i in range(0, len(indices), newchunksize): 
    print i 
    rows = [X[chunk][row] for chunk, row in indices[i:i + newchunksize]] 
    numpy.save('X_shuffled_' + str(i), numpy.array(rows)) 
    rows = [Y[chunk][row] for chunk, row in indices[i:i + newchunksize]] 
    numpy.save('Y_shuffled_' + str(i), numpy.array(rows)) 
0

以下假设您的数据已经被分成了某种可轻松检索的记录。 (我不知道是否有用于numpy数据的标准文件格式。)

  1. dict的形式创建数据的索引,通过ñ映射每个唯一的记录ID(0 - 1 )再次找到数据的一些手段。例如,如果它全部位于一个二进制文件中,则会存储形式为(file_offset, record_length)的元组。没有必要坚持数据本身。

  2. 创建的Ñ元素的列表,包含索引dict的键(再次,0至Ñ - 1)。

  3. 随机播放记录ID列表。 (如果需要,请提供您自己的随机数生成器。)

  4. 打开包含混洗数据的新文件(或其他)。

  5. 从开始到结束从列表中读取记录ID。对于每个记录ID,请在索引中查找该记录的位置。在该位置抓取数据并将其附加到输出文件。

伪代码:

# This assumes a binary file of unequal-length 
# records. It also assumes that the file won't 
# be changed while we're doing this. 

# Create index. 
index = {} 
rec_offset = 0 
for rec_id, record in original_data.iterate_records(): 
    # This bit depends greatly on how your data 
    # is stored... 
    rec_length = len(record) 
    index[rec_id] = (rec_offset, rec_length) 
    rec_offset += rec_length 

# Shuffle. 
num_records_indexed = rec_id + 1 # rec_id is still in scope. 
records_order = list(range(num_records_indexed)) 
records_order = random.shuffle(records_order, "<optional_RNG_here>") 

# Create new shuffled-data file. 
with open("output_file.bin", "wb") as output: 
    for rec_id in records_order: 
     rec_offset, rec_length = index[rec_id] 
     record = original_data.get_rec_at(rec_offset, rec_length) 
     output.write(record) 

索引,重排,和去索引都是O(Ñ),所以最糟糕的部分应该是I/O:读取所述数据,然后复制它(第二次阅读,再加上写)。