# python: 3.6
# encoding: utf-8
import torch
import torch.nn as nn
# import torch.nn.functional as F
import fastNLP.modules.encoder as encoder
class CNNText(torch.nn.Module):
    """
    Text classification model built from an embedding layer, a
    convolution + max-pooling encoder, dropout and a linear classifier.

    Follows the architecture of 'Yoon Kim. 2014. Convolutional Neural
    Networks for Sentence Classification.'

    :param embed_num: size of the embedding table (passed to
        ``encoder.Embedding``)
    :param embed_dim: embedding dimension; also used as ``in_channels`` of
        the convolution encoder
    :param num_classes: number of output classes of the final linear layer
    :param kernel_nums: output channels per convolution kernel (passed as
        ``out_channels``); their sum is the input size of the classifier
    :param kernel_sizes: convolution kernel sizes (passed to
        ``encoder.ConvMaxpool``)
    :param padding: padding forwarded to ``encoder.ConvMaxpool``
    :param dropout: drop probability applied before the classifier
    """

    def __init__(self, embed_num,
                 embed_dim,
                 num_classes,
                 kernel_nums=(3, 4, 5),
                 kernel_sizes=(3, 4, 5),
                 padding=0,
                 dropout=0.5):
        super(CNNText, self).__init__()

        # no support for pre-trained embedding currently
        self.embed = encoder.Embedding(embed_num, embed_dim)
        self.conv_pool = encoder.ConvMaxpool(
            in_channels=embed_dim,
            out_channels=kernel_nums,
            kernel_sizes=kernel_sizes,
            padding=padding)
        self.dropout = nn.Dropout(dropout)
        # The classifier consumes one feature per convolution output channel.
        self.fc = encoder.Linear(sum(kernel_nums), num_classes)

    def forward(self, word_seq):
        """
        Compute class scores for a batch of index sequences.

        :param word_seq: torch.LongTensor, [batch_size, seq_len]
        :return output: dict with key ``'pred'`` holding the output of the
            final linear layer, [batch_size, num_classes]
        """
        x = self.embed(word_seq)  # [N,L] -> [N,L,C]
        # presumably [N,L,C] -> [N,sum(kernel_nums)]; verify against ConvMaxpool
        x = self.conv_pool(x)
        x = self.dropout(x)
        x = self.fc(x)  # -> [N, N_class]
        return {'pred': x}

    def predict(self, word_seq):
        """
        Predict the highest-scoring class index for each sequence.

        :param word_seq: torch.LongTensor, [batch_size, seq_len]
        :return predict: dict with key ``'pred'`` holding a
            torch.LongTensor of class indices, [batch_size]
        """
        output = self(word_seq)
        # max over the class dimension; keep only the indices (argmax).
        _, predict = output['pred'].max(dim=1)
        return {'pred': predict}