refactor pissa, improve llamaboard

This commit is contained in:
hiyouga
2024-06-28 01:04:24 +08:00
parent ef38daa0a4
commit 8baf3b22b0
16 changed files with 219 additions and 216 deletions

View File

@@ -17,11 +17,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch
from peft import PeftModel
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.optimization import get_scheduler
@@ -40,7 +38,6 @@ if is_galore_available():
if TYPE_CHECKING:
from accelerate import Accelerator
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
@@ -175,51 +172,6 @@ def create_reward_model(
return reward_model
def convert_pissa_adapter(
output_dir: str,
state_dict: Dict[str, "torch.Tensor"],
accelerator: "Accelerator",
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
) -> None:
r"""
Converts the PiSSA adapter to a LoRA adapter.
"""
pissa_init_dir = os.path.join(training_args.output_dir, "pissa_init")
pissa_backup_dir = os.path.join(output_dir, "pissa_backup")
if output_dir == pissa_init_dir:
logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
unwrapped_model = accelerator.unwrap_model(model)
if isinstance(unwrapped_model, PeftModel):
init_lora_weights = getattr(unwrapped_model.peft_config["default"], "init_lora_weights")
setattr(unwrapped_model.peft_config["default"], "init_lora_weights", True)
unwrapped_model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=training_args.save_safetensors,
)
setattr(unwrapped_model.peft_config["default"], "init_lora_weights", init_lora_weights)
elif output_dir == training_args.output_dir: # at the end of training
logger.info("Converted PiSSA adapter will be saved at: {}.".format(output_dir))
unwrapped_model = accelerator.unwrap_model(model)
if isinstance(unwrapped_model, PeftModel): # backup the pissa adapter for further use
unwrapped_model.save_pretrained(
pissa_backup_dir,
state_dict=state_dict,
safe_serialization=training_args.save_safetensors,
)
unwrapped_model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=training_args.save_safetensors,
convert_pissa_to_lora=pissa_init_dir,
)
# TODO: the model is applied pissa again unexpectedly
unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
unwrapped_model.set_adapter("default")
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)