import _pickle
import inspect
import os
import warnings
from collections import Counter
from collections import namedtuple
import numpy as np
import torch
# CheckRes bundles the result of a signature check (see _check_arg_dict_list):
# which params are missing/unused/duplicated, which are required, every name
# the callable needs, and its *varargs name (if any).
# NOTE: the `verbose` keyword of namedtuple was removed in Python 3.7
# (bpo-32492); passing it raises TypeError on modern interpreters.
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required',
                                   'all_needed', 'varargs'])
def save_pickle(obj, pickle_path, file_name):
    """Save an object into a pickle file.

    :param obj: an object
    :param pickle_path: str, the directory where the pickle file is to be saved
    :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
    """
    if not os.path.exists(pickle_path):
        # makedirs (not mkdir) so that nested directories are created too,
        # matching the behavior of pickle_exist().
        os.makedirs(pickle_path)
        print("make dir {} before saving pickle file".format(pickle_path))
    with open(os.path.join(pickle_path, file_name), "wb") as f:
        _pickle.dump(obj, f)
    print("{} saved in {}".format(file_name, pickle_path))
def load_pickle(pickle_path, file_name):
    """Load an object from a given pickle file.

    :param pickle_path: str, the directory where the pickle file is.
    :param file_name: str, the name of the pickle file.
    :return obj: an object stored in the pickle
    """
    full_path = os.path.join(pickle_path, file_name)
    with open(full_path, "rb") as f:
        obj = _pickle.load(f)
    print("{} loaded from {}".format(file_name, pickle_path))
    return obj
def pickle_exist(pickle_path, pickle_name):
    """Check if a given pickle file exists in the directory.

    :param pickle_path: the directory of target pickle file
    :param pickle_name: the filename of target pickle file
    :return: True if file exists else False
    """
    # Side effect kept from the original implementation: the directory is
    # created if it does not exist yet.
    if not os.path.exists(pickle_path):
        os.makedirs(pickle_path)
    return os.path.exists(os.path.join(pickle_path, pickle_name))
def _build_args(func, **kwargs):
spect = inspect.getfullargspec(func)
if spect.varkw is not None:
return kwargs
needed_args = set(spect.args)
defaults = []
if spect.defaults is not None:
defaults = [arg for arg in spect.defaults]
start_idx = len(spect.args) - len(defaults)
output = {name: default for name, default in zip(spect.args[start_idx:], defaults)}
output.update({name: val for name, val in kwargs.items() if name in needed_args})
return output
def _map_args(maps: dict, **kwargs):
# maps: key=old name, value= new name
output = {}
for name, val in kwargs.items():
if name in maps:
assert isinstance(maps[name], str)
output.update({maps[name]: val})
else:
output.update({name: val})
for keys in maps.keys():
if keys not in output.keys():
# TODO: add UNUSED warning.
pass
return output
def _get_arg_list(func):
assert callable(func)
spect = inspect.getfullargspec(func)
if spect.defaults is not None:
args = spect.args[: -len(spect.defaults)]
defaults = spect.args[-len(spect.defaults):]
defaults_val = spect.defaults
else:
args = spect.args
defaults = None
defaults_val = None
varargs = spect.varargs
kwargs = spect.varkw
return args, defaults, defaults_val, varargs, kwargs
# check args
def _check_arg_dict_list(func, args):
    """Compare the keys supplied in ``args`` against the signature of ``func``.

    :param func: the callable whose signature is inspected
    :param args: a dict, or a list/tuple of dicts, whose keys are candidate
        keyword arguments for ``func``
    :return: a CheckRes naming missing/unused/duplicated parameters, the
        required and all needed parameter names, and the *varargs name if any
    """
    arg_dict_list = [args] if isinstance(args, dict) else args
    assert callable(func) and isinstance(arg_dict_list, (list, tuple))
    assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
    spec = inspect.getfullargspec(func)
    # `self` is supplied by the binding machinery, never by callers.
    all_args = {arg for arg in spec.args if arg != 'self'}
    defaults = list(spec.defaults) if spec.defaults is not None else []
    default_args = set(spec.args[len(spec.args) - len(defaults):])
    require_args = all_args - default_args
    key_counts = Counter()
    for arg_dict in arg_dict_list:
        key_counts.update(arg_dict.keys())
    duplicated = [name for name, cnt in key_counts.items() if cnt > 1]
    input_args = set(key_counts.keys())
    return CheckRes(missing=list(require_args - input_args),
                    unused=list(input_args - all_args),
                    duplicated=duplicated,
                    required=list(require_args),
                    all_needed=list(all_args),
                    varargs=[spec.varargs] if spec.varargs else [])
def get_func_signature(func):
    """
    Given a function or method, return its signature.

    For example:

    (1) function
        def func(a, b='a', *args):
            xxxx
        get_func_signature(func) # 'func(a, b='a', *args)'

    (2) method
        class Demo:
            def __init__(self):
                xxx
            def forward(self, a, b='a', **args)
        demo = Demo()
        get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'

    :param func: a function or a method
    :return: str or None
    """
    if inspect.ismethod(func):
        owner = func.__self__.__class__.__name__
        sig = str(inspect.signature(func))
        # A bound method's signature omits `self`; splice it back in. `sig` is
        # '()' (len 2) when there are no further parameters.
        prefix = '(self, ' if len(sig) > 2 else '(self'
        return owner + '.' + func.__name__ + prefix + sig[1:]
    if inspect.isfunction(func):
        return func.__name__ + str(inspect.signature(func))
def _is_function_or_method(func):
"""
:param func:
:return:
"""
if not inspect.ismethod(func) and not inspect.isfunction(func):
return False
return True
def _check_function_or_method(func):
if not _is_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.")
def _move_dict_value_to_device(*args, device: torch.device):
"""
move data to model's device, element in *args should be dict. This is a inplace change.
:param device: torch.device
:param args:
:return:
"""
if not isinstance(device, torch.device):
raise TypeError(f"device must be `torch.device`, got `{type(device)}`")
for arg in args:
if isinstance(arg, dict):
for key, value in arg.items():
if isinstance(value, torch.Tensor):
arg[key] = value.to(device)
else:
raise TypeError("Only support `dict` type right now.")
class CheckError(Exception):
    """
    CheckError. Used in losses.LossBase, metrics.MetricBase.
    """

    def __init__(self, check_res: CheckRes, func_signature: str):
        # Build one multi-line message listing every detected problem.
        msgs = [f'Problems occurred when calling `{func_signature}`']
        if check_res.varargs:
            msgs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
        if check_res.missing:
            msgs.append(f"\tmissing param: {check_res.missing}")
        if check_res.duplicated:
            msgs.append(f"\tduplicated param: {check_res.duplicated}")
        if check_res.unused:
            msgs.append(f"\tunused param: {check_res.unused}")
        super().__init__('\n'.join(msgs))
        self.check_res = check_res
        self.func_signature = func_signature
# Check levels consumed by _check_loss_evaluate / _check_forward_error:
#   STRICT_CHECK_LEVEL escalates unused fields/params into errors,
#   WARNING_CHECK_LEVEL only emits warnings.warn for them,
#   IGNORE_CHECK_LEVEL (the default 0) reports nothing about unused entries.
IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
                         pred_dict: dict, target_dict: dict, dataset, check_level=0):
    """Report problems found when feeding predictions/targets to a loss or metric.

    :param prev_func_signature: signature of the upstream function (e.g. forward)
        whose outputs feed this one
    :param func_signature: signature of the function being checked
    :param check_res: CheckRes produced by _check_arg_dict_list
    :param pred_dict: outputs of the upstream function
    :param target_dict: target fields
    :param dataset: presumably a DataSet supporting `name in dataset` — used to
        suggest setting fields as target (TODO confirm against caller)
    :param check_level: IGNORE/WARNING/STRICT check level
    :raises NameError: when missing/duplicated params exist (or unused ones
        under STRICT_CHECK_LEVEL), with collected suggestions in the message
    """
    errs = []
    unuseds = []
    _unused_field = []
    _unused_param = []
    suggestions = []
    if check_res.unused:
        # Split unused names by whether they came from targets or from params.
        for _unused in check_res.unused:
            if _unused in target_dict:
                _unused_field.append(_unused)
            else:
                _unused_param.append(_unused)
    if _unused_field:
        unuseds.append(f"\tunused field: {_unused_field}")
    if _unused_param:
        unuseds.append(f"\tunused param: {_unused_param}")  # output from predict or forward
    module_name = func_signature.split('.')[0]
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}")
        import re
        mapped_missing = []    # names renamed via a `key` mapping
        unmapped_missing = []  # names used as-is
        input_func_map = {}
        for _miss in check_res.missing:
            if '(' in _miss:
                # names may look like 'SomeParam(assign to xxx)'
                _miss = _miss.split('(')[0]
            matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
            if len(matches) == 2:
                fun_arg, module_name = matches
                input_func_map[_miss] = fun_arg
                if fun_arg == _miss:
                    unmapped_missing.append(_miss)
                else:
                    mapped_missing.append(_miss)
            else:
                unmapped_missing.append(_miss)
        for _miss in mapped_missing:
            if _miss in dataset:
                suggestions.append(f"Set {_miss} as target.")
            else:
                _tmp = ''
                if check_res.unused:
                    _tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
                if _tmp:
                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
                else:
                    _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
                suggestions.append(_tmp)
        for _miss in unmapped_missing:
            if _miss in dataset:
                suggestions.append(f"Set {_miss} as target.")
            else:
                _tmp = ''
                if check_res.unused:
                    _tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
                if _tmp:
                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
                else:
                    _tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.'
                suggestions.append(_tmp)
    if check_res.duplicated:
        errs.append(f"\tduplicated param: {check_res.duplicated}.")
        suggestions.append(f"Delete {check_res.duplicated} in the output of "
                           f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
    # Unused entries only become hard errors alongside other errors, or under
    # the strict check level.
    if len(errs) > 0:
        errs.extend(unuseds)
    elif check_level == STRICT_CHECK_LEVEL:
        errs.extend(unuseds)
    if len(errs) > 0:
        errs.insert(0, f'Problems occurred when calling {func_signature}')
        sugg_str = ""
        if len(suggestions) > 1:
            for idx, sugg in enumerate(suggestions):
                if idx > 0:
                    sugg_str += '\t\t\t'
                sugg_str += f'({idx+1}). {sugg}\n'
            sugg_str = sugg_str[:-1]
        elif suggestions:
            # FIX: guard against an empty suggestions list (reachable when only
            # unused entries exist under STRICT_CHECK_LEVEL); the original
            # unconditionally indexed suggestions[0] and raised IndexError.
            sugg_str += suggestions[0]
        errs.append(f'\ttarget field: {list(target_dict.keys())}')
        errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}')
        err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
        raise NameError(err_str)
    if check_res.unused:
        if check_level == WARNING_CHECK_LEVEL:
            if not module_name:
                module_name = func_signature.split('.')[0]
            _unused_warn = f'{check_res.unused} is not used by {module_name}.'
            warnings.warn(message=_unused_warn)
def _check_forward_error(forward_func, batch_x, dataset, check_level):
    """Report problems found when feeding ``batch_x`` to ``forward_func``.

    :param forward_func: the model's forward function
    :param batch_x: dict (or list of dicts) of candidate inputs for forward
    :param dataset: presumably a DataSet supporting `name in dataset` — used to
        suggest setting fields as input (TODO confirm against caller)
    :param check_level: IGNORE/WARNING/STRICT check level
    :raises NameError: when params are missing (or unused under
        STRICT_CHECK_LEVEL), with collected suggestions in the message
    """
    check_res = _check_arg_dict_list(forward_func, batch_x)
    func_signature = get_func_signature(forward_func)
    errs = []
    suggestions = []
    _unused = []
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}")
        _miss_in_dataset = []
        _miss_out_dataset = []
        for _miss in check_res.missing:
            if _miss in dataset:
                _miss_in_dataset.append(_miss)
            else:
                _miss_out_dataset.append(_miss)
        if _miss_in_dataset:
            suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
        if _miss_out_dataset:
            _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
            suggestions.append(_tmp)
    if check_res.unused:
        _unused = [f"\tunused field: {check_res.unused}"]
    # Unused fields only become hard errors alongside other errors, or under
    # the strict check level.
    if len(errs) > 0:
        errs.extend(_unused)
    elif check_level == STRICT_CHECK_LEVEL:
        errs.extend(_unused)
    if len(errs) > 0:
        errs.insert(0, f'Problems occurred when calling {func_signature}')
        sugg_str = ""
        if len(suggestions) > 1:
            for idx, sugg in enumerate(suggestions):
                sugg_str += f'({idx+1}). {sugg}'
        elif suggestions:
            # FIX: guard against an empty suggestions list (reachable when only
            # unused fields exist under STRICT_CHECK_LEVEL); the original
            # unconditionally indexed suggestions[0] and raised IndexError.
            sugg_str += suggestions[0]
        err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
        raise NameError(err_str)
    if _unused:
        if check_level == WARNING_CHECK_LEVEL:
            _unused_warn = _unused[0] + f' in {func_signature}.'
            warnings.warn(message=_unused_warn)
def seq_lens_to_masks(seq_lens, float=False):
    """
    Convert seq_lens to masks.

    :param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,)
    :param float: if True, the return masks is in float type, otherwise it is byte.
    :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
    :raises NotImplementedError: for list/np.ndarray inputs (not supported yet)
    """
    if isinstance(seq_lens, np.ndarray):
        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
        assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
        # FIX: `raise NotImplemented` raises a TypeError (NotImplemented is a
        # constant, not an exception); the intended error is NotImplementedError.
        raise NotImplementedError
    elif isinstance(seq_lens, torch.LongTensor):
        # FIX: the original message interpolated `len(seq_lens.size())==1`,
        # printing a bool instead of the offending dimension count.
        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())}."
        batch_size = seq_lens.size(0)
        max_len = seq_lens.max()
        # Position indexes broadcast against the lengths: mask[i, j] = j < len_i.
        indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
        masks = indexes.lt(seq_lens.unsqueeze(1))
        if float:
            masks = masks.float()
        return masks
    elif isinstance(seq_lens, list):
        raise NotImplementedError
    else:
        raise NotImplementedError
def seq_mask(seq_len, max_len):
    """Create sequence mask.

    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
    :param max_len: int, the maximum sequence length in a batch.
    :return mask: torch.Tensor, [batch_size, max_len]; mask[i, j] is truthy
        iff position j is inside sequence i (j < seq_len[i]).
    """
    if not isinstance(seq_len, torch.Tensor):
        seq_len = torch.LongTensor(seq_len)
    lengths = seq_len.view(-1, 1).long()  # [batch_size, 1]
    positions = torch.arange(
        start=0, end=max_len, dtype=torch.long, device=lengths.device
    ).view(1, -1)                         # [1, max_len]
    return torch.gt(lengths, positions)   # [batch_size, max_len]