allow non-packing pretraining
hiyouga
2024-03-09 22:21:46 +08:00
parent 412c52e325
commit bdb496644c
22 changed files with 64 additions and 67 deletions

@@ -230,7 +230,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     model_args.compute_dtype = torch.float16
     model_args.model_max_length = data_args.cutoff_len
-    model_args.aqlm_optimization = not training_args.predict_with_generate
+    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
     # Log on each process the small summary:
     logger.info(
@@ -253,7 +253,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
-    model_args.aqlm_optimization = False
     model_args.device_map = "auto"
     if data_args.template is None:
@@ -267,7 +266,6 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
-    model_args.aqlm_optimization = True
     model_args.device_map = "auto"
     if data_args.template is None:
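
The line added in get_train_args carries the substance of this change: data_args.packing now defaults to None, and the parser only falls back to enabling packing when the stage is "pt", so pretraining can opt out by setting packing to false while unchanged configs keep the old behavior. The snippet below is a minimal, self-contained sketch of that default resolution; the DataArguments and FinetuningArguments dataclasses here are simplified stand-ins for illustration, not the project's actual hparams definitions.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DataArguments:
    # None means "not set by the user"; the parser then picks a stage-dependent default.
    packing: Optional[bool] = None


@dataclass
class FinetuningArguments:
    stage: str = "sft"


def resolve_packing(data_args: DataArguments, finetuning_args: FinetuningArguments) -> bool:
    # Same expression as in the diff: keep an explicit user choice,
    # otherwise enable packing only for the pretraining ("pt") stage.
    return data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"


# Pretraining with no explicit flag -> packing stays on (previous default preserved).
assert resolve_packing(DataArguments(), FinetuningArguments(stage="pt")) is True
# Pretraining with packing explicitly disabled -> the new non-packing path.
assert resolve_packing(DataArguments(packing=False), FinetuningArguments(stage="pt")) is False
# SFT with no explicit flag -> packing off by default.
assert resolve_packing(DataArguments(), FinetuningArguments(stage="sft")) is False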