程序员的救赎 发表于 2019-8-15 13:57:42

初识机器学习——鸢尾花分类

本帖最后由 程序员的救赎 于 2019-8-23 13:49 编辑

一、查看数据集


        简单说明一下数据集的构成及其作用:
                训练集——学习知识——>课本;
                验证集——检验学习能力——>课后习题;
                测试集——检验泛化能力——>考试;   
             一般来说,验证集从训练集中划分,有时候并不是必须的。

        鸢尾花数据集是机器学习和统计学中一个经典的数据集,包含在scikit-learn的datasets模块中from sklearn.datasets import load_iris
iris_dataset = load_iris()

        load_iris返回的iris对象是一个Bunch对象,与字典非常相似,里面包含键和值
iris_dataset.keys()dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])

        DESCR键对应的值是数据集的简要说明,这里只截取部分
print(iris_dataset['DESCR'][:193]).. _iris_dataset:

        Iris plants dataset
        --------------------

        **Data Set Characteristics:**

           :Number of Instances: 150 (50 in each of three classes)
           :Number of Attributes: 4 numeric, pre

        target_names 键对应的值是一个字符串数组,里面包含要预测的花的花种
        iris_dataset['target_names']array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

        feature_names 键对应的值是一个字符串列表,对每一个特征进行了说明
        iris_dataset['feature_names']['sepal length (cm)',
       'sepal width (cm)',
        'petal length (cm)',
       'petal width (cm)']

        数据包含在target和data中
        iris_dataset['data'][:5]# 打印前五行array([,
              ,
              ,
              ,
              ])

        iris_dataset['data'].shape# 查看数组形状(150, 4)
        可以看出,数组形状为150x4。每一行为一朵花的数据(sample),每一列为花的不同属性(feature)

        iris_dataset['target']array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
             2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
             2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
        target为一维数组,0, 1,2分别和上面的target_names下标对应的类别相对应,一个数字标记了一朵花的类别 (label)


二、划分训练数据和测试数据

        from sklearn.model_selection import train_test_split
        划分数据集, 可以设置测试集的比例(默认为0.25), 指定随机数生成的种子(确保每次得到的输出是固定的)
X_train, X_test, y_train, y_test = train_test_split(
iris_dataset['data'], iris_dataset['target'], random_state=0)
        数据集用X表示, 标签用y表示是受函数y = f(x)启发,使用大写的X代表二维数组,而小写的y代表一维标量。 打乱数据集后再进行划分的目的是防止因为标签是有序的而导致测试集或者训练集中的数据类别不完整,进而影响训练结果

X_train.shape(112, 4)
X_test.shape(38, 4)


三、观察数据

        在构建学习模型之前,通常要检查一下数据。一是做模型选择,二是检测数据是否异常。模型选择要根据数据的特定进行判断,而检测数据异常最佳方法之一是将其可视化
        import pandas as pd
import mglearn
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)

        绘制散点图
grr = pd.plotting.scatter_matrix(iris_dataframe,
                                 c=y_train,# 颜色参数,使用的训练集的label
                                 figsize=(15, 15), # 图像尺寸
                                 marker='o',# 标记类型
                                 hist_kwds={'bins': 20},
                                 s=60,
                                 alpha=.8, # 图像透明度
                                 cmap=mglearn.cm3)


        从图中可以知道,从花瓣和花萼的测量标签数据基本可以将三个类别区分开,说明机器学习很可能学会区分它们


四、训练模型——k近邻
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=1)# 设置邻居数目
knn.fit(X_train, y_train)KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
         metric_params=None, n_jobs=None, n_neighbors=1, p=2,
         weights='uniform')

        模型训练好之后,需要进行预测
import numpy as np
X_new = np.array([])
X_new.shape        (1, 4)
predicion = knn.predict(X_new)
predicionarray()

        预测结果表明这个花属于类别0, 即setosa


五、模型评估

        y_pred = knn.predict(X_test)
y_pred        array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
       0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 2])

        计算预测精度
        np.mean(y_pred == y_test)0.9736842105263158
      预测精度代表了我们模型在面对新的数据时的处理能力(泛化能力),精度越高,说明模型越好。

        至此,第一个机器学习任务结束。

机器学习的一般步骤:
      获取数据——数据清洗和处理——划分数据集——模型选择——参数设置——训练模型——模型评估——调整模型——预测结果

      由于这里数据集是标准数据集,所以不需要清洗数据。


页: [1]
查看完整版本: 初识机器学习——鸢尾花分类