fix #1422
@@ -190,8 +190,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             if len(response_index) == 0:
                 response_length = 1  # allow empty response
-            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
-                response_length = response_index[-1].item() + 2  # save the EOS token
             else:
                 response_length = response_index[-1].item() + 1
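Note: with the elif branch removed, the response length always stops at the last real response token instead of reserving an extra slot for the EOS token when pad equals EOS. A minimal standalone sketch of that remaining logic, assuming response_index holds the nonzero positions of a response mask; the get_response_length helper and the mask argument are illustrative, not part of the diff:

import torch

def get_response_length(response_mask: torch.Tensor) -> int:
    # response_mask: 1 for real response tokens, 0 for padding.
    response_index = response_mask.nonzero()
    if len(response_index) == 0:
        return 1  # allow empty response
    return response_index[-1].item() + 1  # last real token position -> length

# usage: a response of 3 real tokens followed by padding
mask = torch.tensor([1, 1, 1, 0, 0])
print(get_response_length(mask))  # 3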
@@ -221,7 +219,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1].item()  # use the score on the EOS token
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type

         replace_model(unwrapped_model, target="default")
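Note: the reward is now read at the last position whose token is not EOS, rather than at the last attended position, with a fallback to index 0 when every token is EOS. A runnable sketch of just that index selection; the reward_end_index helper is illustrative, not part of the repository:

import torch

def reward_end_index(input_ids: torch.Tensor, eos_token_id: int) -> int:
    # Index whose value-head score is taken as the sequence reward:
    # the last non-EOS position, or 0 if every token is EOS.
    end_indexes = (input_ids != eos_token_id).nonzero()
    return end_indexes[-1].item() if len(end_indexes) else 0

# usage: prompt + response padded with EOS (id 2)
input_ids = torch.tensor([5, 17, 42, 9, 2, 2, 2])
print(reward_end_index(input_ids, eos_token_id=2))  # 3 -> values[i, 3] is used as the reward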