K-近邻（kNN）算法

    科技2025-05-17  8

    代码来自 Peter Harrington《机器学习实战》（Machine Learning in Action），相关数据集见 www.manning.com/MachineLearninginAction。

    kNN.py

    from numpy import *
    import operator


    class k_algrithom:
        """k-nearest-neighbours (kNN) helpers from *Machine Learning in Action*, ch. 2.

        The attributes set in ``__init__`` are placeholders; no method reads or
        writes them.
        """

        def __init__(self):
            # Placeholder attributes; none of the methods below use them.
            self.inx = None
            self.dataset = None
            self.labels = None
            self.k = None

        def classify0(self, inx, dataset, labels, k):
            """Classify ``inx`` by majority vote of its ``k`` nearest rows in ``dataset``.

            :param inx: query vector; must broadcast against one row of ``dataset``
                (e.g. shape ``(d,)`` or ``(1, d)``).
            :param dataset: 2-D array of training samples, one sample per row.
            :param labels: sequence where ``labels[i]`` is the label of ``dataset[i]``.
            :param k: number of nearest neighbours that vote; values larger than the
                number of samples are clamped to it.
            :returns: the label with the most votes. On a tie, the label whose first
                vote came from the nearer neighbour wins.
            :raises ValueError: if ``k < 1`` or ``dataset`` has no rows.
            """
            dataSetSize = dataset.shape[0]
            if k < 1 or dataSetSize == 0:
                raise ValueError('k must be >= 1 and dataset must not be empty')
            k = min(k, dataSetSize)
            # Euclidean distance from inx to every row (broadcasting replaces tile()).
            diffMat = dataset - inx
            distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
            # Stable sort: equidistant samples keep their dataset order, so results
            # are deterministic.
            sortedDistIndices = distances.argsort(kind='stable')
            classCount = {}
            for i in range(k):
                voteLabel = labels[sortedDistIndices[i]]
                classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
            # sorted() stays stable with reverse=True, so among equal counts the
            # first-inserted (nearest) label comes first.
            sortedClassCount = sorted(classCount.items(),
                                      key=operator.itemgetter(1), reverse=True)
            return sortedClassCount[0][0]

        def file2matrix(self, filename):
            """Parse a tab-separated text file into a feature matrix and a label list.

            On each non-blank line, the first three tab-separated fields are parsed as
            float features and the last field as an int class label. Blank lines
            (e.g. a trailing newline) are skipped.

            :param filename: path of the text file to read.
            :returns: ``(returnMat, classLabelVector)``: an ``(n, 3)`` float array and
                a list of ``n`` ints, where ``n`` is the number of non-blank lines.
            """
            with open(filename) as fr:
                stripped = (line.strip() for line in fr)
                lines = [line for line in stripped if line]
            returnMat = zeros((len(lines), 3))
            classLabelVector = []
            for index, line in enumerate(lines):
                listFromLine = line.split('\t')
                returnMat[index, :] = [float(v) for v in listFromLine[0:3]]
                classLabelVector.append(int(listFromLine[-1]))
            return returnMat, classLabelVector

        def autonorm(self, dataset):
            """Min-max scale every column of ``dataset`` into ``[0, 1]``.

            :param dataset: 2-D numeric array, one sample per row.
            :returns: ``(normDataset, ranges, minvals)`` with
                ``normDataset == (dataset - minvals) / ranges``. A constant column gets
                a range of 1 instead of 0, so it maps to 0 rather than NaN. The returned
                ``ranges`` and ``minvals`` can be reused to scale new samples the same way.
            """
            minvals = dataset.min(0)
            maxvals = dataset.max(0)
            ranges = maxvals - minvals
            # Guard against 0/0 -> NaN for columns where every sample is identical.
            ranges = where(ranges == 0, 1, ranges)
            normDataset = (dataset - minvals) / ranges
            return normDataset, ranges, minvals

        def imgvector(self, filename):
            """Flatten a 32x32 text image of digit characters into a ``(1, 1024)`` vector.

            The character at row ``i``, column ``j`` of the file becomes element
            ``32*i + j`` of the result.
            """
            returnVect = zeros((1, 1024))
            with open(filename) as fr:
                for i in range(32):
                    lineStr = fr.readline()
                    for j in range(32):
                        returnVect[0, 32 * i + j] = int(lineStr[j])
            return returnVect

    running.py

    from kmean.kNN import *
    import matplotlib.pyplot as plt
    import os


    def plotDatingData(filename='datingTestSet2.txt'):
        """Scatter-plot features 1 and 2 of the dating data, sized and coloured by label."""
        datingDataMat, datingLabels = k_algrithom().file2matrix(filename)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],
                   15.0 * array(datingLabels), 15.0 * array(datingLabels))
        plt.show()


    def datingClassTest(hoRatio=0.1, k=3, filename='datingTestSet2.txt'):
        """Hold-out test of the kNN classifier on the dating-site data set.

        The data are min-max normalized. The first ``hoRatio`` fraction of rows is
        classified with ``k`` neighbours against the remaining rows. Each
        prediction and the overall error rate are printed.

        :raises ValueError: if ``hoRatio`` selects no test rows.
        """
        knn = k_algrithom()
        datingDataMat, datingLabels = knn.file2matrix(filename)
        normMat, ranges, minvals = knn.autonorm(datingDataMat)
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)
        if numTestVecs == 0:
            raise ValueError('hoRatio=%r selects no test rows out of %d' % (hoRatio, m))
        errorCount = 0.0
        for i in range(numTestVecs):
            classfireResult = knn.classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                            datingLabels[numTestVecs:m], k)
            print('the classfier came back with: %d, the real answer is :%d'
                  % (classfireResult, datingLabels[i]))
            if classfireResult != datingLabels[i]:
                errorCount += 1.0
        print('the total error rate is: %f' % (errorCount / float(numTestVecs)))


    def _classFromFilename(filename):
        """Return the digit class encoded in a file name such as '7_45.txt' (-> 7)."""
        return int(filename.split('.')[0].split('_')[0])


    def _loadDigitDir(dirname, knn):
        """Read every .txt digit image in ``dirname``.

        Files are processed in sorted name order; entries not ending in .txt are
        ignored.

        :returns: ``(labels, matrix)``. Row ``i`` of the ``(n, 1024)`` matrix is a
            flattened image whose class, taken from its file name, is ``labels[i]``.
        """
        names = sorted(name for name in os.listdir(dirname) if name.endswith('.txt'))
        matrix = zeros((len(names), 1024))
        labels = []
        for i, name in enumerate(names):
            labels.append(_classFromFilename(name))
            matrix[i, :] = knn.imgvector(os.path.join(dirname, name))
        return labels, matrix


    def handwritingClassTest(k=3):
        """Classify every image in testDigits/ against trainingDigits/ and report errors.

        The true class of each test image comes from that test file's own name.
        Each prediction, the error count and the error rate are printed.

        :raises ValueError: if testDigits/ contains no .txt files.
        """
        knn = k_algrithom()
        hwLabels, trainingMat = _loadDigitDir('trainingDigits', knn)
        testLabels, testMat = _loadDigitDir('testDigits', knn)
        mTest = len(testLabels)
        if mTest == 0:
            raise ValueError('no .txt digit files found in testDigits')
        errorcount = 0.0
        for vectorUnderTest, realClass in zip(testMat, testLabels):
            classfireResult = knn.classify0(vectorUnderTest, trainingMat, hwLabels, k)
            print('the classfier came back with: %d, the real answer is :%d'
                  % (classfireResult, realClass))
            if classfireResult != realClass:
                errorcount += 1.0
        print('\nthe total number of errors is: %d' % errorcount)
        print('\nthe total error rate is: %f' % (errorcount / float(mTest)))


    if __name__ == '__main__':
        handwritingClassTest()
    Processed: 0.015, SQL: 8