import gc
from copy import deepcopy
from typing import List, Tuple

import pandas as pd

from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import result_to_dataframe
from autorag.utils.util import pop_params, to_list


class Vllm(BaseGenerator):
    def __init__(self, project_dir: str, llm: str, **kwargs):
        super().__init__(project_dir, llm, **kwargs)
        try:
            from vllm import SamplingParams, LLM
        except ImportError:
            raise ImportError(
                "Please install vllm library. You can install it by running `pip install vllm`."
            )

        model_from_kwargs = kwargs.pop("model", None)
        model = llm if model_from_kwargs is None else model_from_kwargs

        input_kwargs = deepcopy(kwargs)
        sampling_params_init_params = pop_params(
            SamplingParams.from_optional, input_kwargs
        )
        self.vllm_model = LLM(model, **input_kwargs)

        # Keep only keys that are valid SamplingParams fields; everything else
        # was consumed by the LLM constructor above.
        kwargs_keys = list(kwargs.keys())
        for key in kwargs_keys:
            if key not in sampling_params_init_params:
                kwargs.pop(key)
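
    # Illustration (example values, not from the original module): given
    # Vllm(".", "facebook/opt-125m", temperature=0.7, tensor_parallel_size=1),
    # `temperature` is a SamplingParams field and survives the pruning above,
    # while `tensor_parallel_size` goes to the LLM constructor and is removed.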

    def __del__(self):
        try:
            import torch
            import contextlib

            if torch.cuda.is_available():
                from vllm.distributed.parallel_state import (
                    destroy_model_parallel,
                    destroy_distributed_environment,
                )

                # Tear down vLLM's distributed state and release GPU memory so
                # another model can be loaded in the same process afterwards.
                destroy_model_parallel()
                destroy_distributed_environment()
                del self.vllm_model.llm_engine.model_executor
                del self.vllm_model
                with contextlib.suppress(AssertionError):
                    torch.distributed.destroy_process_group()
                gc.collect()
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except ImportError:
            del self.vllm_model

        super().__del__()

    @result_to_dataframe(["generated_texts", "generated_tokens", "generated_log_probs"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        prompts = self.cast_to_run(previous_result)
        return self._pure(prompts, **kwargs)

    def _pure(
        self, prompts: List[str], **kwargs
    ) -> Tuple[List[str], List[List[int]], List[List[float]]]:
        """
        Generate texts from the input prompts with the vLLM engine.
        You can set ``logprobs`` to get the log probabilities of the generated
        tokens; it defaults to 1.

        :param prompts: A list of prompts.
        :param kwargs: Extra parameters for generating the text.
        :return: A tuple of three elements.
            The first element is a list of the generated texts.
            The second element is a list of the generated texts' token ids.
            The third element is a list of the generated texts' log probs.
        """
        try:
            from vllm.outputs import RequestOutput
            from vllm.sequence import SampleLogprobs
            from vllm import SamplingParams
        except ImportError:
            raise ImportError(
                "Please install vllm library. You can install it by running `pip install vllm`."
            )

        if "logprobs" not in kwargs:
            kwargs["logprobs"] = 1

        sampling_params = pop_params(SamplingParams.from_optional, kwargs)
        generate_params = SamplingParams(**sampling_params)
        results: List[RequestOutput] = self.vllm_model.generate(
            prompts, generate_params
        )
        generated_texts = list(map(lambda x: x.outputs[0].text, results))
        generated_token_ids = list(map(lambda x: x.outputs[0].token_ids, results))
        log_probs: List[SampleLogprobs] = list(
            map(lambda x: x.outputs[0].logprobs, results)
        )
        # Each element of `log_probs` is a list of {token_id: Logprob} dicts,
        # one per generated token. Look up each generated token id in its own
        # dict to get that token's log probability.
        generated_log_probs = list(
            map(
                lambda x: list(map(lambda y: y[0][y[1]].logprob, zip(x[0], x[1]))),
                zip(log_probs, generated_token_ids),
            )
        )
        return (
            to_list(generated_texts),
            to_list(generated_token_ids),
            to_list(generated_log_probs),
        )

    async def astream(self, prompt: str, **kwargs):
        raise NotImplementedError

    def stream(self, prompt: str, **kwargs):
        raise NotImplementedError
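

# Minimal usage sketch, assuming a CUDA GPU and `pip install vllm`; the model
# name, project directory, and prompt below are placeholders, not part of the
# original module.
if __name__ == "__main__":
    generator = Vllm(project_dir=".", llm="facebook/opt-125m", max_tokens=32)
    texts, token_ids, log_probs = generator._pure(
        ["What does retrieval-augmented generation do?"]
    )
    print(texts[0])  # generated text for the first (and only) prompt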