2012-10-10 20 views
2

我写了Pybrain神经网络的这个简单测试,但它并没有像我期望的那样工作。这个想法是训练一个数字集到4095的数据集,其中包含素数和非素数。神经网络对不同的激活报告相同的响应

#!/usr/bin/env python 
# A simple feedforward neural network that attempts to learn Primes 

from pybrain.datasets import ClassificationDataSet 
from pybrain.tools.shortcuts import buildNetwork 
from pybrain.supervised import BackpropTrainer 

class PrimesDataSet(ClassificationDataSet): 
    """ A dataset for primes """ 

    def generatePrimes(self, n): 
     if n == 2: 
      return [2] 
     elif n < 2: 
      return [] 
     s = range(3, n + 1, 2) 
     mroot = n ** 0.5 
     half = (n + 1)/2 - 1 
     i = 0 
     m = 3 
     while m <= mroot: 
      if s[i]: 
       j = (m * m - 3)/2 
       s[j] = 0 
       while j < half: 
        s[j] = 0 
        j += m 
      i = i + 1 
      m = 2 * i + 3 
     return [2] + [x for x in s if x] 

    def binaryString(self, n): 
     return "{0:12b}".format(n) 

    def __init__(self): 
     ClassificationDataSet.__init__(self, 12, 1) 
     primes = self.generatePrimes(4095) 
     for prime in primes: 
      b = self.binaryString(prime).split() 
      self.addSample(b, [1]) 
     for n in range(4095): 
      if n not in primes: 
       b = self.binaryString(n).split() 
       self.addSample(b, [0]) 

def testTraining(): 
    d = PrimesDataSet() 
    d._convertToOneOfMany() 
    n = buildNetwork(d.indim, 12, d.outdim, recurrent=True) 
    t = BackpropTrainer(n, learningrate = 0.01, momentum = 0.99, verbose = True) 
    t.trainOnDataset(d, 100) 
    t.testOnData(verbose=True) 
    print "Is 7 prime? ", n.activate(d.binaryString(7).split()) 
    print "Is 6 prime? ", n.activate(d.binaryString(6).split()) 
    print "Is 100 prime? ", n.activate(d.binaryString(100).split()) 


if __name__ == '__main__': 
    testTraining() 

直索(请)这是否甚至有可能的问题,我的问题是,7,6,和100的最后三个报表打印测试是否是质都返回相同的:

Is 7 prime? [ 0.34435841 0.65564159] 
Is 6 prime? [ 0.34435841 0.65564159] 
Is 100 prime? [ 0.34435841 0.65564159] 

(或类似的东西) 我解释这些结果的方式是神经网络以65%的确定性预测这些数字中的每一个是是质数。我的神经网络学会了如何处理所有的输入,或者我做错了什么?

回答

0

看起来你实际上只使用一个输入。

d.binaryString(7).split() 

相当于

"{0:12b}".format(7).split() 

计算结果为

['111']. 

我想你打算是像

[int(c) for c in "{0:012b}".format(7)] 

其结果是

[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1] 

P.S.检查你输入统计模型到底是什么总是个好主意:)

+0

谢谢!最简单的事情 - 我甚至没想过要检查。 – lambda