Source code for fastNLP.core.tester

import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature


class Tester(object):
    """A collection of model inference and performance evaluation, used over a validation/dev set or a test set.

    :param DataSet data: a validation/development set
    :param torch.nn.Module model: a PyTorch model
    :param MetricBase metrics: a metric object or a list of metrics (List[MetricBase])
    :param int batch_size: batch size for validation
    :param bool use_cuda: whether to use CUDA in validation
    :param int verbose: if >= 1, print the evaluation results after testing
    """

    def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1):
        super(Tester, self).__init__()

        if not isinstance(data, DataSet):
            raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")

        self.metrics = _prepare_metrics(metrics)

        self.data = data
        self.use_cuda = use_cuda
        self.batch_size = batch_size
        self.verbose = verbose

        if torch.cuda.is_available() and self.use_cuda:
            self._model = model.cuda()
        else:
            self._model = model
        self._model_device = model.parameters().__next__().device

        # if the model defines a `predict` method, use it for evaluation;
        # otherwise fall back to the ordinary `forward`
        if hasattr(self._model, 'predict'):
            self._predict_func = self._model.predict
            if not callable(self._predict_func):
                _model_name = model.__class__.__name__
                raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                                f"for evaluation, not `{type(self._predict_func)}`.")
        else:
            self._predict_func = self._model.forward
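
    # Illustrative usage sketch (not part of the original source). `dev_data`
    # is assumed to be a fastNLP DataSet with inputs/targets set, and `model`
    # an nn.Module whose `predict`/`forward` returns a dict:
    #
    #     from fastNLP.core.metrics import AccuracyMetric
    #
    #     tester = Tester(data=dev_data, model=model,
    #                     metrics=AccuracyMetric(pred="pred", target="target"),
    #                     batch_size=32, use_cuda=False, verbose=1)
    #     eval_results = tester.test()  # e.g. {"AccuracyMetric": {"acc": 0.93}}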
    def test(self):
        """Start test or validation.

        :return eval_results: a dictionary whose keys are the class names of the metrics in use and
            whose values are the evaluation results of those metrics.
        """
        # turn on the testing mode; clean up the history
        network = self._model
        self._mode(network, is_test=True)
        data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
        eval_results = {}
        try:
            with torch.no_grad():
                for batch_x, batch_y in data_iterator:
                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                    pred_dict = self._data_forward(self._predict_func, batch_x)
                    if not isinstance(pred_dict, dict):
                        raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
                                        f"must be `dict`, got {type(pred_dict)}.")
                    for metric in self.metrics:
                        metric(pred_dict, batch_y)
                for metric in self.metrics:
                    eval_result = metric.get_metric()
                    if not isinstance(eval_result, dict):
                        raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
                                        f"`dict`, got {type(eval_result)}")
                    metric_name = metric.__class__.__name__
                    eval_results[metric_name] = eval_result
        except CheckError as e:
            prev_func_signature = get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                 dataset=self.data, check_level=0)

        if self.verbose >= 1:
            print("[tester] \n{}".format(self._format_eval_results(eval_results)))
        self._mode(network, is_test=False)
        return eval_results
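
    # Illustrative note (not part of the original source): each metric in
    # self.metrics follows an accumulate/aggregate protocol -- `metric(pred_dict,
    # batch_y)` updates per-batch statistics, and `metric.get_metric()` returns
    # the aggregated result as a dict. A toy metric obeying that protocol:
    #
    #     class ToyAccuracy:
    #         def __init__(self):
    #             self.correct, self.total = 0, 0
    #
    #         def __call__(self, pred_dict, target_dict):
    #             # assumes `pred` holds class scores and `target` holds labels
    #             pred = pred_dict["pred"].argmax(dim=-1)
    #             self.correct += (pred == target_dict["target"]).sum().item()
    #             self.total += target_dict["target"].numel()
    #
    #         def get_metric(self):
    #             # must return a dict, as test() requires
    #             return {"acc": self.correct / self.total}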
    def _mode(self, model, is_test=False):
        """Switch between train mode and test mode. This is for PyTorch currently.

        :param model: a PyTorch model
        :param bool is_test: whether in test mode or not
        """
        if is_test:
            model.eval()
        else:
            model.train()

    def _data_forward(self, func, x):
        """A forward pass of the model."""
        x = _build_args(func, **x)
        y = func(**x)
        return y

    def _format_eval_results(self, results):
        """Override this method to support more print formats.

        :param results: dict, (str: float) is (metric name: value)
        """
        _str = ''
        for metric_name, metric_result in results.items():
            _str += metric_name + ': '
            _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
            _str += '\n'
        return _str[:-1]
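
# Illustrative sketch (an assumption, not part of the original source):
# `_build_args` from fastNLP.core.utils filters a batch dict down to the
# keyword arguments that `func` actually accepts, which is why extra fields
# in `batch_x` do not break the call in `_data_forward`. A minimal version
# of that idea, using only the standard library:
#
#     import inspect
#
#     def _build_args_sketch(func, **kwargs):
#         # keep only the kwargs that appear in func's signature
#         params = inspect.signature(func).parameters
#         return {name: value for name, value in kwargs.items() if name in params}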