import torch.nn as nn
from fastNLP.modules.utils import initial_parameter
class LSTM(nn.Module):
    """Long Short Term Memory layer: a thin wrapper around ``torch.nn.LSTM``.

    Args:
        input_size : input size (number of features per time step)
        hidden_size : hidden size. Default: 100
        num_layers : number of stacked LSTM layers. Default: 1
        dropout : dropout rate passed to ``nn.LSTM`` (applied between stacked
            layers). Default: 0.0
        batch_first : If True, input/output tensors are batch-major. Default: True
        bidirectional : If True, becomes a bidirectional RNN. Default: False
        bias : If False, the LSTM layers use no bias weights. Default: True
        initial_method : initialization scheme forwarded to fastNLP's
            ``initial_parameter``; None presumably keeps PyTorch's default
            initialization (TODO confirm against ``initial_parameter``). Default: None
        get_hidden : If True, ``forward`` also returns the final ``(h_n, c_n)``.
            Default: False
    """

    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
                 bidirectional=False, bias=True, initial_method=None, get_hidden=False):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                            dropout=dropout, bidirectional=bidirectional)
        self.get_hidden = get_hidden
        initial_parameter(self, initial_method)

    def forward(self, x, h0=None, c0=None):
        """Run the LSTM over ``x``.

        Args:
            x : input sequence, laid out as ``nn.LSTM`` expects for the
                configured ``batch_first``
            h0 : optional initial hidden state
            c0 : optional initial cell state

        If only one of ``h0`` / ``c0`` is given, the other is filled with zeros
        (``nn.LSTM``'s default initial state), so the provided tensor is used
        rather than silently ignored.

        Returns:
            The output sequence, or ``(output, (h_n, c_n))`` when
            ``get_hidden`` is True.
        """
        if h0 is None and c0 is None:
            output, (ht, ct) = self.lstm(x)
        else:
            # h_0 and c_0 share one shape here (no proj_size is configured),
            # so the missing half can be built from the one provided.
            if h0 is None:
                h0 = c0.new_zeros(c0.shape)
            elif c0 is None:
                c0 = h0.new_zeros(h0.shape)
            output, (ht, ct) = self.lstm(x, (h0, c0))
        if self.get_hidden:
            return output, (ht, ct)
        return output
if __name__ == "__main__":
    # Minimal smoke check: building the module with input_size=10 and all
    # other settings left at their defaults should not raise.
    lstm = LSTM(input_size=10)