本文主要介紹了使用 PyTorch CNN 識別手寫字、並以自建圖片數據集實現訓練的方法,分享給大家,具體如下:
# library
# standard library
import os

# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable  # unused since PyTorch 0.4 (plain tensors track gradients); import kept unchanged
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# torch.manual_seed(1)  # uncomment for reproducible runs

# Hyper Parameters
EPOCH = 1  # train the training data n times; 1 epoch keeps the demo fast
BATCH_SIZE = 50
LR = 0.001  # learning rate
EVAL_EVERY = 50  # run a test-set evaluation every this many training steps

root = "./mnist/raw/"


def default_loader(path):
    """Open the image at *path* and return a fully loaded PIL image.

    ``Image.open`` is lazy and keeps the file handle open; ``copy()`` forces
    the pixel data into memory so the ``with`` block can close the file.
    """
    with Image.open(path) as img:
        return img.copy()


class MyDataset(Dataset):
    """Image-classification dataset described by a plain-text index file.

    Each non-blank line of *txt* holds ``<image_path> <integer_label>``
    separated by whitespace; any further columns are ignored. Paths
    containing whitespace are therefore not supported.

    Args:
        txt: path to the index file.
        transform: optional callable applied to the grayscale PIL image.
        target_transform: optional callable applied to the integer label.
        loader: callable mapping an image path to a PIL image or an array
            accepted by ``Image.fromarray``.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        with open(txt, 'r') as fh:
            for line in fh:
                # split() with no argument drops the trailing newline and any
                # surrounding whitespace without touching the path itself.
                words = line.split()
                if not words:
                    continue  # tolerate blank lines, e.g. at end of file
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for entry *index*, with transforms applied."""
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.asarray(img))
        # The network's first conv layer takes a single input channel, so
        # normalise every image to 8-bit grayscale.
        img = img.convert('L')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


class CNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a linear 10-way classifier.

    Shape comments assume 1x28x28 inputs, which the final
    ``nn.Linear(32 * 7 * 7, 10)`` requires.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(  # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,    # input channels (grayscale)
                out_channels=16,  # n_filters
                kernel_size=5,    # filter size
                stride=1,         # filter movement/step
                padding=2,        # keeps width/height unchanged: padding=(kernel_size-1)/2 when stride=1
            ),                    # output shape (16, 28, 28)
            nn.ReLU(),            # activation
            nn.MaxPool2d(kernel_size=2),  # max over 2x2 areas, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(  # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # output shape (32, 14, 14)
            nn.ReLU(),                   # activation
            nn.MaxPool2d(2),             # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes

    def forward(self, x):
        """Return ``(logits, flattened_features)``; features are kept for visualization."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten conv2 output to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x


def _evaluate(model, loader, loss_func):
    """Return ``(mean per-sample loss, accuracy)`` of *model* over *loader*.

    *model* must return ``(logits, ...)`` like ``CNN.forward``. *loss_func*
    is assumed to use the default ``reduction='mean'`` (as
    ``nn.CrossEntropyLoss()`` does). Returns ``(nan, nan)`` when *loader*
    yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for tx, ty in loader:
            output = model(tx)[0]
            n = ty.size(0)
            # The loss is a batch mean; weight it by batch size so the final
            # division yields a true per-sample mean even for a short last batch.
            total_loss += loss_func(output, ty).item() * n
            correct += (output.argmax(dim=1) == ty).sum().item()
            total += n
    model.train(was_training)  # restore the caller's train/eval mode
    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total


def main():
    """Train ``CNN`` on the index files under ``root`` and print test metrics periodically."""
    # ToTensor converts the 8-bit grayscale image to a float tensor in [0, 1].
    train_data = MyDataset(txt=os.path.join(root, 'train.txt'), transform=torchvision.transforms.ToTensor())
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    test_data = MyDataset(txt=os.path.join(root, 'test.txt'), transform=torchvision.transforms.ToTensor())
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)

    cnn = CNN()
    print(cnn)  # net architecture

    optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
    loss_func = nn.CrossEntropyLoss()  # targets are class indices, not one-hot vectors

    # training and testing
    for epoch in range(EPOCH):
        cnn.train()
        for step, (b_x, b_y) in enumerate(train_loader):
            output = cnn(b_x)[0]           # cnn logits
            loss = loss_func(output, b_y)  # cross entropy loss
            optimizer.zero_grad()          # clear gradients for this training step
            loss.backward()                # backpropagation, compute gradients
            optimizer.step()               # apply gradients

            if step % EVAL_EVERY == 0:
                test_loss, acc_rate = _evaluate(cnn, test_loader, loss_func)
                print('Test Loss: {:.6f}, Acc: {:.6f}'.format(test_loss, acc_rate))


if __name__ == '__main__':
    main()
新聞熱點
疑難解答
圖片精選