Source code for fastNLP.modules.encoder.variational_rnn

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
from fastNLP.modules.utils import initial_parameter

try:
    from torch import flip
except ImportError:
    def flip(x, dims):
        indices = [slice(None)] * x.dim()
        for dim in dims:
            indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
        return x[tuple(indices)]
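
# Quick illustration of the fallback above: it reverses a tensor along the given
# dimensions, matching torch.flip on versions that provide it.
#
#     flip(torch.arange(4), [0])    # tensor([3, 2, 1, 0])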


class VarRnnCellWrapper(nn.Module):
    """Wrapper for normal RNN cells, adding support for variational dropout."""

    def __init__(self, cell, hidden_size, input_p, hidden_p):
        super(VarRnnCellWrapper, self).__init__()
        self.cell = cell
        self.hidden_size = hidden_size
        self.input_p = input_p
        self.hidden_p = hidden_p
    def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
        """
        :param PackedSequence input_x: [seq_len, batch_size, input_size]
        :param hidden: for LSTM, a tuple of (h_0, c_0), each [batch_size, hidden_size];
            for other RNNs, h_0 of shape [batch_size, hidden_size]
        :param mask_x: [batch_size, input_size], dropout mask for the input
        :param mask_h: [batch_size, hidden_size], dropout mask for the hidden state
        :return PackedSequence output: [seq_len, batch_size, hidden_size]
            hidden: for LSTM, a tuple of (h_n, c_n), each [batch_size, hidden_size];
            for other RNNs, h_n of shape [batch_size, hidden_size]
        """
        def get_hi(hi, h0, size):
            h0_size = size - hi.size(0)
            if h0_size > 0:
                return torch.cat([hi, h0[:h0_size]], dim=0)
            return hi[:size]

        is_lstm = isinstance(hidden, tuple)
        input, batch_sizes = input_x
        output = []
        cell = self.cell
        if is_reversed:
            batch_iter = flip(batch_sizes, [0])
            idx = input.size(0)
        else:
            batch_iter = batch_sizes
            idx = 0

        if is_lstm:
            hn = (hidden[0].clone(), hidden[1].clone())
        else:
            hn = hidden.clone()
        hi = hidden
        for size in batch_iter:
            if is_reversed:
                input_i = input[idx - size: idx] * mask_x[:size]
                idx -= size
            else:
                input_i = input[idx: idx + size] * mask_x[:size]
                idx += size
            mask_hi = mask_h[:size]
            if is_lstm:
                hx, cx = hi
                hi = (get_hi(hx, hidden[0], size) * mask_hi,
                      get_hi(cx, hidden[1], size))
                hi = cell(input_i, hi)
                hn[0][:size] = hi[0]
                hn[1][:size] = hi[1]
                output.append(hi[0])
            else:
                hi = get_hi(hi, hidden, size) * mask_hi
                hi = cell(input_i, hi)
                hn[:size] = hi
                output.append(hi)

        if is_reversed:
            output = list(reversed(output))
        output = torch.cat(output, dim=0)
        return PackedSequence(output, batch_sizes), hn
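# Illustrative usage sketch for the wrapper above, assuming the two-field PackedSequence
# form used throughout this file; the tensors and sizes below are made up for the example.
#
#     cell = nn.LSTMCell(input_size=4, hidden_size=8)
#     wrapper = VarRnnCellWrapper(cell, hidden_size=8, input_p=0.3, hidden_p=0.3)
#     packed = pack_padded_sequence(torch.randn(5, 2, 4), torch.LongTensor([5, 3]))
#     h0 = (torch.zeros(2, 8), torch.zeros(2, 8))
#     # one mask per batch, reused at every time step inside forward()
#     mask_x = nn.functional.dropout(torch.ones(2, 4), p=0.3, training=True)
#     mask_h = nn.functional.dropout(torch.ones(2, 8), p=0.3, training=True)
#     out, (h_n, c_n) = wrapper(packed, h0, mask_x, mask_h)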
class VarRNNBase(nn.Module):
    """Base implementation of a variational-dropout RNN.

    Refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks
    (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_.
    """

    def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                 bias=True, batch_first=False,
                 input_dropout=0, hidden_dropout=0, bidirectional=False):
        super(VarRNNBase, self).__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.input_dropout = input_dropout
        self.hidden_dropout = hidden_dropout
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self._all_cells = nn.ModuleList()
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                cell = Cell(input_size, self.hidden_size, bias)
                self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
        initial_parameter(self)
        self.is_lstm = (self.mode == "LSTM")

    def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
        is_lstm = self.is_lstm
        idx = self.num_directions * n_layer + n_direction
        cell = self._all_cells[idx]
        hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
        output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
        return output_x, hidden_x
    def forward(self, input, hx=None):
        is_lstm = self.is_lstm
        is_packed = isinstance(input, PackedSequence)
        if not is_packed:
            seq_len = input.size(1) if self.batch_first else input.size(0)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
            input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
        else:
            max_batch_size = int(input.batch_sizes[0])
            input, batch_sizes = input

        if hx is None:
            hx = input.new_zeros(self.num_layers * self.num_directions,
                                 max_batch_size, self.hidden_size, requires_grad=True)
            if is_lstm:
                hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))

        # sample the input/output dropout masks once per forward call;
        # they are reused at every time step
        mask_x = input.new_ones((max_batch_size, self.input_size))
        mask_out = input.new_ones((max_batch_size, self.hidden_size * self.num_directions))
        mask_h_ones = input.new_ones((max_batch_size, self.hidden_size))
        nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
        nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)

        hidden = input.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
        if is_lstm:
            cellstate = input.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
        for layer in range(self.num_layers):
            output_list = []
            input_seq = PackedSequence(input, batch_sizes)
            # a fresh hidden-state mask is drawn per layer, shared across time steps and directions
            mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout,
                                           training=self.training, inplace=False)
            for direction in range(self.num_directions):
                output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx,
                                                       mask_x if layer == 0 else mask_out, mask_h)
                output_list.append(output_x.data)
                idx = self.num_directions * layer + direction
                if is_lstm:
                    hidden[idx] = hidden_x[0]
                    cellstate[idx] = hidden_x[1]
                else:
                    hidden[idx] = hidden_x
            input = torch.cat(output_list, dim=-1)

        if is_lstm:
            hidden = (hidden, cellstate)
        if is_packed:
            output = PackedSequence(input, batch_sizes)
        else:
            input = PackedSequence(input, batch_sizes)
            output, _ = pad_packed_sequence(input, batch_first=self.batch_first)

        return output, hidden
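# Note on the masks drawn in forward(): mask_x / mask_out are sampled once per call and
# mask_h once per layer; every time step of that layer then reuses the same mask, which is
# what makes the dropout "variational". A standalone illustration of such a mask:
#
#     m = nn.functional.dropout(torch.ones(3, 4), p=0.5, training=True)
#     # m contains only 0.0 and 2.0 (= 1 / (1 - p)); multiplying the step inputs by the
#     # same m drops the same coordinates at every position of the sequence.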
class VarLSTM(VarRNNBase):
    """Variational dropout LSTM."""

    def __init__(self, *args, **kwargs):
        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)


class VarRNN(VarRNNBase):
    """Variational dropout RNN."""

    def __init__(self, *args, **kwargs):
        super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)


class VarGRU(VarRNNBase):
    """Variational dropout GRU."""

    def __init__(self, *args, **kwargs):
        super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
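# Minimal end-to-end sketch using the classes above; the sizes are arbitrary and the call
# assumes a PyTorch version whose PackedSequence matches the two-field form used in this file.
#
#     lstm = VarLSTM(input_size=10, hidden_size=20, num_layers=2,
#                    batch_first=True, bidirectional=True,
#                    input_dropout=0.3, hidden_dropout=0.3)
#     x = torch.randn(4, 7, 10)                    # [batch_size, seq_len, input_size]
#     output, (h_n, c_n) = lstm(x)
#     # output: [4, 7, 40]   (hidden_size * num_directions)
#     # h_n, c_n: [num_layers * num_directions, 4, 20]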
# if __name__ == '__main__':
#     x = torch.Tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])[:, :, None] * 0.1
#     mask = (x != 0).float().view(3, -1)
#     seq_lens = torch.LongTensor([3, 2, 1])
#     y = torch.Tensor([[0, 1, 1], [1, 1, 0], [0, 0, 0]])
#     # rev = _reverse_packed_sequence(pack)
#     # print(rev)
#     lstm = VarLSTM(input_size=1, num_layers=2, hidden_size=2,
#                    batch_first=True, bidirectional=True,
#                    input_dropout=0.0, hidden_dropout=0.0,)
#     # lstm = nn.LSTM(input_size=1, num_layers=2, hidden_size=2,
#     #                batch_first=True, bidirectional=True,)
#     loss = nn.BCELoss()
#     m = nn.Sigmoid()
#     optim = torch.optim.SGD(lstm.parameters(), lr=1e-3)
#     for i in range(2000):
#         optim.zero_grad()
#         pack = pack_padded_sequence(x, seq_lens, batch_first=True)
#         out, hidden = lstm(pack)
#         out, lens = pad_packed_sequence(out, batch_first=True)
#         # print(lens)
#         # print(out)
#         # print(hidden[0])
#         # print(hidden[0].size())
#         # print(hidden[1])
#         out = out.sum(-1)
#         out = m(out) * mask
#         l = loss(out, y)
#         l.backward()
#         optim.step()
#         if i % 50 == 0:
#             print(out)