use fp16 model, add logcallback
@@ -12,6 +12,7 @@ from utils import (
     DataCollatorForLLaMA,
     Seq2SeqTrainerForLLaMA,
     ComputeMetrics,
+    LogCallback,
     get_logits_processor,
     plot_loss
 )
@@ -49,6 +50,7 @@ def main():
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        callbacks=[LogCallback()],
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **trainer_kwargs
     )
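
Note: LogCallback is imported from this repo's utils module, whose source is not part of this diff. Below is a minimal sketch of what such a callback could look like, assuming it subclasses transformers.TrainerCallback and hooks on_log; the file name trainer_log.jsonl and the elapsed field are illustrative choices, not the repo's actual behavior.

import json
import os
import time

from transformers import TrainerCallback


class LogCallback(TrainerCallback):
    """Hypothetical sketch: append each logged metrics dict to a JSONL file."""

    def __init__(self, log_file="trainer_log.jsonl"):
        self.log_file = log_file
        self.start_time = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        # The Trainer calls on_log with the metrics it is about to report.
        if logs is None:
            return
        entry = dict(logs)
        entry["step"] = state.global_step
        entry["elapsed"] = round(time.time() - self.start_time, 2)
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, self.log_file), "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
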
@@ -57,7 +59,7 @@ def main():
     gen_kwargs = {
         "do_sample": True,
         "top_p": 0.7,
-        "max_length": data_args.max_source_length + data_args.max_target_length + 1,
+        "max_new_tokens": data_args.max_target_length + 1,
         "temperature": 0.95,
         "logits_processor": get_logits_processor()
     }