Files
autorag_evaluation/autorag-workspace/autorag/deploy/api.py
2025-03-14 17:28:01 +09:00

294 lines
8.4 KiB
Python

import logging
import os
import pathlib
import uuid
from typing import Dict, Optional, List, Union, Literal
import pandas as pd
from quart import Quart, request, jsonify
from quart.helpers import stream_with_context
from pydantic import BaseModel, ValidationError
from autorag.deploy.base import BaseRunner
from autorag.nodes.generator.base import BaseGenerator
from autorag.nodes.promptmaker.base import BasePromptMaker
from autorag.utils.util import fetch_contents, to_list
# Logger shared under the fixed name "AutoRAG" (not the module's __name__).
logger = logging.getLogger("AutoRAG")

# Directory containing this file, and the directory one level above it.
deploy_dir = pathlib.Path(__file__).parent
root_dir = deploy_dir.parent

# String path of the VERSION file located directly under root_dir.
VERSION_PATH = os.path.join(root_dir, "VERSION")
class QueryRequest(BaseModel):
    # Request body accepted by the /v1/run and /v1/stream endpoints.

    # User query; it becomes the "query" column of the one-row pipeline input.
    query: str
    # Column of the final pipeline result that /v1/run returns as `result`.
    result_column: Union[str, None] = "generated_texts"
class RetrievedPassage(BaseModel):
    # One retrieved corpus passage as returned to API clients.

    # Passage text looked up in the corpus for `doc_id`.
    content: str
    # Identifier of the passage, taken from the pipeline's "retrieved_ids".
    doc_id: str
    # Score taken from the pipeline's "retrieve_scores" for this passage.
    score: float
    # Value of the corpus "path" column, when the corpus has that column.
    filepath: Union[str, None] = None
    # Value of the "page" key of the passage metadata, when present.
    file_page: Union[int, None] = None
    # First / second element of the corpus "start_end_idx" pair, when present.
    start_idx: Union[int, None] = None
    end_idx: Union[int, None] = None
class RunResponse(BaseModel):
    # Response body of the /v1/run endpoint.

    # First-row value of the requested result column of the pipeline output.
    result: Union[str, List[str]]
    # Passages retrieved while answering the query.
    retrieved_passage: List[RetrievedPassage]
class RetrievalResponse(BaseModel):
    # Response body of the /v1/retrieve endpoint.

    # Passages retrieved for the query; no generated text is included.
    passages: List[RetrievedPassage]
class StreamResponse(BaseModel):
    """
    One chunk emitted by the /v1/stream endpoint.

    When the type is generated_text, only generated_text is returned. The other fields are None.
    When the type is retrieved_passage, only retrieved_passage and passage_index are returned. The other fields are None.
    """

    # Discriminator telling the client which of the fields below is populated.
    type: Literal["generated_text", "retrieved_passage"]
    # The optional fields default to None so a chunk can be built from only the
    # fields relevant to its type, as the docstring describes. Without a
    # default, Pydantic v2 treats an Optional field as required.
    generated_text: Optional[str] = None
    retrieved_passage: Optional[RetrievedPassage] = None
    # Position of `retrieved_passage` within the list of retrieved passages.
    passage_index: Optional[int] = None
class VersionResponse(BaseModel):
    # Response body of the /version endpoint.

    # Stripped content of the VERSION file.
    version: str
class ApiRunner(BaseRunner):
    """
    Serve the pipeline as a Quart HTTP API.

    Routes registered on ``self.app``:

    - ``POST /v1/run``: run every module and return the requested result
      column together with the retrieved passages.
    - ``POST /v1/retrieve``: run every module except prompt makers and
      generators and return only the retrieved passages.
    - ``POST /v1/stream``: stream the retrieved passages, then the generated
      text deltas.
    - ``GET /version``: return the content of the VERSION file.

    NOTE(review): ``self.module_instances`` and ``self.module_params`` are read
    here but never assigned in this class; presumably ``BaseRunner.__init__``
    builds them from ``config`` -- confirm against the base class.
    """

    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        """
        Build the Quart app, load the corpus and register the routes.

        :param config: Pipeline configuration forwarded to ``BaseRunner``.
        :param project_dir: Directory that contains ``data/corpus.parquet``.
            When None, the current working directory is used.
        """
        super().__init__(config, project_dir)
        self.app = Quart(__name__)

        # The signature allows None, but os.path.join(None, ...) raises
        # TypeError, so resolve the default before building the path.
        # NOTE(review): assumes BaseRunner treats None as the current working
        # directory as well -- confirm.
        if project_dir is None:
            project_dir = os.getcwd()
        corpus_path = os.path.join(project_dir, "data", "corpus.parquet")
        self.corpus_df = pd.read_parquet(corpus_path, engine="pyarrow")
        self.__add_api_route()

    @staticmethod
    def __make_pseudo_qa(query: str) -> pd.DataFrame:
        """Return a one-row pseudo QA frame used as the pipeline input."""
        return pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )

    @staticmethod
    def __apply_module(
        module_instance, module_param: Dict, previous_result: pd.DataFrame
    ) -> pd.DataFrame:
        """Run one module and merge its output columns into the running result."""
        new_result = module_instance.pure(
            previous_result=previous_result, **module_param
        )
        # Columns produced again by this module replace the earlier ones.
        duplicated_columns = previous_result.columns.intersection(new_result.columns)
        kept_previous = previous_result.drop(columns=duplicated_columns)
        return pd.concat([kept_previous, new_result], axis=1)

    def __add_api_route(self):
        """Register all HTTP routes on ``self.app``."""

        @self.app.route("/v1/run", methods=["POST"])
        async def run_query():
            try:
                # model_validate also reports a non-object JSON body (null,
                # list, string) as a ValidationError; QueryRequest(**data)
                # would raise TypeError instead and surface as a 500.
                data = QueryRequest.model_validate(await request.get_json())
            except ValidationError as e:
                return jsonify(e.errors()), 400

            previous_result = self.__make_pseudo_qa(data.query)
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                previous_result = self.__apply_module(
                    module_instance, module_param, previous_result
                )

            # The column only exists after the pipeline ran, so it can only be
            # checked here. A wrong name is a client error, not a server error.
            if data.result_column not in previous_result.columns:
                return jsonify(
                    {
                        "error": "Invalid request. The result column "
                        f"'{data.result_column}' does not exist in the pipeline result."
                    }
                ), 400

            generated_text = previous_result[data.result_column].tolist()[0]
            retrieved_passage = self.extract_retrieve_passage(previous_result)
            response = RunResponse(
                result=generated_text, retrieved_passage=retrieved_passage
            )
            return jsonify(response.model_dump()), 200

        @self.app.route("/v1/retrieve", methods=["POST"])
        async def run_retrieve_only():
            data = await request.get_json()
            # A body that is not a JSON object cannot carry a 'query' key.
            query = data.get("query", None) if isinstance(data, dict) else None
            if query is None:
                return jsonify(
                    {
                        "error": "Invalid request. You need to include 'query' in the request body."
                    }
                ), 400

            previous_result = self.__make_pseudo_qa(query)
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                # Retrieval only: skip prompt building and text generation.
                if isinstance(module_instance, (BasePromptMaker, BaseGenerator)):
                    continue
                previous_result = self.__apply_module(
                    module_instance, module_param, previous_result
                )

            retrieved_passages = self.extract_retrieve_passage(previous_result)
            retrieval_response = RetrievalResponse(passages=retrieved_passages)
            return jsonify(retrieval_response.model_dump()), 200

        @self.app.route("/v1/stream", methods=["POST"])
        async def stream_query():
            try:
                data = QueryRequest.model_validate(await request.get_json())
            except ValidationError as e:
                return jsonify(e.errors()), 400

            @stream_with_context
            async def generate():
                previous_result = self.__make_pseudo_qa(data.query)
                for module_instance, module_param in zip(
                    self.module_instances, self.module_params
                ):
                    if not isinstance(module_instance, BaseGenerator):
                        previous_result = self.__apply_module(
                            module_instance, module_param, previous_result
                        )
                        continue

                    # Generator reached: first emit every retrieved passage ...
                    retrieved_passages = self.extract_retrieve_passage(
                        previous_result
                    )
                    for i, retrieved_passage in enumerate(retrieved_passages):
                        yield (
                            StreamResponse(
                                type="retrieved_passage",
                                generated_text=None,
                                retrieved_passage=retrieved_passage,
                                passage_index=i,
                            )
                            .model_dump_json()
                            .encode("utf-8")
                        )

                    # ... then stream the generated text delta by delta.
                    assert len(previous_result) == 1
                    prompt: str = previous_result["prompts"].tolist()[0]
                    async for delta in module_instance.astream(
                        prompt=prompt, **module_param
                    ):
                        response = StreamResponse(
                            type="generated_text",
                            generated_text=delta,
                            retrieved_passage=None,
                            passage_index=None,
                        )
                        yield response.model_dump_json().encode("utf-8")

            # NOTE(review): "X-Something" looks like a placeholder header; it is
            # kept unchanged so the wire behavior stays the same.
            return generate(), 200, {"X-Something": "value"}

        @self.app.route("/version", methods=["GET"])
        def get_version():
            # Explicit encoding so the read does not depend on the locale.
            with open(VERSION_PATH, "r", encoding="utf-8") as f:
                version = f.read().strip()
            response = VersionResponse(version=version)
            return jsonify(response.model_dump()), 200

    def run_api_server(
        self, host: str = "0.0.0.0", port: int = 8000, remote: bool = True, **kwargs
    ):
        """
        Run the pipeline as an api server.
        Here is api endpoint documentation => https://docs.auto-rag.com/deploy/api_endpoint.html

        :param host: The host of the api server.
        :param port: The port of the api server.
        :param remote: Whether to expose the api server to the public internet using ngrok.
            Defaults to True.
        :param kwargs: Other arguments for Quart app.run.
        """
        logger.info(f"Run api server at {host}:{port}")
        if remote:
            # Imported lazily so pyngrok is only needed when a tunnel is requested.
            from pyngrok import ngrok

            http_tunnel = ngrok.connect(str(port), "http")
            public_url = http_tunnel.public_url
            logger.info(f"Public API URL: {public_url}")
        self.app.run(host=host, port=port, **kwargs)

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        """
        Build the retrieved passages of the first row of a pipeline result.

        :param df: Pipeline result with "retrieved_ids" and "retrieve_scores" columns.
            Only the first row is read.
        :return: One RetrievedPassage per retrieved id, in retrieval order.
        """
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        scores = df["retrieve_scores"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]

        def fetch_optional_column(column_name: str) -> list:
            # Optional corpus columns fall back to None for every passage, so a
            # corpus without the column does not fail the whole request.
            if column_name not in self.corpus_df.columns:
                return [None] * len(retrieved_ids)
            return fetch_contents(
                self.corpus_df, [retrieved_ids], column_name=column_name
            )[0]

        paths = fetch_optional_column("path")
        metadatas = fetch_optional_column("metadata")
        # to_list is applied before the truthiness test on each entry below.
        # NOTE(review): presumably it turns array-like entries into plain
        # lists -- confirm against autorag.utils.util.
        start_end_indices = to_list(fetch_optional_column("start_end_idx"))

        return [
            RetrievedPassage(
                content=content,
                doc_id=doc_id,
                score=score,
                filepath=path,
                # Metadata may be missing; only a dict can carry "page".
                file_page=metadata.get("page", None)
                if isinstance(metadata, dict)
                else None,
                start_idx=start_end_idx[0] if start_end_idx else None,
                end_idx=start_end_idx[1] if start_end_idx else None,
            )
            for content, doc_id, score, path, metadata, start_end_idx in zip(
                contents,
                retrieved_ids,
                scores,
                paths,
                metadatas,
                start_end_indices,
            )
        ]