add docstrings, refactor logger

This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 8eac1b929f
commit 54c6905937
30 changed files with 334 additions and 57 deletions

View File

@@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
@@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr})
@override
def zero_grad(self, set_to_none: bool = True) -> None:
pass
@override
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
pass