hiyouga
2023-09-21 15:25:29 +08:00
parent e510006ed6
commit ace3f85a72
4 changed files with 30 additions and 14 deletions

@@ -13,14 +13,14 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
-from transformers.utils import check_min_version, is_torch_npu_available
+from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
 try:
-    from transformers.deepspeed import is_deepspeed_zero3_enabled
-except ImportError:
-    from transformers.integrations import is_deepspeed_zero3_enabled
+    from transformers.integrations import is_deepspeed_zero3_enabled
+except ImportError:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters
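The swap above prefers the new home of is_deepspeed_zero3_enabled (moved into transformers.integrations around transformers v4.33, leaving transformers.deepspeed as a deprecated shim) and keeps the old path as a fallback. A minimal standalone sketch of the same fallback-import pattern, assuming only that one of the two module paths provides the helper:

# Try the new import path first so recent transformers releases do not
# emit a deprecation warning; fall back on versions that predate the move.
try:
    from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError:  # older transformers keeps it in transformers.deepspeed
    from transformers.deepspeed import is_deepspeed_zero3_enabled

# Returns True only when a DeepSpeed ZeRO stage-3 config is active.
print(is_deepspeed_zero3_enabled())

Trying the new location first matters: on recent versions the deprecated module still imports successfully (it only warns), so a deepspeed-first ordering would never reach the except branch.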
@@ -85,7 +85,7 @@ def load_model_and_tokenizer(
     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
 
     # Fix config (for Qwen)
-    if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
+    if hasattr(config, "fp16") and hasattr(config, "bf16"):
         setattr(config, "fp16", model_args.compute_dtype == torch.float16)
         setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
@@ -215,11 +215,7 @@ def load_model_and_tokenizer(
     # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        if is_torch_npu_available():
-            infer_dtype = torch.float16
-        else:
-            infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
-        model = model.to(infer_dtype) if model_args.quantization_bit is None else model
+        model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
 
     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(