@@ -25,9 +25,8 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter
|
||||
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
|
||||
from llmtuner.tuner.core.utils import prepare_model_for_training
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -41,7 +40,7 @@ logger = get_logger(__name__)
|
||||
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
|
||||
require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
|
||||
|
||||
|
||||
@@ -64,7 +63,7 @@ def load_model_and_tokenizer(
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
@@ -99,15 +98,9 @@ def load_model_and_tokenizer(
|
||||
|
||||
# Set RoPE scaling
|
||||
if model_args.rope_scaling is not None:
|
||||
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
|
||||
if is_trainable:
|
||||
logger.warning("Qwen model does not support RoPE scaling in training.")
|
||||
else:
|
||||
setattr(config, "use_dynamic_ntk", True)
|
||||
setattr(config, "use_logn_attn", True)
|
||||
logger.info("Using dynamic NTK scaling.")
|
||||
|
||||
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
@@ -129,9 +122,6 @@ def load_model_and_tokenizer(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
else:
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
|
||||
# Set FlashAttention-2
|
||||
if model_args.flash_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
@@ -155,7 +145,6 @@ def load_model_and_tokenizer(
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
is_mergeable = True
|
||||
if model_args.quantization_bit is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
@@ -165,7 +154,7 @@ def load_model_and_tokenizer(
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
if model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
@@ -175,7 +164,6 @@ def load_model_and_tokenizer(
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
@@ -207,7 +195,7 @@ def load_model_and_tokenizer(
|
||||
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
# Prepare model with valuehead for RLHF
|
||||
@@ -226,7 +214,7 @@ def load_model_and_tokenizer(
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
load_valuehead_params(model, model_args)
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
|
||||
Reference in New Issue
Block a user