import inspect
from collections import defaultdict
import numpy as np
import torch
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.core.vocabulary import Vocabulary
class MetricBase(object):
    """Base class for all metrics.

    ``MetricBase`` handles the validity check of its input dictionaries,
    ``pred_dict`` and ``target_dict``.
    ``pred_dict`` is the output of ``forward()`` or of the prediction function of a model.
    ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``.

    ``MetricBase`` performs the following checks:

        1. whether self.evaluate has varargs, which is not supported.
        2. whether params needed by self.evaluate are missing from ``pred_dict``, ``target_dict``.
        3. whether params needed by self.evaluate are duplicated in ``pred_dict``, ``target_dict``.
        4. whether params in ``pred_dict``, ``target_dict`` are not used by evaluate (might cause a warning).

    Besides, before passing params into self.evaluate, params from the input dicts which are
    not used by self.evaluate are filtered out (but if ``**kwargs`` is present in
    self.evaluate, no filtering is conducted).
    In cases where the check is not necessary, ``_fast_param_map`` is used instead.
    """

    def __init__(self):
        # key is the parameter name in `evaluate`, value is the key used in the input dicts.
        self.param_map = {}
        # Becomes True after the first call that went through the full (non fast-path) check.
        self._checked = False

    def evaluate(self, *args, **kwargs):
        """Accumulate statistics for one batch. Must be overridden by subclasses."""
        raise NotImplementedError

    def _init_param_map(self, key_map=None, **kwargs):
        """Check the validity of key_map and the keyword mappings, then add them to self.param_map.

        A value of ``None`` means "use the parameter's own name as the input key".

        :param key_map: dict mapping an ``evaluate`` parameter name to an input key
        :param kwargs: further ``param=input_key`` mappings
        :return: None
        :raises TypeError: if key_map is not a dict, or a key/value is not a str
        :raises ValueError: if several parameters are mapped to the same input key
        :raises NameError: if a mapped parameter is not an argument of ``self.evaluate``
        """
        # input key -> set of evaluate parameters that claim it
        value_counter = defaultdict(set)
        if key_map is not None:
            if not isinstance(key_map, dict):
                raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
            for key, value in key_map.items():
                if value is None:
                    self.param_map[key] = key
                    continue
                if not isinstance(key, str):
                    raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
                if not isinstance(value, str):
                    raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
                self.param_map[key] = value
                value_counter[value].add(key)
        for key, value in kwargs.items():
            if value is None:
                self.param_map[key] = key
                continue
            if not isinstance(value, str):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
            value_counter[value].add(key)
        for value, key_set in value_counter.items():
            if len(key_set) > 1:
                raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")

        # check consistency between the signature of evaluate and param_map
        func_spect = inspect.getfullargspec(self.evaluate)
        func_args = [arg for arg in func_spect.args if arg != 'self']
        for func_param in self.param_map:
            if func_param not in func_args:
                raise NameError(
                    f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
                    f"initialization parameters, or change its signature.")

    def get_metric(self, reset=True):
        """Return the accumulated metric. Must be overridden by subclasses."""
        # `NotImplemented` is a comparison sentinel, not an exception; raising it fails with TypeError.
        raise NotImplementedError

    def _fast_param_map(self, pred_dict, target_dict):
        """Only used internally. When pred_dict and target_dict are unequivocal (each has exactly
        one element and evaluate maps exactly two parameters), users do not need to pass key_map.

        :param pred_dict:
        :param target_dict:
        :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
        """
        fast_param = {}
        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
            fast_param['pred'] = list(pred_dict.values())[0]
            # The target has to come from target_dict, not from pred_dict.
            fast_param['target'] = list(target_dict.values())[0]
            return fast_param
        return fast_param

    def __call__(self, pred_dict, target_dict):
        """
        Call self.evaluate with arguments picked from pred_dict and target_dict.

        If ``_fast_param_map`` returns a non-empty dict, it is passed to self.evaluate directly.
        Otherwise, input keys are renamed according to self.param_map and, on the first such
        call only, the inputs are checked for
            (1) params needed by self.evaluate that are missing from pred_dict, target_dict,
            (2) params needed by self.evaluate that are duplicated in pred_dict, target_dict.
        Params which are not used by self.evaluate are filtered out by ``_build_args`` before
        the call.

        :param pred_dict: usually the output of forward or prediction function
        :param target_dict: usually features set as target.
        :return: None
        :raises TypeError: if self.evaluate is not callable
        :raises NameError: if self.param_map names a parameter that evaluate does not have
        :raises CheckError: if required inputs are missing or duplicated
        """
        if not callable(self.evaluate):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

        fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
        if fast_param:
            self.evaluate(**fast_param)
            return

        if not self._checked:
            # 1. check consistency between signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            func_args = {arg for arg in func_spect.args if arg != 'self'}
            for func_arg in self.param_map:
                if func_arg not in func_args:
                    raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")

            # 2. parameters without an explicit mapping are looked up under their own name
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}

        # Rename the input keys to the parameter names of evaluate.
        mapped_pred_dict = {}
        mapped_target_dict = {}
        duplicated = []
        for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
            # One point each for: is a mapped key, is in pred_dict, is in target_dict.
            # All three together means a needed key is provided twice.
            not_duplicate_flag = 0
            if input_arg in self._reverse_param_map:
                mapped_arg = self._reverse_param_map[input_arg]
                not_duplicate_flag += 1
            else:
                mapped_arg = input_arg
            if input_arg in pred_dict:
                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
                not_duplicate_flag += 1
            if input_arg in target_dict:
                mapped_target_dict[mapped_arg] = target_dict[input_arg]
                not_duplicate_flag += 1
            if not_duplicate_flag == 3:
                duplicated.append(input_arg)

        # missing
        if not self._checked:
            check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
            # Report missing parameters under the key the user has to provide.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
                # Don't delete `` in this information, nor add ``
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                                                        f"in `{self.__class__.__name__}`)"

            check_res = CheckRes(missing=replaced_missing,
                                 unused=check_res.unused,
                                 duplicated=duplicated,
                                 required=check_res.required,
                                 all_needed=check_res.all_needed,
                                 varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(self.evaluate))
        refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)

        self.evaluate(**refined_args)
        self._checked = True

        return
class AccuracyMetric(MetricBase):
    """Accuracy Metric.

    Accumulates the number of correct predictions and the number of counted positions over
    successive calls to ``evaluate``; ``get_metric`` reports their ratio.
    """

    def __init__(self, pred=None, target=None, seq_lens=None):
        """
        :param pred: str, key used to fetch the prediction from the input dicts; ``None`` means 'pred'
        :param target: str, key used to fetch the target from the input dicts; ``None`` means 'target'
        :param seq_lens: str, key used to fetch the sequence lengths; ``None`` means 'seq_lens'
        """
        super().__init__()

        self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)

        self.total = 0
        self.acc_count = 0

    def _fast_param_map(self, pred_dict, target_dict):
        """Only used internally. When pred_dict and target_dict are unequivocal, users do not
        need to pass key_map.

        The fast path applies when target_dict holds exactly one tensor and pred_dict holds
        either one value (taken as ``pred``) or two tensors of which exactly one is
        1-dimensional and has fewer dimensions than the other (taken as ``seq_lens``).

        :param pred_dict:
        :param target_dict:
        :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
        """
        fast_param = {}
        targets = list(target_dict.values())
        if len(targets) == 1 and isinstance(targets[0], torch.Tensor):
            if len(pred_dict) == 1:
                pred = list(pred_dict.values())[0]
                fast_param['pred'] = pred
            elif len(pred_dict) == 2:
                pred1 = list(pred_dict.values())[0]
                pred2 = list(pred_dict.values())[1]
                if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
                    return fast_param
                # The 1-dimensional tensor is taken to be the sequence lengths.
                if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
                    seq_lens = pred1
                    pred = pred2
                elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
                    seq_lens = pred2
                    pred = pred1
                else:
                    return fast_param
                fast_param['pred'] = pred
                fast_param['seq_lens'] = seq_lens
            else:
                return fast_param
            fast_param['target'] = targets[0]
        # TODO need to make sure they all have same batch_size
        return fast_param

    def evaluate(self, pred, target, seq_lens=None):
        """Accumulate the counts for one batch.

        :param pred: torch.Tensor. Its shape can be:
                torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]),
                torch.Size([B, max_len, n_classes])
        :param target: torch.Tensor. Its shape can be, respectively:
                torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
        :param seq_lens: torch.Tensor or None. Respectively:
                None, None, torch.Size([B]), torch.Size([B]).
        :raises TypeError: if an argument is not a torch.Tensor
        :raises RuntimeError: if the shapes of pred and target do not fit together
        """
        # TODO the error should name the user's real key, since users do not know what `pred` is.
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")

        if seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_lens)}.")

        if seq_lens is not None:
            masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
        else:
            masks = None

        if pred.size() == target.size():
            pass
        elif len(pred.size()) == len(target.size()) + 1:
            # pred carries one extra trailing dimension; reduce it to the index of the maximum.
            pred = pred.argmax(dim=-1)
        else:
            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        pred = pred.float()
        target = target.float()

        if masks is not None:
            self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
            self.total += torch.sum(masks.float()).item()
        else:
            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
            self.total += pred.numel()

    def get_metric(self, reset=True):
        """Returns computed metric.

        :param bool reset: whether to recount next time.
        :return evaluate_result: {"acc": float}
        """
        # The epsilon keeps the division defined when nothing has been counted yet,
        # the same way the other metrics in this module guard their ratios.
        accuracy = float(self.acc_count) / (self.total + 1e-12)
        evaluate_result = {'acc': round(float(accuracy), 6)}
        if reset:
            self.acc_count = 0
            self.total = 0
        return evaluate_result
[docs]def bmes_tag_to_spans(tags, ignore_labels=None):
"""
:param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
spans[-1][1][1] = idx
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1]))
for span in spans
if span[0] not in ignore_labels
]
def bio_tag_to_spans(tags, ignore_labels=None):
    """Group a BIO tag sequence into labelled spans.

    Tags are lower-cased; the first character is the BIO prefix and everything from the
    third character on is the label. ``O`` tags never produce a span.

    :param tags: List[str], the tag sequence
    :param ignore_labels: List[str], spans whose (lower-cased) label is in this list are dropped
    :return: List[Tuple[str, Tuple[int, int]]]. [(label, (start, end))], ``end`` is inclusive
    """
    skipped = set(ignore_labels) if ignore_labels else set()

    collected = []  # entries are [label, start, end]
    last_prefix = None
    for position, raw_tag in enumerate(tags):
        lowered = raw_tag.lower()
        prefix, label = lowered[:1], lowered[2:]
        if prefix == 'i' and last_prefix in ('b', 'i') and label == collected[-1][0]:
            # Continuation of the span that is currently open.
            collected[-1][2] = position
        elif prefix != 'o':
            # 'b', an 'i' that continues nothing, or any unknown prefix opens a new span.
            collected.append([label, position, position])
        last_prefix = prefix

    return [(label, (start, end))
            for label, start, end in collected
            if label not in skipped]
class SpanFPreRecMetric(MetricBase):
    """
    For sequence labeling, compute F, precision and recall at the span level.
    The resulting metric is
        {
            'f': xxx,  # named `f` so that an f_beta value can be reported
            'pre': xxx,
            'rec': xxx
        }
    If only_gross=False, the statistics of every label are returned as well
        {
            'f': xxx,
            'pre': xxx,
            'rec': xxx,
            'f-label': xxx,
            'pre-label': xxx,
            'rec-label': xxx,
            ...
        }
    """

    def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None,
                 only_gross=True, f_type='micro', beta=1):
        """
        :param tag_vocab: Vocabulary of the tags. Supported tags are "B" (no label) or "B-xxx" (xxx being a
            label, e.g. NN in POS tagging). When decoding, tags with the same xxx are treated as the same
            label, e.g. ['B-NN', 'E-NN'] is merged into one 'NN'.
        :param pred: str, key used to fetch the prediction from the dicts passed in. ``None`` means 'pred'
        :param target: str, key used to fetch the target from the dicts passed in. ``None`` means 'target'
        :param seq_lens: str, key used to fetch the sequence lengths from the dicts passed in.
            ``None`` means 'seq_lens'.
        :param encoding_type: str, currently 'bio' and 'bmes' are supported
        :param ignore_labels: List[str]. Labels in this list are left out of the computation. E.g. passing
            ['NN'] in POS tagging excludes the label 'NN'.
        :param only_gross: bool. Whether to only compute the overall f, precision, recall; if False, the
            f, pre, rec of every label are returned in addition.
        :param f_type: str, 'micro' or 'macro'. 'micro': sum up TP, FN and FP over all labels first, then
            compute f, precision, recall. 'macro': compute f, precision, recall per label, then average them
            (every label has the same weight).
        :param beta: float, f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common values
            are 0.5, 1, 2. With 0.5 precision weighs more than recall; with 1 both weigh the same; with 2
            recall weighs more than precision.
        :raises ValueError: if encoding_type or f_type is not supported
        :raises TypeError: if tag_vocab is not a Vocabulary
        """
        encoding_type = encoding_type.lower()
        if encoding_type not in ('bio', 'bmes'):
            raise ValueError("Only support 'bio' or 'bmes' type.")
        if not isinstance(tag_vocab, Vocabulary):
            raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
        if f_type not in ('micro', 'macro'):
            raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))

        self.encoding_type = encoding_type
        if self.encoding_type == 'bmes':
            self.tag_to_span_func = bmes_tag_to_spans
        elif self.encoding_type == 'bio':
            self.tag_to_span_func = bio_tag_to_spans
        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross

        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)

        self.tag_vocab = tag_vocab

        # Per-label span counts, accumulated across calls to evaluate.
        self._true_positives = defaultdict(int)
        self._false_positives = defaultdict(int)
        self._false_negatives = defaultdict(int)

    def evaluate(self, pred, target, seq_lens):
        """
        Accumulate span-level true positives, false positives and false negatives for one batch.
        A lot of the design comes from allennlp's measure.

        :param pred: torch.Tensor with the same 2-dimensional size as target, or with one extra
            trailing dimension that is reduced with argmax
        :param target: 2-dimensional torch.Tensor of gold tag ids
        :param seq_lens: torch.Tensor; entry i is the number of leading positions of row i to use
        :return: None
        :raises TypeError: if an argument is not a torch.Tensor
        :raises ValueError: if a gold id is not smaller than the size of pred's last dimension
        :raises RuntimeError: if the shapes of pred and target do not fit together
        """
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")

        if not isinstance(seq_lens, torch.Tensor):
            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_lens)}.")

        if pred.size() == target.size() and len(target.size()) == 2:
            pass
        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
            # Read the number of classes BEFORE argmax drops the class dimension; afterwards
            # the last dimension is the sequence length.
            num_classes = pred.size(-1)
            pred = pred.argmax(dim=-1)
            if (target >= num_classes).any():
                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
                                 "id >= {}, the number of classes.".format(num_classes))
        else:
            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        batch_size = pred.size(0)
        for i in range(batch_size):
            length = int(seq_lens[i])
            pred_tags = pred[i, :length].tolist()
            gold_tags = target[i, :length].tolist()

            pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
            gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

            pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
            gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    # Each gold span can be matched at most once.
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            # Gold spans that were never matched.
            for span in gold_spans:
                self._false_negatives[span[0]] += 1

    def get_metric(self, reset=True):
        """Return the accumulated f, precision and recall.

        :param bool reset: whether to clear the accumulated counts afterwards
        :return: dict with the keys 'f', 'pre', 'rec' and, if ``only_gross`` is False, the keys
            'f-<label>', 'pre-<label>', 'rec-<label>' for every non-empty label
        """
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives.keys())
            tags.update(set(self._false_positives.keys()))
            tags.update(set(self._true_positives.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives[tag]
                fn = self._false_negatives[tag]
                fp = self._false_positives[tag]
                f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards against tags without a label
                    f_key = 'f-{}'.format(tag)
                    pre_key = 'pre-{}'.format(tag)
                    rec_key = 'rec-{}'.format(tag)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                # Without any span seen so far the sums are 0; avoid dividing by zero.
                num_tags = len(tags) or 1
                evaluate_result['f'] = f_sum / num_tags
                evaluate_result['pre'] = pre_sum / num_tags
                evaluate_result['rec'] = rec_sum / num_tags

        if self.f_type == 'micro':
            f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
                                                  sum(self._false_negatives.values()),
                                                  sum(self._false_positives.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        if reset:
            self._true_positives = defaultdict(int)
            self._false_positives = defaultdict(int)
            self._false_negatives = defaultdict(int)

        return evaluate_result

    def _compute_f_pre_rec(self, tp, fn, fp):
        """
        Compute f_beta, precision and recall from raw counts. A small epsilon in every
        denominator keeps the ratios defined when the counts are all zero.

        :param tp: int, true positive
        :param fn: int, false negative
        :param fp: int, false positive
        :return: (f, pre, rec)
        """
        pre = tp / (fp + tp + 1e-13)
        rec = tp / (fn + tp + 1e-13)
        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)

        return f, pre, rec
class BMESF1PreRecMetric(MetricBase):
    """
    Compute f1, precision and recall for the BMES tagging scheme. Since predictions may contain
    illegal transitions such as "BS", they are repaired with the table below. cur_B means the
    current tag is B, next_B means the following tag is B. ``cur_B=S`` rewrites the current tag
    (predicted as B) to S; ``next_M=B`` rewrites the following tag (predicted as M) to B.

        |       |  next_B |  next_M  |  next_E  |  next_S |   end   |
        |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:|
        | start |   ok    | next_M=B | next_E=S |   ok    |    -    |
        | cur_B | cur_B=S |    ok    |    ok    | cur_B=S | cur_B=S |
        | cur_M | cur_M=E |    ok    |    ok    | cur_M=E | cur_M=E |
        | cur_E |   ok    | next_M=B | next_E=S |   ok    |   ok    |
        | cur_S |   ok    | next_M=B | next_E=S |   ok    |   ok    |

    Example:
        the prediction BSEMS is treated as SSSSS.

    This metric does not validate target; make sure the target sequences are legal.
    pred should have the shape (batch_size, max_len) or (batch_size, max_len, 4).
    target has the shape (batch_size, max_len)
    seq_lens has the shape (batch_size, )
    """

    def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None):
        """
        The idx of each of the four BMES tags has to be declared. Every number that is none of
        b_idx, m_idx, e_idx, s_idx is treated as s_idx.

        :param b_idx: int, tag idx of the Begin tag.
        :param m_idx: int, tag idx of the Middle tag.
        :param e_idx: int, tag idx of the End tag.
        :param s_idx: int, tag idx of the Single tag
        :param pred: str, key used to fetch the prediction from the dicts passed in. ``None`` means 'pred'
        :param target: str, key used to fetch the target from the dicts passed in. ``None`` means 'target'
        :param seq_lens: str, key used to fetch the sequence lengths from the dicts passed in.
            ``None`` means 'seq_lens'.
        """
        super().__init__()

        self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)

        self.yt_wordnum = 0  # number of words in the targets
        self.yp_wordnum = 0  # number of words in the predictions
        self.corr_num = 0  # number of target words whose tags were all predicted correctly

        self.b_idx = b_idx
        self.m_idx = m_idx
        self.e_idx = e_idx
        self.s_idx = s_idx

        # The table from the class docstring: _valida_matrix[cur_tag][next_tag] -> (offset, new_tag).
        # offset 0 rewrites the current tag, offset 1 rewrites the next tag, (-1, -1) means legal.
        # -1 is the magic idx of the virtual start/end position. Rows are keyed by the actual tag
        # ids so that the lookup stays correct for any choice of b_idx/m_idx/e_idx/s_idx.
        boundary = -1
        legal = (-1, -1)

        def _row(next_b, next_m, next_e, next_s, at_end):
            return {self.b_idx: next_b, self.m_idx: next_m, self.e_idx: next_e,
                    self.s_idx: next_s, boundary: at_end}

        # Rows for positions after which a new word has to start (start, E and S).
        next_m_to_b = (1, self.b_idx)
        next_e_to_s = (1, self.s_idx)
        cur_b_to_s = (0, self.s_idx)
        cur_m_to_e = (0, self.e_idx)
        self._valida_matrix = {
            boundary: _row(legal, next_m_to_b, next_e_to_s, legal, legal),
            self.b_idx: _row(cur_b_to_s, legal, legal, cur_b_to_s, cur_b_to_s),
            self.m_idx: _row(cur_m_to_e, legal, legal, cur_m_to_e, cur_m_to_e),
            self.e_idx: _row(legal, next_m_to_b, next_e_to_s, legal, legal),
            self.s_idx: _row(legal, next_m_to_b, next_e_to_s, legal, legal),
        }

    def _validate_tags(self, tags):
        """
        Given a Tensor of tags, return a legal tag sequence.

        :param tags: Tensor, shape: (seq_len, )
        :return: list of tags after the illegal transitions have been repaired
        """
        assert len(tags) != 0
        assert isinstance(tags, torch.Tensor) and len(tags.size()) == 1
        # Surround the sequence with the virtual start/end marker.
        padded_tags = [-1, *tags.tolist(), -1]
        for idx in range(len(padded_tags) - 1):
            cur_tag = padded_tags[idx]
            if cur_tag not in self._valida_matrix:
                cur_tag = self.s_idx
            # Unknown tag ids count as S.
            if padded_tags[idx + 1] not in self._valida_matrix:
                padded_tags[idx + 1] = self.s_idx
            next_tag = padded_tags[idx + 1]
            shift_tag = self._valida_matrix[cur_tag][next_tag]
            if shift_tag[0] != -1:
                padded_tags[idx + shift_tag[0]] = shift_tag[1]

        return padded_tags[1:-1]

    def evaluate(self, pred, target, seq_lens):
        """Accumulate the word counts for one batch.

        A word ends at every S or E tag. A target word counts as correct when all of its tags
        equal the (repaired) predicted tags at the same positions.

        :param pred: torch.Tensor with the same 2-dimensional size as target, or with one extra
            trailing dimension that is reduced with argmax
        :param target: 2-dimensional torch.Tensor of gold tag ids
        :param seq_lens: torch.Tensor; entry i is the number of leading positions of row i to use
        :raises TypeError: if an argument is not a torch.Tensor
        :raises RuntimeError: if the shapes of pred and target do not fit together
        """
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")

        if not isinstance(seq_lens, torch.Tensor):
            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_lens)}.")

        if pred.size() == target.size() and len(target.size()) == 2:
            pass
        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
            pred = pred.argmax(dim=-1)
        else:
            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        word_end_tags = (self.s_idx, self.e_idx)
        for idx in range(len(pred)):
            seq_len = seq_lens[idx]
            target_tags = target[idx][:seq_len].tolist()
            pred_tags = pred[idx][:seq_len]
            pred_tags = self._validate_tags(pred_tags)
            start_idx = 0  # first position of the target word currently being read
            for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)):
                if t_tag in word_end_tags:
                    self.yt_wordnum += 1
                    # The word is correct only if every tag of it was predicted correctly.
                    if target_tags[start_idx:t_idx + 1] == pred_tags[start_idx:t_idx + 1]:
                        self.corr_num += 1
                    start_idx = t_idx + 1
                if p_tag in word_end_tags:
                    self.yp_wordnum += 1

    def get_metric(self, reset=True):
        """Return the accumulated f1, precision and recall, rounded to 6 digits.

        :param bool reset: whether to clear the accumulated counts afterwards
        :return: {'f': float, 'pre': float, 'rec': float}
        """
        # The epsilons keep the ratios defined when a count is zero.
        P = self.corr_num / (self.yp_wordnum + 1e-12)
        R = self.corr_num / (self.yt_wordnum + 1e-12)
        F = 2 * P * R / (P + R + 1e-12)
        evaluate_result = {'f': round(F, 6), 'pre': round(P, 6), 'rec': round(R, 6)}
        if reset:
            self.yp_wordnum = 0
            self.yt_wordnum = 0
            self.corr_num = 0
        return evaluate_result
def _prepare_metrics(metrics):
"""
Prepare list of Metric based on input
:param metrics:
:return: List[fastNLP.MetricBase]
"""
_metrics = []
if metrics:
if isinstance(metrics, list):
for metric in metrics:
if isinstance(metric, type):
metric = metric()
if isinstance(metric, MetricBase):
metric_name = metric.__class__.__name__
if not callable(metric.evaluate):
raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.")
if not callable(metric.get_metric):
raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
_metrics.append(metric)
else:
raise TypeError(
f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
elif isinstance(metrics, MetricBase):
_metrics = [metrics]
else:
raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, "
f"got {type(metrics)}.")
return _metrics
def accuracy_topk(y_true, y_prob, k=1):
    """Compute accuracy of y_true matching top-k probable labels in y_prob.

    If ``k`` exceeds the number of classes, all classes are considered.

    :param y_true: ndarray, true label, [n_samples]
    :param y_prob: ndarray, label probabilities, [n_samples, n_classes]
    :param k: int, k in top-k
    :returns acc: accuracy of top-k
    """
    # Indices of the k largest probabilities per row, most probable first.
    y_pred_topk = np.argsort(y_prob, axis=-1)[:, -1:-k - 1:-1]
    # Compare through broadcasting against a column of true labels; unlike tiling the labels
    # k times, this also works when the slice above yields fewer than k columns.
    y_match = np.any(y_pred_topk == np.expand_dims(y_true, axis=1), axis=-1)
    acc = np.sum(y_match) / y_match.shape[0]
    return acc
def pred_topk(y_prob, k=1):
    """Return top-k predicted labels and corresponding probabilities.

    If ``k`` exceeds the number of classes, all classes are returned, so the second dimension
    of both results is ``min(k, n_classes)``.

    :param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels
    :param k: int, k of top-k
    :returns (y_pred_topk, y_prob_topk):
        y_pred_topk: ndarray, size [n_samples, k], predicted top-k labels
        y_prob_topk: ndarray, size [n_samples, k], probabilities for top-k labels
    """
    # Indices of the k largest probabilities per row, most probable first.
    y_pred_topk = np.argsort(y_prob, axis=-1)[:, -1:-k - 1:-1]
    # A column of row numbers broadcasts against the label indices, whatever their width is.
    row_index = np.arange(len(y_prob))[:, np.newaxis]
    y_prob_topk = y_prob[row_index, y_pred_topk]
    return y_pred_topk, y_prob_topk