算法概述
優點 精度高、對異常值不敏感、無數據輸入假定。 缺點 計算復雜度高、空間復雜度高。 試用數據范圍 數值型和標稱型
工作原理:將新數據的每個特征與樣本集中數據對應特征進行比較,計算之間的距離值,選取樣本數據集中前k個最相似的數據。
偽代碼: 1. 計算已知類別數據集中的點與當前點之間的距離 2. 按照距離遞增次序排序 3. 選取與當前點距離最小的k個點 4. 確定前k個點所在類別的出現頻率 5. 返回前k個點出現頻率最高的類別作為當前點的預測分類
引入numpy包。
def createDataSet(): group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]]) labels = ['A', 'A', 'B', 'B'] return group, labels def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0]#4l diffMat = tile(inX, (dataSetSize,1)) - dataSet#inx按4*1重復排列再與dataset做差 sqDiffMat = diffMat**2#每個元素平方 sqDistances = sqDiffMat.sum(axis=1)#一個框里的都加起來 distances = sqDistances**0.5#加起來之后每個開根號 sortedDistIndicies = distances.argsort() #返回的是數組值從小到大的索引值,最小為0 classCount={} for i in range(k):#即0到k-1.最后得到的classCount是距離最近的k個點的類分布,即什么類出現幾次如{'A': 1, 'B': 2} voteIlabel = labels[sortedDistIndicies[i]]#返回從近到遠第i個點所對應的類 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1#字典模式,記錄類對應出現次數.這里.get(a,b)就是尋找字典classcount的a對應的值,如果沒找到a,就顯示b sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]注1:這里的operator.itemgetter(1)是指定義一個函數,這個函數可以獲取它括號里的第1個元素 如
a=[3,4,2,6]b=operator.itemgetter(1)b(a)Out[116]: 4c=operator.itemgetter(0)c(a)Out[118]: 3如果operator.itemgetter(1,2)的話指的是用多級排序,以1為首,2為輔
注2:這里sorted(iterable,cmp=None,key=None,reverse=False)
參數解釋:
(1)iterable指定要排序的list或者iterable,不用多說;
(2)cmp為函數,指定排序時進行比較的函數,可以指定一個函數或者lambda函數,如:
(3)key為函數,指定取待排序元素的哪一項進行排序,函數用上面的例子來說明,key指定的lambda函數功能是去元素student的第三個域(即:student[2]) (4)reverse=False是升序,True是降。默認升序。
故綜合起來sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
是以第1個元素為排序標準,將classCount進行降序排列。 如classCount為{‘A’: 1, ‘B’: 2},則sortedClassCount[(‘B’, 2), (‘A’, 1)],則最后返回值sortedClassCount[0][0]為“B”
注3:另外代碼在使用前應加一句 group,labels=createDataSet()
因為group,labels是內置變量,在外面直接使用的話要再定義一遍。
最后輸入如classify0([1,0], group, labels, 2)
來使用
問題描述:訓練集對應0~9的數字,每個數字有200個這樣的txt,要進行識別,這里我們先把矩陣轉換為向量,再比較一個新的向量跟已有數據集向量的距離,找最為接近的k個訓練實例,對應的最多的類就是我們識別出的數字。
def img2vector(filename):#把32*32矩陣處理成1*1024向量 returnVect = zeros((1,1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) return returnVectdef handwritingClassTest(): hwLabels = [] trainingFileList = listdir('trainingDigits') #load the training set,list of filename m = len(trainingFileList) trainingMat = zeros((m,1024)) #形成array,共有m個元素,每個是1*1024向量 for i in range(m): fileNameStr = trainingFileList[i] #如'0_1.txt' fileStr = fileNameStr.split('.')[0] #take off “.txt”即用.分割成2個元素,取前一個 classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#厲害了 testFileList = listdir('testDigits') #iterate through the test set errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) 出來結果為......the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the total number of errors is: 11the total error rate is: 0.011628值得注意的是,這段代碼一定要在已經導入os的情況下才能用,即必須最前面有 from os import listdir
新聞熱點
疑難解答