国产探花免费观看_亚洲丰满少妇自慰呻吟_97日韩有码在线_资源在线日韩欧美_一区二区精品毛片,辰东完美世界有声小说,欢乐颂第一季,yy玄幻小说排行榜完本

首頁 > 編程 > Python > 正文

使用PyTorch實現MNIST手寫體識別代碼

2020-02-15 21:29:29
字體:
來源:轉載
供稿:網友

實驗環境

win10 + anaconda + jupyter notebook

Pytorch1.1.0

Python3.7

gpu環境(可選)

MNIST數據集介紹

MNIST 包括6萬張28x28的訓練樣本,1萬張測試樣本,可以說是CV里的“Hello Word”。本文使用的CNN網絡將MNIST數據的識別率提高到了99%。下面我們就開始進行實戰。

導入包

import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torchvision import datasets, transformstorch.__version__

定義超參數

BATCH_SIZE=512EPOCHS=20 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

數據集

我們直接使用PyTorch中自帶的dataset,并使用DataLoader對訓練數據和測試數據分別進行讀取。如果下載過數據集這里download可選擇False

train_loader = torch.utils.data.DataLoader(    datasets.MNIST('data', train=True, download=True,             transform=transforms.Compose([              transforms.ToTensor(),              transforms.Normalize((0.1307,), (0.3081,))            ])),    batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(    datasets.MNIST('data', train=False, transform=transforms.Compose([              transforms.ToTensor(),              transforms.Normalize((0.1307,), (0.3081,))            ])),    batch_size=BATCH_SIZE, shuffle=True)

定義網絡

該網絡包括兩個卷積層和兩個線性層,最后輸出10個維度,即代表0-9十個數字。

class ConvNet(nn.Module):  def __init__(self):    super().__init__()    self.conv1=nn.Conv2d(1,10,5) # input:(1,28,28) output:(10,24,24)     self.conv2=nn.Conv2d(10,20,3) # input:(10,12,12) output:(20,10,10)    self.fc1 = nn.Linear(20*10*10,500)    self.fc2 = nn.Linear(500,10)  def forward(self,x):    in_size = x.size(0)    out = self.conv1(x)    out = F.relu(out)    out = F.max_pool2d(out, 2, 2)     out = self.conv2(out)    out = F.relu(out)    out = out.view(in_size,-1)    out = self.fc1(out)    out = F.relu(out)    out = self.fc2(out)    out = F.log_softmax(out,dim=1)    return out

實例化網絡

model = ConvNet().to(DEVICE) # 將網絡移動到gpu上optimizer = optim.Adam(model.parameters()) # 使用Adam優化器

定義訓練函數

def train(model, device, train_loader, optimizer, epoch):  model.train()  for batch_idx, (data, target) in enumerate(train_loader):    data, target = data.to(device), target.to(device)    optimizer.zero_grad()    output = model(data)    loss = F.nll_loss(output, target)    loss.backward()    optimizer.step()    if(batch_idx+1)%30 == 0:       print('Train Epoch: {} [{}/{} ({:.0f}%)]/tLoss: {:.6f}'.format(        epoch, batch_idx * len(data), len(train_loader.dataset),        100. * batch_idx / len(train_loader), loss.item()))            
發表評論 共有條評論
用戶名: 密碼:
驗證碼: 匿名發表
主站蜘蛛池模板: 滨海县| 思南县| 那曲县| 保定市| 綦江县| 泗阳县| 淮滨县| 金华市| 丹巴县| 崇阳县| 云林县| 龙游县| 弥渡县| 额济纳旗| 怀安县| 哈巴河县| 邹城市| 南平市| 全椒县| 英山县| 南城县| 延庆县| 枣阳市| 安岳县| 潞城市| 乌兰察布市| 清原| 商城县| 保山市| 永胜县| 高安市| 台湾省| 交口县| 梁平县| 香格里拉县| 思南县| 宜春市| 连城县| 海城市| 吴川市| 乌兰浩特市|