鱼C论坛

 找回密码
 立即注册
查看: 1441|回复: 0

[技术交流] 机器学习实战--利用k-近邻算法改进约会网站

[复制链接]
发表于 2018-4-10 09:57:23 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 微言大义 于 2018-4-10 10:43 编辑

Python小白,照着《机器学习实战》敲了k-近邻算法的代码,添加了一些注释,完整的代码可以运行,背景环境还需要看书
  1. import numpy as np
  2. import operator
  3. import matplotlib
  4. import matplotlib.pyplot as plt
  5. from pylab import mpl

  6. def file2matrix(filename):
  7.     fr = open(filename)
  8.     numberOfLines = len(fr.readlines())     #得到文件行数
  9.     returnMat = np.zeros((numberOfLines,3))    #创建以0填充的矩阵Numpy,将另一维度设置为固定的3
  10.     classLabelVector = []
  11.     fr = open(filename)
  12.     index = 0
  13.     for line in fr.readlines():
  14.         line = line.strip()     #截取掉所有的回车字符
  15.         listFromLine = line.split('\t')     #使用tab字符\t将上一步得到的整行数据分割成一个元素列表
  16.         returnMat[index,:] = listFromLine[0:3]  #选取前三个元素,存储到特征矩阵中
  17.         classLabelVector.append(int(listFromLine[-1]))  #利用负索引将表的最后一列存储到 classLabelVector中,必须明确列表中存储的元素是整形,否则Python会将这些元素当做字符串处理
  18.         index += 1
  19.     return returnMat,classLabelVector
  20.    

  21. #分析数据:使用Matplotlib创建散点图
  22. datingDataMat,datingLabels =file2matrix('datingTestSet2.txt')
  23. fig = plt.figure()
  24. ax = fig.add_subplot(111)
  25. ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
  26. mpl.rcParams['font.sans-serif'] = ['SimHei']    #将坐标轴标记转换为汉字
  27. plt.xlabel('玩视频游戏所耗时间比')
  28. plt.ylabel('每周消费的冰激凌公斤数')
  29. plt.show()
  30. fig = plt.figure(2)
  31. ax = fig.add_subplot(111)
  32. ax.scatter(datingDataMat[:,1],datingDataMat[:,0],15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
  33. plt.xlabel('玩视频游戏所耗时间比')
  34. plt.ylabel('每年获取的飞行常客旅程数')
  35. plt.show()


  36. #归一化特征值
  37. def autoNorm(dataSet):
  38.     minVals = dataSet.min(0)    #将最小值放在minVal中,参数0使得函数可以从列中选取最小值,而不是从当前行的最小值
  39.     maxVals = dataSet.max(0)
  40.     ranges = maxVals-minVals    #计算可能的取值范围
  41.     normDataSet = np.zeros(np.shape(dataSet))  #创建新的返回矩阵
  42.     m = dataSet.shape[0]
  43.     normDataSet = dataSet - np.tile(minVals,(m,1)) #tile函数将变量内容复制成输入矩阵同样大小的矩阵
  44.     normDataSet = normDataSet/np.tile(ranges,(m,1))
  45.     return normDataSet,ranges,minVals   #特征值相除

  46. #执行autoNorm函数,监测函数执行结果
  47. normMat,ranges,minVals = autoNorm(datingDataMat)

  48. def classify0(inX,dataSet,labels,k):   #用于分类的输入向量inX,输入的训练样本集dataSet,标签向量labels,用于选择最近邻居数目的k值
  49.         dataSetSize = dataSet.shape[0]   #利用shape函数读取dataSet第一维度长度
  50.         diffMat = np.tile(inX,(dataSetSize,1)) - dataSet        #利用欧氏距离公式计算两向量点之间的距离
  51.         sqDiffMat = diffMat ** 2
  52.         sqDistance = sqDiffMat.sum(axis=1)          #将矩阵的每一行向量相加求和
  53.         distances = sqDistance**0.5
  54.         sortedDistIndicies = distances.argsort()        #返回数组值从小到大的索引值
  55.         classCount = {}                #建立空字典
  56.         for i in range(k):                #选择距离最小的k个点
  57.                 voteIlabel = labels[sortedDistIndicies[i]]
  58.                 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1        #将classCount字典分解成元组列表
  59.         sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)        #导入运算符模块的itemgetter方法,按照第二个元素的次序对元组进行排序
  60.         return sortedClassCount[0][0]

  61. #测试错误率
  62. def datingClassTest():
  63.     hoRatio = 0.1
  64.     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
  65.     normMat,ranges,minVals = autoNorm(datingDataMat)
  66.     m = normMat.shape[0]
  67.     numTestVecs = int(m*hoRatio)
  68.     errorCount = 0.0
  69.     for i in range(numTestVecs):
  70.         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
  71.         print("the classifier came back with: %d, the real answer is: %d"% (classifierResult, datingLabels[i]))
  72.         if (classifierResult != datingLabels[i]):
  73.             errorCount += 1.0   
  74.     print ("the total error rate is: %f" % (errorCount/float(numTestVecs)))

  75. #约会网站预测函数,调用之前的函数,添加了输入函数,以便用户输入
  76. def classifyPerson():
  77.     resultList = ['not at all','in small doses','in large doses']
  78.     percentTats = float(input("percentage of time spent playing video games?"))
  79.     ffMiles = float(input("frequent flier miles earned per year?"))
  80.     iceCream = float(input("liters of ice cream consumed per year?"))
  81.     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
  82.     normMat,ranges,minVals = autoNorm(datingDataMat)
  83.     inArr = np.array([ffMiles,percentTats,iceCream])
  84.     classifierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3)
  85.     print("you will probably like this person:",resultList[classifierResult - 1])
复制代码


运行程序,得散点图如下:

执行框输入classifyPerson()即可,运行结果如下:
percentage of time spent playing video games?10

frequent flier miles earned per year?10000

liters of ice cream consumed per year?0.5

you will probably like this person: in small doses

散点图

散点图

散点图

散点图
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-27 09:18

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表