Source code for fastNLP.modules.aggregator.attention

import math

import torch
import torch.nn.functional as F
from torch import nn

from fastNLP.modules.utils import mask_softmax


class Attention(torch.nn.Module):
    def __init__(self, normalize=False):
        super(Attention, self).__init__()
        self.normalize = normalize

    def forward(self, query, memory, mask):
        similarities = self._atten_forward(query, memory)
        if self.normalize:
            return mask_softmax(similarities, mask)
        return similarities

    def _atten_forward(self, query, memory):
        raise NotImplementedError
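
# --- Illustrative sketch (not part of the original fastNLP module) -----------
# A minimal example of how a concrete subclass might fill in ``_atten_forward``:
# plain dot-product scores between query and memory. The class name
# ``_ExampleDotScore`` is hypothetical and chosen only for this illustration.
class _ExampleDotScore(Attention):
    def _atten_forward(self, query, memory):
        # query:  [batch, query_len, hidden]
        # memory: [batch, memory_len, hidden]
        # returns similarity scores of shape [batch, query_len, memory_len]
        return torch.matmul(query, memory.transpose(1, 2))
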

class DotAtte(nn.Module):
    def __init__(self, key_size, value_size):
        super(DotAtte, self).__init__()
        self.key_size = key_size
        self.value_size = value_size
        self.scale = math.sqrt(key_size)

    def forward(self, Q, K, V, seq_mask=None):
        """
        :param Q: [batch, seq_len, key_size]
        :param K: [batch, seq_len, key_size]
        :param V: [batch, seq_len, value_size]
        :param seq_mask: [batch, seq_len], 1 for real tokens, 0 for padding
        """
        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale  # [batch, seq_len, seq_len]
        if seq_mask is not None:
            # unsqueeze to [batch, 1, seq_len] so the key mask broadcasts over the query dimension
            output.masked_fill_(seq_mask.lt(1).unsqueeze(1), -float('inf'))
        output = nn.functional.softmax(output, dim=2)
        return torch.matmul(output, V)
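
# --- Illustrative sketch (not part of the original fastNLP module) -----------
# Shows how ``DotAtte`` might be called; the shapes (batch=2, seq_len=4,
# key_size=8, value_size=8) and the mask values are assumptions made up for
# the example.
def _demo_dot_atte():
    atte = DotAtte(key_size=8, value_size=8)
    Q = torch.rand(2, 4, 8)
    K = torch.rand(2, 4, 8)
    V = torch.rand(2, 4, 8)
    seq_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])  # 1 = real token, 0 = padding
    return atte(Q, K, V, seq_mask=seq_mask)  # [2, 4, 8]
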

class MultiHeadAtte(nn.Module):
    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
        super(MultiHeadAtte, self).__init__()
        # For every head, three input projections are stored consecutively in
        # ``in_linear``: query and key (both to key_size) and value (to value_size).
        self.in_linear = nn.ModuleList()
        for i in range(num_atte * 3):
            out_feat = key_size if (i % 3) != 2 else value_size
            self.in_linear.append(nn.Linear(input_size, out_feat))
        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
        self.out_linear = nn.Linear(value_size * num_atte, output_size)

    def forward(self, Q, K, V, seq_mask=None):
        heads = []
        for i in range(len(self.attes)):
            j = i * 3
            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j + 1](K), self.in_linear[j + 2](V)
            headi = self.attes[i](qi, ki, vi, seq_mask)
            heads.append(headi)
        output = torch.cat(heads, dim=2)  # concatenate heads along the feature dimension
        return self.out_linear(output)
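
# --- Illustrative sketch (not part of the original fastNLP module) -----------
# Shows how ``MultiHeadAtte`` might be instantiated for self-attention; all
# sizes (input_size=16, output_size=16, key_size=8, value_size=8, num_atte=4)
# are assumptions made up for the example.
def _demo_multi_head_atte():
    atte = MultiHeadAtte(input_size=16, output_size=16,
                         key_size=8, value_size=8, num_atte=4)
    x = torch.rand(2, 5, 16)  # [batch, seq_len, input_size]
    return atte(x, x, x)      # [2, 5, 16]
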

class Bi_Attention(nn.Module):
    def __init__(self):
        super(Bi_Attention, self).__init__()
        self.inf = 10e12

    def forward(self, in_x1, in_x2, x1_len, x2_len):
        # in_x1: [batch_size, x1_seq_len, hidden_size]
        # in_x2: [batch_size, x2_seq_len, hidden_size]
        # x1_len: [batch_size, x1_seq_len]
        # x2_len: [batch_size, x2_seq_len]

        assert in_x1.size()[0] == in_x2.size()[0]
        assert in_x1.size()[2] == in_x2.size()[2]
        # The batch size and hidden size must be equal.
        assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
        # The seq len in in_x and x_len must be equal.
        assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]

        batch_size = in_x1.size()[0]
        x1_max_len = in_x1.size()[1]
        x2_max_len = in_x2.size()[1]

        in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]

        attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]

        a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
        a_mask = a_mask.view(batch_size, x1_max_len, -1)
        a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
        b_mask = x2_len.le(0.5).float() * -self.inf
        b_mask = b_mask.view(batch_size, -1, x2_max_len)
        b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]

        attention_a = F.softmax(attention_matrix + a_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
        attention_b = F.softmax(attention_matrix + b_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]

        out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
        attention_b_t = torch.transpose(attention_b, 1, 2)
        out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]

        return out_x1, out_x2
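
# --- Illustrative sketch (not part of the original fastNLP module) -----------
# Shows how ``Bi_Attention`` might be called; the shapes (batch=2, x1_seq_len=3,
# x2_seq_len=4, hidden_size=8) and the 0/1 length masks are assumptions made up
# for the example.
def _demo_bi_attention():
    atte = Bi_Attention()
    in_x1 = torch.rand(2, 3, 8)
    in_x2 = torch.rand(2, 4, 8)
    x1_len = torch.tensor([[1, 1, 1], [1, 1, 0]]).float()        # [2, 3]
    x2_len = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]]).float()  # [2, 4]
    out_x1, out_x2 = atte(in_x1, in_x2, x1_len, x2_len)
    return out_x1, out_x2  # [2, 3, 8], [2, 4, 8]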