Source code for fastNLP.models.base_model

import torch

from fastNLP.modules.decoder.MLP import MLP


[docs]class BaseModel(torch.nn.Module): """Base PyTorch model for all models. """ def __init__(self): super(BaseModel, self).__init__() def fit(self, train_data, dev_data=None, **train_args): pass def predict(self, *args, **kwargs): raise NotImplementedError
[docs]class NaiveClassifier(BaseModel): def __init__(self, in_feature_dim, out_feature_dim): super(NaiveClassifier, self).__init__() self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
[docs] def forward(self, x): return {"predict": torch.sigmoid(self.mlp(x))}
def predict(self, x): return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}