upgrade peft, fix #1088 #1411

This commit is contained in:
hiyouga
2023-11-07 16:13:36 +08:00
parent 66a91e1fe3
commit b2a60905f3
15 changed files with 133 additions and 99 deletions

View File

@@ -25,9 +25,8 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
from llmtuner.tuner.core.utils import prepare_model_for_training
if TYPE_CHECKING:
@@ -41,7 +40,7 @@ logger = get_logger(__name__)
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
@@ -64,7 +63,7 @@ def load_model_and_tokenizer(
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"token": model_args.hf_hub_token
}
tokenizer = AutoTokenizer.from_pretrained(
@@ -99,15 +98,9 @@ def load_model_and_tokenizer(
# Set RoPE scaling
if model_args.rope_scaling is not None:
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
if is_trainable:
logger.warning("Qwen model does not support RoPE scaling in training.")
else:
setattr(config, "use_dynamic_ntk", True)
setattr(config, "use_logn_attn", True)
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
else:
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
@@ -129,9 +122,6 @@ def load_model_and_tokenizer(
model_args.rope_scaling, scaling_factor
))
else:
logger.warning("Current model does not support RoPE scaling.")
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
@@ -155,7 +145,6 @@ def load_model_and_tokenizer(
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -165,7 +154,7 @@ def load_model_and_tokenizer(
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
@@ -175,7 +164,6 @@ def load_model_and_tokenizer(
bnb_4bit_quant_type=model_args.quantization_type
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -207,7 +195,7 @@ def load_model_and_tokenizer(
# Initialize adapters
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = init_adapter(model, model_args, finetuning_args, is_trainable)
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
@@ -226,7 +214,7 @@ def load_model_and_tokenizer(
logger.info("Load reward model from {}".format(model_args.reward_model))
if getattr(model, "is_peft_model", False):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
load_valuehead_params(model, model_args)
# Prepare model for inference
if not is_trainable: