Source code for fastNLP.modules.encoder.lstm

import torch.nn as nn

from fastNLP.modules.utils import initial_parameter


class LSTM(nn.Module):
    """Long Short Term Memory

    Args:
        input_size : input size
        hidden_size : hidden size. Default: 100
        num_layers : number of hidden layers. Default: 1
        dropout : dropout rate. Default: 0.0
        batch_first : If True, input and output tensors are (batch, seq, feature). Default: True
        bidirectional : If True, becomes a bidirectional LSTM. Default: False
        bias : If False, the layer does not use bias weights. Default: True
        initial_method : method used by initial_parameter to initialize the weights. Default: None
        get_hidden : If True, forward also returns the final hidden states (h_n, c_n). Default: False
    """

    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
                 bidirectional=False, bias=True, initial_method=None, get_hidden=False):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                            dropout=dropout, bidirectional=bidirectional)
        self.get_hidden = get_hidden
        initial_parameter(self, initial_method)
    def forward(self, x, h0=None, c0=None):
        # Initial states are used only when both h0 and c0 are provided;
        # otherwise nn.LSTM falls back to zero initial states.
        if h0 is not None and c0 is not None:
            x, (ht, ct) = self.lstm(x, (h0, c0))
        else:
            x, (ht, ct) = self.lstm(x)
        if self.get_hidden:
            return x, (ht, ct)
        else:
            return x
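
# Shape reference (an added note, not part of the original source), assuming
# the default batch_first=True and following the nn.LSTM conventions:
#   x               : (batch, seq_len, input_size)
#   returned output : (batch, seq_len, num_directions * hidden_size)
#   ht, ct          : (num_layers * num_directions, batch, hidden_size)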

if __name__ == "__main__":
    lstm = LSTM(10)
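    # Minimal smoke test (an added sketch, not part of the original source):
    # push a random batch through the module. The tensor sizes are illustrative
    # and assume the default batch_first=True and hidden_size=100.
    import torch

    x = torch.randn(4, 7, 10)         # batch=4, seq_len=7, input_size=10
    output = lstm(x)                  # get_hidden=False -> output only
    print(output.size())              # torch.Size([4, 7, 100])

    lstm_h = LSTM(10, get_hidden=True)
    output, (ht, ct) = lstm_h(x)      # get_hidden=True -> (output, (h_n, c_n))
    print(ht.size(), ct.size())       # each torch.Size([1, 4, 100])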