2013-11-28 67 views
0

我想用FANN近似平方函数。代码如下:FANN没有正确训练

#include "../FANN-2.2.0-Source/src/include/doublefann.h" 
#include "../FANN-2.2.0-Source/src/include/fann_cpp.h" 
#include <cstdlib> 
#include <iostream> 

using namespace std; 
using namespace FANN; 

//Remember: fann_type is double! 
int main(int argc, char** argv) { 
    //create a test network: [1,2,1] MLP 
    neural_net * net = new neural_net; 
    const unsigned int layers[3] = {1,3,1}; 
    net->create_standard_array(3,layers); 

    //net->create_standard(num_layers, num_input, num_hidden, num_output); 

    net->set_learning_rate(0.7f); 

    net->set_activation_steepness_hidden(0.7); 
    net->set_activation_steepness_output(0.7); 

    net->set_activation_function_hidden(SIGMOID_SYMMETRIC_STEPWISE); 
    net->set_activation_function_output(SIGMOID_SYMMETRIC_STEPWISE); 
    net->set_training_algorithm(TRAIN_QUICKPROP); 

    //cout<<net->get_train_error_function() 
    //exit(0); 
    //test the number 2 
    fann_type * testinput = new fann_type; 
    *testinput = 2; 
    fann_type * testoutput = new fann_type; 
    *testoutput = *(net->run(testinput)); 
    double outputasdouble = (double) *testoutput; 
    cout<<"Test output: "<<outputasdouble<<endl; 

    //make a training set of x->x^2 
    training_data * squaredata = new training_data; 
    squaredata->read_train_from_file("trainingdata.txt"); 

    net->train_on_data(*squaredata,1000,100,0.001); 

    cout<<endl<<"Easy!"; 
    return 0; 
} 

trainingdata.txt是这样的:

10 1 1 
1 1 
2 4 
3 9 
4 16 
5 25 
6 36 
7 49 
8 64 
9 81 
10 100 

我觉得我所做的一切都是正确的使用API​​。但是,当我运行它时,我会遇到一个巨大的错误,这个错误在训练中似乎不会减少。

Test output: -0.0311087 
Max epochs  1000. Desired error: 0.0010000000. 
Epochs   1. Current error: 633.9928588867. Bit fail 10. 
Epochs   100. Current error: 614.3250122070. Bit fail 9. 
Epochs   200. Current error: 614.3250122070. Bit fail 9. 
Epochs   300. Current error: 614.3250122070. Bit fail 9. 
Epochs   400. Current error: 614.3250122070. Bit fail 9. 
Epochs   500. Current error: 614.3250122070. Bit fail 9. 
Epochs   600. Current error: 614.3250122070. Bit fail 9. 
Epochs   700. Current error: 614.3250122070. Bit fail 9. 
Epochs   800. Current error: 614.3250122070. Bit fail 9. 
Epochs   900. Current error: 614.3250122070. Bit fail 9. 
Epochs   1000. Current error: 614.3250122070. Bit fail 9. 

Easy! 

我做错了什么?

回答

3

如果对输出图层使用sigmoid函数,则输出将提供(0,1)的范围。

您可能有两种选择,(1)将所有的输出分成一个常量,比如1e4。当测试数据出现时,您还可以将其除以1e4。问题在于,你可能无法预测大于100的平方数(100^2 = 1e4);(2)使隐藏层和输出层都成为线性的,并且网络将自动学习权重以给出你拥有的任何输出值。

+0

但是,如果我缩放,会不会改变数据,从而破坏近似? –

+0

你可能有两种选择,(1)将所有的输出分成一个常量,比如1e4。当测试数据出现时,您还可以将其除以1e4。问题在于,你可能无法预测大于100的平方数(100^2 = 1e4);(2)使隐藏层和输出层都成为线性的,并且网络将自动学习权重以给出你拥有的任何输出值。 – lennon310

+0

当你说“使隐藏层和输出层都是线性的”时,你的意思是激活功能,对吧? –