Start a new file and first copy over the earlier code:
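```python
# fetch_mldata has been removed from scikit-learn (mldata.org is offline);
# fetch_openml downloads the same 70,000 MNIST images
from sklearn.datasets import fetch_openml
import numpy as np
# Jupyter notebook magic: remove this line when running as a plain script
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

mnist = fetch_openml('mnist_784', version=1, as_frame=False, data_home='./datasets')
X, y = mnist["data"], mnist["target"].astype(np.uint8)  # OpenML stores labels as strings

# First 60,000 images are the training set, last 10,000 the test set
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

# X[0] is a 5 in the OpenML ordering (X[36000] was a 5 in the old mldata file)
some_digit = X[0]

# Binary targets: True for 5s, False for every other digit
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
```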
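Next, compute the decision function. Instead of a True/False prediction, decision_function() returns a score for each instance; the classifier predicts by comparing that score with a threshold, so choosing the threshold ourselves lets us trade precision against recall: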
```python
y_scores = sgd_clf.decision_function([some_digit])
y_scores
```
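In the author's run this displays array([ 66994.58438748]) (the exact value varies with the random shuffle). This score is not a probability: for a linear model it is proportional to the instance's signed distance from the decision boundary, so we still need a threshold to turn it into a yes/no prediction: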
```python
threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
```
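Comparing the score with 0 mirrors what a linear classifier such as SGDClassifier does inside predict(): for a binary problem it returns the positive class when the decision function is greater than 0. The NumPy sketch below illustrates that rule with a hypothetical two-feature model (the weights w, the bias b and the helper names are made up for illustration), where the score is w·x + b, and shows how raising the threshold can turn a positive prediction into a negative one.

```python
import numpy as np

# Hypothetical weights and bias of an already-trained linear binary classifier
w = np.array([2.0, -1.0])
b = -0.5

def decision_function(X):
    # Score = w.x + b: positive on one side of the decision boundary, negative on the other
    return X @ w + b

def predict(X, threshold=0.0):
    # Positive class when the score exceeds the threshold (0 mirrors predict())
    return decision_function(X) > threshold

X = np.array([[1.0, 0.0], [0.0, 1.0], [3.0, 1.0]])
scores = decision_function(X)               # 1.5, -1.5, 4.5
default_pred = predict(X)                   # True, False, True
strict_pred = predict(X, threshold=2.0)     # False, False, True
```

With the higher threshold the first instance (score 1.5) is no longer predicted positive, which is the same effect the next experiment shows on the MNIST digit.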
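Above we set the threshold to 0, and the output is array([ True], dtype=bool). The score is above the threshold, so the digit is classified as a 5, which happens to be correct for this image; exceeding the threshold only means a positive prediction, not a correct one. How should the threshold be chosen, then? Let's try a higher one: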
```python
threshold = 200000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
```
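This time the output is array([False], dtype=bool): with the higher threshold the same 5 is missed, so raising the threshold lowers recall. The threshold therefore has to be chosen with care, and the precision-recall curve is the tool for choosing it. First, use cross-validation to get decision scores (rather than predictions) for every training instance: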
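```python
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve

# method="decision_function" returns out-of-fold scores instead of predictions
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
                             method="decision_function")
# An old scikit-learn bug (0.19.0) could return an (n, 2) array here;
# keep only the score column in that case, otherwise y_scores is already 1-D
if y_scores.ndim == 2:
    y_scores = y_scores[:, 1]

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
```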
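To make the shape of these arrays concrete, here is a rough NumPy sketch of the idea behind precision_recall_curve, run on made-up labels and scores (hypothetical toy data, not MNIST). Each distinct score serves as a threshold, an instance counts as positive when its score is greater than or equal to that threshold, and a final precision of 1 and recall of 0 are appended with no threshold of their own. The helper name precision_recall_by_threshold is hypothetical, and the exact arrays scikit-learn returns can differ slightly between versions.

```python
import numpy as np

def precision_recall_by_threshold(y_true, scores):
    """Hypothetical helper sketching the idea behind precision_recall_curve."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    thresholds = np.unique(scores)  # distinct scores, sorted ascending
    precisions, recalls = [], []
    for t in thresholds:
        pred = scores >= t  # positive when the score reaches the threshold
        tp = np.sum(pred & y_true)
        # pred is never empty, because t is itself one of the scores
        precisions.append(tp / np.sum(pred))
        recalls.append(tp / np.sum(y_true))
    # Final point: precision 1, recall 0, with no matching threshold
    return np.array(precisions + [1.0]), np.array(recalls + [0.0]), thresholds

# Made-up labels (1 = positive) and decision scores
y_true = [0, 0, 1, 1, 0, 1]
scores = [-3.0, -1.0, 0.5, 2.0, 1.0, 4.0]
p, r, t = precision_recall_by_threshold(y_true, scores)
# t is [-3, -1, 0.5, 1, 2, 4]; p and r each have len(t) + 1 entries
```

The extra trailing element is why a plot against the thresholds has to drop the last precision and recall values. Note also that precision dips from 0.75 to about 0.67 when the threshold rises from 0.5 to 1, because a true positive (score 0.5) drops out while a false positive (score 1.0) remains: precision is not guaranteed to rise with the threshold, whereas recall can only stay the same or fall.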
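Then write a function that plots precision and recall against the threshold:
```python
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    # precisions and recalls have one more element than thresholds
    # (a final precision of 1 and recall of 0), so drop the last one
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0, 1])
```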
Call the plotting function:
```python
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()
```
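The output figure (image not preserved): precision (blue dashed) and recall (green solid) plotted against the decision threshold.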
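From the plot, precision becomes high at a threshold of around 50000 (raising the threshold further pushes precision up while recall keeps falling). Let's predict "5" only where the decision score is above 50000 and see how precision and recall turn out: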
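```python
from sklearn.metrics import precision_score, recall_score

# Predict "5" wherever the out-of-fold decision score exceeds 50000;
# y_scores is 1-D, so no np.argmax workaround is needed
y_train_pred_90 = (y_scores > 50000)
precision_score(y_train_5, y_train_pred_90)
```
The precision is 0.90712074303405577. Now the recall:
```python
recall_score(y_train_5, y_train_pred_90)
```
The recall is 0.54049068437557646. So at this threshold about 91% of the images flagged as 5s really are 5s, but only about 54% of the actual 5s are detected.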
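Instead of reading the threshold off the plot, it can be picked from the arrays directly: take the lowest threshold whose precision reaches the target, which keeps recall as high as possible at that precision (recall never increases as the threshold rises). The sketch below assumes, as precision_recall_curve returns them, that thresholds are ascending and precisions has one extra trailing element; the helper name lowest_threshold_for_precision and the toy numbers are hypothetical. With the MNIST arrays above it would be called as lowest_threshold_for_precision(precisions, thresholds, 0.90).

```python
import numpy as np

def lowest_threshold_for_precision(precisions, thresholds, target):
    """Hypothetical helper: smallest threshold whose precision is >= target.

    Assumes thresholds are ascending and precisions has one extra trailing
    element (as returned by precision_recall_curve), which is dropped here.
    """
    reaches_target = np.asarray(precisions)[:-1] >= target
    if not reaches_target.any():
        raise ValueError("no threshold reaches the target precision")
    # argmax on a boolean array returns the index of the first True
    return np.asarray(thresholds)[np.argmax(reaches_target)]

# Toy curve (made-up numbers, not MNIST results)
precisions = np.array([0.5, 0.6, 0.75, 2 / 3, 1.0, 1.0, 1.0])
thresholds = np.array([-3.0, -1.0, 0.5, 1.0, 2.0, 4.0])

t90 = lowest_threshold_for_precision(precisions, thresholds, 0.90)  # 2.0
t70 = lowest_threshold_for_precision(precisions, thresholds, 0.70)  # 0.5
```

Because precision can dip as the threshold rises, the helper searches for the first qualifying threshold rather than assuming a single crossing point: for a 70% target it returns 0.5 even though the next threshold (1.0) falls back below 70%.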