fix llava rlhf
@@ -7,9 +7,10 @@ from ..extras.logging import get_logger
 from ..extras.misc import count_parameters, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
-from .utils.misc import load_valuehead_params, register_autoclass
+from .utils.misc import register_autoclass
 from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .utils.unsloth import load_unsloth_pretrained_model
+from .utils.valuehead import load_valuehead_params


 if TYPE_CHECKING:
@@ -105,7 +106,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)

     model = None
     lazy_load = False
@@ -130,7 +131,7 @@ def load_model(
         model = convert_pretrained_model_to_mod(model, config, model_args)

     if not lazy_load:
-        patch_model(model, tokenizer, model_args, is_trainable)
+        patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
         register_autoclass(config, model, tokenizer)

     model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
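For context, the change threads the existing add_valuehead flag of load_model() into both patch_config() and patch_model(), and moves the load_valuehead_params import to the new .utils.valuehead module. A minimal standalone sketch of that call flow follows; the stub bodies and placeholder arguments are illustrative assumptions, not the project's real implementation.

# Sketch only: shows how load_model() forwards add_valuehead into both
# patcher entry points after this commit. Stub bodies are placeholders.

def patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead):
    # hypothetical stub: the real function adjusts the HF config before model init
    print(f"patch_config(is_trainable={is_trainable}, add_valuehead={add_valuehead})")

def patch_model(model, tokenizer, model_args, is_trainable, add_valuehead):
    # hypothetical stub: the real function patches the instantiated model
    print(f"patch_model(is_trainable={is_trainable}, add_valuehead={add_valuehead})")

def load_model(tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=False):
    config, init_kwargs, model = object(), {}, object()  # placeholders for the real objects
    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
    patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
    return model

load_model(tokenizer=None, model_args=None, finetuning_args=None, add_valuehead=True)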