commit f9e818d79c
parent ccc8b64cc2
Author: hiyouga
Date: 2024-06-07 04:18:05 +08:00

7 changed files with 47 additions and 54 deletions


@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import Trainer
@@ -7,6 +7,7 @@ from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names

+from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
 from ..hparams import FinetuningArguments, ModelArguments
@@ -399,3 +400,24 @@ def create_custom_scheduler(

         for param in optimizer_dict.keys():
             param.register_post_accumulate_grad_hook(scheduler_hook)
+
+
+def get_batch_logps(
+    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+) -> Tuple["torch.Tensor", "torch.Tensor"]:
+    r"""
+    Computes the log probabilities of the given labels under the given logits.
+
+    Returns:
+        logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
+        valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
+    """
+    if logits.shape[:-1] != labels.shape:
+        raise ValueError("Logits (batch_size x seqlen x vocab) and labels (batch_size x seqlen) must have matching shapes.")
+
+    labels = labels[:, 1:].clone()  # shift labels so position t is predicted from tokens before t
+    logits = logits[:, :-1, :]  # drop the logit after the last label to align with shifted labels
+    loss_mask = labels != label_pad_token_id
+    labels[labels == label_pad_token_id] = 0  # dummy token, zeroed out by the mask below
+    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
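
For context, a minimal usage sketch of the new helper follows. Only get_batch_logps itself and the value of IGNORE_INDEX (-100) come from the diff above; the toy shapes, the prompt-masking pattern, and the length-normalization line are illustrative assumptions, not part of this commit.

    # Toy example (not part of the commit): exercising get_batch_logps on dummy tensors.
    import torch

    IGNORE_INDEX = -100  # same value as ..extras.constants.IGNORE_INDEX

    batch_size, seq_len, vocab_size = 2, 6, 32
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels[:, :3] = IGNORE_INDEX  # mask the prompt tokens, as a data collator would

    logps, valid_length = get_batch_logps(logits, labels)
    # logps: shape (2,), summed log-probs over the unmasked, shifted positions
    # valid_length: tensor([3, 3]) -- three non-masked label positions per sequence
    avg_logps = logps / valid_length  # length-normalized variant, e.g. for SimPO-style losses

Returning valid_length alongside the summed log probabilities lets callers choose between a sum and a per-token average without recomputing the mask, which is presumably why the function returns the pair.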