fix full/freeze tuning for mllm

This commit is contained in:
hiyouga
2024-05-27 20:37:57 +08:00
parent 838f2fb3e4
commit 08564838bd
7 changed files with 76 additions and 61 deletions

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple, List
from typing import TYPE_CHECKING, Tuple
import torch
import transformers.models
@@ -82,8 +82,3 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def filter_vision_tower_linear(target_modules: List[str]) -> str:
    r"""
    Build a regex pattern that matches any of the given target module names
    while excluding every module that lives under the vision tower.

    Args:
        target_modules: module-name fragments to match (joined as alternatives).

    Returns:
        A regex string of the form ``^(?!.*vision_tower).*(?:a|b|...).*``.
    """
    alternatives = "|".join(target_modules)
    return f"^(?!.*vision_tower).*(?:{alternatives}).*"