support rank0 logger

This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent bd08b8c441
commit c38aa29336
42 changed files with 316 additions and 252 deletions

View File

@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class VllmEngine(BaseEngine):
@@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))