在照著Tensorflow官網的demo敲了一遍分類器項目的代碼后,運行倒是成功了,結果也不錯。但是最終還是要訓練自己的數據,所以嘗試準備加載自定義的數據,然而demo中只是出現了fashion_mnist.load_data()并沒有詳細的讀取過程,隨后我又找了些資料,把讀取的過程記錄在這里。
首先提一下需要用到的模塊:
import osimport kerasimport matplotlib.pyplot as pltfrom PIL import Imagefrom keras.preprocessing.image import ImageDataGeneratorfrom sklearn.model_selection import train_test_split
圖片分類器項目,首先確定你要處理的圖片分辨率將是多少,這里的例子為30像素:
IMG_SIZE_X = 30
IMG_SIZE_Y = 30
其次確定你圖片的方式目錄:
image_path = r'D:/Projects/ImageClassifier/data/set'path = "./data"# 你也可以使用相對路徑的方式# image_path =os.path.join(path, "set")
目錄下的結構如下:

相應的label.txt如下:
動漫
風景
美女
物語
櫻花
接下來是接在labels.txt,如下:
label_name = "labels.txt"label_path = os.path.join(path, label_name)class_names = np.loadtxt(label_path, type(""))這里簡便起見,直接利用了numpy的loadtxt函數直接加載。
之后便是正式處理圖片數據了,注釋就寫在里面了:
re_load = Falsere_build = False# re_load = Truere_build = Truedata_name = "data.npz"data_path = os.path.join(path, data_name)model_name = "model.h5"model_path = os.path.join(path, model_name)count = 0# 這里判斷是否存在序列化之后的數據,re_load是一個開關,是否強制重新處理,測試用,可以去除。if not os.path.exists(data_path) or re_load:  labels = []  images = []  print('Handle images')  # 由于label.txt是和圖片防止目錄的分類目錄一一對應的,即每個子目錄的目錄名就是labels.txt里的一個label,所以這里可以通過讀取class_names的每一項去拼接path后讀取  for index, name in enumerate(class_names):    # 這里是拼接后的子目錄path    classpath = os.path.join(image_path, name)    # 先判斷一下是否是目錄    if not os.path.isdir(classpath):      continue    # limit是測試時候用的這里可以去除    limit = 0    for image_name in os.listdir(classpath):      if limit >= max_size:        break      # 這里是拼接后的待處理的圖片path      imagepath = os.path.join(classpath, image_name)      count = count + 1      limit = limit + 1      # 利用Image打開圖片      img = Image.open(imagepath)      # 縮放到你最初確定要處理的圖片分辨率大小      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))      # 轉為灰度圖片,這里彩色通道會干擾結果,并且會加大計算量      img = img.convert("L")      # 轉為numpy數組      img = np.array(img)      # 由(30,30)轉為(1,30,30)(即`channels_first`),當然你也可以轉換為(30,30,1)(即`channels_last`)但為了之后預覽處理后的圖片方便這里采用了(1,30,30)的格式存放      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))      # 這里利用循環生成labels數據,其中存放的實際是class_names中對應元素的索引      labels.append([index])      # 添加到images中,最后統一處理      images.append(img)      # 循環中一些狀態的輸出,可以去除      print("{} class: {} {} limit: {} {}"         .format(count, index + 1, class_names[index], limit, imagepath))  # 最后一次性將images和labels都轉換成numpy數組  npy_data = np.array(images)  npy_labels = np.array(labels)  # 處理數據只需要一次,所以我們選擇在這里利用numpy自帶的方法將處理之后的數據序列化存儲  np.savez(data_path, x=npy_data, y=npy_labels)  print("Save images by npz")else:  # 如果存在序列化號的數據,便直接讀取,提高速度  npy_data = np.load(data_path)["x"]  npy_labels = np.load(data_path)["y"]  print("Load images by npz")image_data = npy_datalabels_data = npy_labels            
新聞熱點
疑難解答