Refactor data preprocessing; fix MLLM RLHF

This commit is contained in:
hiyouga
2024-05-24 04:08:25 +08:00
parent a506f3628b
commit 3a023bca2a
15 changed files with 572 additions and 464 deletions

View File

@@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer):
r"""
Computes the average log probabilities of the labels under the given logits.
"""
-all_logits: "torch.Tensor" = model(
-input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
-).logits.to(torch.float32)
+all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps = self.get_batch_logps(
logits=all_logits,