Source code for fastNLP.modules.encoder.transformer

from torch import nn

from ..aggregator.attention import MultiHeadAtte
from ..other_modules import LayerNormalization


class TransformerEncoder(nn.Module):
    class SubLayer(nn.Module):
        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
            super(TransformerEncoder.SubLayer, self).__init__()
            # Multi-head self-attention followed by a position-wise feed-forward
            # network, each wrapped with a residual connection and layer normalization.
            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
            self.norm1 = LayerNormalization(output_size)
            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
                                     nn.ReLU(),
                                     nn.Linear(output_size, output_size))
            self.norm2 = LayerNormalization(output_size)

        def forward(self, input, seq_mask=None):
            attention = self.atte(input)
            norm_atte = self.norm1(attention + input)    # residual + layer norm
            output = self.ffn(norm_atte)
            return self.norm2(output + norm_atte)        # residual + layer norm

    def __init__(self, num_layers, **kargs):
        super(TransformerEncoder, self).__init__()
        # nn.Sequential cannot forward the extra seq_mask argument to each
        # sub-layer, so the sub-layers are kept in a ModuleList and applied
        # one by one in forward below.
        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])

    def forward(self, x, seq_mask=None):
        output = x
        for layer in self.layers:
            output = layer(output, seq_mask)
        return output
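
A minimal usage sketch follows. It assumes, as in SubLayer.forward above, that MultiHeadAtte accepts a single [batch, seq_len, input_size] tensor, and that input_size equals output_size so the residual additions line up; all sizes shown are hypothetical.

    import torch

    # Hypothetical sizes; keyword arguments are forwarded to SubLayer.__init__.
    encoder = TransformerEncoder(num_layers=2,
                                 input_size=64,
                                 output_size=64,
                                 key_size=16,
                                 value_size=16,
                                 num_atte=4)
    x = torch.rand(8, 20, 64)   # [batch, seq_len, input_size]
    out = encoder(x)            # [batch, seq_len, output_size]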