2016-04-29

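How do I plot a sigmoid probability curve in scikit-learn?

I'm trying to recreate this image using Python, given 2 classes and their associated predicted probabilities from a classifier.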

I'd like to see something like this: [image: sigmoid curve]

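It isn't working, though, because I get mostly straight lines. Note: I know the data shown here is currently suspect and/or bad. I need to tune my inputs and model, but I wanted to see the plot first.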

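Basically, I thought I would "correct" the predict_proba() outputs so they were all with respect to the "0" class: if the model predicted class "1", the probability of it being class "0" is 1 - (class-1 probability), so a 95% prediction for class "1" becomes 5% for class "0". Then I ordered the records by my corrected prediction value, hoping to end up with something sigmoid-ish.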
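For a binary classifier the two columns of predict_proba() already sum to 1, so the "corrected" class-0 probability is simply column 0. The sigmoid shape only appears when the probability is plotted against the model's linear predictor z (the Z in the Matlab snippet), not against the sort order of the probabilities. A minimal numpy sketch, assuming a logistic-regression model described by a hypothetical coefficient vector `coef` and scalar `intercept` (stand-ins for `clf.coef_[0]` and `clf.intercept_[0]`); the toy data is made up:

```python
import numpy as np


def sigmoid(z):
    """Logistic function: maps the linear predictor z to P(class 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def sigmoid_points(X, coef, intercept):
    """Return (z, p1, p0) sorted by z, ready to plot p1 against z.

    coef and intercept are hypothetical stand-ins for clf.coef_[0] and
    clf.intercept_[0] of a fitted logistic-regression model.
    """
    z = np.asarray(X, dtype=float) @ np.asarray(coef, dtype=float) + intercept
    order = np.argsort(z)      # sort along the x-axis (z), not by probability
    z = z[order]
    p1 = sigmoid(z)            # what predict_proba(X)[:, 1] would return
    p0 = 1.0 - p1              # the "corrected" class-0 probability
    return z, p1, p0


# Hypothetical toy data: 5 records, 2 predictors.
X = np.array([[0.5, 1.0], [-1.0, 2.0], [2.0, -0.5], [0.0, 0.0], [-2.0, -2.0]])
z, p1, p0 = sigmoid_points(X, coef=[1.5, -0.5], intercept=0.2)
print(np.round(z, 2))
print(np.round(p1, 3))
```

Plotting `p1` against `z` (for example with `plt.plot(z, p1, "o")`) traces the S-curve; plotting it against a row index does not, because the spacing of the records along the x-axis is then lost.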

Unfortunately, I ended up with this: [image]

Here's the chunk of my Python where I tried (and failed) to plot the probability sigmoid:

########################### 
## I removed my original Python code because it was very, very wrong so as to avoid any confusion. 
########################### 

For reference, below is the Matlab code for the plot I want to replicate with my Python model.

%Build the model 
mdl = fitglm(X, Y, 'distr', 'binomial', 'link', 'logit') 
%Build the sigmoid model 
B = mdl.Coefficients{:, 1}; 
Z = mdl.Fitted.LinearPredictor 
yhat = glmval(B, X, 'logit'); 
figure, scatter(Z, yhat), hold on, 
gscatter(Z, zeros(length(X),1)-0.1, Y) % plot original classes 
hold off, xlabel('\bf Z'), grid on, ylim([-0.2 1.05]) 
title('\bf Predicted Probability of each record') 
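A rough scikit-learn/matplotlib counterpart of the Matlab snippet above, offered as a sketch rather than a drop-in replacement. It assumes a binary LogisticRegression: `decision_function()` returns the linear predictor (playing the role of `mdl.Fitted.LinearPredictor`), and `predict_proba()[:, 1]` plays the role of `glmval(B, X, 'logit')`. The `make_classification` data is only a stand-in for the question's X and Y, and `plot_predicted_probability` is a hypothetical helper name. scikit-learn applies an L2 penalty by default whereas `fitglm` does not, so a large `C` is used here to approximate the unpenalized fit.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def plot_predicted_probability(clf, X, y):
    """Plot P(class 1) against the linear predictor Z, with the true classes
    drawn as a strip at y = -0.1 (like the Matlab gscatter call)."""
    z = clf.decision_function(X)          # Z = intercept + X . coefficients
    yhat = clf.predict_proba(X)[:, 1]     # equals 1 / (1 + exp(-Z))

    fig, ax = plt.subplots()
    ax.scatter(z, yhat, s=10, label="predicted P(class 1)")
    # One colour per true class, plotted just below zero.
    for cls, colour in zip(clf.classes_, ["black", "green"]):
        mask = y == cls
        ax.scatter(z[mask], np.full(mask.sum(), -0.1), s=10, c=colour,
                   label=f"actual class {cls}")
    ax.set_xlabel("Z")
    ax.set_ylim(-0.2, 1.05)
    ax.grid(True)
    ax.set_title("Predicted Probability of each record")
    ax.legend(loc="center right")
    return fig, z, yhat


# Synthetic stand-in for the question's X and Y (hypothetical data).
X, y = make_classification(n_samples=200, n_features=4, random_state=0)
# A large C weakens the default L2 penalty, approximating fitglm's plain
# maximum-likelihood fit.
clf = LogisticRegression(C=1e6, max_iter=1000).fit(X, y)
fig, z, yhat = plot_predicted_probability(clf, X, y)
plt.show()
```

Using `decision_function()` avoids rebuilding the dot product by hand; the probabilities fall on a single sigmoid because each one is a fixed function of its own Z.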

Answer


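There may be a more Pythonic way to do this, but here is what I came up with in the end.

(Keep in mind that in this case the data is not perfectly separated, so the curve doesn't have the traditional look of a sigmoid with the classes split at the 0.50 point.)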
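On a logistic curve the 0.50 point sits at z = 0 (since 1 / (1 + e^0) = 0.5), so "not perfectly separated" means some records of each class fall on the wrong side of z = 0. A small numpy sketch of that split, using the same rule as scikit-learn's binary `predict()` (z > 0 means class 1); `split_by_correctness` is a hypothetical helper and the inputs are made-up values:

```python
import numpy as np


def split_by_correctness(z, y_true):
    """Split linear-predictor values z into correctly and wrongly classified
    records, using the 0.5 probability threshold (equivalent to z > 0)."""
    z = np.asarray(z, dtype=float)
    y_true = np.asarray(y_true)
    predicted = (z > 0).astype(int)   # sigmoid(z) > 0.5  <=>  z > 0
    right = z[predicted == y_true]
    wrong = z[predicted != y_true]
    return right, wrong


# Made-up example: two records sit on the wrong side of z = 0.
right, wrong = split_by_correctness([-2.0, -0.5, 0.3, 1.5], [0, 1, 0, 1])
print("right:", right)
print("wrong:", wrong)
```

The two arrays correspond to the `rightDotProducts` and `wrongDotProducts` lists built in the loop below; the only difference is at exactly 0.5, where this sketch counts a class-1 record as wrong.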

#############################################################################
#### Draws a sigmoid probability plot from prediction results ###############
#############################################################################
# Assumes clf is a fitted binary LogisticRegression and that X_train,
# X_test, y_test and labels were defined earlier in the script.
import matplotlib.pyplot as plt
import numpy as np

print('-' * 40)

# make the predictions (class) and also get the prediction probabilities 
y_train_predict = clf.predict(X_train) 
y_train_predictProbas = clf.predict_proba(X_train) 
y_train_predictProbas = y_train_predictProbas[:, 1] 

y_test_predict = clf.predict(X_test) 
y_test_predictProbas = clf.predict_proba(X_test) 
y_test_predictProbas = y_test_predictProbas[:, 1] 

# Get the thetas (coefficients) from the model
thetas = clf.coef_[0]
intercept = clf.intercept_[0]
print('thetas =', thetas)
print('intercept =', intercept)

# Display the predictors and their associated thetas
# (labels is defined elsewhere; its first entry is skipped)
for idx, x in enumerate(thetas):
    print("Predictor: " + str(labels[idx + 1]) + " = " + str(x))

# Prepend the intercept to the thetas (scikit-learn stores theta0
# separately, in intercept_)
interceptAndThetas = np.append([intercept], thetas)
X_testWithThetaZero = []
for row in X_test:
    X_testWithThetaZero.append(np.append([1], row))

# Calculate the dot product (the linear predictor z) for plotting the sigmoid
dotProductResult = []
for idx, x in enumerate(X_testWithThetaZero):
    dotProductResult.append(np.dot(x, interceptAndThetas))


fig, ax1 = plt.subplots() 

wrongDotProducts = []
rightDotProducts = []
# Build the plot
for idx in range(0, len(dotProductResult)):
    # plot the predicted value on the sigmoid curve
    if y_test[idx] == 1:
        ax1.scatter(dotProductResult[idx], y_test_predictProbas[idx], c=['green'], linewidths=0.0)
    else:
        ax1.scatter(dotProductResult[idx], y_test_predictProbas[idx], c=['black'], linewidths=0.0)

    # plot the actual class
    if y_test[idx] == 1:
        ax1.scatter(dotProductResult[idx], y_test[idx], c=['green'], linewidths=0.0)
        # determine which ones are "wrong" so we can make a histogram
        if y_test_predictProbas[idx] < 0.5:
            wrongDotProducts.append(dotProductResult[idx])
        else:
            rightDotProducts.append(dotProductResult[idx])
    else:
        ax1.scatter(dotProductResult[idx], y_test[idx], c=['black'], linewidths=0.0)
        # determine which ones are "wrong" so we can make a histogram
        if y_test_predictProbas[idx] > 0.5:
            wrongDotProducts.append(dotProductResult[idx])
        else:
            rightDotProducts.append(dotProductResult[idx])

#plt.xlim([-0.05, numInstances + 0.05]) 
plt.ylim([-0.05, 1.05]) 
plt.xlabel('x') 
plt.grid(which='major', axis='both')
plt.ylabel('Prediction Probability') 
plt.title('Sigmoid Curve & Histogram of Predictions') 


## Add a histogram showing the distributions of correct and incorrect predictions
ax2 = ax1.twinx() 
ax2.hist(wrongDotProducts, bins=[-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7], hatch="/", color="black", alpha=0.2) 
ax2.hist(rightDotProducts, bins=[-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7], hatch="\\", color="green", alpha=0.2) 

ax2.set_ylabel('Histogram Count of Actual Class\n1=Green 0=Black') 
ax2.set_xlabel('') 
ax2.set_title('') 
plt.show()  
