release v0.7.0

This commit is contained in:
hiyouga
2024-04-26 23:18:00 +08:00
parent 031775ade8
commit 168f56683a
13 changed files with 163 additions and 44 deletions

View File

@@ -11,10 +11,13 @@ from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import MultiModalData
if TYPE_CHECKING:
import torch
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -39,20 +42,30 @@ class VllmEngine(BaseEngine):
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path,
trust_remote_code=True,
download_dir=model_args.cache_dir,
dtype=infer_dtype,
max_model_len=model_args.vllm_maxlen,
tensor_parallel_size=get_device_count() or 1,
gpu_memory_utilization=model_args.vllm_gpu_util,
disable_log_stats=True,
disable_log_requests=True,
enforce_eager=model_args.vllm_enforce_eager,
enable_lora=model_args.adapter_name_or_path is not None,
)
self.model = AsyncLLMEngine.from_engine_args(engine_args)
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
"dtype": infer_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
}
if model_args.visual_inputs:
# TODO: auto derive from config
# https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
self.image_feature_size = 576
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
engine_args["image_input_shape"] = "1,3,336,336"
engine_args["image_feature_size"] = self.image_feature_size
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
@@ -67,6 +80,9 @@ class VllmEngine(BaseEngine):
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
@@ -110,12 +126,21 @@ class VllmEngine(BaseEngine):
max_tokens=generating_args["max_new_tokens"],
skip_special_tokens=True,
)
if self.processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
result_generator = self.model.generate(
prompt=None,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=prompt_ids,
lora_request=self.lora_request,
multi_modal_data=multi_modal_data,
)
return result_generator