import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import copy
import numpy as np
import torch
from collections import defaultdict
from torch import nn
from torch.nn import functional as F
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.utils import seq_mask
[docs]def mst(scores):
"""
with some modification to support parser output for MST decoding
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
"""
length = scores.shape[0]
min_score = scores.min() - 1
eye = np.eye(length)
scores = scores * (1 - eye) + min_score * eye
heads = np.argmax(scores, axis=1)
heads[0] = 0
tokens = np.arange(1, length)
roots = np.where(heads[tokens] == 0)[0] + 1
if len(roots) < 1:
root_scores = scores[tokens, 0]
head_scores = scores[tokens, heads[tokens]]
new_root = tokens[np.argmax(root_scores / head_scores)]
heads[new_root] = 0
elif len(roots) > 1:
root_scores = scores[roots, 0]
scores[roots, 0] = 0
new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1
new_root = roots[np.argmin(
scores[roots, new_heads] / root_scores)]
heads[roots] = new_heads
heads[new_root] = 0
edges = defaultdict(set)
vertices = set((0,))
for dep, head in enumerate(heads[tokens]):
vertices.add(dep + 1)
edges[head].add(dep + 1)
for cycle in _find_cycle(vertices, edges):
dependents = set()
to_visit = set(cycle)
while len(to_visit) > 0:
node = to_visit.pop()
if node not in dependents:
dependents.add(node)
to_visit.update(edges[node])
cycle = np.array(list(cycle))
old_heads = heads[cycle]
old_scores = scores[cycle, old_heads]
non_heads = np.array(list(dependents))
scores[np.repeat(cycle, len(non_heads)),
np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score
new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1
new_scores = scores[cycle, new_heads] / old_scores
change = np.argmax(new_scores)
changed_cycle = cycle[change]
old_head = old_heads[change]
new_head = new_heads[change]
heads[changed_cycle] = new_head
edges[new_head].add(changed_cycle)
edges[old_head].remove(changed_cycle)
return heads
def _find_cycle(vertices, edges):
"""
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py
"""
_index = 0
_stack = []
_indices = {}
_lowlinks = {}
_onstack = defaultdict(lambda: False)
_SCCs = []
def _strongconnect(v):
nonlocal _index
_indices[v] = _index
_lowlinks[v] = _index
_index += 1
_stack.append(v)
_onstack[v] = True
for w in edges[v]:
if w not in _indices:
_strongconnect(w)
_lowlinks[v] = min(_lowlinks[v], _lowlinks[w])
elif _onstack[w]:
_lowlinks[v] = min(_lowlinks[v], _indices[w])
if _lowlinks[v] == _indices[v]:
SCC = set()
while True:
w = _stack.pop()
_onstack[w] = False
SCC.add(w)
if not(w != v):
break
_SCCs.append(SCC)
for v in vertices:
if v not in _indices:
_strongconnect(v)
return [SCC for SCC in _SCCs if len(SCC) > 1]
[docs]class GraphParser(BaseModel):
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding
"""
def __init__(self):
super(GraphParser, self).__init__()
[docs] def forward(self, x):
raise NotImplementedError
def _greedy_decoder(self, arc_matrix, mask=None):
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
flip_mask = (mask == 0).byte()
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if mask is not None:
heads *= mask.long()
return heads
def _mst_decoder(self, arc_matrix, mask=None):
batch_size, seq_len, _ = arc_matrix.shape
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
ans = matrix.new_zeros(batch_size, seq_len).long()
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
mask[batch_idx, lens-1] = 0
for i, graph in enumerate(matrix):
len_i = lens[i]
if len_i == seq_len:
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
else:
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
if mask is not None:
ans *= mask.long()
return ans
[docs]class ArcBiaffine(nn.Module):
"""helper module for Biaffine Dependency Parser predicting arc
"""
def __init__(self, hidden_size, bias=True):
super(ArcBiaffine, self).__init__()
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
self.has_bias = bias
if self.has_bias:
self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True)
else:
self.register_parameter("bias", None)
initial_parameter(self)
[docs] def forward(self, head, dep):
"""
:param head arc-head tensor = [batch, length, emb_dim]
:param dep arc-dependent tensor = [batch, length, emb_dim]
:return output tensor = [bacth, length, length]
"""
output = dep.matmul(self.U)
output = output.bmm(head.transpose(-1, -2))
if self.has_bias:
output += head.matmul(self.bias).unsqueeze(1)
return output
[docs]class LabelBilinear(nn.Module):
"""helper module for Biaffine Dependency Parser predicting label
"""
def __init__(self, in1_features, in2_features, num_label, bias=True):
super(LabelBilinear, self).__init__()
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)
[docs] def forward(self, x1, x2):
output = self.bilinear(x1, x2)
output += self.lin(torch.cat([x1, x2], dim=2))
return output
[docs]class BiaffineParser(GraphParser):
"""Biaffine Dependency Parser implemantation.
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
<https://arxiv.org/abs/1611.01734>`_ .
"""
def __init__(self,
word_vocab_size,
word_emb_dim,
pos_vocab_size,
pos_emb_dim,
word_hid_dim,
pos_hid_dim,
rnn_layers,
rnn_hidden_size,
arc_mlp_size,
label_mlp_size,
num_label,
dropout,
use_var_lstm=False,
use_greedy_infer=False):
super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
self.word_norm = nn.LayerNorm(word_hid_dim)
self.pos_norm = nn.LayerNorm(pos_hid_dim)
if use_var_lstm:
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
input_dropout=dropout,
hidden_dropout=dropout,
bidirectional=True)
else:
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
dropout=dropout,
bidirectional=True)
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
nn.LayerNorm(arc_mlp_size),
nn.ELU(),
TimestepDropout(p=dropout),)
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
nn.LayerNorm(label_mlp_size),
nn.ELU(),
TimestepDropout(p=dropout),)
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
self.normal_dropout = nn.Dropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
self.reset_parameters()
self.explore_p = 0.2
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Embedding):
continue
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 0.1)
nn.init.constant_(m.bias, 0)
else:
for p in m.parameters():
nn.init.normal_(p, 0, 0.1)
[docs] def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
"""
:param word_seq: [batch_size, seq_len] sequence of word's indices
:param pos_seq: [batch_size, seq_len] sequence of word's indices
:param word_seq_origin_len: [batch_size, seq_len] sequence of length masks
:param gold_heads: [batch_size, seq_len] sequence of golden heads
:return dict: parsing results
arc_pred: [batch_size, seq_len, seq_len]
label_pred: [batch_size, seq_len, seq_len]
mask: [batch_size, seq_len]
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
"""
# prepare embeddings
device = self.parameters().__next__().device
word_seq = word_seq.long().to(device)
pos_seq = pos_seq.long().to(device)
word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1)
batch_size, seq_len = word_seq.shape
# print('forward {} {}'.format(batch_size, seq_len))
# get sequence mask
mask = seq_mask(word_seq_origin_len, seq_len).long()
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
word, pos = self.word_fc(word), self.pos_fc(pos)
word, pos = self.word_norm(word), self.pos_norm(pos)
x = torch.cat([word, pos], dim=2) # -> [N,L,C]
del word, pos
# lstm, extract features
sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True)
x = x[sort_idx]
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
feat, _ = self.lstm(x) # -> [N,L,C]
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
feat = feat[unsort_idx]
# for arc biaffine
# mlp, reduce dim
arc_dep = self.arc_dep_mlp(feat)
arc_head = self.arc_head_mlp(feat)
label_dep = self.label_dep_mlp(feat)
label_head = self.label_head_mlp(feat)
del feat
# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
# use gold or predicted arc to predict label
if gold_heads is None or not self.training:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self._greedy_decoder(arc_pred, mask)
else:
heads = self._mst_decoder(arc_pred, mask)
head_pred = heads
else:
assert self.training # must be training mode
if torch.rand(1).item() < self.explore_p:
heads = self._greedy_decoder(arc_pred, mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
if head_pred is not None:
res_dict['head_pred'] = head_pred
return res_dict
[docs] def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
"""
Compute loss.
:param arc_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, n_tags]
:param head_indices: [batch_size, seq_len]
:param head_labels: [batch_size, seq_len]
:param mask: [batch_size, seq_len]
:return: loss value
"""
batch_size, seq_len, _ = arc_pred.shape
flip_mask = (mask == 0)
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
arc_logits = F.log_softmax(_arc_pred, dim=2)
label_logits = F.log_softmax(label_pred, dim=2)
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, head_indices]
label_loss = label_logits[batch_index, child_index, head_labels]
arc_loss = arc_loss[:, 1:]
label_loss = label_loss[:, 1:]
float_mask = mask[:, 1:].float()
arc_nll = -(arc_loss*float_mask).mean()
label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll
[docs] def predict(self, word_seq, pos_seq, word_seq_origin_len):
"""
:param word_seq:
:param pos_seq:
:param word_seq_origin_len:
:return: head_pred: [B, L]
label_pred: [B, L]
seq_len: [B,]
"""
res = self(word_seq, pos_seq, word_seq_origin_len)
output = {}
output['head_pred'] = res.pop('head_pred')
_, label_pred = res.pop('label_pred').max(2)
output['label_pred'] = label_pred
return output