"""Save and load PyTorch models (source code for fastNLP.io.model_io)."""

import torch

from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
    """Loader for models."""

    def __init__(self):
        super(ModelLoader, self).__init__()

    @staticmethod
    def load_pytorch(empty_model, model_path, map_location=None):
        """Load model parameters from a ".pkl" file into an existing PyTorch model.

        The file is read with ``torch.load`` and the result is passed to
        ``empty_model.load_state_dict``, so the file is expected to hold a
        state dict (as written by ``ModelSaver.save_pytorch`` with
        ``param_only=True``).

        NOTE: ``torch.load`` unpickles the file; only load trusted files.

        :param empty_model: a PyTorch model with initialized parameters; it is
            updated in place.
        :param str model_path: the path to the saved parameters.
        :param map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` lets a
            checkpoint saved on GPU be loaded on a CPU-only machine. The default
            ``None`` keeps ``torch.load``'s default behaviour.
        """
        state_dict = torch.load(model_path, map_location=map_location)
        empty_model.load_state_dict(state_dict)

    @staticmethod
    def load_pytorch_model(model_path, map_location=None):
        """Load an entire (pickled) model.

        NOTE: ``torch.load`` unpickles the file, which can execute arbitrary
        code; only load trusted files.
        NOTE(review): newer PyTorch releases change ``torch.load``'s
        ``weights_only`` default, which may reject full pickled models; confirm
        against the installed version.

        :param str model_path: the path to the saved model.
        :param map_location: forwarded to ``torch.load``; ``None`` (default)
            keeps ``torch.load``'s default device mapping.
        :return: the object stored in ``model_path``.
        """
        return torch.load(model_path, map_location=map_location)
class ModelSaver(object):
    """Save a PyTorch model to a file.

    :param save_path: the path of the file to write (e.g.
        ``"./save/model_ckpt_100.pkl"``), or any file-like object accepted by
        ``torch.save``.

    Example::

        saver = ModelSaver("./save/model_ckpt_100.pkl")
        saver.save_pytorch(model)
    """

    def __init__(self, save_path):
        self.save_path = save_path

    def save_pytorch(self, model, param_only=True):
        """Save a PyTorch model into a ".pkl" file.

        Missing parent directories of ``save_path`` are created before writing.

        :param model: a PyTorch model
        :param bool param_only: if true, save only ``model.state_dict()``;
            otherwise pickle the entire model object.
        """
        import os

        # Only path-like targets have a parent directory to create; file-like
        # objects are handed to torch.save unchanged.
        if isinstance(self.save_path, (str, os.PathLike)):
            parent = os.path.dirname(self.save_path)
            if parent:  # "" means the current directory, which already exists
                os.makedirs(parent, exist_ok=True)

        if param_only:
            torch.save(model.state_dict(), self.save_path)
        else:
            torch.save(model, self.save_path)