Fix LLaVA RLHF

This commit is contained in:
hiyouga
2024-04-28 03:01:49 +08:00
parent 4dbbce21d5
commit b3e33c703e
5 changed files with 79 additions and 43 deletions

View File

@@ -7,9 +7,10 @@ from ..extras.logging import get_logger
from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils.misc import load_valuehead_params, register_autoclass
from .utils.misc import register_autoclass
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .utils.unsloth import load_unsloth_pretrained_model
from .utils.valuehead import load_valuehead_params
if TYPE_CHECKING:
@@ -105,7 +106,7 @@ def load_model(
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
model = None
lazy_load = False
@@ -130,7 +131,7 @@ def load_model(
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable)
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)