improve rlhf
@@ -1,7 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by the CarperAI's trlx library.
-# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,28 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-# MIT License
-#
-# Copyright (c) 2022 CarperAI
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
 
 import json
 import os
@@ -53,6 +31,7 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, ProcessorMixin
     from transformers.trainer import PredictionOutput
+    from trl import AutoModelForCausalLMWithValueHead
 
     from ...hparams import FinetuningArguments
 
@@ -108,46 +87,23 @@ class PairwiseTrainer(Trainer):
         See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
         """
         # Compute rewards
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
 
-        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)
 
         # Split the inputs and rewards into two parts, chosen and rejected
         batch_size = inputs["input_ids"].size(0) // 2
-        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
-        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
-        chosen_scores, rejected_scores = [], []
-
-        # Compute pairwise loss. Only backprop on the different tokens before padding
-        loss = 0
-        for i in range(batch_size):
-            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
-
-            if len(check_divergence) == 0:
-                end_index = chosen_length
-                div_index = end_index - 1
-            else:
-                end_index = max(chosen_length, rejected_length)
-                div_index = check_divergence[0]
-
-            assert div_index > 0
-            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
-            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            if return_outputs:  # use the score on the last token except pad token for inference
-                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
-                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
-            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
-
-        loss = loss / batch_size
+        chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
+        chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
+        chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
+        rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
+        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
+
+        loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
         if return_outputs:
-            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
-            return loss, [loss, chosen_scores, rejected_scores]
-
-        return loss
+            return loss, (loss, chosen_scores, rejected_scores)
+        else:
+            return loss
 
     def save_predictions(self, predict_results: "PredictionOutput") -> None:
         r"""
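For reference, -logsigmoid(chosen - rejected) is the standard pairwise ranking objective for reward models (a Bradley-Terry log-likelihood): it approaches zero as the chosen score pulls ahead and equals log(2) when the two scores tie. A two-line sanity probe of those endpoints:

import torch
import torch.nn.functional as F

print(-F.logsigmoid(torch.tensor(5.0)))  # chosen wins by 5 -> ~0.0067
print(-F.logsigmoid(torch.tensor(0.0)))  # tie -> log(2) ~ 0.6931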