fix import error

This commit is contained in:
hiyouga
2023-08-23 20:45:03 +08:00
parent 57146c101f
commit 2de1a7610a
3 changed files with 6 additions and 4 deletions

View File

@@ -3,6 +3,7 @@
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
@@ -12,7 +13,7 @@ from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments