import logging
import time
from asyncio import to_thread
from typing import List, Optional, Tuple

import pandas as pd
import requests

from autorag.nodes.generator.base import BaseGenerator
from autorag.utils.util import get_event_loop, process_batch, result_to_dataframe

logger = logging.getLogger("AutoRAG")

DEFAULT_MAX_TOKENS = 4096  # Default generation token limit when max_tokens is not given


class VllmAPI(BaseGenerator):
    def __init__(
        self,
        project_dir,
        llm: str,
        uri: str,
        max_tokens: Optional[int] = None,
        batch: int = 16,
        *args,
        **kwargs,
    ):
        """
        vLLM API wrapper that calls a vLLM server through its OpenAI-compatible
        chat/completions endpoint.

        :param project_dir: Project directory.
        :param llm: Name of the model served by the vLLM server (e.g., a LLaMA model).
        :param uri: Base URI of the vLLM API server.
        :param max_tokens: Maximum number of tokens to generate.
            Default is 4096.
        :param batch: Request batch size.
            Default is 16.
        """
        super().__init__(project_dir, llm, *args, **kwargs)
        assert batch > 0, "Batch size must be greater than 0."
        self.uri = uri.rstrip("/")  # Normalize the API URI (no trailing slash)
        self.batch = batch
        # Use the provided max_tokens if given, otherwise fall back to the default
        self.max_token_size = max_tokens if max_tokens else DEFAULT_MAX_TOKENS
        self.max_model_len = self.get_max_model_length()
        logger.info(f"{llm} max model length: {self.max_model_len}")

@result_to_dataframe(["generated_texts", "generated_tokens", "generated_log_probs"])
|
|
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
|
|
prompts = self.cast_to_run(previous_result)
|
|
return self._pure(prompts, **kwargs)
|
|
|
|
    def _pure(
        self, prompts: List[str], truncate: bool = True, **kwargs
    ) -> Tuple[List[str], List[List[int]], List[List[float]]]:
        """
        Call the vLLM API to generate text for each prompt.

        :param prompts: List of input prompts.
        :param truncate: Whether to truncate input prompts to fit within the token limit.
        :param kwargs: Additional options (e.g., temperature, top_p).
        :return: Generated texts, token lists, and log probability lists.
        """
        if kwargs.get("logprobs") is not None:
            kwargs.pop("logprobs")
            logger.warning(
                "The 'logprobs' parameter has no effect; it is always set to True."
            )
        if kwargs.get("n") is not None:
            kwargs.pop("n")
            logger.warning("The 'n' parameter has no effect; it is always set to 1.")

        if truncate:
            prompts = list(map(lambda p: self.truncate_by_token(p), prompts))
        loop = get_event_loop()
        tasks = [to_thread(self.get_result, prompt, **kwargs) for prompt in prompts]
        results = loop.run_until_complete(process_batch(tasks, self.batch))

        answer_result = list(map(lambda x: x[0], results))
        token_result = list(map(lambda x: x[1], results))
        logprob_result = list(map(lambda x: x[2], results))
        return answer_result, token_result, logprob_result

    def truncate_by_token(self, prompt: str) -> str:
        """
        Truncate a prompt so that it fits within the model's maximum context length.
        """
        # Tokenize via the vLLM server, cut to max_model_len, then detokenize
        tokens = self.encoding_for_model(prompt)["tokens"]
        return self.decoding_for_model(tokens[: self.max_model_len])["prompt"]

    def call_vllm_api(self, prompt: str, **kwargs) -> dict:
        """
        Call the vLLM server's chat/completions endpoint.

        :param prompt: Input prompt.
        :param kwargs: Additional API options (e.g., temperature, max_tokens).
        :return: Parsed JSON response from the API.
        """
        payload = {
            "model": self.llm,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": kwargs.get("temperature", 0.4),
            "max_tokens": min(
                kwargs.get("max_tokens", self.max_token_size), self.max_token_size
            ),
            "logprobs": True,
            "n": 1,
        }
        start_time = time.time()  # Record request start time
        response = requests.post(f"{self.uri}/v1/chat/completions", json=payload)
        end_time = time.time()  # Record request end time

        response.raise_for_status()
        elapsed_time = end_time - start_time  # Calculate elapsed time
        logger.info(
            f"Request chat completions to vllm server completed in {elapsed_time:.2f} seconds"
        )
        return response.json()

    # Streaming is not supported yet; these implement BaseGenerator's abstract methods.
    async def astream(self, prompt: str, **kwargs):
        """
        Asynchronous streaming is not implemented.
        """
        raise NotImplementedError("astream method is not implemented for VLLM API yet.")

    def stream(self, prompt: str, **kwargs):
        """
        Synchronous streaming is not implemented.
        """
        raise NotImplementedError("stream method is not implemented for VLLM API yet.")

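    # Sketch of the response shape that get_result expects from an
    # OpenAI-compatible vLLM server (for reference only; field names follow the
    # chat/completions format requested by call_vllm_api above):
    #
    #   {
    #     "choices": [
    #       {
    #         "message": {"content": "..."},
    #         "logprobs": {"content": [{"token": "...", "logprob": -0.12}, ...]}
    #       }
    #     ]
    #   }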
    def get_result(self, prompt: str, **kwargs):
        """
        Request a single completion and return the answer text together with its
        tokens and log probabilities.
        """
        response = self.call_vllm_api(prompt, **kwargs)
        choice = response["choices"][0]
        answer = choice["message"]["content"]

        # Handle cases where logprobs is missing or None
        if choice.get("logprobs") and "content" in choice["logprobs"]:
            logprobs = list(map(lambda x: x["logprob"], choice["logprobs"]["content"]))
            tokens = list(
                map(
                    lambda x: self.encoding_for_model(x["token"])["tokens"],
                    choice["logprobs"]["content"],
                )
            )
        else:
            logprobs = []
            tokens = []

        return answer, tokens, logprobs

    def encoding_for_model(self, answer_piece: str):
        """
        Tokenize text with the vLLM server's /tokenize endpoint.
        """
        payload = {
            "model": self.llm,
            "prompt": answer_piece,
            "add_special_tokens": True,
        }
        response = requests.post(f"{self.uri}/tokenize", json=payload)
        response.raise_for_status()
        return response.json()

    def decoding_for_model(self, tokens: List[int]):
        """
        Detokenize token ids with the vLLM server's /detokenize endpoint.
        """
        payload = {
            "model": self.llm,
            "tokens": tokens,
        }
        response = requests.post(f"{self.uri}/detokenize", json=payload)
        response.raise_for_status()
        return response.json()

    def get_max_model_length(self):
        """
        Query the vLLM server's /v1/models endpoint for the maximum model length.
        """
        response = requests.get(f"{self.uri}/v1/models")
        response.raise_for_status()
        json_data = response.json()
        return json_data["data"][0]["max_model_len"]
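

if __name__ == "__main__":
    # Usage sketch, not part of the AutoRAG pipeline API: it assumes a vLLM
    # server with an OpenAI-compatible endpoint is already running. The project
    # directory, model name, and URI below are placeholders for illustration only.
    generator = VllmAPI(
        project_dir="./project",  # placeholder project directory
        llm="meta-llama/Llama-3.1-8B-Instruct",  # placeholder served model name
        uri="http://localhost:8000",  # placeholder vLLM server URI
        max_tokens=512,
        batch=4,
    )
    # Inside an AutoRAG pipeline, `pure` receives the previous node's DataFrame;
    # here `_pure` is called directly with raw prompts to keep the sketch small.
    texts, tokens, log_probs = generator._pure(
        ["Explain retrieval-augmented generation in one sentence."],
        temperature=0.2,
    )
    print(texts[0])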