use low_cpu_mem_usage to speed up loading

hiyouga
2023-06-03 18:19:01 +08:00
parent dca27b4412
commit 771f454ff1
3 changed files with 24 additions and 12 deletions
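The flag named in the commit title is the `low_cpu_mem_usage` argument of `from_pretrained` in transformers: instead of first allocating a randomly initialized model in CPU RAM and then overwriting it with checkpoint weights, the model skeleton is built without materializing parameters and the checkpoint tensors are loaded into it directly, keeping peak host memory near a single copy of the weights. A minimal sketch of the flag in isolation (the model id and dtype are placeholders, not taken from this commit; recent transformers versions need the accelerate package for this flag):

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any causal LM id behaves the same way.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",       # placeholder model id, not from the commit
    torch_dtype=torch.float16,   # same dtype the patched loader uses
    low_cpu_mem_usage=True,      # build an empty skeleton, then load weights into it
)
```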


@@ -143,15 +143,24 @@ def load_pretrained(
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
+    config_kwargs = {
+        "trust_remote_code": True,
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
-        padding_side="left"
+        padding_side="left",
+        **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     # Quantization configurations (using bitsandbytes library).
     config_kwargs = {}
     if model_args.quantization_bit is not None:
         assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
         require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
@@ -162,23 +171,19 @@ def load_pretrained(
config_kwargs["load_in_8bit"] = True
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pretrained models (without valuehead).
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
torch_dtype=torch.float16, # the model weights are float16 type
low_cpu_mem_usage=True,
**config_kwargs
)
model = prepare_model_for_training(model) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if not is_trainable:
model.requires_grad_(False) # fix all model params
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
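With this hunk, the model load combines the float16 dtype, the new `low_cpu_mem_usage` flag, and, when 8-bit quantization is requested, the bitsandbytes kwargs set just above. A hedged sketch of the call path the patched code assembles (the function name and defaults here are illustrative, not the repo's API):

```python
from typing import Optional

import torch
from transformers import AutoConfig, AutoModelForCausalLM

def load_base_model(name_or_path: str, quantization_bit: Optional[int] = None):
    """Illustrative recap of the loading path after this commit."""
    config = AutoConfig.from_pretrained(name_or_path, trust_remote_code=True)
    config_kwargs = {}
    if quantization_bit is not None:
        assert quantization_bit == 8, "only 8-bit quantization is accepted here"
        config_kwargs["load_in_8bit"] = True   # needs bitsandbytes >= 0.39
        config_kwargs["device_map"] = "auto"   # only set together with load_in_8bit
    return AutoModelForCausalLM.from_pretrained(
        name_or_path,
        config=config,
        torch_dtype=torch.float16,   # checkpoint weights are float16
        low_cpu_mem_usage=True,      # the new flag: avoid a second in-RAM copy
        **config_kwargs,
    )

# e.g. model = load_base_model("huggyllama/llama-7b")  # placeholder id
```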
@@ -194,6 +199,9 @@ def load_pretrained(
         if model_args.quantization_bit is not None:
             model._is_int8_training_enabled = True
+    if not is_trainable:
+        model.requires_grad_(False) # fix all model params
     print_trainable_params(model)
     return model, tokenizer
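A quick way to observe the effect of the flag is a hypothetical helper script like the one below (not part of the repo): run it once with `LOW_CPU_MEM=1` and once with `LOW_CPU_MEM=0` and compare the printed resident set size; `psutil` is assumed to be installed, and the gap grows with model size.

```python
import os

import psutil  # assumption: installed, used only for the memory readout
import torch
from transformers import AutoModelForCausalLM

low = os.environ.get("LOW_CPU_MEM", "1") == "1"
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",       # placeholder id; needs enough RAM either way
    torch_dtype=torch.float16,
    low_cpu_mem_usage=low,
)
rss_gib = psutil.Process(os.getpid()).memory_info().rss / 2**30
print(f"low_cpu_mem_usage={low}: RSS after load ~ {rss_gib:.2f} GiB")
```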