2010-05-03 90 views
4

我使用的是libsvm,文档使我相信有一种方法可以输出输出分类精度的相信概率。这是吗?如果是这样,任何人都可以提供一个清晰的例子来说明如何在代码中做到这一点?如何计算使用libsvm进行多类预测的概率?

目前,我使用的Java库通过以下方式

SvmModel model = Svm.svm_train(problem, parameters); 
    SvmNode x[] = getAnArrayOfSvmNodesForProblem(); 
    double predictedValue = Svm.svm_predict(model, x); 

回答

7

鉴于你的代码片段,我会假设你希望使用libSVM打包的Java API,而不是更详细的由jlibsvm提供。

为了能够与概率估计预测,与svm_parameter字段训练模型probability设置为1。然后,只需更改您的代码,以便它调用svm方法svm_predict_probability而不是svm_predict

修改你的片段中,我们有:

parameters.probability = 1; 
svm_model model = svm.svm_train(problem, parameters); 

svm_node x[] = problem.x[0]; // let's try the first data pt in problem 
double[] prob_estimates = new double[NUM_LABEL_CLASSES]; 
svm.svm_predict_probability(model, x, prob_estimates); 

值得知道训练多类概率估计可以改变预测由分类制成。有关更多信息,请参阅问题Calculating Nearest Match to Mean/Stddev Pair With LibSVM

+0

@dmcer哪个包具有较小的学习曲线(Java API的使用LIBSVM或jlibsvm包装)?一般来说,我是SVM的新手。 – GobiasKoffi 2010-10-15 14:49:35

+0

@rohanbk - 可能是jlibsvm,因为它看起来和感觉就像一个典型的Java API。 – dmcer 2010-10-15 23:20:13

+0

@dmcer您是否有使用WEKA进行SVM的经验? – GobiasKoffi 2010-10-16 05:26:43

1

接受的答案就像一个魅力。确保在训练期间设置probability = 1

如果你想在信心没有与门限符合下降的预测,这里是代码示例:

double confidenceScores[] = new double[model.nr_class]; 
svm.svm_predict_probability(model, svmVector, confidenceScores); 

/*System.out.println("text="+ text); 
for (int i = 0; i < model.nr_class; i++) { 
    System.out.println("i=" + i + ", labelNum:" + model.label[i] + ", name=" + classLoadMap.get(model.label[i]) + ", score="+confidenceScores[i]); 
}*/ 

//finding max confidence; 
int maxConfidenceIndex = 0; 
double maxConfidence = confidenceScores[maxConfidenceIndex]; 
for (int i = 1; i < confidenceScores.length; i++) { 
    if(confidenceScores[i] > maxConfidence){ 
     maxConfidenceIndex = i; 
     maxConfidence = confidenceScores[i]; 
    } 
} 

double threshold = 0.3; // set this based data & no. of classes 
int labelNum = model.label[maxConfidenceIndex]; 
// reverse map number to name 
String targetClassLabel = classLoadMap.get(labelNum); 
LOG.info("classNumber:{}, className:{}; confidence:{}; for text:{}", 
     labelNum, targetClassLabel, (maxConfidence), text); 
if (maxConfidence < threshold) { 
    LOG.info("Not enough confidence; threshold={}", threshold); 
    targetClassLabel = null; 
} 
return targetClassLabel;