reimplement neftune

This commit is contained in:
hiyouga
2023-10-22 16:15:08 +08:00
parent b42a145253
commit 7b4acf7265
9 changed files with 36 additions and 104 deletions

View File

@@ -206,8 +206,7 @@ def load_model_and_tokenizer(
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
if is_trainable:
model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = model.train() if is_trainable else model.eval()