fix #4120
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@@ -7,6 +7,7 @@ from transformers.optimization import get_scheduler
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_galore_available
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
@@ -399,3 +400,24 @@ def create_custom_scheduler(
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
param.register_post_accumulate_grad_hook(scheduler_hook)
|
||||
|
||||
|
||||
def get_batch_logps(
|
||||
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Computes the log probabilities of the given labels under the given logits.
|
||||
|
||||
Returns:
|
||||
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
|
||||
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
|
||||
"""
|
||||
if logits.shape[:-1] != labels.shape:
|
||||
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
|
||||
|
||||
labels = labels[:, 1:].clone()
|
||||
logits = logits[:, :-1, :]
|
||||
loss_mask = labels != label_pad_token_id
|
||||
labels[labels == label_pad_token_id] = 0 # dummy token
|
||||
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
|
||||
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
|
||||
|
||||
Reference in New Issue
Block a user