Source code for fastNLP.modules.encoder.transformer
from torch import nn
from ..aggregator.attention import MultiHeadAtte
from ..dropout import TimestepDropout
class TransformerEncoder(nn.Module):
    """Stack of identical Transformer encoder layers.

    Every layer runs multi-head self-attention followed by a feed-forward
    network; each of the two sub-blocks is wrapped in a residual connection
    followed by layer normalization.
    """

    class SubLayer(nn.Module):
        """One encoder layer: self-attention block plus feed-forward block."""

        def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
            """
            :param model_size: int, size of the input features; also the size of the output features.
            :param inner_size: int, hidden size of the FFN layer.
            :param key_size: int, dimension of each head.
            :param value_size: int, dimension of the value in each head.
            :param num_head: int, number of heads.
            :param dropout: float, passed to the attention module and to the
                ``TimestepDropout`` that closes the FFN.
            """
            super(TransformerEncoder.SubLayer, self).__init__()
            self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
            self.norm1 = nn.LayerNorm(model_size)
            self.ffn = nn.Sequential(
                nn.Linear(model_size, inner_size),
                nn.ReLU(),
                nn.Linear(inner_size, model_size),
                TimestepDropout(dropout),
            )
            self.norm2 = nn.LayerNorm(model_size)

        def forward(self, input, seq_mask=None, atte_mask_out=None):
            """
            :param input: [batch, seq_len, model_size]
            :param seq_mask: optional mask multiplied into the output; it must
                broadcast against [batch, seq_len, model_size]
                (``TransformerEncoder.forward`` passes [batch, seq_len, 1]).
                Presumably 1 marks real tokens and 0 marks padding -- confirm
                against callers. ``None`` disables output masking.
            :param atte_mask_out: optional mask forwarded unchanged as the
                fourth argument of the attention module.
            :return: [batch, seq_len, model_size]
            """
            # Self-attention: query, key and value are all the layer input.
            attention = self.atte(input, input, input, atte_mask_out)
            norm_atte = self.norm1(attention + input)
            # NOTE: the former in-place ``attention *= seq_mask`` at this point
            # was dead code (``attention`` is never read again) and crashed
            # when ``seq_mask`` was None, so it has been dropped.
            output = self.ffn(norm_atte)
            output = self.norm2(output + norm_atte)
            if seq_mask is not None:
                # Multiply the masked positions of the layer output by the mask.
                output *= seq_mask
            return output

    def __init__(self, num_layers, **kargs):
        """
        :param num_layers: int, number of stacked ``SubLayer`` modules.
        :param kargs: keyword arguments forwarded unchanged to every
            ``SubLayer`` constructor.
        """
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])

    def forward(self, x, seq_mask=None):
        """
        :param x: [batch, seq_len, model_size], input sequence.
        :param seq_mask: [batch, seq_len] or None. Entries smaller than 1 are
            flagged in the attention mask, and every layer's output is
            multiplied by the mask. ``None`` disables all masking.
        :return: [batch, seq_len, model_size]
        """
        output = x
        if seq_mask is None:
            atte_mask_out = None
        else:
            # [batch, 1, seq_len]: True where the mask value is below 1.
            atte_mask_out = (seq_mask < 1)[:, None, :]
            # [batch, seq_len, 1]: broadcasts over the feature dimension.
            seq_mask = seq_mask[:, :, None]
        for layer in self.layers:
            output = layer(output, seq_mask, atte_mask_out)
        return output