tiny fix
This commit is contained in:
@@ -13,14 +13,14 @@ from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase
|
||||
)
|
||||
from transformers.utils import check_min_version, is_torch_npu_available
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
try:
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
except ImportError:
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
except ImportError:
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters
|
||||
@@ -85,7 +85,7 @@ def load_model_and_tokenizer(
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
|
||||
# Fix config (for Qwen)
|
||||
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
|
||||
if hasattr(config, "fp16") and hasattr(config, "bf16"):
|
||||
setattr(config, "fp16", model_args.compute_dtype == torch.float16)
|
||||
setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
|
||||
|
||||
@@ -215,11 +215,7 @@ def load_model_and_tokenizer(
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
if is_torch_npu_available():
|
||||
infer_dtype = torch.float16
|
||||
else:
|
||||
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
|
||||
model = model.to(infer_dtype) if model_args.quantization_bit is None else model
|
||||
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
|
||||
Reference in New Issue
Block a user