@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import hashlib
+from itertools import chain
 from typing import List, Literal, Optional, Tuple
 
 import transformers
@@ -84,6 +85,8 @@ def init_adapter(
                 param.data = param.data.to(torch.float32)
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+        if len(model_args.checkpoint_dir) > 1:
+            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
         load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
 
     if finetuning_args.finetuning_type == "lora":
@@ -117,6 +120,9 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)
 
+    if model_args.checkpoint_dir is not None:
+        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+
     return model
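
Note: the LoRA branch above wraps the base model with peft's LoraConfig and get_peft_model. A minimal, self-contained sketch of that wrapping step follows, using a small stand-in model; the hyperparameter values and target_modules below are illustrative assumptions, not values taken from this diff.

# Minimal sketch: wrap a causal LM with a LoRA adapter via peft.
# Hyperparameters and target_modules are illustrative assumptions.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("gpt2")  # small stand-in model
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,                       # lora_rank
    lora_alpha=32.0,           # lora_alpha
    lora_dropout=0.1,          # lora_dropout
    target_modules=["c_attn"]  # attention projection(s) to adapt; model-specific
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA matrices require gradients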
@@ -131,19 +137,14 @@ def load_pretrained(
 
     Support both training and inference.
     """
 
-    if (not is_trainable) and (model_args.checkpoint_dir is None):
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
-
-    if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
-        for checkpoint_dir in model_args.checkpoint_dir:
-            if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
-                raise ValueError("The fine-tuning arguments are not found in the provided dictionary.")
-        logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
-        finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
-        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
-            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
+    if finetuning_args is None: # load the fine-tuning arguments
+        if model_args.checkpoint_dir is None:
+            logger.warning("Checkpoint is not found at evaluation, load the original model.")
+            finetuning_args = FinetuningArguments(finetuning_type="none")
+        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
+            finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
+        else:
+            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
 
     assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
@@ -350,7 +351,7 @@ def preprocess_data(
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
                 if examples["query"][i]:
-                    query += examples["query"][i]
+                    query += "\n" + examples["query"][i]
                 prompt = "Below is an instruction that describes a task. "
                 prompt += "Write a response that appropriately completes the request.\n"
                 prompt += "Instruction:\n" + prefix
@@ -361,6 +362,20 @@ def preprocess_data(
                 prompt += "Human: {}\nAssistant: ".format(query)
                 yield prompt, answer
 
+    def preprocess_pretrain_dataset(examples):
+        # build grouped texts with format `<s> X1 X2 X3 ...`
+        text_ids = tokenizer(examples["prompt"])["input_ids"]
+        concatenated_ids = list(chain(*text_ids))
+        total_length = len(concatenated_ids)
+        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        # split by chunks of max_source_length
+        result = [concatenated_ids[i: i+data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)]
+        return {
+            "input_ids": result,
+            "labels": result.copy()
+        }
+
     def preprocess_supervised_dataset(examples):
         # build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
         model_inputs = {"input_ids": [], "labels": []}
@@ -425,7 +440,9 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
         print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
+        print("labels:\n{}".format(
+            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
+        )
 
     def print_pairwise_dataset_example(example):
         print("accept_ids:\n{}".format(example["accept_ids"]))
@@ -437,11 +454,11 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
 
-    if stage == "sft":
-        if (not training_args.do_train) and training_args.predict_with_generate: # with generation
-            preprocess_function = preprocess_evaluation_dataset
-        else: # without generation
-            preprocess_function = preprocess_supervised_dataset
+    if stage == "pt":
+        preprocess_function = preprocess_pretrain_dataset
+    elif stage == "sft":
+        preprocess_function = preprocess_evaluation_dataset \
+            if training_args.predict_with_generate else preprocess_supervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":