1、保存整個網絡結構信息和模型參數信息:
torch.save(model_object, './model.pth')
直接加載即可使用:
model = torch.load('./model.pth')
2、只保存網絡的模型參數-推薦使用
torch.save(model_object.state_dict(), './params.pth')
加載則要先從本地網絡模塊導入網絡,然后再加載參數:
from models import AgeModelmodel = AgeModel()model.load_state_dict(torch.load('./params.pth'))
以上這篇pytorch模型存儲的2種實現方法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持武林站長站。
新聞熱點
疑難解答