This commit is contained in:
hiyouga
2024-04-01 22:53:52 +08:00
parent 54b7d34908
commit 4a6ca621c0
4 changed files with 23 additions and 15 deletions

View File

@@ -73,7 +73,7 @@ class CustomORPOTrainer(DPOTrainer):
Computes the average log probabilities of the labels under the given logits.
"""
all_logits: "torch.Tensor" = model(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
).logits.to(torch.float32)
all_logps = self.get_batch_logps(