Source code for fastNLP.modules.utils

import torch
import torch.nn as nn
import torch.nn.init as init


def mask_softmax(matrix, mask):
    """Apply softmax over the last dimension of ``matrix``.

    :param matrix: torch.Tensor, the scores to normalize.
    :param mask: None or torch.Tensor. Only ``mask is None`` is supported so far;
        passing a mask raises NotImplementedError.
    """
    if mask is None:
        result = torch.nn.functional.softmax(matrix, dim=-1)
    else:
        raise NotImplementedError
    return result
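
The masked branch is not implemented here. A common way to support it is to fill masked-out positions with a large negative value before the softmax, so that they receive near-zero probability. The sketch below is only an illustration, not part of fastNLP, and assumes the convention that ``mask`` holds 1 for valid positions and 0 for padding:

import torch
import torch.nn.functional as F

def masked_softmax_sketch(matrix, mask):
    # Assumed convention: mask == 1 for valid positions, 0 for padding.
    # Padded positions get a large negative score, so softmax pushes
    # their probability to (approximately) zero.
    filled = matrix.masked_fill(mask == 0, -1e9)
    return F.softmax(filled, dim=-1)

scores = torch.randn(2, 4)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
print(masked_softmax_sketch(scores, mask))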


def initial_parameter(net, initial_method=None):
    """A method used to initialize the weights of PyTorch models.

    :param net: a PyTorch model
    :param initial_method: str, one of the following initializations.

        - xavier_uniform
        - xavier_normal (default)
        - kaiming_normal, or msra
        - kaiming_uniform
        - orthogonal
        - sparse
        - normal
        - uniform

    """
    if initial_method == 'xavier_uniform':
        init_method = init.xavier_uniform_
    elif initial_method == 'xavier_normal':
        init_method = init.xavier_normal_
    elif initial_method == 'kaiming_normal' or initial_method == 'msra':
        init_method = init.kaiming_normal_
    elif initial_method == 'kaiming_uniform':
        init_method = init.kaiming_uniform_
    elif initial_method == 'orthogonal':
        init_method = init.orthogonal_
    elif initial_method == 'sparse':
        init_method = init.sparse_
    elif initial_method == 'normal':
        init_method = init.normal_
    elif initial_method == 'uniform':
        init_method = init.uniform_
    else:
        init_method = init.xavier_normal_

    def weights_init(m):
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):  # for all the cnn layers
            if initial_method is not None:
                init_method(m.weight.data)
            else:
                init.xavier_normal_(m.weight.data)
                init.normal_(m.bias.data)
        elif isinstance(m, nn.LSTM):
            for w in m.parameters():
                if len(w.data.size()) > 1:
                    init_method(w.data)  # weight
                else:
                    init.normal_(w.data)  # bias
        elif hasattr(m, 'weight') and m.weight.requires_grad:
            init_method(m.weight.data)
        else:
            for w in m.parameters():
                if w.requires_grad:
                    if len(w.data.size()) > 1:
                        init_method(w.data)  # weight
                    else:
                        init.normal_(w.data)  # bias

    net.apply(weights_init)
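
As an illustration, here is a small usage sketch; the toy model below is arbitrary and not part of fastNLP. Since ``net.apply`` visits every submodule, any ``nn.Module`` can be passed in:

import torch.nn as nn

# Hypothetical toy model used only to demonstrate initial_parameter.
toy_net = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 2),
)
initial_parameter(toy_net)                         # default: xavier_normal
initial_parameter(toy_net, initial_method='msra')  # alias for kaiming_normal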
def seq_mask(seq_len, max_len):
    """Create a sequence mask.

    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
    :param max_len: int, the maximum sequence length in a batch.
    :return mask: torch.Tensor of shape [batch_size, max_len]; mask[i, j] is True/1
        iff position j lies within the i-th sequence length.
    """
    if not isinstance(seq_len, torch.Tensor):
        seq_len = torch.LongTensor(seq_len)
    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long,
                             device=seq_len.device).view(1, -1)  # [1, max_len]
    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
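
For example (output shown for recent PyTorch versions, where ``torch.gt`` returns a bool tensor; older versions return a ByteTensor of 0/1):

lengths = [3, 1, 4]
mask = seq_mask(lengths, max_len=4)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False],
#         [ True,  True,  True,  True]])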