update loading logic
@@ -5,8 +5,8 @@ import torch
 import logging
 from typing import Dict, List, Optional
 
-from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
-from transformers.modeling_utils import PreTrainedModel
+from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
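The two widened imports above are what enable the new sharded branch in the next hunk. As a point of reference, in the transformers releases this code targets the two checkpoint-name constants resolve to the filenames probed below; this is an illustration only, and the exact values may differ across transformers versions.

# Illustration only: the filenames behind the imported constants
# (values as in contemporary transformers releases; may vary by version).
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

print(WEIGHTS_NAME)        # pytorch_model.bin - a single-file checkpoint
print(WEIGHTS_INDEX_NAME)  # pytorch_model.bin.index.json - the index of a sharded checkpoint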
@@ -133,11 +133,14 @@ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get sta
 
 def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if not os.path.exists(weights_file):
+    if os.path.exists(weights_file):
+        model_state_dict = torch.load(weights_file, map_location="cpu")
+        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
+    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
+        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
+    else:
         logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
         return False
-    model_state_dict = torch.load(weights_file, map_location="cpu")
-    model.load_state_dict(model_state_dict, strict=False) # skip missing keys
     return True
 
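Put together, the rewritten function now distinguishes three cases instead of two. A minimal usage sketch follows; the checkpoint path is hypothetical, and it assumes load_trainable_params and the module-level logger from the patched file are importable.

import torch

# Stand-in module; any torch.nn.Module works for the single-file branch.
model = torch.nn.Linear(4, 4)

# checkpoint_dir has pytorch_model.bin            -> torch.load + load_state_dict(strict=False)
# checkpoint_dir has pytorch_model.bin.index.json -> load_sharded_checkpoint(strict=False)
# neither file present                            -> warning logged, returns False
if not load_trainable_params(model, "output/checkpoint-1000"):  # hypothetical path
    raise ValueError("checkpoint directory has no trainable weights")

Both branches load with strict=False, which lets a checkpoint that, judging by the truncated comment on get_state_dict, stores only a subset of the model's parameters be applied on top of an already-initialized model without raising on the missing keys.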