Use fp16 model; add LogCallback

This commit is contained in:
hiyouga
2023-05-28 21:30:28 +08:00
parent 769c6ab56b
commit 0c9fda01e3
7 changed files with 112 additions and 10 deletions

View File

@@ -12,6 +12,7 @@ from utils import (
DataCollatorForLLaMA,
Seq2SeqTrainerForLLaMA,
ComputeMetrics,
LogCallback,
get_logits_processor,
plot_loss
)
@@ -49,6 +50,7 @@ def main():
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[LogCallback()],
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs
)
@@ -57,7 +59,7 @@ def main():
gen_kwargs = {
"do_sample": True,
"top_p": 0.7,
"max_length": data_args.max_source_length + data_args.max_target_length + 1,
"max_new_tokens": data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}