Files
autorag_evaluation/autorag-workspace/autorag/cli.py
2025-03-14 17:28:01 +09:00

210 lines
6.1 KiB
Python

import importlib.resources
import logging
import os
import pathlib
import subprocess
from pathlib import Path
from typing import Optional
import click
import nest_asyncio
from autorag import dashboard
from autorag.deploy import extract_best_config as original_extract_best_config
from autorag.deploy.api import ApiRunner
from autorag.evaluator import Evaluator
from autorag.validator import Validator
logger = logging.getLogger("AutoRAG")
autorag_dir = os.path.dirname(os.path.realpath(__file__))
version_file = os.path.join(autorag_dir, "VERSION")
with open(version_file, "r") as f:
__version__ = f.read().strip()
@click.group()
@click.version_option(__version__)
def cli():
pass
@click.command()
@click.option(
"--config",
"-c",
help="Path to config yaml file. Must be yaml or yml file.",
type=str,
)
@click.option(
"--qa_data_path", help="Path to QA dataset. Must be parquet file.", type=str
)
@click.option(
"--corpus_data_path", help="Path to corpus dataset. Must be parquet file.", type=str
)
@click.option(
"--project_dir", help="Path to project directory.", type=str, default=None
)
@click.option(
"--skip_validation",
help="Skip validation or not. Default is False.",
type=bool,
default=False,
)
def evaluate(config, qa_data_path, corpus_data_path, project_dir, skip_validation):
if not config.endswith(".yaml") and not config.endswith(".yml"):
raise ValueError(f"Config file {config} is not a yaml or yml file.")
if not os.path.exists(config):
raise ValueError(f"Config file {config} does not exist.")
evaluator = Evaluator(qa_data_path, corpus_data_path, project_dir=project_dir)
evaluator.start_trial(config, skip_validation=skip_validation)
@click.command()
@click.option(
"--config_path", type=str, help="Path to extracted config yaml file.", default=None
)
@click.option("--host", type=str, default="0.0.0.0", help="Host address")
@click.option("--port", type=int, default=8000, help="Port number")
@click.option(
"--trial_dir",
type=click.Path(file_okay=False, dir_okay=True, exists=True),
default=None,
help="Path to trial directory.",
)
@click.option(
"--project_dir", help="Path to project directory.", type=str, default=None
)
@click.option(
"--remote", help="Run the API server in remote mode.", type=bool, default=False
)
def run_api(config_path, host, port, trial_dir, project_dir, remote: bool):
if trial_dir is None:
runner = ApiRunner.from_yaml(config_path, project_dir=project_dir)
else:
runner = ApiRunner.from_trial_folder(trial_dir)
logger.info(f"Running API server at {host}:{port}...")
nest_asyncio.apply()
runner.run_api_server(host, port, remote=remote)
@click.command()
@click.option(
"--yaml_path", type=click.Path(path_type=Path), help="Path to the YAML file."
)
@click.option(
"--project_dir",
type=click.Path(path_type=Path),
help="Path to the project directory.",
)
@click.option(
"--trial_path", type=click.Path(path_type=Path), help="Path to the trial directory."
)
def run_web(
yaml_path: Optional[str], project_dir: Optional[str], trial_path: Optional[str]
):
try:
with importlib.resources.path("autorag", "web.py") as web_path:
web_py_path = str(web_path)
except ImportError:
raise ImportError(
"Could not locate the web.py file within the autorag package."
" Please ensure that autorag is correctly installed."
)
if not yaml_path and not trial_path:
raise ValueError("yaml_path or trial_path must be given.")
elif yaml_path and trial_path:
raise ValueError("yaml_path and trial_path cannot be given at the same time.")
elif yaml_path and not project_dir:
subprocess.run(
["streamlit", "run", web_py_path, "--", "--yaml_path", yaml_path]
)
elif yaml_path and project_dir:
subprocess.run(
[
"streamlit",
"run",
web_py_path,
"--",
"--yaml_path",
yaml_path,
"--project_dir",
project_dir,
]
)
elif trial_path:
subprocess.run(
["streamlit", "run", web_py_path, "--", "--trial_path", trial_path]
)
@click.command()
@click.option(
"--trial_dir",
type=click.Path(dir_okay=True, file_okay=False, exists=True),
required=True,
)
@click.option(
"--port", type=int, default=7690, help="Port number. The default is 7690."
)
def run_dashboard(trial_dir: str, port: int):
dashboard.run(trial_dir, port=port)
@click.command()
@click.option("--trial_path", type=click.Path(), help="Path to the trial directory.")
@click.option(
"--output_path",
type=click.Path(),
help="Path to the output directory." " Must be .yaml or .yml file.",
)
def extract_best_config(trial_path: str, output_path: str):
original_extract_best_config(trial_path, output_path)
@click.command()
@click.option("--trial_path", help="Path to trial directory.", type=str)
def restart_evaluate(trial_path):
if not os.path.exists(trial_path):
raise ValueError(f"trial_path {trial_path} does not exist.")
project_dir = str(pathlib.PurePath(trial_path).parent)
qa_data_path = os.path.join(project_dir, "data", "qa.parquet")
corpus_data_path = os.path.join(project_dir, "data", "corpus.parquet")
evaluator = Evaluator(qa_data_path, corpus_data_path, project_dir)
evaluator.restart_trial(trial_path)
@click.command()
@click.option(
"--config",
"-c",
help="Path to config yaml file. Must be yaml or yml file.",
type=str,
)
@click.option(
"--qa_data_path", help="Path to QA dataset. Must be parquet file.", type=str
)
@click.option(
"--corpus_data_path", help="Path to corpus dataset. Must be parquet file.", type=str
)
def validate(config, qa_data_path, corpus_data_path):
if not config.endswith(".yaml") and not config.endswith(".yml"):
raise ValueError(f"Config file {config} is not a parquet file.")
if not os.path.exists(config):
raise ValueError(f"Config file {config} does not exist.")
validator = Validator(qa_data_path=qa_data_path, corpus_data_path=corpus_data_path)
validator.validate(config)
cli.add_command(evaluate, "evaluate")
cli.add_command(run_api, "run_api")
cli.add_command(run_web, "run_web")
cli.add_command(run_dashboard, "dashboard")
cli.add_command(extract_best_config, "extract_best_config")
cli.add_command(restart_evaluate, "restart_evaluate")
cli.add_command(validate, "validate")
if __name__ == "__main__":
cli()