保存模型
保存模型僅僅是為了測試的時(shí)候,只需要
torch.save(model.state_dict, path)
path 為保存的路徑
但是有時(shí)候模型及數(shù)據(jù)太多,難以一次性訓(xùn)練完的時(shí)候,而且用的還是 Adam優(yōu)化器的時(shí)候, 一定要保存好訓(xùn)練的優(yōu)化器參數(shù)以及epoch
state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch } torch.save(state, path)
因?yàn)檫@里
def adjust_learning_rate(optimizer, epoch): lr_t = lr lr_t = lr_t * (0.3 ** (epoch // 2)) for param_group in optimizer.param_groups: param_group['lr'] = lr_t
學(xué)習(xí)率是根據(jù)epoch變化的, 如果不保存epoch的話,基本上每次都從epoch為0開始訓(xùn)練,這樣學(xué)習(xí)率就相當(dāng)于不變了!!
恢復(fù)模型
恢復(fù)模型只用于測試的時(shí)候,
model.load_state_dict(torch.load(path))
path為之前存儲(chǔ)模型時(shí)的路徑
但是如果是用于繼續(xù)訓(xùn)練的話,
checkpoint = torch.load(path)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])start_epoch = checkpoint['epoch']+1
依次恢復(fù)出模型 優(yōu)化器參數(shù)以及epoch
以上這篇Pytorch保存模型用于測試和用于繼續(xù)訓(xùn)練的區(qū)別詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持武林網(wǎng)之家。
新聞熱點(diǎn)
疑難解答
圖片精選