我正在使用(encog 3.3.0库)构建用于图像识别的神经网络。为了避免混淆我的神经网络,我将图像转换为50x50灰度,因为我基本上想从图像中进行一些与颜色无关的特征提取。我有两个输出类。Encog ::非收敛错误率
我的输入::一个CSV文件,其中包含318行,每行有2502列。每行对应一个图像。前2500列是图像的50x50像素,后2列是输出类。 输入有159行,其中2500个正常图像像素,然后1,0输出和159行,其中有2500个正常图像像素,然后0,1作为输出。 0意味着它不属于,1意味着它属于那个类。
我的输入:: 318行和2502列。以下是其中的一行::
255,243,251,255,244,255,235,67,51,52,53,54,54,54,53,53,53,54,55,55 ..... 53, 54,54,53,53,52,54,54,54,54,54,54,54,54,57,57,57,57,57,57,57,57,57,57,0,1
最后的0,1代表输出类。
我的图层::我有3个图层。具有2500个神经元的输入层,具有1000个神经元的隐藏层和具有2个神经元的输出层。
问题:当我开始以0.7的学习率和0.8的动量训练网络时,即使经过100次迭代,错误率也不会收敛并持续在0.45-0.5左右振荡。
下面是我的代码::
公共类image_recognition {
static final int COLUMNS = 2500;
static final int OUTPUT = 2;
public BasicNetwork network;
public double[][] input;
public double[][] ideal;
public MLDataSet trainingSet;
public void createNetwork() {
network = new BasicNetwork();
//simpleFeedForward(int input, int hidden1, int hidden2, int output, boolean tanh)
network = EncogUtility.simpleFeedForward(image_recognition.COLUMNS, 1000, 0, image_recognition.OUTPUT, false);
network.reset();
}
public void train() {
//BasicMLDataSet(double[][] input, double[][] ideal)
trainingSet = new BasicMLDataSet(input, ideal);
//Backpropagation(ContainsFlat network, MLDataSet training, double learnRate, double momentum)
final Backpropagation train = new Backpropagation(network, trainingSet, 0.7, 0.8);
int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" + train.getError());
long time = System.currentTimeMillis();
System.out.println("after iteration time :: ");
System.out.println(time);
epoch++;
} while ((epoch < 5000) && (train.getError() > 0.3));
}
public double evaluate() {
System.out.println("Neural Network Results:");
for(MLDataPair pair: trainingSet) {
final MLData output = network.compute(pair.getInput());
String actualoutput1 = String.format("%.6f", output.getData(0));
String idealoutput1 = String.format("%.1f", pair.getIdeal().getData(0));
String actualoutput2 = String.format("%.6f", output.getData(1));
String idealoutput2 = String.format("%.1f", pair.getIdeal().getData(1));
System.out.println("actual1 = " + actualoutput1 + ", actual2 = " + actualoutput2 + " ,ideal1 = " + idealoutput1 + " ,ideal2 = " + idealoutput2 );
}
return 0;
}
public void load(String filename) throws IOException {
int size = 0;
ReadCSV csv;
csv = new ReadCSV(filename, false, CSVFormat.DECIMAL_POINT);
while (csv.next()) {
size++;
}
csv.close();
// allocate enough space
input = new double[size][image_recognition.COLUMNS];
ideal = new double[size][image_recognition.OUTPUT];
// now load it
int index = 0;
csv = new ReadCSV(filename, false, CSVFormat.DECIMAL_POINT);
while (csv.next()) {
for(int i=0;i<image_recognition.COLUMNS;i++)
{
input[index][i] = Double.parseDouble(csv.get(i));
}
for(int i=0;i<image_recognition.OUTPUT;i++)
{
ideal[index][i] = Double.parseDouble(csv.get(image_recognition.COLUMNS+i));
}
index++;
}
csv.close();
}
public static void main(final String args[]) {
try {
image_recognition prg = new image_recognition();
long b1 = System.currentTimeMillis();
System.out.println("before loading time :: ");
System.out.println(b1);
prg.load("mycsv.csv");
long a1 = System.currentTimeMillis();
System.out.println("after loading, before creating network time :: ");
System.out.println(a1);
prg.createNetwork();
long a2 = System.currentTimeMillis();
System.out.println("after creating network, before training time :: ");
System.out.println(a2);
prg.train();
long a3 = System.currentTimeMillis();
System.out.println("after training, before testing time :: ");
System.out.println(a3);
prg.evaluate();
} catch (Throwable t) {
t.printStackTrace();
}
}
}
我的输出::
大纪元#1错误:0.48833917036172103
大纪元# 2错误:0.5
历元#3错误:0.5
历元#4错误:0.5
历元#5错误:0.45956570930539425
.........
历元#23错误:0.4744859426599884
历元#24错误:0.5
历元#25错误:0.5
...........
历元#49错误:0.5912731593753425
历元#50错误:0.5
历元#51错误:0.5031968130459842
.. .........
历元#71错误:0.5046318360708989
历元#72错误:0.49357338328109024
历元#73错误:0.486820369587797
...........
历元#103错误:0.5155249407683976
Epoch#104错误:0.4835673679113441
大纪元#105错误:0.49407335871268354
.........
大纪元#142错误:0.49038913805594664
大纪元#143错误:0.4660191340060382
请指引我为什么错误率不会收敛。我试着运行它以获得更多迭代,但仍然不会收敛。我需要的错误是至少0.1。
您尝试过哪些网络拓扑结构?很可能你的网络布局不适合这个问题。 –