add some
This commit is contained in:
@@ -40,7 +40,9 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
||||
def load_tokenizer(
|
||||
model_args: "ModelArguments",
|
||||
) -> Dict[str, Union["PreTrainedTokenizer", "AutoProcesser"]]:
|
||||
r"""
|
||||
Loads pretrained tokenizer.
|
||||
|
||||
@@ -78,33 +80,25 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
||||
)
|
||||
|
||||
patch_tokenizer(tokenizer)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_processor(model_args: "ModelArguments") -> "AutoProcessor":
|
||||
r"""
|
||||
Loads processor. Must before load_model.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
except Exception: # try the fast one
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=True,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
return processor
|
||||
tokenizer_modules = {"tokenizer": tokenizer, "processor": None}
|
||||
if model_args.use_mllm:
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
except Exception: # try the fast one
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=True,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
tokenizer_modules["processor"] = processor
|
||||
return tokenizer_modules
|
||||
|
||||
|
||||
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
|
||||
|
||||
Reference in New Issue
Block a user