update loading logic

This commit is contained in:
hiyouga
2023-06-28 12:07:16 +08:00
parent 0a46313cca
commit 4d0fddba21
2 changed files with 40 additions and 21 deletions

View File

@@ -5,8 +5,8 @@ import torch
import logging
from typing import Dict, List, Optional
from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
@@ -133,11 +133,14 @@ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get sta
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if not os.path.exists(weights_file):
if os.path.exists(weights_file):
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
else:
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
return False
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
return True