# Source code for fastNLP.io.base_loader

import _pickle as pickle
import os


class BaseLoader(object):
    """Base loader for all loaders.

    Provides simple UTF-8 text readers and a pickle-backed cache around
    :meth:`load`.
    """

    def __init__(self):
        super(BaseLoader, self).__init__()

    @staticmethod
    def load_lines(data_path):
        """Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

        :param data_path: path of the text file to read.
        :return: list of str, one entry per line of the file.
        """
        # Fixed: the encoding was previously spelled "utf=8".
        with open(data_path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f]

    @classmethod
    def load(cls, data_path):
        """Read a UTF-8 text file and split each stripped line into single characters.

        :param data_path: path of the text file to read.
        :return: list of lists of str; one inner list per line, one element per character.
        """
        with open(data_path, "r", encoding="utf-8") as f:
            return [list(sent.strip()) for sent in f]

    @classmethod
    def load_with_cache(cls, data_path, cache_path):
        """Return ``cls.load(data_path)``, using a pickle file at ``cache_path`` as a cache.

        The cache is used only when it exists and its mtime is strictly newer than
        ``data_path``'s. Otherwise the data is reloaded and the cache is rewritten.

        :param data_path: path of the source data file (must exist).
        :param cache_path: path of the pickle cache file.
        :return: the loaded (or cached) object.
        """
        # NOTE(review): this calls ``cls.load(data_path)``. A subclass that redefines
        # ``load`` as an instance method would receive ``data_path`` as ``self`` — confirm.
        if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path):
            # SECURITY: unpickling can execute arbitrary code; only point
            # ``cache_path`` at files written by this process or another trusted source.
            with open(cache_path, 'rb') as f:
                return pickle.load(f)

        obj = cls.load(data_path)

        # Write to a sibling temp file, then atomically move it into place.
        # A failed or interrupted dump therefore never leaves a truncated cache
        # that looks fresher than the data.
        tmp_path = "{}.{}.tmp".format(cache_path, os.getpid())
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(obj, f)
            os.replace(tmp_path, cache_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return obj
class DataLoaderRegister:
    """Register for all data sets.

    Maps a read-function name to the bound ``load`` method of a reader instance.
    The table is the class attribute ``_readers``. It is shared by this class and
    by any subclass that does not override that attribute.
    """
    _readers = {}

    @classmethod
    def set_reader(cls, reader_cls, read_fn_name):
        """Register ``reader_cls().load`` under ``read_fn_name``.

        ``reader_cls`` is instantiated with no arguments at registration time.
        A class with no ``load`` attribute is returned unchanged and nothing is
        registered.

        :param reader_cls: reader class to register.
        :param read_fn_name: name under which its ``load`` is stored.
        :return: ``reader_cls`` itself.
        :raises KeyError: if ``read_fn_name`` is already registered.
        """
        registry = cls._readers
        if read_fn_name in registry:
            existing = registry[read_fn_name]
            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(existing, reader_cls, read_fn_name))
        if not hasattr(reader_cls, 'load'):
            # Nothing to register; hand the class back untouched.
            return reader_cls
        registry[read_fn_name] = reader_cls().load
        return reader_cls

    @classmethod
    def get_reader(cls, read_fn_name):
        """Return the ``load`` callable registered under ``read_fn_name``.

        :param read_fn_name: name passed to :meth:`set_reader`.
        :return: the stored bound ``load`` method.
        :raises AttributeError: if no reader is registered under that name.
        """
        if read_fn_name not in cls._readers:
            raise AttributeError('no read function: {}'.format(read_fn_name))
        return cls._readers[read_fn_name]