国产探花免费观看_亚洲丰满少妇自慰呻吟_97日韩有码在线_资源在线日韩欧美_一区二区精品毛片,辰东完美世界有声小说,欢乐颂第一季,yy玄幻小说排行榜完本

首頁 > 編程 > Python > 正文

Pytorch 實現(xiàn)數(shù)據(jù)集自定義讀取

2020-02-15 21:29:42
字體:
供稿:網(wǎng)友

以讀取VOC2012語義分割數(shù)據(jù)集為例,具體見代碼注釋:

VocDataset.py

from PIL import Imageimport torchimport torch.utils.data as dataimport numpy as npimport osimport torchvisionimport torchvision.transforms as transformsimport time#VOC數(shù)據(jù)集分類對應(yīng)顏色標簽VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],        [0, 64, 128]]#顏色標簽空間轉(zhuǎn)到序號標簽空間,就他媽這里浪費巨量的時間,這里還他媽的有問題def voc_label_indices(colormap, colormap2label):  """Assign label indices for Pascal VOC2012 Dataset."""  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])  #out = np.empty(idx.shape, dtype = np.int64)   out = colormap2label[idx]  out=out.astype(np.int64)#數(shù)據(jù)類型轉(zhuǎn)換  end = time.time()  return outclass MyDataset(data.Dataset):#創(chuàng)建自定義的數(shù)據(jù)讀取類  def __init__(self, root, is_train, crop_size=(320,480)):    self.rgb_mean =(0.485, 0.456, 0.406)    self.rgb_std = (0.229, 0.224, 0.225)    self.root=root    self.crop_size=crop_size    images = []#創(chuàng)建空列表存文件名稱    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')    with open(txt_fname, 'r') as f:      self.images = f.read().split()    #數(shù)據(jù)名稱整理    self.files = []    for name in self.images:      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)      self.files.append({        "img": img_file,        "label": label_file,        "name": name      })    self.colormap2label = np.zeros(256**3)    #整個循環(huán)的意思就是將顏色標簽映射為單通道的數(shù)組索引    for i, cm in enumerate(VOC_COLORMAP):      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i  #按照索引讀取每個元素的具體內(nèi)容  def __getitem__(self, index):        datafiles = self.files[index]    name = datafiles["name"]    image = Image.open(datafiles["img"])    label = Image.open(datafiles["label"]).convert('RGB')#打開的是PNG格式的圖片要轉(zhuǎn)到rgb的格式下,不然結(jié)果會比較要命    #以圖像中心為中心截取固定大小圖像,小于固定大小的圖像則自動填0    imgCenterCrop = transforms.Compose([       transforms.CenterCrop(self.crop_size),       transforms.ToTensor(),       transforms.Normalize(self.rgb_mean, self.rgb_std),#圖像數(shù)據(jù)正則化     ])    labelCenterCrop = transforms.CenterCrop(self.crop_size)    cropImage=imgCenterCrop(image)    croplabel=labelCenterCrop(label)    croplabel=torch.from_numpy(np.array(croplabel)).long()#把標簽數(shù)據(jù)類型轉(zhuǎn)為torch        #將顏色標簽圖轉(zhuǎn)為序號標簽圖    mylabel=voc_label_indices(croplabel, self.colormap2label)        return cropImage,mylabel  #返回圖像數(shù)據(jù)長度  def __len__(self):    return len(self.files)            
發(fā)表評論 共有條評論
用戶名: 密碼:
驗證碼: 匿名發(fā)表
主站蜘蛛池模板: 即墨市| 龙州县| 六安市| 梓潼县| 武胜县| 布尔津县| 军事| 娄底市| 瑞丽市| 西乌珠穆沁旗| 定襄县| 巍山| 镶黄旗| 郁南县| 蓝山县| 新巴尔虎左旗| 梧州市| 吉安县| 阿勒泰市| 普格县| 团风县| 梅州市| 黑河市| 台前县| 霍城县| 桦川县| 科尔| 吐鲁番市| 松桃| 聂拉木县| 武功县| 绍兴县| 南充市| 元阳县| 新竹县| 松滋市| 南木林县| 乐至县| 新源县| 天水市| 西乌珠穆沁旗|