import logging
import os
import pathlib
import uuid
from typing import Dict, Optional, List, Union, Literal

import pandas as pd
from quart import Quart, request, jsonify
from quart.helpers import stream_with_context
from pydantic import BaseModel, ValidationError

from autorag.deploy.base import BaseRunner
from autorag.nodes.generator.base import BaseGenerator
from autorag.nodes.promptmaker.base import BasePromptMaker
from autorag.utils.util import fetch_contents, to_list

logger = logging.getLogger("AutoRAG")

deploy_dir = pathlib.Path(__file__).parent
root_dir = pathlib.Path(__file__).parent.parent

VERSION_PATH = os.path.join(root_dir, "VERSION")


class QueryRequest(BaseModel):
    query: str
    result_column: Optional[str] = "generated_texts"
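
# Example request body for /v1/run and /v1/stream (illustrative values only):
#   {"query": "Why is the sky blue?", "result_column": "generated_texts"}
# "result_column" is optional and defaults to "generated_texts".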


class RetrievedPassage(BaseModel):
    content: str
    doc_id: str
    score: float
    filepath: Optional[str] = None
    file_page: Optional[int] = None
    start_idx: Optional[int] = None
    end_idx: Optional[int] = None


class RunResponse(BaseModel):
    result: Union[str, List[str]]
    retrieved_passage: List[RetrievedPassage]


class RetrievalResponse(BaseModel):
    passages: List[RetrievedPassage]


class StreamResponse(BaseModel):
    """
    When the type is "generated_text", only generated_text is returned; the other fields are None.
    When the type is "retrieved_passage", only retrieved_passage and passage_index are returned; the other fields are None.
    """

    type: Literal["generated_text", "retrieved_passage"]
    generated_text: Optional[str]
    retrieved_passage: Optional[RetrievedPassage]
    passage_index: Optional[int]
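
# Each streamed chunk is one StreamResponse serialized with model_dump_json(),
# so the wire format looks like (illustrative):
#   {"type": "retrieved_passage", "generated_text": null, "retrieved_passage": {...}, "passage_index": 0}
#   {"type": "generated_text", "generated_text": "The sky", "retrieved_passage": null, "passage_index": null}
# Chunks are concatenated without a delimiter, so clients need an incremental
# JSON parser (or similar) to split them.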


class VersionResponse(BaseModel):
    version: str


class ApiRunner(BaseRunner):
    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        super().__init__(config, project_dir)
        self.app = Quart(__name__)

        data_dir = os.path.join(project_dir, "data")
        self.corpus_df = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )
        self.__add_api_route()

    def __add_api_route(self):
        @self.app.route("/v1/run", methods=["POST"])
        async def run_query():
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            # Pseudo QA data so the pipeline modules can run on a single query.
            previous_result = pd.DataFrame(
                {
                    "qid": str(uuid.uuid4()),
                    "query": [data.query],
                    "retrieval_gt": [[]],
                    "generation_gt": [""],
                }
            )
            # Run every pipeline module, replacing overlapping columns with
            # each module's fresh output.
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(columns=duplicated_columns)
                previous_result = pd.concat([drop_previous_result, new_result], axis=1)

            # Extract the generated answer and the retrieved passages.
            generated_text = previous_result[data.result_column].tolist()[0]
            retrieved_passage = self.extract_retrieve_passage(previous_result)

            response = RunResponse(
                result=generated_text, retrieved_passage=retrieved_passage
            )

            return jsonify(response.model_dump()), 200
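
        # Example call (illustrative):
        #   curl -X POST http://localhost:8000/v1/run \
        #        -H "Content-Type: application/json" \
        #        -d '{"query": "Why is the sky blue?"}'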

        @self.app.route("/v1/retrieve", methods=["POST"])
        async def run_retrieve_only():
            data = await request.get_json()
            query = data.get("query", None)
            if query is None:
                return jsonify(
                    {
                        "error": "Invalid request. You need to include 'query' in the request body."
                    }
                ), 400

            # Pseudo QA data so the pipeline modules can run on a single query.
            previous_result = pd.DataFrame(
                {
                    "qid": str(uuid.uuid4()),
                    "query": [query],
                    "retrieval_gt": [[]],
                    "generation_gt": [""],
                }
            )
            # Run the pipeline, skipping prompt maker and generator modules,
            # since this endpoint only performs retrieval.
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                if isinstance(module_instance, (BasePromptMaker, BaseGenerator)):
                    continue
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(columns=duplicated_columns)
                previous_result = pd.concat([drop_previous_result, new_result], axis=1)

            # Extract the retrieved passages.
            retrieved_passages = self.extract_retrieve_passage(previous_result)

            retrieval_response = RetrievalResponse(passages=retrieved_passages)
            return jsonify(retrieval_response.model_dump()), 200
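
        # Example call (illustrative):
        #   curl -X POST http://localhost:8000/v1/retrieve \
        #        -H "Content-Type: application/json" \
        #        -d '{"query": "Why is the sky blue?"}'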

        @self.app.route("/v1/stream", methods=["POST"])
        async def stream_query():
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            @stream_with_context
            async def generate():
                # Pseudo QA data so the pipeline modules can run on a single query.
                previous_result = pd.DataFrame(
                    {
                        "qid": str(uuid.uuid4()),
                        "query": [data.query],
                        "retrieval_gt": [[]],
                        "generation_gt": [""],
                    }
                )

                for module_instance, module_param in zip(
                    self.module_instances, self.module_params
                ):
                    if not isinstance(module_instance, BaseGenerator):
                        new_result = module_instance.pure(
                            previous_result=previous_result, **module_param
                        )
                        duplicated_columns = previous_result.columns.intersection(
                            new_result.columns
                        )
                        drop_previous_result = previous_result.drop(
                            columns=duplicated_columns
                        )
                        previous_result = pd.concat(
                            [drop_previous_result, new_result], axis=1
                        )
                    else:
                        # Emit the retrieved passages first, one chunk each.
                        retrieved_passages = self.extract_retrieve_passage(
                            previous_result
                        )
                        for i, retrieved_passage in enumerate(retrieved_passages):
                            yield (
                                StreamResponse(
                                    type="retrieved_passage",
                                    generated_text=None,
                                    retrieved_passage=retrieved_passage,
                                    passage_index=i,
                                )
                                .model_dump_json()
                                .encode("utf-8")
                            )
                        # Then stream the generation result delta by delta.
                        assert len(previous_result) == 1
                        prompt: str = previous_result["prompts"].tolist()[0]
                        async for delta in module_instance.astream(
                            prompt=prompt, **module_param
                        ):
                            response = StreamResponse(
                                type="generated_text",
                                generated_text=delta,
                                retrieved_passage=None,
                                passage_index=None,
                            )
                            yield response.model_dump_json().encode("utf-8")

            # "X-Something" is a placeholder header; replace or drop it as needed.
            return generate(), 200, {"X-Something": "value"}
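
        # A minimal client sketch for this endpoint (illustrative; assumes the
        # httpx package, and note that received chunk boundaries are not
        # guaranteed to align with whole StreamResponse objects):
        #
        #   import httpx
        #
        #   async def consume_stream(query: str):
        #       async with httpx.AsyncClient() as client:
        #           async with client.stream(
        #               "POST", "http://localhost:8000/v1/stream", json={"query": query}
        #           ) as response:
        #               async for chunk in response.aiter_bytes():
        #                   print(chunk.decode("utf-8"), end="")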

        @self.app.route("/version", methods=["GET"])
        def get_version():
            with open(VERSION_PATH, "r") as f:
                version = f.read().strip()
            response = VersionResponse(version=version)
            return jsonify(response.model_dump()), 200

    def run_api_server(
        self, host: str = "0.0.0.0", port: int = 8000, remote: bool = True, **kwargs
    ):
        """
        Run the pipeline as an API server.
        The API endpoints are documented at https://docs.auto-rag.com/deploy/api_endpoint.html.

        :param host: The host of the API server.
        :param port: The port of the API server.
        :param remote: Whether to expose the API server to the public internet using ngrok.
        :param kwargs: Other arguments for Quart app.run.
        """
        logger.info(f"Run API server at {host}:{port}")
        if remote:
            from pyngrok import ngrok

            http_tunnel = ngrok.connect(str(port), "http")
            public_url = http_tunnel.public_url
            logger.info(f"Public API URL: {public_url}")
        self.app.run(host=host, port=port, **kwargs)

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        scores = df["retrieve_scores"].tolist()[0]
        if "path" in self.corpus_df.columns:
            paths = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="path"
            )[0]
        else:
            paths = [None] * len(retrieved_ids)
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "start_end_idx" in self.corpus_df.columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = [None] * len(retrieved_ids)
        start_end_indices = to_list(start_end_indices)
        return [
            RetrievedPassage(
                content=content,
                doc_id=doc_id,
                score=score,
                filepath=path,
                file_page=metadata.get("page", None),
                start_idx=start_end_idx[0] if start_end_idx else None,
                end_idx=start_end_idx[1] if start_end_idx else None,
            )
            for content, doc_id, score, path, metadata, start_end_idx in zip(
                contents, retrieved_ids, scores, paths, metadatas, start_end_indices
            )
        ]
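

# A minimal usage sketch (illustrative): the config path and project directory
# below are hypothetical placeholders for an already-optimized AutoRAG pipeline.
if __name__ == "__main__":
    import yaml  # assumed available, since AutoRAG pipeline configs are YAML

    with open("pipeline.yaml", "r") as f:  # hypothetical extracted pipeline config
        config = yaml.safe_load(f)
    runner = ApiRunner(config, project_dir=".")
    # remote=False keeps the server local instead of opening an ngrok tunnel.
    runner.run_api_server(host="127.0.0.1", port=8000, remote=False)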