remove PeftTrainer

This commit is contained in:
hiyouga
2023-09-10 22:23:23 +08:00
parent baac22f4f4
commit b218c271ed
17 changed files with 75 additions and 259 deletions

View File

@@ -1,49 +1,21 @@
import os
import torch
from typing import Dict
from transformers.trainer import WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
state_dict: Dict[str, torch.Tensor] = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if os.path.exists(weights_file):
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
else:
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
return False
return True
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
if not os.path.exists(valuehead_file):
vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if not os.path.exists(vhead_file):
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
return False
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True