Source code for fastNLP.core.predictor

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler


[docs]class Predictor(object): """An interface for predicting outputs based on trained models. It does not care about evaluations of the model, which is different from Tester. This is a high-level model wrapper to be called by FastNLP. This class does not share any operations with Trainer and Tester. Currently, Predictor does not support GPU. """ def __init__(self): self.batch_size = 1 self.batch_output = []
[docs] def predict(self, network, data): """Perform inference using the trained model. :param network: a PyTorch model (cpu) :param data: a DataSet object. :return: list of batch outputs """ # turn on the testing mode; clean up the history self.mode(network, test=True) batch_output = [] data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) for batch_x, _ in data_iterator: with torch.no_grad(): prediction = self.data_forward(network, batch_x) batch_output.append(prediction) return batch_output
def mode(self, network, test=True): if test: network.eval() else: network.train()
[docs] def data_forward(self, network, x): """Forward through network.""" y = network(**x) return y