support badam for all stages

Author: hiyouga
Date: 2024-04-16 17:44:48 +08:00
Parent: 4d660c5ade
Commit: e3d8fc75eb
9 changed files with 61 additions and 28 deletions


@@ -1,3 +1,4 @@
+from types import MethodType
 from typing import TYPE_CHECKING, Optional
 
 from transformers import Trainer
@@ -23,6 +24,10 @@ class CustomTrainer(Trainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
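
The MethodType rebinding above swaps the accelerator's gradient-clipping routine for BAdam's sparse-tensor-aware version on that one instance, without touching the Accelerator class itself. Below is a minimal, self-contained sketch of this instance-level patching pattern; the Accelerator stub and the replacement function are illustrative stand-ins, not the real accelerate or badam code:

from types import MethodType

class Accelerator:
    # Illustrative stand-in for accelerate.Accelerator, only to show the patch.
    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
        print("default dense clipping")

def clip_grad_norm_for_sparse_tensor(self, parameters, max_norm, norm_type=2):
    # Stand-in for badam.clip_grad_norm_for_sparse_tensor: a free function whose
    # first argument plays the role of `self` once it is bound to the instance.
    print("sparse-tensor-aware clipping on", type(self).__name__)

accelerator = Accelerator()
# Bind the free function to this single instance, exactly as the trainer does:
accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, accelerator)
accelerator.clip_grad_norm_([], max_norm=1.0)  # now dispatches to the sparse-aware version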
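
The second hunk is cut off right after the create_optimizer override begins. In a Trainer subclass this hook typically builds the custom optimizer lazily and then defers to the parent implementation; a sketch of that pattern follows, with create_custom_optimizer as a hypothetical stand-in for the project's factory (its real name and signature are not visible in this diff):

import torch
from transformers import Trainer

def create_custom_optimizer(model, args, finetuning_args):
    # Hypothetical factory: return a BAdam (or other custom) optimizer here,
    # or None to let Trainer fall back to its default optimizer.
    return None

class CustomTrainer(Trainer):
    def __init__(self, finetuning_args, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            # Build once; Trainer.create_optimizer reuses self.optimizer if it is already set.
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()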