2017-02-03 67 views
3

培训我在大数据集上使用Keras(使用MagnaTagATune数据集进行音乐自动标记)。所以我尝试使用fit_generator()函数与自定义数据生成器。但是在培训过程中损失函数和指标的价值不会改变。看起来我的网络根本没有训练。Keras:网络不使用fit_generator()

当我使用fit()函数而不是fit_generator()时,一切正常,但我无法将整个数据集保存在内存中。

我试图既Theano和TensorFlow后端

主代码:

if __name__ == '__main__': 
    model = models.FCN4() 
    model.compile(optimizer='adam', 
        loss='binary_crossentropy', 
        metrics=['accuracy', 'categorical_accuracy', 'precision', 'recall']) 
    gen = mttutils.generator_v2(csv_path, melgrams_dir) 
    history = model.fit_generator(gen.generate(0,750), 
            samples_per_epoch=750, 
            nb_epoch=80, 
            validation_data=gen.generate(750,1000,False), 
            nb_val_samples=250) 
    # RESULTS SAVING 
    np.save(output_history, history.history) 
    model.save(output_model) 

类generator_v2:

genres = ['guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock', 'fast', 
     'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian', 'opera', 'male', 'singing', 
     'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet', 'flute', 'woman', 'male vocal', 'no vocal', 
     'pop', 'soft', 'sitar', 'solo', 'man', 'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 
     'female vocal', 'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice', 'choral'] 

def __init__(self, csv_path, melgrams_dir): 

    def get_dict_vals(dictionary, keys): 
     vals = [] 
     for key in keys: 
      vals.append(dictionary[key]) 
     return vals 

    self.melgrams_dir = melgrams_dir 
    with open(csv_path, newline='') as csvfile: 
     reader = csv.DictReader(csvfile, dialect='excel-tab') 
     self.labels = [] 
     for row in reader: 
      labels_arr = np.array(get_dict_vals(
       row, self.genres)).astype(np.int) 
      labels_arr = labels_arr.reshape((1, labels_arr.shape[0])) 
      if (np.sum(labels_arr) > 0): 
       self.labels.append((row['mp3_path'], labels_arr)) 
     self.size = len(self.labels) 


def generate(self, begin, end): 
    while(1): 
     for count in range(begin, end): 
      try: 
       item = self.labels[count] 
       mels = np.load(os.path.join(
        self.melgrams_dir, item[0] + '.npy')) 
       tags = item[1] 
       yield((mels, tags)) 
      except FileNotFoundError: 
       continue 

为了制备用于配合阵列()函数我用这个代码:

def TEST_get_data_array(csv_path, melgrams_dir): 
    gen = generator_v2(csv_path, melgrams_dir).generate(0,100) 
    item = next(gen) 
    x = np.array(item[0]) 
    y = np.array(item[1]) 
    for i in range(0,100): 
     item = next(gen.training) 
     x = np.concatenate((x,item[0]),axis = 0) 
     y = np.concatenate((y,item[1]),axis = 0) 
    return(x,y) 

对不起,如果我的代码风格不好。谢谢你!

UPD 1: 我试着使用return(X,y),而不是yield(X,y)但没有任何变化。我的新发电机类的

部分:

def generate(self): 
    if((self.count < self.begin) or (self.count >= self.end)): 
     self.count = self.begin 
    item = self.labels[self.count] 
    mels = np.load(os.path.join(self.melgrams_dir, item[0] + '.npy')) 
    tags = item[1] 
    self.count = self.count + 1 
    return((mels, tags)) 

def __next__(self): # fit_generator() uses this method 
    return self.generate() 

fit_generator电话:

history = model.fit_generator(tr_gen, 
           samples_per_epoch = tr_gen.size, 
           nb_epoch = 120, 
           validation_data = val_gen, 
           nb_val_samples = val_gen.size) 

日志:

Epoch 1/120 
10554/10554 [==============================] - 545s - loss: 1.7240 - acc: 0.8922 
Epoch 2/120 
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
Epoch 3/120 
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
Epoch 4/120 
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
... etc (loss is always 1.8922; acc is always 0.8820) 
+0

在'用于范围(开始,结束)'之前,您可能会洗牌您的数据。 –

回答

2

我有同样的问题,因为你与产量的方法。所以我只是存储了当前的索引,并用return语句返回了一个批处理。

所以我只是用return (X, y)而不是yield (X,y)它工作。我不确定这是为什么。如果有人能够阐明这一点,这将是很酷的。

编辑: 您需要将生成器传递给该函数,而不仅仅调用该函数。类似这样的:

model.fit_generator(gen, samples_per_epoch=750, 
            nb_epoch=80, 
            validation_data=gen, 
            nb_val_samples=250) 

Keras会在调用数据的同时调用您的__next__函数。

+0

我试过了,但没有任何变化。请检查我是否正确理解你(我的代码与'return'语句是在主文章的末尾)。谢谢! – Ladislao

+0

应该像这样传递发电机时工作。如果不是,你可以发布你的错误信息? –

+0

是的,我将我的生成器传递到'fit_generator'函数中。没有例外或错误。问题是在培训过程中损失函数的价值没有变化(我已经将日志添加到主要职位)。它看起来像网络不刷新它的权重。这在模型中不是一个错误,因为'fit'函数(使用数组而不是生成器)可以正常工作。 – Ladislao

0

在'生成'方法中,有一个while语句。

def generate(self, begin, end): 
    while(1): # this 
     for count in range(begin, end): 
      try: 
       # something 
       yield(...) 

      except FileNotFoundError: 
       continue 

我觉得不需要这种说法,所以

def generate(self, begin, end): 
    for count in range(begin, end): 
     try: 
      # something 
      yield(...) 

     except FileNotFoundError: 
      continue 
+0

它提出了一个例外: '文件 “/usr/local/lib/python3.4/dist-packages/keras/engine/training.py”,线1528,在fit_generator STR(generator_output)) ValueError异常:输出生成器应该是一个元组(x,y,sample_weight)或(x,y)。发现:没有' 发电机必须是无止境的,因为它必须在下一个时代返回相同批次的数据 – Ladislao

+0

我明白了,对不起,我感激不尽。 – hmm

相关问题