update webui and add CLIs

This commit is contained in:
hiyouga
2024-05-03 02:58:23 +08:00
parent 39e964a97a
commit 245fe47ece
65 changed files with 363 additions and 372 deletions

View File

@@ -1,7 +1,7 @@
import json
import math
import os
from typing import List
from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME
@@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
@@ -21,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
smoothed = []
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
@@ -30,7 +31,27 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
for log in trainer_log:
if log.get("loss", None):
steps.append(log["current_steps"])
losses.append(log["loss"])
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)