"""Source code for fastNLP.core.optimizer."""

import torch


class Optimizer(object):
    """Base holder for model parameters and optimizer settings.

    Subclasses turn the recorded settings into a concrete optimizer object.

    :param model_params: a generator (any object exposing ``__next__``),
        e.g. ``model.parameters()`` for PyTorch models, or ``None`` when the
        parameters will be supplied later.
    :param kwargs: additional optimizer settings, stored unchanged in
        ``self.settings``.
    :raises RuntimeError: if ``model_params`` is neither ``None`` nor a
        generator.
    """

    def __init__(self, model_params, **kwargs):
        # ``or`` short-circuits, so ``hasattr`` is only consulted for non-None input.
        acceptable = model_params is None or hasattr(model_params, "__next__")
        if not acceptable:
            raise RuntimeError(
                "model parameters should be a generator, rather than {}.".format(type(model_params)))
        self.model_params = model_params
        self.settings = kwargs
class SGD(Optimizer):
    """Settings for building a :class:`torch.optim.SGD` optimizer.

    :param float lr: learning rate. Default: 0.001
    :param float momentum: momentum. Default: 0
    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
        May be ``None``, in which case parameters are supplied later to
        :meth:`construct_from_pytorch`.
    :raises TypeError: if ``lr`` is not an ``int`` or ``float`` (``bool`` is rejected).
    """

    def __init__(self, lr=0.001, momentum=0, model_params=None):
        # torch.optim accepts integer learning rates as well as floats; bool is an
        # int subclass but never a meaningful learning rate, so it stays rejected.
        if isinstance(lr, bool) or not isinstance(lr, (int, float)):
            raise TypeError("learning rate has to be float.")
        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

    def construct_from_pytorch(self, model_params):
        """Create the underlying :class:`torch.optim.SGD`.

        :param model_params: parameters to optimize; only used when no
            parameters were given to ``__init__``.
        :return: a ``torch.optim.SGD`` built with ``self.settings``.
        """
        # Parameters given at construction take precedence over the argument.
        # Note: a stored generator is consumed by the first call.
        params = model_params if self.model_params is None else self.model_params
        return torch.optim.SGD(params, **self.settings)
class Adam(Optimizer):
    """Settings for building a :class:`torch.optim.Adam` optimizer.

    All settings are forwarded unchanged to ``torch.optim.Adam``.

    :param float lr: learning rate. Default: 0.001
    :param float weight_decay: weight decay (L2 penalty) coefficient. Default: 0
    :param tuple betas: coefficients for the running averages of the gradient
        and its square. Default: (0.9, 0.999)
    :param float eps: term added to the denominator for numerical stability.
        Default: 1e-8
    :param bool amsgrad: whether to use the AMSGrad variant. Default: False
    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
        May be ``None``, in which case parameters are supplied later to
        :meth:`construct_from_pytorch`.
    :raises TypeError: if ``lr`` is not an ``int`` or ``float`` (``bool`` is rejected).
    """

    def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False,
                 model_params=None):
        # torch.optim accepts integer learning rates as well as floats; bool is an
        # int subclass but never a meaningful learning rate, so it stays rejected.
        if isinstance(lr, bool) or not isinstance(lr, (int, float)):
            raise TypeError("learning rate has to be float.")
        super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad,
                                   weight_decay=weight_decay)

    def construct_from_pytorch(self, model_params):
        """Create the underlying :class:`torch.optim.Adam`.

        :param model_params: parameters to optimize; only used when no
            parameters were given to ``__init__``.
        :return: a ``torch.optim.Adam`` built with ``self.settings``.
        """
        # Parameters given at construction take precedence over the argument.
        # Note: a stored generator is consumed by the first call.
        params = model_params if self.model_params is None else self.model_params
        return torch.optim.Adam(params, **self.settings)