This article wraps raw NumPy array data in a PyTorch Dataset class so that it can later supply data for training a deep network.
Loading and saving the image information
First, import the required libraries and define the paths.
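import os

# Import the submodule explicitly; a bare "import matplotlib" is not guaranteed to load it.
import matplotlib.image
import numpy as np
from keras.datasets import mnist
from PIL import Image
from torch.utils.data.dataset import Dataset

# scipy.misc is not imported: scipy.misc.imsave has been removed from SciPy,
# and the code below saves images with matplotlib instead.

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'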
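The data is organized into three sets: baseset holds all the data (trainingset + testset), trainingset is the training set, and testset is the test set. MNIST is loaded directly through keras.datasets; if the automatic download does not work, download the .npz file manually and save it to the appropriate directory.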
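For the manual-download case, the archive can also be read with NumPy alone. The sketch below assumes the file is the standard Keras mnist.npz, with arrays stored under the keys x_train, y_train, x_test and y_test; the helper name load_mnist_npz is hypothetical, and the demo runs on a tiny synthetic archive rather than the real dataset.

```python
import os
import tempfile

import numpy as np


def load_mnist_npz(npz_path):
    """Read the four MNIST arrays from a manually downloaded .npz archive.

    Assumption: the archive uses the keys x_train, y_train, x_test, y_test.
    """
    with np.load(npz_path) as data:
        x_train, y_train = data['x_train'], data['y_train']
        x_test, y_test = data['x_test'], data['y_test']
    return (x_train, y_train), (x_test, y_test)


# Demo on a tiny synthetic archive that uses the same key names.
demo_path = os.path.join(tempfile.mkdtemp(), 'mnist.npz')
np.savez(
    demo_path,
    x_train=np.zeros((3, 28, 28), dtype=np.uint8),
    y_train=np.array([1, 2, 3], dtype=np.uint8),
    x_test=np.zeros((2, 28, 28), dtype=np.uint8),
    y_test=np.array([4, 5], dtype=np.uint8),
)
(x_train, y_train), (x_test, y_test) = load_mnist_npz(demo_path)
print(x_train.shape, x_test.shape)
```

The returned tuples have the same layout as the result of mnist.load_data(), so the rest of the code would not need to change.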
def LoadData(root_path, base_path, training_path, test_path):
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_baseset = np.concatenate((x_train, x_test))
    y_baseset = np.concatenate((y_train, y_test))
    train_num = len(x_train)
    test_num = len(x_test)

    # baseset: every image (training + test)
    file_img = open(os.path.join(root_path, base_path) + 'baseset_img.txt', 'w')
    file_label = open(os.path.join(root_path, base_path) + 'baseset_label.txt', 'w')
    for i in range(train_num + test_num):
        file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n')  # image path
        file_label.write(str(y_baseset[i]) + '\n')  # label
        # scipy.misc.imsave(root_path + base_path + '/img/' + str(i) + '.png', x_baseset[i])
        matplotlib.image.imsave(root_path + base_path + 'img/' + str(i) + '.png', x_baseset[i])
    file_img.close()
    file_label.close()

    # trainingset
    file_img = open(os.path.join(root_path, training_path) + 'trainingset_img.txt', 'w')
    file_label = open(os.path.join(root_path, training_path) + 'trainingset_label.txt', 'w')
    for i in range(train_num):
        file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n')  # image path
        file_label.write(str(y_train[i]) + '\n')  # label
        # scipy.misc.imsave(root_path + training_path + '/img/' + str(i) + '.png', x_train[i])
        matplotlib.image.imsave(root_path + training_path + 'img/' + str(i) + '.png', x_train[i])
    file_img.close()
    file_label.close()

    # testset
    file_img = open(os.path.join(root_path, test_path) + 'testset_img.txt', 'w')
    file_label = open(os.path.join(root_path, test_path) + 'testset_label.txt', 'w')
    for i in range(test_num):
        file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n')  # image path
        file_label.write(str(y_test[i]) + '\n')  # label
        # scipy.misc.imsave(root_path + test_path + '/img/' + str(i) + '.png', x_test[i])
        matplotlib.image.imsave(root_path + test_path + 'img/' + str(i) + '.png', x_test[i])
    file_img.close()
    file_label.close()
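To check what LoadData wrote, the two index files of a set can be read back and paired line by line. This is a minimal sketch, assuming one image path per line in the image list and one integer label per line in the label file; the helper name read_index and the demo file names are hypothetical.

```python
import os
import tempfile


def read_index(img_list_path, label_path):
    """Pair each image path with its integer label, line by line."""
    with open(img_list_path) as f_img, open(label_path) as f_label:
        paths = [line.strip() for line in f_img if line.strip()]
        labels = [int(line) for line in f_label if line.strip()]
    if len(paths) != len(labels):
        raise ValueError('image list and label list differ in length')
    return list(zip(paths, labels))


# Demo on two tiny index files written to a temporary folder.
demo_dir = tempfile.mkdtemp()
img_txt = os.path.join(demo_dir, 'demo_img.txt')
label_txt = os.path.join(demo_dir, 'demo_label.txt')
with open(img_txt, 'w') as f:
    f.write('img/0.png\nimg/1.png\n')
with open(label_txt, 'w') as f:
    f.write('5\n0\n')

pairs = read_index(img_txt, label_txt)
print(pairs)  # [('img/0.png', 5), ('img/1.png', 0)]
```

A list of (path, label) pairs in this form is a convenient starting point for a Dataset class that loads each image on demand.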
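Before running this code, create the corresponding folders and .txt files (opening a file in 'w' mode creates it if it is missing, but the folders must already exist). The ./data folder is structured as follows: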
The /img folder
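The folders can also be created from Python rather than by hand. The sketch below assumes the layout implied by the paths in LoadData, namely an img folder under each of baseset, trainingset and testset; the helper name make_data_dirs is hypothetical, and the demo uses a temporary directory in place of the real root_path.

```python
import os
import tempfile


def make_data_dirs(root_path, split_dirs=('baseset/', 'trainingset/', 'testset/')):
    """Create <root>/<split>/img for every split; existing folders are left alone."""
    created = []
    for split in split_dirs:
        img_dir = os.path.join(root_path, split, 'img')
        os.makedirs(img_dir, exist_ok=True)
        created.append(img_dir)
    return created


# Demo in a temporary directory standing in for root_path.
demo_root = tempfile.mkdtemp()
created_dirs = make_data_dirs(demo_root)
for d in created_dirs:
    print(d)
```

Because exist_ok=True is passed to os.makedirs, calling the helper again on an existing layout does nothing.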
MNIST images are actually grayscale, so the images saved here with matplotlib are pseudo-color: imsave maps each 2-D array through its default colormap.
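If true grayscale files are preferred, one option is to write the arrays with PIL, which the article already imports. This is a sketch of an alternative, not what the code above does; the helper name save_gray_png is hypothetical, and the demo uses a synthetic 28x28 array instead of a real MNIST digit.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def save_gray_png(array, png_path):
    """Save a 2-D uint8 array as a single-channel (mode 'L') PNG."""
    Image.fromarray(np.asarray(array, dtype=np.uint8)).save(png_path)


# Demo: a synthetic 28x28 pattern standing in for one MNIST digit.
demo_img = (np.arange(28 * 28).reshape(28, 28) % 256).astype(np.uint8)
demo_png = os.path.join(tempfile.mkdtemp(), '0.png')
save_gray_png(demo_img, demo_png)

# Reload the file to confirm that the pixel values survive unchanged.
with Image.open(demo_png) as reloaded:
    reloaded_mode = reloaded.mode
    reloaded_pixels = np.array(reloaded)
print(reloaded_mode, reloaded_pixels.shape)
```

Passing cmap='gray' to matplotlib.image.imsave is another way to get a gray rendering while keeping the matplotlib call.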