# Source code for fastNLP.modules.aggregator.kmax_pool

# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn
# import torch.nn.functional as F


class KMaxPool(nn.Module):
    """K max-pooling module.

    Keeps the ``k`` largest values along the last dimension of the input,
    preserving the order in which they appear in the input, and flattens
    every dimension after the first.
    """

    def __init__(self, k=1):
        super(KMaxPool, self).__init__()
        # Number of values kept along the last dimension.
        self.k = k

    def forward(self, x):
        """Apply k-max pooling over the last dimension.

        :param x: input tensor, assumed to be ``[N, C, L]`` with ``L >= k``
            (``torch.topk`` raises otherwise)
        :return: tensor of shape ``[N, C * k]``
        """
        # [N,C,L] -> [N,C,k]: positions of the k largest values.
        _, index = torch.topk(x, self.k, dim=-1, sorted=False)
        # topk(sorted=False) gives no ordering guarantee, so the flattened
        # output would not be positionally consistent. Sorting the indices
        # keeps the selected values in their original sequence order, which
        # is the standard k-max pooling definition, and makes it deterministic.
        index, _ = torch.sort(index, dim=-1)
        x = torch.gather(x, -1, index)
        # [N,C,k] -> [N,C*k]
        return torch.reshape(x, (x.size(0), -1))