I. PyTorch Batch Training
1. Overview
PyTorch provides a tool for wrapping data for batch training: DataLoader. To use it, first convert your data into torch tensors, then wrap those tensors into a Dataset that torch can recognize, and finally put the Dataset into a DataLoader. The example below walks through these steps.
import torch
import torch.utils.data as Data

torch.manual_seed(1)    # set the random seed for reproducibility
BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(0.5, 5, 10)

# Wrap the data into torch's Dataset format
# (older PyTorch releases used the keyword arguments data_tensor= and
# target_tensor=; current releases take the tensors positionally)
torch_dataset = Data.TensorDataset(x, y)

# Put torch_dataset into a DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,  # batch size
    # if the number of samples in the dataset is not divisible by
    # batch_size, the last batch simply contains whatever remains
    shuffle=True,           # whether to shuffle the sample order
    num_workers=2,          # number of worker processes for loading data
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('Epoch:', epoch, '|Step:', step,
              '|batch_x:', batch_x.numpy(), '|batch_y', batch_y.numpy())

'''
shuffle=True
Epoch: 0 |Step: 0 |batch_x: [ 6. 7. 2. 3. 1.] |batch_y [ 3. 3.5 1. 1.5 0.5]
Epoch: 0 |Step: 1 |batch_x: [ 9. 10. 4. 8. 5.] |batch_y [ 4.5 5. 2. 4. 2.5]
Epoch: 1 |Step: 0 |batch_x: [ 3. 4. 2. 9. 10.] |batch_y [ 1.5 2. 1. 4.5 5. ]
Epoch: 1 |Step: 1 |batch_x: [ 1. 7. 8. 5. 6.] |batch_y [ 0.5 3.5 4. 2.5 3. ]
Epoch: 2 |Step: 0 |batch_x: [ 3. 9. 2. 6. 7.] |batch_y [ 1.5 4.5 1. 3. 3.5]
Epoch: 2 |Step: 1 |batch_x: [ 10. 4. 8. 1. 5.] |batch_y [ 5. 2. 4. 0.5 2.5]

shuffle=False
Epoch: 0 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5]
Epoch: 0 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ]
Epoch: 1 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5]
Epoch: 1 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ]
Epoch: 2 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5]
Epoch: 2 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ]
'''

2. TensorDataset
class torch.utils.data.TensorDataset(data_tensor, target_tensor)
The TensorDataset class packs samples and their labels together into a torch Dataset; data_tensor and target_tensor are both tensors. (In current PyTorch releases the signature is TensorDataset(*tensors): any number of tensors is accepted positionally, as long as they share the same first dimension.)
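As a quick illustration, here is a minimal sketch (assuming a current PyTorch release, hence the positional arguments) showing that indexing a TensorDataset returns one tuple of aligned slices:

import torch
import torch.utils.data as Data

x = torch.linspace(1, 10, 10)
y = torch.linspace(0.5, 5, 10)

dataset = Data.TensorDataset(x, y)  # tensors must share the first dimension
print(len(dataset))   # 10 -- one entry per row of x and y
print(dataset[0])     # (tensor(1.), tensor(0.5000)) -- a (sample, label) pair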
3. DataLoader
The signature is as follows:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
dataset is an object in Torch's Dataset format; batch_size is the number of samples per training batch and defaults to 1; shuffle indicates whether samples are drawn in random order; num_workers is the number of worker processes used to read samples.
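One parameter from the signature above worth spelling out is drop_last, which defaults to False: when the sample count is not divisible by batch_size, the final, smaller batch is kept; setting drop_last=True discards it instead. A minimal sketch (the toy tensor here is just for illustration):

import torch
import torch.utils.data as Data

dataset = Data.TensorDataset(torch.arange(10.))  # 10 samples, batch_size 4

for batch, in Data.DataLoader(dataset, batch_size=4):
    print(batch)  # tensor([0., 1., 2., 3.]), tensor([4., 5., 6., 7.]), tensor([8., 9.])

for batch, in Data.DataLoader(dataset, batch_size=4, drop_last=True):
    print(batch)  # only the two full batches; the remainder of 2 is dropped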
II. PyTorch Optimizers
In this experiment, we first construct a dataset, convert it to the required format, and place it in a DataLoader for later use. We then define a default network with a fixed architecture and build one copy of it for each optimizer, so that the networks differ only in the optimizer used to train them. By recording the loss values during training and plotting them at the end, we obtain the optimization trajectory of each optimizer, as sketched below.
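The following is a minimal sketch of the experiment just described. The specific choices here (a quadratic toy dataset, one hidden layer of 20 units, the learning rate, and the four optimizers SGD, Momentum, RMSprop, and Adam) are illustrative assumptions, not details fixed by the text:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import matplotlib.pyplot as plt

torch.manual_seed(1)
LR, BATCH_SIZE, EPOCH = 0.01, 32, 12

# Construct the dataset: y = x^2 plus noise, wrapped in a DataLoader.
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1 * torch.randn(x.size())
loader = Data.DataLoader(Data.TensorDataset(x, y),
                         batch_size=BATCH_SIZE, shuffle=True)

# The fixed default architecture: one hidden layer of 20 units.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)
    def forward(self, x):
        return self.predict(F.relu(self.hidden(x)))

# One identical network per optimizer; only the optimizer differs.
nets = [Net() for _ in range(4)]
optimizers = [
    torch.optim.SGD(nets[0].parameters(), lr=LR),
    torch.optim.SGD(nets[1].parameters(), lr=LR, momentum=0.8),
    torch.optim.RMSprop(nets[2].parameters(), lr=LR, alpha=0.9),
    torch.optim.Adam(nets[3].parameters(), lr=LR, betas=(0.9, 0.99)),
]
loss_func = nn.MSELoss()
losses = [[], [], [], []]   # one loss history per optimizer

for epoch in range(EPOCH):
    for batch_x, batch_y in loader:
        for net, opt, hist in zip(nets, optimizers, losses):
            loss = loss_func(net(batch_x), batch_y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            hist.append(loss.item())   # record the loss at every step

# Plot all four loss curves on one figure to compare the optimizers.
for hist, label in zip(losses, ['SGD', 'Momentum', 'RMSprop', 'Adam']):
    plt.plot(hist, label=label)
plt.legend()
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim(0, 0.2)
plt.show()

Plotting the recorded losses on a single figure makes the differences in convergence speed and stability between the optimizers directly visible.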