Fix Dockerfile build issue
autorag/deploy/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from .base import (
    extract_node_line_names,
    extract_node_strategy,
    summary_df_to_yaml,
    extract_best_config,
    Runner,
)
from .api import ApiRunner
from .gradio import GradioRunner
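With these re-exports in place, downstream code can import the whole deploy surface from the package root; a quick sketch:

    from autorag.deploy import Runner, ApiRunner, GradioRunner, extract_best_config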
autorag/deploy/api.py (new file, 293 lines)
@@ -0,0 +1,293 @@
import logging
import os
import pathlib
import uuid
from typing import Dict, Optional, List, Union, Literal

import pandas as pd
from quart import Quart, request, jsonify
from quart.helpers import stream_with_context
from pydantic import BaseModel, ValidationError

from autorag.deploy.base import BaseRunner
from autorag.nodes.generator.base import BaseGenerator
from autorag.nodes.promptmaker.base import BasePromptMaker
from autorag.utils.util import fetch_contents, to_list

logger = logging.getLogger("AutoRAG")

deploy_dir = pathlib.Path(__file__).parent
root_dir = pathlib.Path(__file__).parent.parent

VERSION_PATH = os.path.join(root_dir, "VERSION")


class QueryRequest(BaseModel):
    query: str
    result_column: Optional[str] = "generated_texts"


class RetrievedPassage(BaseModel):
    content: str
    doc_id: str
    score: float
    filepath: Optional[str] = None
    file_page: Optional[int] = None
    start_idx: Optional[int] = None
    end_idx: Optional[int] = None


class RunResponse(BaseModel):
    result: Union[str, List[str]]
    retrieved_passage: List[RetrievedPassage]


class RetrievalResponse(BaseModel):
    passages: List[RetrievedPassage]


class StreamResponse(BaseModel):
    """
    When the type is generated_text, only generated_text is returned; the other fields are None.
    When the type is retrieved_passage, only retrieved_passage and passage_index are returned; the other fields are None.
    """

    type: Literal["generated_text", "retrieved_passage"]
    generated_text: Optional[str]
    retrieved_passage: Optional[RetrievedPassage]
    passage_index: Optional[int]


class VersionResponse(BaseModel):
    version: str


class ApiRunner(BaseRunner):
    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        super().__init__(config, project_dir)
        self.app = Quart(__name__)

        data_dir = os.path.join(project_dir, "data")
        self.corpus_df = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )
        self.__add_api_route()

    def __add_api_route(self):
        @self.app.route("/v1/run", methods=["POST"])
        async def run_query():
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            previous_result = pd.DataFrame(
                {
                    "qid": str(uuid.uuid4()),
                    "query": [data.query],
                    "retrieval_gt": [[]],
                    "generation_gt": [""],
                }
            )  # pseudo qa data for execution
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(columns=duplicated_columns)
                previous_result = pd.concat([drop_previous_result, new_result], axis=1)

            # Extract the generated answer and the retrieved passages
            generated_text = previous_result[data.result_column].tolist()[0]
            retrieved_passage = self.extract_retrieve_passage(previous_result)

            response = RunResponse(
                result=generated_text, retrieved_passage=retrieved_passage
            )

            return jsonify(response.model_dump()), 200

        @self.app.route("/v1/retrieve", methods=["POST"])
        async def run_retrieve_only():
            data = await request.get_json()
            query = data.get("query", None)
            if query is None:
                return jsonify(
                    {
                        "error": "Invalid request. You need to include 'query' in the request body."
                    }
                ), 400

            previous_result = pd.DataFrame(
                {
                    "qid": str(uuid.uuid4()),
                    "query": [query],
                    "retrieval_gt": [[]],
                    "generation_gt": [""],
                }
            )  # pseudo qa data for execution
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                # Skip prompt making and generation; only run retrieval-side modules
                if isinstance(module_instance, (BasePromptMaker, BaseGenerator)):
                    continue
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(columns=duplicated_columns)
                previous_result = pd.concat([drop_previous_result, new_result], axis=1)

            retrieved_passages = self.extract_retrieve_passage(previous_result)

            retrieval_response = RetrievalResponse(passages=retrieved_passages)
            return jsonify(retrieval_response.model_dump()), 200

        @self.app.route("/v1/stream", methods=["POST"])
        async def stream_query():
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            @stream_with_context
            async def generate():
                previous_result = pd.DataFrame(
                    {
                        "qid": str(uuid.uuid4()),
                        "query": [data.query],
                        "retrieval_gt": [[]],
                        "generation_gt": [""],
                    }
                )  # pseudo qa data for execution

                for module_instance, module_param in zip(
                    self.module_instances, self.module_params
                ):
                    if not isinstance(module_instance, BaseGenerator):
                        new_result = module_instance.pure(
                            previous_result=previous_result, **module_param
                        )
                        duplicated_columns = previous_result.columns.intersection(
                            new_result.columns
                        )
                        drop_previous_result = previous_result.drop(
                            columns=duplicated_columns
                        )
                        previous_result = pd.concat(
                            [drop_previous_result, new_result], axis=1
                        )
                    else:
                        # Send the retrieved passages first
                        retrieved_passages = self.extract_retrieve_passage(
                            previous_result
                        )
                        for i, retrieved_passage in enumerate(retrieved_passages):
                            yield (
                                StreamResponse(
                                    type="retrieved_passage",
                                    generated_text=None,
                                    retrieved_passage=retrieved_passage,
                                    passage_index=i,
                                )
                                .model_dump_json()
                                .encode("utf-8")
                            )
                        # Then stream the generation result delta by delta
                        assert len(previous_result) == 1
                        prompt: str = previous_result["prompts"].tolist()[0]
                        async for delta in module_instance.astream(
                            prompt=prompt, **module_param
                        ):
                            response = StreamResponse(
                                type="generated_text",
                                generated_text=delta,
                                retrieved_passage=None,
                                passage_index=None,
                            )
                            yield response.model_dump_json().encode("utf-8")

            return generate(), 200, {"X-Something": "value"}

        @self.app.route("/version", methods=["GET"])
        def get_version():
            with open(VERSION_PATH, "r") as f:
                version = f.read().strip()
            response = VersionResponse(version=version)
            return jsonify(response.model_dump()), 200

    def run_api_server(
        self, host: str = "0.0.0.0", port: int = 8000, remote: bool = True, **kwargs
    ):
        """
        Run the pipeline as an API server.
        API endpoint documentation: https://docs.auto-rag.com/deploy/api_endpoint.html

        :param host: The host of the API server.
        :param port: The port of the API server.
        :param remote: Whether to expose the API server to the public internet using ngrok.
        :param kwargs: Other keyword arguments for Quart's app.run.
        """
        logger.info(f"Run api server at {host}:{port}")
        if remote:
            from pyngrok import ngrok

            http_tunnel = ngrok.connect(str(port), "http")
            public_url = http_tunnel.public_url
            logger.info(f"Public API URL: {public_url}")
        self.app.run(host=host, port=port, **kwargs)

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        scores = df["retrieve_scores"].tolist()[0]
        if "path" in self.corpus_df.columns:
            paths = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="path"
            )[0]
        else:
            paths = [None] * len(retrieved_ids)
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "start_end_idx" in self.corpus_df.columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = [None] * len(retrieved_ids)
        start_end_indices = to_list(start_end_indices)
        return list(
            map(
                lambda content, doc_id, score, path, metadata, start_end_idx: RetrievedPassage(
                    content=content,
                    doc_id=doc_id,
                    score=score,
                    filepath=path,
                    file_page=metadata.get("page", None),
                    start_idx=start_end_idx[0] if start_end_idx else None,
                    end_idx=start_end_idx[1] if start_end_idx else None,
                ),
                contents,
                retrieved_ids,
                scores,
                paths,
                metadatas,
                start_end_indices,
            )
        )
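For context, a minimal sketch of how this server could be started and queried. The trial path, host, and port are placeholder assumptions; `from_trial_folder` is inherited from `BaseRunner` below.

    # Start the server from an evaluated trial folder (path is a placeholder).
    from autorag.deploy import ApiRunner

    runner = ApiRunner.from_trial_folder("/path/to/project/0")
    runner.run_api_server(host="0.0.0.0", port=8000, remote=False)

    # From another process, query the running server (host/port assumed above).
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/run",
        json={"query": "Who is the author?", "result_column": "generated_texts"},
    )
    print(resp.json()["result"])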
autorag/deploy/base.py (new file, 235 lines)
@@ -0,0 +1,235 @@
import logging
import os
import pathlib
import uuid
from copy import deepcopy
from typing import Optional, Dict, List

import pandas as pd
import yaml

from autorag.support import get_support_modules
from autorag.utils.util import load_summary_file, load_yaml_config

logger = logging.getLogger("AutoRAG")


def extract_node_line_names(config_dict: Dict) -> List[str]:
    """
    Extract node line names in the order given by the config dictionary.

    :param config_dict: The YAML configuration dict for the pipeline.
        You can load this from trial_folder/config.yaml.
    :return: The list of node line names,
        in the order they appear in the pipeline.
    """
    return [node_line["node_line_name"] for node_line in config_dict["node_lines"]]


def extract_node_strategy(config_dict: Dict) -> Dict:
    """
    Extract node strategies from the given config dictionary.
    The return value is a dictionary mapping each node type to its strategy.

    :param config_dict: The YAML configuration dict for the pipeline.
        You can load this from trial_folder/config.yaml.
    :return: A dict whose key is node_type and whose value is the strategy dict.
    """
    return {
        node["node_type"]: node.get("strategy", {})
        for node_line in config_dict["node_lines"]
        for node in node_line["nodes"]
    }


def summary_df_to_yaml(summary_df: pd.DataFrame, config_dict: Dict) -> Dict:
    """
    Convert a trial summary dataframe to a config YAML dict.

    :param summary_df: The trial summary dataframe of the evaluated trial.
    :param config_dict: The YAML configuration dict for the pipeline.
        You can load this from trial_folder/config.yaml.
    :return: Dictionary representation of the config YAML file.
        You can save this dictionary to a YAML file.
    """

    # summary_df columns: 'node_line_name', 'node_type', 'best_module_filename',
    # 'best_module_name', 'best_module_params', 'best_execution_time'
    node_line_names = extract_node_line_names(config_dict)
    node_strategies = extract_node_strategy(config_dict)
    strategy_df = pd.DataFrame(
        {
            "node_type": list(node_strategies.keys()),
            "strategy": list(node_strategies.values()),
        }
    )
    summary_df = summary_df.merge(strategy_df, on="node_type", how="left")
    summary_df["categorical_node_line_name"] = pd.Categorical(
        summary_df["node_line_name"], categories=node_line_names, ordered=True
    )
    summary_df = summary_df.sort_values(by="categorical_node_line_name")
    grouped = summary_df.groupby("categorical_node_line_name", observed=False)

    node_lines = [
        {
            "node_line_name": node_line_name,
            "nodes": [
                {
                    "node_type": row["node_type"],
                    "strategy": row["strategy"],
                    "modules": [
                        {
                            "module_type": row["best_module_name"],
                            **row["best_module_params"],
                        }
                    ],
                }
                for _, row in node_line.iterrows()
            ],
        }
        for node_line_name, node_line in grouped
    ]
    return {"node_lines": node_lines}


def extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:
    """
    Extract the optimal pipeline from an evaluated trial.

    :param trial_path: The path to the trial directory that you want to extract the pipeline from.
        The trial must already be evaluated.
    :param output_path: The output path where the pipeline YAML file will be saved.
        Must be a .yaml or .yml file.
        If None, the YAML file is not saved and only the dict is returned.
        Default is None.
    :return: The dictionary of the extracted pipeline.
    """
    summary_path = os.path.join(trial_path, "summary.csv")
    if not os.path.exists(summary_path):
        raise ValueError(f"summary.csv does not exist in {trial_path}.")
    trial_summary_df = load_summary_file(
        summary_path, dict_columns=["best_module_params"]
    )
    config_yaml_path = os.path.join(trial_path, "config.yaml")
    with open(config_yaml_path, "r") as f:
        config_dict = yaml.safe_load(f)
    yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)
    yaml_dict["vectordb"] = extract_vectordb_config(trial_path)
    if output_path is not None:
        with open(output_path, "w") as f:
            yaml.safe_dump(yaml_dict, f)
    return yaml_dict


def extract_vectordb_config(trial_path: str) -> List[Dict]:
    # get the vectordb.yaml file from the project's resources directory
    project_dir = pathlib.PurePath(os.path.realpath(trial_path)).parent
    vectordb_config_path = os.path.join(project_dir, "resources", "vectordb.yaml")
    if not os.path.exists(vectordb_config_path):
        raise ValueError(f"vectordb.yaml does not exist at {vectordb_config_path}.")
    with open(vectordb_config_path, "r") as f:
        vectordb_dict = yaml.safe_load(f)
    result = vectordb_dict.get("vectordb", [])
    if len(result) != 0:
        return result
    # return the default chroma setting
    return [
        {
            "name": "default",
            "db_type": "chroma",
            "client_type": "persistent",
            "embedding_model": "openai",
            "collection_name": "openai",
            "path": os.path.join(project_dir, "resources", "chroma"),
        }
    ]


class BaseRunner:
    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        self.config = config
        project_dir = os.getcwd() if project_dir is None else project_dir
        os.environ["PROJECT_DIR"] = project_dir

        # init modules
        node_lines = deepcopy(self.config["node_lines"])
        self.module_instances = []
        self.module_params = []
        for node_line in node_lines:
            for node in node_line["nodes"]:
                if len(node["modules"]) != 1:
                    raise ValueError(
                        "The number of modules in a node must be 1 to use the runner. "
                        "Please use the extract_best_config method to extract a YAML file from an evaluated trial."
                    )
                module = node["modules"][0]
                module_type = module.pop("module_type")
                module_params = module
                module_instance = get_support_modules(module_type)(
                    project_dir=project_dir,
                    **module_params,
                )
                self.module_instances.append(module_instance)
                self.module_params.append(module_params)

    @classmethod
    def from_yaml(cls, yaml_path: str, project_dir: Optional[str] = None):
        """
        Load a Runner from a YAML file.
        The file must be a YAML file extracted from an evaluated trial using the extract_best_config method.

        :param yaml_path: The path of the YAML file.
        :param project_dir: The path of the project directory.
            Default is the current directory.
        :return: Initialized Runner.
        """
        config = load_yaml_config(yaml_path)
        return cls(config, project_dir=project_dir)

    @classmethod
    def from_trial_folder(cls, trial_path: str):
        """
        Load a Runner from an evaluated trial folder.
        The trial must already be evaluated using the Evaluator class.
        It sets project_dir to the parent directory of the trial folder.

        :param trial_path: The path of the trial folder.
        :return: Initialized Runner.
        """
        config = extract_best_config(trial_path)
        return cls(config, project_dir=os.path.dirname(trial_path))


class Runner(BaseRunner):
    def run(self, query: str, result_column: str = "generated_texts"):
        """
        Run the pipeline with a query.
        The loaded pipeline must start from a single query,
        so the first module of the pipeline must be a `query_expansion` or `retrieval` module.

        :param query: The query of the user.
        :param result_column: The result column name for the answer.
            Default is `generated_texts`, which is the output of the `generation` module.
        :return: The result of the pipeline.
        """
        previous_result = pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )  # pseudo qa data for execution
        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            new_result = module_instance.pure(
                previous_result=previous_result, **module_param
            )
            duplicated_columns = previous_result.columns.intersection(
                new_result.columns
            )
            drop_previous_result = previous_result.drop(columns=duplicated_columns)
            previous_result = pd.concat([drop_previous_result, new_result], axis=1)

        return previous_result[result_column].tolist()[0]
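A short sketch of the intended extract-then-run flow, assuming an already evaluated trial at a placeholder path:

    from autorag.deploy import Runner, extract_best_config

    # Persist the best pipeline of an evaluated trial as a standalone YAML file.
    extract_best_config(trial_path="/path/to/project/0", output_path="best.yaml")

    # Load the pipeline back and answer a single query.
    runner = Runner.from_yaml("best.yaml", project_dir="/path/to/project")
    print(runner.run("Who is the author?"))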
autorag/deploy/gradio.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import logging
import uuid

import pandas as pd

from autorag.deploy.base import BaseRunner

import gradio as gr


logger = logging.getLogger("AutoRAG")


class GradioRunner(BaseRunner):
    def run_web(
        self,
        server_name: str = "0.0.0.0",
        server_port: int = 7680,
        share: bool = False,
        **kwargs,
    ):
        """
        Run a web interface to interact with the pipeline.
        You can access the web interface at `http://server_name:server_port` in your browser.

        :param server_name: The host of the web interface. Default is 0.0.0.0.
        :param server_port: The port of the web interface. Default is 7680.
        :param share: Whether to create a publicly shareable link. Default is False.
        :param kwargs: Other keyword arguments for gr.ChatInterface.launch.
        """

        logger.info(f"Run web interface at http://{server_name}:{server_port}")

        def get_response(message, _):
            return self.run(message)

        gr.ChatInterface(
            get_response, title="📚 AutoRAG", retry_btn=None, undo_btn=None
        ).launch(
            server_name=server_name, server_port=server_port, share=share, **kwargs
        )

    def run(self, query: str, result_column: str = "generated_texts"):
        """
        Run the pipeline with a query.
        The loaded pipeline must start from a single query,
        so the first module of the pipeline must be a `query_expansion` or `retrieval` module.

        :param query: The query of the user.
        :param result_column: The result column name for the answer.
            Default is `generated_texts`, which is the output of the `generation` module.
        :return: The result of the pipeline.
        """
        previous_result = pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )  # pseudo qa data for execution
        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            new_result = module_instance.pure(
                previous_result=previous_result, **module_param
            )
            duplicated_columns = previous_result.columns.intersection(
                new_result.columns
            )
            drop_previous_result = previous_result.drop(columns=duplicated_columns)
            previous_result = pd.concat([drop_previous_result, new_result], axis=1)

        return previous_result[result_column].tolist()[0]
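Launching the web UI follows the same pattern as the other runners; a sketch with a placeholder trial path:

    from autorag.deploy import GradioRunner

    runner = GradioRunner.from_trial_folder("/path/to/project/0")
    runner.run_web(server_port=7680)  # serves the chat UI at http://0.0.0.0:7680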
autorag/deploy/swagger.yml (new file, 202 lines)
@@ -0,0 +1,202 @@
openapi: 3.0.0
info:
  title: Example API
  version: 1.0.1
paths:
  /v1/run:
    post:
      summary: Run a query and get generated text with retrieved passages
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                query:
                  type: string
                  description: The query string
                result_column:
                  type: string
                  description: The result column name
                  default: generated_texts
              required:
                - query
      responses:
        '200':
          description: Successful response
          content:
            application/json:
              schema:
                type: object
                properties:
                  result:
                    oneOf:
                      - type: string
                      - type: array
                        items:
                          type: string
                    description: The generated answer, or a list of answers
                  retrieved_passage:
                    type: array
                    items:
                      type: object
                      properties:
                        content:
                          type: string
                        doc_id:
                          type: string
                        score:
                          type: number
                          format: float
                        filepath:
                          type: string
                          nullable: true
                        file_page:
                          type: integer
                          nullable: true
                        start_idx:
                          type: integer
                          nullable: true
                        end_idx:
                          type: integer
                          nullable: true
  /v1/retrieve:
    post:
      summary: Retrieve documents based on a query
      operationId: runRetrieveOnly
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                query:
                  type: string
                  description: The query string to retrieve documents.
              required:
                - query
      responses:
        '200':
          description: Successful retrieval of documents
          content:
            application/json:
              schema:
                type: object
                properties:
                  passages:
                    type: array
                    items:
                      type: object
                      properties:
                        doc_id:
                          type: string
                          description: The unique identifier for the document.
                        content:
                          type: string
                          description: The content of the retrieved document.
                        score:
                          type: number
                          format: float
                          description: The score of the retrieved document.
        '400':
          description: Invalid request due to missing query parameter
          content:
            application/json:
              schema:
                type: object
                properties:
                  error:
                    type: string
                    description: Error message explaining the issue.
  /v1/stream:
    post:
      summary: Stream generated text with retrieved passages
      description: >
        This endpoint streams StreamResponse objects. The retrieved passages
        are sent first, followed by the generated text streamed incrementally.
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                query:
                  type: string
                  description: The query string
                result_column:
                  type: string
                  description: The result column name
                  default: generated_texts
              required:
                - query
      responses:
        '200':
          description: Successful response with streaming
          content:
            text/event-stream:
              schema:
                type: object
                properties:
                  type:
                    type: string
                    enum:
                      - generated_text
                      - retrieved_passage
                    description: |
                      When the type is "generated_text", only "generated_text" is returned.
                      The other fields are null. When the type is "retrieved_passage", only
                      "retrieved_passage" and "passage_index" are returned. The other fields are null.
                  generated_text:
                    type: string
                    nullable: true
                    description: |
                      The generated text, only present when "type" is "generated_text".
                  retrieved_passage:
                    type: object
                    nullable: true
                    properties:
                      content:
                        type: string
                      doc_id:
                        type: string
                      score:
                        type: number
                        format: float
                      filepath:
                        type: string
                        nullable: true
                      file_page:
                        type: integer
                        nullable: true
                      start_idx:
                        type: integer
                        nullable: true
                      end_idx:
                        type: integer
                        nullable: true
                  passage_index:
                    type: integer
                    nullable: true
                    description: |
                      The index of the retrieved passage, only present when "type" is "retrieved_passage".
                required:
                  - type
                oneOf:
                  - required:
                      - generated_text
                  - required:
                      - retrieved_passage
                      - passage_index

  /version:
    get:
      summary: Get the API version
      description: Returns the current version of the API as a string.
      responses:
        '200':
          description: Successful response
          content:
            application/json:
              schema:
                type: object
                properties:
                  version:
                    type: string
                    description: The version of the API
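To exercise the spec end to end, a small client sketch; the base URL is an assumption, and note that /v1/stream emits undelimited JSON bytes, so chunk boundaries are not guaranteed to align with object boundaries:

    import requests

    BASE = "http://localhost:8000"  # assumed host/port

    print(requests.get(f"{BASE}/version").json())  # {"version": "..."}

    passages = requests.post(f"{BASE}/v1/retrieve", json={"query": "example"}).json()
    print(passages["passages"][0]["doc_id"])

    with requests.post(f"{BASE}/v1/stream", json={"query": "example"}, stream=True) as r:
        for chunk in r.iter_content(chunk_size=None):
            print(chunk.decode("utf-8"))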