refactor model_dtype, fix PPO trainer
@@ -24,7 +24,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
     from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
-from llmtuner.extras.misc import count_parameters
+from llmtuner.extras.misc import count_parameters, infer_optim_dtype
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
@@ -86,11 +86,17 @@ def load_model_and_tokenizer(
     if getattr(config, "model_type", None) == "chatglm":
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
+    # Set model dtype
+    if model_args.compute_dtype is not None:
+        setattr(config, "torch_dtype", model_args.compute_dtype)
+    else: # priority: bf16 > fp16 > fp32
+        optim_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        setattr(config, "torch_dtype", optim_dtype)
+
     # Fix config (for Qwen)
     if getattr(config, "model_type", None) == "qwen":
-        setattr(config, "fp16", model_args.compute_dtype == torch.float16)
-        setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
-        setattr(config, "fp32", model_args.compute_dtype == torch.float32)
+        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
 
     # Set RoPE scaling
     if model_args.rope_scaling is not None:
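Note: a minimal sketch of what the newly imported infer_optim_dtype helper might look like, assuming it only encodes the "priority: bf16 > fp16 > fp32" rule from the comment above; the actual implementation in llmtuner.extras.misc may differ.

import torch

def infer_optim_dtype(model_dtype):
    # Prefer bf16 when the checkpoint already uses it and the device supports it,
    # otherwise fall back to fp16 on CUDA, and to fp32 on CPU-only setups.
    if model_dtype == torch.bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if torch.cuda.is_available():
        return torch.float16
    return torch.float32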
@@ -131,9 +137,7 @@ def load_model_and_tokenizer(
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
-                LlamaPatches._prepare_decoder_attention_mask
-            )
+            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
             logger.info("Using FlashAttention-2 for faster training and inference.")
         elif getattr(config, "model_type", None) == "qwen":
             logger.info("Qwen models automatically enable FlashAttention if installed.")
@@ -180,7 +184,6 @@ def load_model_and_tokenizer(
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
-        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
@@ -203,7 +206,7 @@ def load_model_and_tokenizer(
 
     # Initialize adapters
     if is_trainable:
-        model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
+        model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()
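Note: the renamed upcast_layernorm argument suggests a boolean switch rather than a dtype value. A hypothetical sketch of the technique such a flag usually controls, upcasting LayerNorm weights to fp32 for mixed-precision stability; the function name and the LAYERNORM_NAMES fragments below are illustrative, not taken from the repo.

import torch

LAYERNORM_NAMES = ("norm", "ln_f", "ln_attn")  # assumed name fragments, adjust per architecture

def upcast_layernorm(model):
    # Cast 1-D (LayerNorm/RMSNorm) parameters to float32 so that mixed-precision
    # training does not accumulate rounding error in the normalization weights.
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(key in name for key in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)
    return model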