当前位置: 首页 > ML-算法 > 正文

K-近邻算法实现

PlantUML Syntax:

@startuml
scale 600*400
skinparam defaultFontName AR PL UKai CN
skinparam defaultFontSize 16

title
kNN算法流程图
end title

start
:(1) 计算已知类别数据集中的点与当前点之间的距离;
:(2) 按照距离递增次序排序;
:(3) 选取与当前点距离最小的k个点;
:(4) 确定前k个点所在类别的出现频率;
:(5) 返回前k个点出现频率最高的类别作为当前点的预测分类;
stop
@enduml

#!/usr/bin/env python3

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import operator

def main():
    """Build the toy training set and classify the point (0, 0) with k=3."""
    training_points, training_labels = createTrainingDataSet()
    predicted = classify0([0, 0], training_points, training_labels, 3)
    print(predicted)

def dating_test():
    """Load the dating data, normalize it, and draw two scatter plots.

    Marker size and color both scale with the class label so the three
    classes are visually separable.
    """
    data_matrix, dating_labels = file2matrix("datingTestSet2.txt")
    data_matrix, _, _ = autoNorm(data_matrix)  # only the normalized matrix is used
    print(data_matrix[0:15])
    print(dating_labels[0:15])
    figure = plt.figure()
    axes = figure.add_subplot(111)
    label_scale = 15.0 * np.array(dating_labels)
    axes.scatter(data_matrix[:, 1], data_matrix[:, 2], label_scale, label_scale)
    axes.scatter(data_matrix[:, 0], data_matrix[:, 1], label_scale, label_scale)
    plt.show()

def main_test():
    """Print basic properties of a small NumPy array (scratch/debug helper)."""
    matrix = np.arange(15).reshape(3, 5)
    for value in (matrix, matrix.T, matrix.ndim, matrix.shape, matrix.dtype):
        print(value)

def createTrainingDataSet():
    """Return the classic four-point toy training set and its labels.

    Returns:
        tuple: (4x2 ndarray of 2-D feature vectors, list of 'A'/'B' labels).
    """
    features = np.array([
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ])
    class_labels = ['A', 'A', 'B', 'B']
    return features, class_labels

def classify0(inX, trainingDataSet, labels, k):
    """Classify inX by majority vote among its k nearest training samples.

    Args:
        inX: query feature vector (sequence of numbers).
        trainingDataSet: 2-D ndarray, one training sample per row.
        labels: class label for each training row.
        k: number of nearest neighbours to poll.

    Returns:
        The label most frequent among the k nearest neighbours.
    """
    # Euclidean distance from inX to every training row (broadcast subtract
    # instead of np.tile; the arithmetic is identical).
    deltas = trainingDataSet - np.asarray(inX)
    distances = ((deltas ** 2).sum(axis=1)) ** 0.5
    nearest_first = distances.argsort()
    votes = {}
    for idx in nearest_first[:k]:
        neighbour_label = labels[idx]
        votes[neighbour_label] = votes.get(neighbour_label, 0) + 1
    # sorted() is stable, so among tied counts the label seen first
    # (i.e. belonging to a closer neighbour) wins — same as the original.
    ranked = sorted(votes.items(), key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]

def file2matrix(filename):
    """Parse a tab-separated dating-data file into features and labels.

    Each line is expected to hold three numeric feature columns followed
    by an integer class label, all separated by tabs.

    Args:
        filename: path to the tab-separated data file.

    Returns:
        tuple: (n x 3 float ndarray of features, list of int class labels).
    """
    # Read the file once under a context manager. The original opened the
    # file twice (once just to count lines) and never closed either handle.
    with open(filename) as fr:
        lines = fr.readlines()
    returnMat = np.zeros((len(lines), 3))
    classLabelVector = []
    for index, line in enumerate(lines):
        listFromLine = line.strip().split('\t')
        # NumPy coerces the string fields to float on assignment.
        returnMat[index, :] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector

def autoNorm(dataSet):
    """Min-max scale each column of dataSet into [0, 1].

    Args:
        dataSet: 2-D ndarray, one sample per row.

    Returns:
        tuple: (normalized ndarray, per-column ranges, per-column minimums) —
        ranges and minimums let a caller scale a new sample the same way.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Broadcasting replaces the original np.tile copies (identical result)
    # and the dead `np.zeros` pre-allocation that was immediately overwritten.
    # NOTE(review): a constant column (range 0) still divides by zero here,
    # exactly as in the original — callers must avoid degenerate columns.
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals

def datingClassTest():
    """Hold-out evaluation of classify0 on the dating data set.

    Uses the first 10% of rows as test queries against the remaining 90%
    as the training set, printing each prediction and the overall error
    rate.
    """
    hoRatio = 0.1  # fraction of rows held out for testing
    datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], 3)
        print(f"the classifier came back with: {classifierResult}, the real answer is {datingLabels[i]}")
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    # Message typo fixed: "rata" -> "rate".
    print(f"the total error rate is {errorCount / float(numTestVecs)}")

def realP():
    """Normalize one hand-typed sample and print its predicted class."""
    feature_mat, label_list = file2matrix("datingTestSet2.txt")
    norm_mat, value_ranges, min_values = autoNorm(feature_mat)
    sample = np.array([50000, 10, 0.5])
    # Scale the query exactly as the training data was scaled.
    scaled_sample = (sample - min_values) / value_ranges
    prediction = classify0(scaled_sample, norm_mat, label_list, 3)
    print(f"predict result is {prediction}")

# Script entry point: run the hold-out evaluation, then classify a single
# hand-entered sample. The commented-out calls are alternative manual demos.
if __name__ == '__main__':
 datingClassTest()
 realP()
 # main()
 # main_test()
 # dating_test()
赞 赏

   微信赞赏  支付宝赞赏


本文固定链接: https://www.jack-yin.com/coding/ml/ml-algorithm/2795.html | 边城网事

该日志由 边城网事 于2019年08月04日发表在 ML-算法 分类下, 你可以发表评论,并在保留原文地址及作者的情况下引用到你的网站或博客。
原创文章转载请注明: K-近邻算法实现 | 边城网事

K-近邻算法实现 暂无评论

发表评论

快捷键:Ctrl+Enter