鱼C论坛

 找回密码
 立即注册
查看: 1620|回复: 0

[技术交流] 机器学习系列------决策函数

[复制链接]
发表于 2018-6-14 09:51:09 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
        新启一个文件,首先把以前的代码复制过来:
  1. from sklearn.datasets import fetch_mldata



  2. mnist=fetch_mldata('MNIST original',data_home='.\datasets')


  3. X,y=mnist["data"],mnist["target"]
  4. %matplotlib inline
  5. import matplotlib
  6. import matplotlib.pyplot as plt

  7. X_train,X_test,y_train,y_test=X[:60000],X[60000:],y[:60000],y[60000:]

  8. import numpy as np
  9. shuffle_index=np.random.permutation(60000)
  10. X_train,y_train=X_train[shuffle_index],y_train[shuffle_index]
  11. some_digit=X[36000]
  12. y_train_5=(y_train==5)
  13. y_test_5=(y_test==5)
  14. from sklearn.linear_model import SGDClassifier

  15. sgd_clf=SGDClassifier(random_state=42)
  16. sgd_clf.fit(X_train,y_train_5)
复制代码

        建立决策函数,决策函数是通过调试阀值帮助我们预测精准度的一个工具:
  1. y_scores=sgd_clf.decision_function([some_digit])
  2. y_scores
复制代码

        会显示:array([ 66994.58438748]),这个数是预测准确的一个分数,但是必须要有一个阀值来判断边界:
  1. threshold=0
  2. y_some_digit_pred=(y_scores>threshold)
  3. y_some_digit_pred
复制代码

        以上我们把阀值设定为0,他会显示:array([ True], dtype=bool)。因为我们这个分数已经超过了阀值,就说明预测结果是对的,但是阀值又通过什么来确定呢,我们把阀值设定高点试试:
  1. threshold=200000
  2. y_some_digit_pred=(y_scores>threshold)
  3. y_some_digit_pred
复制代码

        显示:array([False], dtype=bool)。就说明阀值一定要取的正合适才行,就要通过准确率召回曲线确定阀值:
  1. from sklearn.model_selection import cross_val_predict
  2. y_scores=cross_val_predict(sgd_clf,X_train,y_train_5,cv=3,method="decision_function")
  3. from sklearn.metrics import precision_recall_curve
  4. precisions,recalls,thresholds=precision_recall_curve(y_train_5,y_scores[:,1])
复制代码

        然后我们写个画图功能的函数:
  1. def plot_precision_recall_vs_threshold(precisions,recalls,thresholds):
  2.     plt.plot(thresholds,precisions[0:-1],"b--",label="Precision")
  3.     plt.plot(thresholds,recalls[:-1],"g-",label="Recall")
  4.     plt.xlabel("Threshold")
  5.     plt.legend(loc="upper left")
  6.     plt.ylim([0,1])
复制代码

        调用画图函数:
  1. plot_precision_recall_vs_threshold(precisions,recalls,thresholds)
  2. plt.show()
复制代码

        输出的图像为:
dsfsdfd.png
        从上图可以得出准确度追高的时候在50000左右,我们可以看一下决策函数大于50000时的评分:
  1. from sklearn.metrics import precision_score,recall_score
  2. y_train_pred_90=(y_scores>50000)
  3. precision_score(y_train_5,np.argmax(y_train_pred_90, axis=1))
复制代码

        显示结果为:0.90712074303405577
  1. recall_score(y_train_5,np.argmax(y_train_pred_90, axis=1))
复制代码

        0.54049068437557646

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-20 10:55

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表