import torch
from fastNLP.io.base_loader import BaseLoader
class ModelLoader(BaseLoader):
    """
    Loader for models saved with ``torch.save``.

    Both loaders delegate to ``torch.load``, which unpickles the file.
    Only load checkpoints from trusted sources.
    """

    def __init__(self):
        super(ModelLoader, self).__init__()

    @staticmethod
    def load_pytorch(empty_model, model_path, map_location=None):
        """Load model parameters from a ".pkl" file into an existing PyTorch model.

        :param empty_model: a PyTorch model with initialized parameters; its
            parameters are overwritten in place via ``load_state_dict``.
        :param str model_path: the path to the saved parameter file (a state dict).
        :param map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` to load
            a checkpoint saved on GPU onto a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behaviour.
        :return: ``empty_model``, now holding the loaded parameters (returned for
            convenience; the model is modified in place either way).
        """
        # NOTE(review): torch.load unpickles the file -- never call on untrusted input.
        state_dict = torch.load(model_path, map_location=map_location)
        empty_model.load_state_dict(state_dict)
        return empty_model

    @staticmethod
    def load_pytorch_model(model_path, map_location=None):
        """Load an entire (pickled) model object.

        :param str model_path: the path to the saved model.
        :param map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` to load
            a model saved on GPU onto a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behaviour.
        :return: the deserialized model object.
        """
        # NOTE(review): torch.load unpickles the file -- never call on untrusted input.
        return torch.load(model_path, map_location=map_location)
class ModelSaver(object):
    """Save a PyTorch model (or its parameters) to a single file.

    :param str save_path: the path of the file to write (e.g. a ".pkl" file).
        The value is passed directly to ``torch.save``; the parent directory
        is not created here, so it is assumed to exist already.

    Example::

        saver = ModelSaver("./save/model_ckpt_100.pkl")
        saver.save_pytorch(model)
    """

    def __init__(self, save_path):
        self.save_path = save_path

    def save_pytorch(self, model, param_only=True):
        """Save a PyTorch model into a ".pkl" file.

        :param model: a PyTorch model.
        :param bool param_only: if truthy, save only ``model.state_dict()``
            (the parameters); otherwise pickle the entire model object.
        """
        # Plain truthiness (not ``is True``), so values like ``1`` select the
        # parameter-only path instead of silently pickling the whole model.
        if param_only:
            torch.save(model.state_dict(), self.save_path)
        else:
            torch.save(model, self.save_path)