fix layer norm dtype

hiyouga
2023-09-28 00:25:55 +08:00
parent b0b0138e1d
commit 84b7486885
6 changed files with 28 additions and 22 deletions


@@ -128,10 +128,6 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
-    # Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535)
-    if getattr(config, "model_type", None) == "llama":
-        LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm
-
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
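For context, the lines removed in this hunk monkey-patched transformers' LlamaRMSNorm with a variant that keeps the norm weight in fp32 (the "Fix RMSNorm in fp32 weight" referenced in the removed comment). Below is a minimal sketch of what such a patched RMSNorm typically looks like, assuming the standard (hidden_size, eps) constructor; the actual LlamaPatches implementation may differ.

import torch
import torch.nn as nn

class LlamaRMSNorm(nn.Module):
    # Hypothetical stand-in for LlamaPatches.LlamaRMSNorm: keep the norm weight
    # in float32 and compute the normalization in float32, then cast the result
    # back to the input dtype.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)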
@@ -205,7 +201,8 @@ def load_model_and_tokenizer(
         tokenizer.__class__.register_for_auto_class()
 
     # Initialize adapters
-    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
+    if is_trainable:
+        model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()
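The new layernorm_dtype argument replaces the removed monkey patch: instead of swapping the module class at import time, the norm parameters can be cast to the requested dtype after the model is loaded. The helper below is a minimal sketch of that idea, not the project's prepare_model_for_training; the function name, the LAYERNORM_NAMES fragments, and the accepted dtype strings are all assumptions for illustration.

import torch
import torch.nn as nn

# Assumed name fragments used to identify norm parameters; the real code may
# use a different list or inspect module types instead.
LAYERNORM_NAMES = ("norm", "ln_f", "ln_1", "ln_2")

def cast_layernorm_dtype(model: nn.Module, layernorm_dtype: str = "fp32") -> nn.Module:
    # Hypothetical helper illustrating the layernorm_dtype handling: cast
    # layer-norm weights to the requested dtype after loading, rather than
    # patching LlamaRMSNorm globally.
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    target_dtype = dtype_map[layernorm_dtype]
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(fragment in name for fragment in LAYERNORM_NAMES):
            param.data = param.data.to(target_dtype)
    return model

Under this sketch, passing layernorm_dtype="fp32" recovers the fp32 behaviour of the old patch, while fp16/bf16 presumably skip the extra upcast.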