|
|
|
|
@@ -1,5 +1,7 @@
|
|
|
|
|
import os
|
|
|
|
|
import math
|
|
|
|
|
import torch
|
|
|
|
|
from types import MethodType
|
|
|
|
|
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
from transformers import (
|
|
|
|
|
@@ -66,15 +68,56 @@ def load_model_and_tokenizer(
|
|
|
|
|
**config_kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
|
|
|
|
if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
|
|
|
|
|
model_to_load = model_args.checkpoint_dir[0]
|
|
|
|
|
else:
|
|
|
|
|
model_to_load = model_args.model_name_or_path
|
|
|
|
|
|
|
|
|
|
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
|
|
|
|
is_mergeable = True
|
|
|
|
|
|
|
|
|
|
if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
|
|
|
|
|
if model_args.compute_dtype == torch.bfloat16:
|
|
|
|
|
setattr(config, "bf16", True)
|
|
|
|
|
else:
|
|
|
|
|
setattr(config, "fp16", True)
|
|
|
|
|
|
|
|
|
|
# Set RoPE scaling
|
|
|
|
|
if model_args.rope_scaling is not None:
|
|
|
|
|
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
|
|
|
|
|
if is_trainable:
|
|
|
|
|
logger.warning("Qwen model does not support rope scaling in training.")
|
|
|
|
|
else:
|
|
|
|
|
setattr(config, "use_dynamic_ntk", True)
|
|
|
|
|
setattr(config, "use_logn_attn", True)
|
|
|
|
|
logger.info("Using dynamic NTK scaling.")
|
|
|
|
|
|
|
|
|
|
elif hasattr(config, "rope_scaling"): # for LLaMA models
|
|
|
|
|
if is_trainable:
|
|
|
|
|
if model_args.rope_scaling == "dynamic":
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Dynamic NTK may not work well with fine-tuning. "
|
|
|
|
|
"See: https://github.com/huggingface/transformers/pull/24653"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
current_max_length = getattr(config, "max_position_embeddings", None)
|
|
|
|
|
if current_max_length and model_args.model_max_length <= current_max_length:
|
|
|
|
|
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
|
|
|
|
scaling_factor = 1.0
|
|
|
|
|
else:
|
|
|
|
|
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
|
|
|
|
else:
|
|
|
|
|
scaling_factor = 2.0
|
|
|
|
|
|
|
|
|
|
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
|
|
|
|
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
|
|
|
|
model_args.rope_scaling, scaling_factor
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
logger.warning("Current model does not support RoPE scaling.")
|
|
|
|
|
|
|
|
|
|
# Quantization configurations (using bitsandbytes library).
|
|
|
|
|
is_mergeable = True
|
|
|
|
|
if model_args.quantization_bit is not None:
|
|
|
|
|
if model_args.quantization_bit == 8:
|
|
|
|
|
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
|
|
|
|
@@ -95,7 +138,7 @@ def load_model_and_tokenizer(
|
|
|
|
|
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
|
|
|
|
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
|
|
|
|
|
|
|
|
|
# Load and prepare pretrained models (without valuehead).
|
|
|
|
|
# Load and prepare pre-trained models (without valuehead).
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
model_to_load,
|
|
|
|
|
config=config,
|
|
|
|
|
@@ -104,6 +147,10 @@ def load_model_and_tokenizer(
|
|
|
|
|
**config_kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Disable custom generate method (for Qwen)
|
|
|
|
|
if "GenerationMixin" not in str(model.generate.__func__):
|
|
|
|
|
model.generate = MethodType(PreTrainedModel.generate, model)
|
|
|
|
|
|
|
|
|
|
# Register auto class to save the custom code files.
|
|
|
|
|
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
|
|
|
|
config.__class__.register_for_auto_class()
|
|
|
|
|
@@ -116,10 +163,10 @@ def load_model_and_tokenizer(
|
|
|
|
|
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
|
|
|
|
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
|
|
|
|
|
|
|
|
|
if stage == "rm" or stage == "ppo": # add value head
|
|
|
|
|
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
|
|
|
|
# Prepare model with valuehead for RLHF
|
|
|
|
|
if stage == "rm" or stage == "ppo":
|
|
|
|
|
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
|
|
|
|
reset_logging()
|
|
|
|
|
|
|
|
|
|
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
|
|
|
|
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
|
|
|
|
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
|
|
|
|
|
@@ -133,9 +180,11 @@ def load_model_and_tokenizer(
|
|
|
|
|
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
|
|
|
|
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
|
|
|
|
|
|
|
|
|
# Prepare model for inference
|
|
|
|
|
if not is_trainable:
|
|
|
|
|
model.requires_grad_(False) # fix all model params
|
|
|
|
|
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
|
|
|
|
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
|
|
|
|
|
model = model.to(infer_dtype) if model_args.quantization_bit is None else model
|
|
|
|
|
|
|
|
|
|
trainable_params, all_param = count_parameters(model)
|
|
|
|
|
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
|
|
|
|
|