This commit is contained in:
hiyouga
2023-11-07 19:42:01 +08:00
parent c52336d144
commit 11c1e1e157
5 changed files with 21 additions and 17 deletions

View File

@@ -190,8 +190,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if len(response_index) == 0:
response_length = 1 # allow empty response
elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
response_length = response_index[-1].item() + 2 # save the EOS token
else:
response_length = response_index[-1].item() + 1
@@ -221,7 +219,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
rewards = []
for i in range(values.size(0)):
end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
replace_model(unwrapped_model, target="default")