support new special token #3420

hiyouga
2024-04-24 23:39:31 +08:00
parent e5d23c053a
commit 297fb8ead3
8 changed files with 47 additions and 7 deletions


@@ -39,6 +39,8 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
    r"""
    Loads pretrained tokenizer.

    Note: including inplace operation of model_args.
    """
    init_kwargs = _get_init_kwargs(model_args)
    try:
@@ -57,6 +59,16 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
        **init_kwargs,
    )

    if model_args.new_special_tokens is not None:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=model_args.new_special_tokens),
            replace_additional_special_tokens=False,
        )
        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
        if num_added_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning("New tokens have been added, changed `resize_vocab` to True.")

    patch_tokenizer(tokenizer)
    return tokenizer
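For reference, below is a minimal standalone sketch of the same technique using the Hugging Face `transformers` API directly, outside LLaMA-Factory. The model checkpoint and token strings are placeholders, and the surrounding training setup is assumed; it only illustrates why adding tokens forces an embedding resize, which is what flipping `resize_vocab` to True triggers in the change above.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint and tokens; substitute your own.
model_name = "meta-llama/Llama-2-7b-hf"
new_special_tokens = ["<|tool_call|>", "<|tool_result|>"]

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Register the new tokens without discarding any additional_special_tokens the
# tokenizer already carries (mirrors replace_additional_special_tokens=False above).
num_added_tokens = tokenizer.add_special_tokens(
    {"additional_special_tokens": new_special_tokens},
    replace_additional_special_tokens=False,
)

# If any token was actually new, its id lies past the current embedding matrix,
# so the embeddings (and LM head) must grow to cover it.
if num_added_tokens > 0:
    model.resize_token_embeddings(len(tokenizer))
```

Within LLaMA-Factory itself, the tokens are supplied through the new `new_special_tokens` model argument introduced by this commit; by the time `load_tokenizer` runs they are already a list of strings (presumably parsed from a comma-separated value, though that parsing happens in the ModelArguments change not shown in this hunk).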