import torch
class Optimizer(object):
"""
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
:param kwargs: additional parameters.
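
    Example (a sketch; ``model`` stands for any ``torch.nn.Module``; in practice a
    subclass such as :class:`SGD` or :class:`Adam` is used rather than this base class)::

        opt = Optimizer(model.parameters(), lr=0.01)   # a generator is accepted
        # passing a list, e.g. Optimizer(list(model.parameters())), raises RuntimeError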
"""
def __init__(self, model_params, **kwargs):
if model_params is not None and not hasattr(model_params, "__next__"):
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
self.model_params = model_params
self.settings = kwargs
class SGD(Optimizer):
"""
:param float lr: learning rate. Default: 0.01
:param float momentum: momentum. Default: 0
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
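
    Example (a sketch; ``model`` stands for any ``torch.nn.Module``)::

        optimizer = SGD(lr=0.01, momentum=0.9)
        # binding the wrapper to concrete parameters yields a regular torch.optim.SGD
        torch_optimizer = optimizer.construct_from_pytorch(model.parameters())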
"""
def __init__(self, lr=0.001, momentum=0, model_params=None):
if not isinstance(lr, float):
            raise TypeError("learning rate has to be a float.")
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

    def construct_from_pytorch(self, model_params):
        """Build and return a ``torch.optim.SGD`` optimizer from the stored settings."""
if self.model_params is None:
            # no parameters were stored at construction time; use the generator passed in
            # here (a generator can only be consumed once, so it is not cached on self)
return torch.optim.SGD(model_params, **self.settings)
else:
return torch.optim.SGD(self.model_params, **self.settings)
class Adam(Optimizer):
"""
:param float lr: learning rate
:param float weight_decay:
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
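
    Example (a sketch; ``model`` stands for any ``torch.nn.Module``)::

        optimizer = Adam(lr=0.001, weight_decay=1e-4)
        # binding the wrapper to concrete parameters yields a regular torch.optim.Adam
        torch_optimizer = optimizer.construct_from_pytorch(model.parameters())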
"""
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None):
if not isinstance(lr, float):
            raise TypeError("learning rate has to be a float.")
super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad,
weight_decay=weight_decay)

    def construct_from_pytorch(self, model_params):
        """Build and return a ``torch.optim.Adam`` optimizer from the stored settings."""
if self.model_params is None:
            # no parameters were stored at construction time; use the generator passed in
            # here (a generator can only be consumed once, so it is not cached on self)
return torch.optim.Adam(model_params, **self.settings)
else:
return torch.optim.Adam(self.model_params, **self.settings)