K-近邻分类算法的python实现及案例分析

文章目录2,处理分类问题二、约会网站配对效果判定三、手写数字识别 总结
一、电影类别分类k-近邻算法是一种基本分类与回归方法,我们可以使用k-近邻算法分类一个电影是是爱情片还是动作片 。1.准备电影数据
以下是我们已有的数据集合,也就是训练样本,假如有一部未看过的电影,如何确定它是爱情片还是动作片呢,我们可以使用kNN来解决这个问题 。
电影名称打斗场景接吻镜头电影类型
Man
104
爱情片
He 's Notinto Dudes
100
爱情片
Woman
81
爱情片
Kevin
101
10
动作片
Robo
99
动作片
Amped
98
动作片
1.创建数据集
首先,使用numpy创建我们所需的数据集代码,
import numpy as npimport operator"""group - 数据集labels - 分类标签"""# 创建数据集def createDataSet():group = np.array([[3, 104], [2, 100], [1, 81], [101, 10], [99, 5], [98, 2]])labels = ['爱情片', '爱情片', '爱情片', '动作片', '动作片', '动作片']return group, labelsif __name__ == '__main__':# 创建数据集group, labels = createDataSet()print(group)print(labels)
运行结果如下
2,处理分类问题
根据欧式距离公式计算出两点之间的距离,并返回前K个点的分类结果 。
2.1分类代码
代码如下:
"""parameters:inX - 用于分类的数据(测试集)dataSet - 用于训练的数据(训练集)labels - 分类标签k - kNN算法参数,选择距离最小的k个点returns:sortedClassCount[0][0] - 分类结果"""# 分类def classifyo(inX,dataSet,labels,k):dataSetSize = dataSet.shape[0]# shape[0] 返回dataSet的行数diffMat = np.tile(inX, (dataSetSize, 1))-dataSet# 二维特征相减后平方sqDiffMat = diffMat ** 2# sum(0)列相加,sum(1)行相加sqDistances = sqDiffMat.sum(axis=1)# 开方计算出距离distances = sqDistances ** 0.5# 返回distances中元素从小到大排序后的索引值sortedDistIndicies = distances.argsort()# 定一个记录类别次数的字典classCount = {}for i in range(k):# 取出前k个元素的类别voteIlabel = labels[sortedDistIndicies[i]]classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]if __name__ == '__main__':# 创建数据集group, labels = createDataSet()# 测试集test = [18, 90]# kNN分类test_class = classifyo(test, group, labels, 3)# 打印分类结果print(print(test_class)
运行结果如下
二、约会网站配对效果判定 1. 收集数据
海伦一直使用在线约会网站寻找适合自己的约会对象 。尽管约会网站会推荐不同的人选,但她没有从中找到喜欢的人 。经过一番总结,她发现曾交往过三种类型的人:
海伦收集约会数据巳经有了一段时间,她把这些数据存放在文本文件.tet中,每个样本数据占据一行,总共有1000行 。海伦的样本主要包含以下3种特征:
2. 准备数据 2.1 从文本文件中解析数据
将上述特征数据输入到分类器之前,还需将待处理数据的格式转换为分类器可以接受的格式 。创建名为的函数,该函数输入为文件名字符串,输出为训练样本矩阵和类标签向量 。
import numpy as npdef file2matrix(filename):fr = open(filename)arrayOLines = fr.readlines()numberOfLines = len(arrayOLines)#返回的NumPy矩阵,解析完成的数据:numberOfLines行,3列returnMat = np.zeros((numberOfLines, 3))classLabelVector = []index = 0for line in arrayOLines:line = line.strip()listFromLine = line.split('\t')returnMat[index, :] = listFromLine[0:3]if listFromLine[-1] == 'didntLike':classLabelVector.append(1)elif listFromLine[-1] == 'smallDoses':classLabelVector.append(2)elif listFromLine[-1] == 'largeDoses':classLabelVector.append(3)index += 1return returnMat, classLabelVectorif __name__ == '__main__':# 打开的文件名filename = "datingTestSet.txt"# 打开并处理数据datingDataMat, datingLabels = file2matrix(filename)print(datingDataMat)print(datingLabels)