refactor pissa, improve llamaboard
@@ -17,11 +17,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import torch
from peft import PeftModel
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.optimization import get_scheduler
@@ -40,7 +38,6 @@ if is_galore_available():


if TYPE_CHECKING:
    from accelerate import Accelerator
    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

@@ -175,51 +172,6 @@ def create_reward_model(
    return reward_model


def convert_pissa_adapter(
    output_dir: str,
    state_dict: Dict[str, "torch.Tensor"],
    accelerator: "Accelerator",
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
) -> None:
    r"""
    Converts the PiSSA adapter to a LoRA adapter.
    """
    pissa_init_dir = os.path.join(training_args.output_dir, "pissa_init")
    pissa_backup_dir = os.path.join(output_dir, "pissa_backup")
    if output_dir == pissa_init_dir:
        logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
        unwrapped_model = accelerator.unwrap_model(model)
        if isinstance(unwrapped_model, PeftModel):
            init_lora_weights = getattr(unwrapped_model.peft_config["default"], "init_lora_weights")
            setattr(unwrapped_model.peft_config["default"], "init_lora_weights", True)
            unwrapped_model.save_pretrained(
                output_dir,
                state_dict=state_dict,
                safe_serialization=training_args.save_safetensors,
            )
            setattr(unwrapped_model.peft_config["default"], "init_lora_weights", init_lora_weights)

    elif output_dir == training_args.output_dir:  # at the end of training
        logger.info("Converted PiSSA adapter will be saved at: {}.".format(output_dir))
        unwrapped_model = accelerator.unwrap_model(model)
        if isinstance(unwrapped_model, PeftModel):  # back up the PiSSA adapter for further use
            unwrapped_model.save_pretrained(
                pissa_backup_dir,
                state_dict=state_dict,
                safe_serialization=training_args.save_safetensors,
            )
            unwrapped_model.save_pretrained(
                output_dir,
                state_dict=state_dict,
                safe_serialization=training_args.save_safetensors,
                convert_pissa_to_lora=pissa_init_dir,
            )
            # TODO: PiSSA is unexpectedly applied to the model again here
            unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
            unwrapped_model.set_adapter("default")

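Not part of the diff: a minimal usage sketch of how convert_pissa_adapter is meant to be driven from a custom Trainer save hook, once for the initial "pissa_init" save and once at the end of training. The PissaSavingTrainer subclass and the pissa_convert flag are illustrative assumptions, not code from this commit.

from typing import Dict, Optional

import torch
from transformers import Seq2SeqTrainer


class PissaSavingTrainer(Seq2SeqTrainer):  # hypothetical subclass, for illustration only
    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        super()._save(output_dir, state_dict)
        if getattr(self.args, "pissa_convert", False):  # assumed flag; not defined in this commit
            # delegate both the initial-adapter save and the final LoRA conversion to the helper above
            convert_pissa_adapter(
                output_dir if output_dir is not None else self.args.output_dir,
                self.model.state_dict() if state_dict is None else state_dict,
                self.accelerator,
                self.model,
                self.args,
            )
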

def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
    r"""
    Returns a list of names of parameters with weight decay (i.e., weights in non-layernorm layers).
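The diff truncates here, before the body of _get_decay_parameter_names. For reference only, a sketch of the conventional way such a list is computed (mirroring the approach used in transformers.Trainer); the helper name below is illustrative and the actual body is not shown in this commit.

from typing import List

from transformers import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names


def decay_parameter_names_sketch(model: "PreTrainedModel") -> List[str]:
    # take every parameter that does not live inside a layernorm module,
    # then drop biases; the remainder are the parameters that receive weight decay
    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    return [name for name in decay_parameters if "bias" not in name]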