2017-05-13 54 views
2

我在我的java代码中使用了weka中的LibSVM。我正在尝试做一个回归。下面是我的代码,Java,weka LibSVM不能正确预测

public static void predict() { 

    try { 
     DataSource sourcePref1 = new DataSource("train_pref2new.arff"); 
     Instances trainData = sourcePref1.getDataSet(); 

     DataSource sourcePref2 = new DataSource("testDatanew.arff"); 
     Instances testData = sourcePref2.getDataSet(); 

     if (trainData.classIndex() == -1) { 
      trainData.setClassIndex(trainData.numAttributes() - 2); 
     } 

     if (testData.classIndex() == -1) { 
      testData.setClassIndex(testData.numAttributes() - 2); 
     } 

     LibSVM svm1 = new LibSVM(); 

     String options = ("-S 3 -K 2 -D 3 -G 1000.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.001 -P 0.1"); 
     String[] optionsArray = options.split(" "); 
     svm1.setOptions(optionsArray); 

     svm1.buildClassifier(trainData); 

     for (int i = 0; i < testData.numInstances(); i++) { 

      double pref1 = svm1.classifyInstance(testData.instance(i));     
      System.out.println("predicted value : " + pref1); 

     } 

    } catch (Exception ex) { 
     Logger.getLogger(Test.class.getName()).log(Level.SEVERE, null, ex); 
    } 
} 

但预测值我从这个代码得到的是比预测值,我使用的Weka GUI越来越不同。

示例: 下面是我为java代码和weka GUI提供的单个测试数据。

的Java代码的预测值作为1.9064516129032265而Weka的GUI的预测值是10.043。我为Java代码和Weka GUI使用相同的训练数据集和相同的参数。

我希望你明白我的问题。有人告诉我我的代码有什么问题吗?

回答

2

您正在使用错误的算法执行SVM回归。 LibSVM用于分类。你想要的那个是SMOreg,这是一个用于回归的特定SVM。

下面是一个完整的示例,显示如何使用SMOreg使用Weka Explorer GUI以及Java API。对于数据,我将使用Weka发行版附带的cpu.arff数据文件。请注意,我将使用此文件进行培训和测试,但理想情况下,您将拥有单独的数据集。

使用Weka的浏览器GUI

  1. 打开资源管理器WEKA GUI,单击Preprocess选项卡上,单击Open File,然后打开cpu.arff文件应该在你的Weka的分布。在我的系统上,该文件在weka-3-8-1/data/cpu.arff之下。资源管理器窗口应如下所示:

Weka Explorer - Choosing the file

  • 点击Classify标签。它应该被称为“预测”,因为您可以在这里进行分类和回归。在Classifier下,点击Choose,然后选择wekaclassifiersfunctionsSMOreg,如下所示。
  • Weka Explorer - Choosing the regression algorithm

  • 现在生成回归模型和评估它。在Test Options下选择Use training set,这样我们的训练集也用于测试(正如我在上面提到的,这不是理想的方法)。现在按Start,结果应该如下所示:
  • Weka Explorer - Results from testing

    记下RMSE值(74.5996)的。我们将在Java代码实现中重新讨论这一点。

    使用Java API

    下面是使用Weka的API来复制在Weka的浏览器GUI前面显示的结果一个完整的Java程序。

    import weka.classifiers.functions.SMOreg; 
    import weka.classifiers.Evaluation; 
    import weka.core.Instance; 
    import weka.core.Instances; 
    import weka.core.converters.ConverterUtils.DataSource; 
    
    public class Tester { 
    
        /** 
        * Builds a regression model using SMOreg, the SVM for regression, and 
        * evaluates it with the Evalution framework. 
        */ 
        public void buildAndEvaluate(String trainingArff, String testArff) throws Exception { 
    
         System.out.printf("buildAndEvaluate() called.\n"); 
    
         // Load the training and test instances. 
         Instances trainingInstances = DataSource.read(trainingArff); 
         Instances testInstances = DataSource.read(testArff); 
    
         // Set the true value to be the last field in each instance. 
         trainingInstances.setClassIndex(trainingInstances.numAttributes()-1); 
         testInstances.setClassIndex(testInstances.numAttributes()-1); 
    
         // Build the SMOregression model. 
         SMOreg smo = new SMOreg(); 
         smo.buildClassifier(trainingInstances); 
    
         // Use Weka's evaluation framework. 
         Evaluation eval = new Evaluation(trainingInstances); 
         eval.evaluateModel(smo, testInstances); 
    
         // Print the options that were used in the ML algorithm. 
         String[] options = smo.getOptions(); 
         System.out.printf("Options used:\n"); 
         for (String option : options) { 
          System.out.printf("%s ", option); 
         } 
         System.out.printf("\n\n"); 
    
         // Print the algorithm details. 
         System.out.printf("Algorithm:\n %s\n", smo.toString()); 
    
         // Print the evaluation results. 
         System.out.printf("%s\n", eval.toSummaryString("\nResults\n=====\n", false)); 
        } 
    
        /** 
        * Builds a regression model using SMOreg, the SVM for regression, and 
        * tests each data instance individually to compute RMSE. 
        */ 
        public void buildAndTestEachInstance(String trainingArff, String testArff) throws Exception { 
    
         System.out.printf("buildAndTestEachInstance() called.\n"); 
    
         // Load the training and test instances. 
         Instances trainingInstances = DataSource.read(trainingArff); 
         Instances testInstances = DataSource.read(testArff); 
    
         // Set the true value to be the last field in each instance. 
         trainingInstances.setClassIndex(trainingInstances.numAttributes()-1); 
         testInstances.setClassIndex(testInstances.numAttributes()-1); 
    
         // Build the SMOregression model. 
         SMOreg smo = new SMOreg(); 
         smo.buildClassifier(trainingInstances); 
    
         int numTestInstances = testInstances.numInstances(); 
    
         // This variable accumulates the squared error from each test instance. 
         double sumOfSquaredError = 0.0; 
    
         // Loop over each test instance. 
         for (int i = 0; i < numTestInstances; i++) { 
    
          Instance instance = testInstances.instance(i); 
    
          double trueValue = instance.value(testInstances.classIndex()); 
          double predictedValue = smo.classifyInstance(instance); 
    
          // Uncomment the next line to see every prediction on the test instances. 
          //System.out.printf("true=%10.5f, predicted=%10.5f\n", trueValue, predictedValue); 
    
          double error = trueValue - predictedValue; 
          sumOfSquaredError += (error * error); 
         } 
    
         // Print the RMSE results. 
         double rmse = Math.sqrt(sumOfSquaredError/numTestInstances); 
         System.out.printf("RMSE = %10.5f\n", rmse); 
        } 
    
        public static void main(String argv[]) throws Exception { 
    
         Tester classify = new Tester(); 
         classify.buildAndEvaluate("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff"); 
         classify.buildAndTestEachInstance("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff"); 
        } 
    } 
    

    我已经写了训练的SMOreg模型和训练数据运行预测评估模型两种功能。

    • buildAndEvaluate()使用的Weka Evaluation框架运行一系列测试,以得到完全相同的 结果作为资源管理GUI评估模型。值得注意的是,它产生了一个RMSE值。

    • buildAndTestEachInstance()评估由明确 遍历每个测试实例,进行预测,计算 误差,并且计算总的RMSE模型。请注意,此RMSE匹配 从buildAndEvaluate()开始的一个,后者与Explorer GUI中的 匹配。

    下面是编译和运行程序的结果。

    prompt> javac -cp weka.jar Tester.java 
    
    prompt> java -cp .:weka.jar Tester 
    
    buildAndEvaluate() called. 
    Options used: 
    -C 1.0 -N 0 -I weka.classifiers.functions.supportVector.RegSMOImproved -T 0.001 -V -P 1.0E-12 -L 0.001 -W 1 -K weka.classifiers.functions.supportVector.PolyKernel -E 1.0 -C 250007 
    
    Algorithm: 
    SMOreg 
    
    weights (not support vectors): 
    +  0.01 * (normalized) MYCT 
    +  0.4321 * (normalized) MMIN 
    +  0.1847 * (normalized) MMAX 
    +  0.1175 * (normalized) CACH 
    +  0.0973 * (normalized) CHMIN 
    +  0.0235 * (normalized) CHMAX 
    -  0.0168 
    
    
    
    Number of kernel evaluations: 21945 (93.081% cached) 
    
    Results 
    ===== 
    
    Correlation coefficient     0.9044 
    Mean absolute error      31.7392 
    Root mean squared error     74.5996 
    Relative absolute error     33.0908 % 
    Root relative squared error    46.4953 % 
    Total Number of Instances    209  
    
    buildAndTestEachInstance() called. 
    RMSE = 74.59964 
    
    +0

    其实Libsvm有2个SVM类型的回归,nu-SVR和epsilon-SVR。通过定义算法的-S参数,我可以决定使用哪种svm类型。在我的代码中,我使用了epsilon-SVR(-S 3)。但是你的代码确实帮我找到了我的代码中的错误。 setClassIndex在我的代码中是错误的。我用你的代码,它的工作。非常感谢您的帮助。 – udi