This commit is contained in:
hiyouga
2024-08-05 23:48:19 +08:00
parent c2921b9960
commit b7ca6c8dc1
13 changed files with 111 additions and 69 deletions

View File

@@ -162,11 +162,12 @@ class PissaConvertCallback(TrainerCallback):
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
model.save_pretrained(
pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
)
) # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
model.set_adapter("default")
if "pissa_init" in model.peft_config.keys():
if "pissa_init" in model.peft_config.keys(): # backward compatibility (peft<0.12.0)
model.delete_adapter("pissa_init")
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)