import math

import torch
from torch import nn

from fastNLP.modules.utils import mask_softmax


class Attention(torch.nn.Module):
    """Base class for attention modules. Subclasses implement ``_atten_forward``
    to compute raw similarity scores between ``query`` and ``memory``."""

    def __init__(self, normalize=False):
        super(Attention, self).__init__()
        self.normalize = normalize

    def forward(self, query, memory, mask):
        similarities = self._atten_forward(query, memory)
        if self.normalize:
            # mask out padded positions before normalizing with softmax
            return mask_softmax(similarities, mask)
        return similarities

    def _atten_forward(self, query, memory):
        raise NotImplementedError
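

# Illustrative sketch, not part of the original fastNLP source: a minimal
# concrete subclass showing how ``_atten_forward`` is meant to be filled in.
# The class name ``_ExampleDotAttention`` is hypothetical.
class _ExampleDotAttention(Attention):
    def _atten_forward(self, query, memory):
        # query:  [batch, query_len, hidden]
        # memory: [batch, mem_len, hidden]
        # raw dot-product similarities: [batch, query_len, mem_len]
        return torch.matmul(query, memory.transpose(1, 2))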


class DotAtte(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(key_size)) V."""

    def __init__(self, key_size, value_size):
        # TODO: never tested
        super(DotAtte, self).__init__()
        self.key_size = key_size
        self.value_size = value_size
        self.scale = math.sqrt(key_size)
    def forward(self, Q, K, V, seq_mask=None):
        """
        :param Q: [batch, seq_len, key_size]
        :param K: [batch, seq_len, key_size]
        :param V: [batch, seq_len, value_size]
        :param seq_mask: [batch, seq_len], 1 for real tokens, 0 for padding
        :return: [batch, seq_len, value_size]
        """
        # similarity scores scaled by sqrt(key_size): [batch, seq_len, seq_len]
        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
        if seq_mask is not None:
            # broadcast the [batch, seq_len] mask over the query dimension so
            # that padded key positions receive zero attention weight
            output.masked_fill_(seq_mask.unsqueeze(1).lt(1), -float('inf'))
        output = nn.functional.softmax(output, dim=2)
        return torch.matmul(output, V)
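

def _dot_atte_example():
    """Illustrative usage sketch, not part of the original fastNLP API.

    The tensor sizes below are arbitrary example values chosen to match the
    shapes documented in :meth:`DotAtte.forward`.
    """
    atte = DotAtte(key_size=64, value_size=64)
    Q = torch.rand(2, 5, 64)                       # [batch, seq_len, key_size]
    K = torch.rand(2, 5, 64)                       # [batch, seq_len, key_size]
    V = torch.rand(2, 5, 64)                       # [batch, seq_len, value_size]
    seq_mask = torch.ones(2, 5, dtype=torch.long)  # 1 = real token, 0 = padding
    return atte(Q, K, V, seq_mask)                 # [batch, seq_len, value_size]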


class MultiHeadAtte(nn.Module):
    """Multi-head attention: each head runs :class:`DotAtte` on its own linear
    projections of Q, K and V; the concatenated head outputs are mapped to
    ``output_size`` by a final linear layer."""

    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
        # TODO: never tested
        super(MultiHeadAtte, self).__init__()
        # three projections (Q, K, V) per head: indices i, i+1 project to
        # key_size, index i+2 projects to value_size
        self.in_linear = nn.ModuleList()
        for i in range(num_atte * 3):
            out_feat = key_size if (i % 3) != 2 else value_size
            self.in_linear.append(nn.Linear(input_size, out_feat))
        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
        self.out_linear = nn.Linear(value_size * num_atte, output_size)
    def forward(self, Q, K, V, seq_mask=None):
        """
        :param Q: [batch, seq_len, input_size]
        :param K: [batch, seq_len, input_size]
        :param V: [batch, seq_len, input_size]
        :param seq_mask: [batch, seq_len]
        :return: [batch, seq_len, output_size]
        """
        heads = []
        for i in range(len(self.attes)):
            j = i * 3
            # project the inputs for head i, then run scaled dot-product attention
            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j + 1](K), self.in_linear[j + 2](V)
            headi = self.attes[i](qi, ki, vi, seq_mask)
            heads.append(headi)
        # concatenate heads along the feature dimension and project to output_size
        output = torch.cat(heads, dim=2)
        return self.out_linear(output)
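

def _multi_head_atte_example():
    """Illustrative usage sketch, not part of the original fastNLP API.

    Sizes are arbitrary example values: with ``num_atte=4`` heads of
    ``value_size=32`` each, the concatenated heads (128 features) are mapped
    back to ``output_size=128`` by the final linear layer.
    """
    atte = MultiHeadAtte(input_size=128, output_size=128,
                         key_size=32, value_size=32, num_atte=4)
    x = torch.rand(2, 5, 128)                      # [batch, seq_len, input_size]
    seq_mask = torch.ones(2, 5, dtype=torch.long)  # 1 = real token, 0 = padding
    # self-attention: the same tensor serves as query, key and value
    return atte(x, x, x, seq_mask)                 # [batch, seq_len, output_size]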