Initial Commit

Author: kyy
Date: 2025-03-18 04:34:57 +00:00
Commit: a4981faeef
60 changed files with 2160 additions and 0 deletions

1
autorag/VERSION Normal file

@@ -0,0 +1 @@
0.3.14

113
autorag/__init__.py Normal file

@@ -0,0 +1,113 @@
import logging
import logging.config
import os
import sys
from random import random
from typing import List, Any
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.base.llms.types import CompletionResponse
from llama_index.core.llms.mock import MockLLM
from llama_index.llms.bedrock import Bedrock
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingModelType
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from langchain_openai.embeddings import OpenAIEmbeddings
from rich.logging import RichHandler
from llama_index.llms.ollama import Ollama
version_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "VERSION")
with open(version_path, "r") as f:
__version__ = f.read().strip()
class LazyInit:
def __init__(self, factory, *args, **kwargs):
self._factory = factory
self._args = args
self._kwargs = kwargs
self._instance = None
def __call__(self):
if self._instance is None:
self._instance = self._factory(*self._args, **self._kwargs)
return self._instance
def __getattr__(self, name):
if self._instance is None:
self._instance = self._factory(*self._args, **self._kwargs)
return getattr(self._instance, name)
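# Illustrative sketch (not part of the commit): LazyInit defers constructing an
# expensive object until it is first called or an attribute is accessed, then caches
# the instance. MockLLM is used here only as a cheap stand-in factory.
def _example_lazy_init():
    lazy_llm = LazyInit(MockLLM, max_tokens=256)
    llm = lazy_llm()  # the factory runs here and the instance is cached
    assert lazy_llm() is llm  # later calls reuse the same instance
    return llm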
rich_format = "[%(filename)s:%(lineno)s] >> %(message)s"
logging.basicConfig(
level="INFO", format=rich_format, handlers=[RichHandler(rich_tracebacks=True)]
)
logger = logging.getLogger("AutoRAG")
def handle_exception(exc_type, exc_value, exc_traceback):
logger = logging.getLogger("AutoRAG")
logger.error("Unexpected exception", exc_info=(exc_type, exc_value, exc_traceback))
sys.excepthook = handle_exception
class AutoRAGBedrock(Bedrock):
async def acomplete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponse:
return self.complete(prompt, formatted=formatted, **kwargs)
generator_models = {
"openai": OpenAI,
"openailike": OpenAILike,
"mock": MockLLM,
"bedrock": AutoRAGBedrock,
"ollama": Ollama,
}
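# Illustrative sketch (not part of the commit): generator_models maps config-level LLM
# names to llama_index LLM classes; combined with LazyInit, client construction is
# deferred until first use. The model name below is hypothetical.
def _example_generator_lookup():
    llm_cls = generator_models["openai"]
    return LazyInit(llm_cls, model="gpt-4o-mini", temperature=0.0)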
# embedding_models = {
# }
try:
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.llms.ollama import Ollama
generator_models["huggingfacellm"] = HuggingFaceLLM
generator_models["ollama"] = Ollama
except ImportError:
logger.info(
"You are using API version of AutoRAG. "
"To use local version, run pip install 'AutoRAG[gpu]'"
)
# try:
# from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# embedding_models["hf_all_mpnet_base_v2"] = HuggingFaceEmbedding # 250312 변경 - 김용연
# embedding_models["hf_KURE-v1"] = HuggingFaceEmbedding # 250312 변경 - 김용연
# embedding_models["hf_snowflake-arctic-embed-l-v2.0-ko"] = HuggingFaceEmbedding # 250313 변경 - 김용연
# except ImportError:
# logger.info(
# "You are using API version of AutoRAG."
# "To use local version, run pip install 'AutoRAG[gpu]'"
# )
try:
import transformers
transformers.logging.set_verbosity_error()
except ImportError:
logger.info(
"You are using API version of AutoRAG."
"To use local version, run pip install 'AutoRAG[gpu]'"
)

40
autorag/cli.py Normal file

@@ -0,0 +1,40 @@
import logging
import os
import click
from autorag import dashboard
logger = logging.getLogger("AutoRAG")
autorag_dir = os.path.dirname(os.path.realpath(__file__))
version_file = os.path.join(autorag_dir, "VERSION")
with open(version_file, "r") as f:
__version__ = f.read().strip()
@click.group()
@click.version_option(__version__)
def cli():
pass
@click.command()
@click.option(
"--trial_dir",
type=click.Path(dir_okay=True, file_okay=False, exists=True),
required=True,
)
@click.option(
"--port", type=int, default=7690, help="Port number. The default is 7690."
)
def run_dashboard(trial_dir: str, port: int):
"""Runs the AutoRAG Dashboard."""
logger.info(f"Starting AutoRAG Dashboard on port {port}...")
dashboard.run(trial_dir, port=port)
cli.add_command(run_dashboard, "dashboard")
if __name__ == "__main__":
cli()
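# Illustrative usage (not part of the commit): with a finished trial directory (path is
# hypothetical), the dashboard can be launched from a shell:
#   python -m autorag.cli dashboard --trial_dir ./my_project/0 --port 7690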

215
autorag/dashboard.py Normal file

@@ -0,0 +1,215 @@
import ast
import logging
import os
from typing import Dict, List
import matplotlib.pyplot as plt
import pandas as pd
import panel as pn
import seaborn as sns
import yaml
from bokeh.models import NumberFormatter, BooleanFormatter
from autorag.utils.util import dict_to_markdown, dict_to_markdown_table
pn.extension(
"terminal",
"tabulator",
"mathjax",
"ipywidgets",
console_output="disable",
sizing_mode="stretch_width",
css_files=[
"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css"
],
)
logger = logging.getLogger("AutoRAG")
def find_node_dir(trial_dir: str) -> List[str]:
trial_summary_df = pd.read_csv(os.path.join(trial_dir, "summary.csv"))
result_paths = []
for idx, row in trial_summary_df.iterrows():
node_line_name = row["node_line_name"]
node_type = row["node_type"]
result_paths.append(os.path.join(trial_dir, node_line_name, node_type))
return result_paths
def get_metric_values(node_summary_df: pd.DataFrame) -> Dict:
non_metric_column_names = [
"filename",
"module_name",
"module_params",
"execution_time",
"average_output_token",
"is_best",
]
best_row = node_summary_df.loc[node_summary_df["is_best"]].drop(
columns=non_metric_column_names, errors="ignore"
)
assert len(best_row) == 1, "The best module must be only one."
return best_row.iloc[0].to_dict()
def make_trial_summary_md(trial_dir):
markdown_text = f"""# Trial Result Summary
- Trial Directory : {trial_dir}
"""
node_dirs = find_node_dir(trial_dir)
for node_dir in node_dirs:
node_summary_filepath = os.path.join(node_dir, "summary.csv")
node_type = os.path.basename(node_dir)
node_summary_df = pd.read_csv(node_summary_filepath)
best_row = node_summary_df.loc[node_summary_df["is_best"]].iloc[0]
metric_dict = get_metric_values(node_summary_df)
markdown_text += f"""---
## {node_type} best module
### Module Name
{best_row['module_name']}
### Module Params
{dict_to_markdown(ast.literal_eval(best_row['module_params']), level=3)}
### Metric Values
{dict_to_markdown_table(metric_dict, key_column_name='metric_name', value_column_name='metric_value')}
"""
return markdown_text
def node_view(node_dir: str):
non_metric_column_names = [
"filename",
"module_name",
"module_params",
"execution_time",
"average_output_token",
"is_best",
]
summary_df = pd.read_csv(os.path.join(node_dir, "summary.csv"))
bokeh_formatters = {
"float": NumberFormatter(format="0.000"),
"bool": BooleanFormatter(),
}
first_df = pd.read_parquet(os.path.join(node_dir, "0.parquet"), engine="pyarrow")
each_module_df_widget = pn.widgets.Tabulator(
pd.DataFrame(columns=first_df.columns),
name="Module DataFrame",
formatters=bokeh_formatters,
pagination="local",
page_size=20,
widths=150,
)
def change_module_widget(event):
if event.column == "detail":
filename = summary_df["filename"].iloc[event.row]
filepath = os.path.join(node_dir, filename)
each_module_df = pd.read_parquet(filepath, engine="pyarrow")
each_module_df_widget.value = each_module_df
df_widget = pn.widgets.Tabulator(
summary_df,
name="Summary DataFrame",
formatters=bokeh_formatters,
buttons={"detail": '<i class="fa fa-eye"></i>'},
widths=150,
)
df_widget.on_click(change_module_widget)
try:
fig, ax = plt.subplots(figsize=(10, 5))
metric_df = summary_df.drop(columns=non_metric_column_names, errors="ignore")
sns.stripplot(data=metric_df, ax=ax)
strip_plot_pane = pn.pane.Matplotlib(fig, tight=True)
fig2, ax2 = plt.subplots(figsize=(10, 5))
sns.boxplot(data=metric_df, ax=ax2)
box_plot_pane = pn.pane.Matplotlib(fig2, tight=True)
plot_pane = pn.Row(strip_plot_pane, box_plot_pane)
layout = pn.Column(
"## Summary distribution plot",
plot_pane,
"## Summary DataFrame",
df_widget,
"## Module Result DataFrame",
each_module_df_widget,
)
except Exception as e:
logger.error(f"Skipping make boxplot and stripplot with error {e}")
layout = pn.Column("## Summary DataFrame", df_widget)
layout.servable()
return layout
CSS = """
div.card-margin:nth-child(1) {
max-height: 300px;
}
div.card-margin:nth-child(2) {
max-height: 400px;
}
"""
def yaml_to_markdown(yaml_filepath):
markdown_content = ""
with open(yaml_filepath, "r", encoding="utf-8") as file:
try:
content = yaml.safe_load(file)
markdown_content += f"## {os.path.basename(yaml_filepath)}\n```yaml\n{yaml.safe_dump(content, allow_unicode=True)}\n```\n\n"
except yaml.YAMLError as exc:
print(f"Error in {yaml_filepath}: {exc}")
return markdown_content
def run(trial_dir: str, port: int = 7690):
trial_summary_md = make_trial_summary_md(trial_dir=trial_dir)
trial_summary_tab = pn.pane.Markdown(trial_summary_md, sizing_mode="stretch_width")
node_views = [
(str(os.path.basename(node_dir)), pn.bind(node_view, node_dir))
for node_dir in find_node_dir(trial_dir)
]
"""
수정 전
node_views = [
(str(os.path.basename(node_dir)), node_view(node_dir))
for node_dir in find_node_dir(trial_dir)
]
"""
yaml_file_markdown = yaml_to_markdown(os.path.join(trial_dir, "config.yaml"))
yaml_file = pn.pane.Markdown(yaml_file_markdown, sizing_mode="stretch_width")
tabs = pn.Tabs(
("Summary", trial_summary_tab),
*node_views,
("Used YAML file", yaml_file),
dynamic=True,
)
    '''
    Before the fix:
    template = pn.template.FastListTemplate(
        site="AutoRAG", title="Dashboard", main=[tabs], raw_css=[CSS]
    ).servable()
    template.show(port=port)
    '''
if CSS not in pn.config.raw_css:
pn.config.raw_css.append(CSS)
template = pn.template.FastListTemplate(
site="AutoRAG", title="Dashboard", main=[tabs]
).servable()
pn.serve(template, port=port, show=False)

3
autorag/schema/__init__.py Normal file

@@ -0,0 +1,3 @@
from .module import Module
from .node import Node
from .base import BaseModule

5 binary files not shown.

35
autorag/schema/base.py Normal file

@@ -0,0 +1,35 @@
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Union
import pandas as pd
class BaseModule(metaclass=ABCMeta):
@abstractmethod
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
pass
@abstractmethod
def _pure(self, *args, **kwargs):
pass
@classmethod
def run_evaluator(
cls,
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args,
**kwargs,
):
instance = cls(project_dir, *args, **kwargs)
result = instance.pure(previous_result, *args, **kwargs)
del instance
return result
@abstractmethod
def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
"""
        This method applies the cast function (a.k.a. the decorator) to the node's pure function.
"""
pass

99
autorag/schema/metricinput.py Normal file

@@ -0,0 +1,99 @@
from dataclasses import dataclass
from typing import Optional, List, Dict, Callable, Any, Union
import numpy as np
import pandas as pd
@dataclass
class MetricInput:
query: Optional[str] = None
queries: Optional[List[str]] = None
retrieval_gt_contents: Optional[List[List[str]]] = None
retrieved_contents: Optional[List[str]] = None
retrieval_gt: Optional[List[List[str]]] = None
retrieved_ids: Optional[List[str]] = None
prompt: Optional[str] = None
generated_texts: Optional[str] = None
generation_gt: Optional[List[str]] = None
generated_log_probs: Optional[List[float]] = None
def is_fields_notnone(self, fields_to_check: List[str]) -> bool:
for field in fields_to_check:
actual_value = getattr(self, field)
if actual_value is None:
return False
try:
if not type_checks.get(type(actual_value), lambda _: False)(
actual_value
):
return False
except Exception:
return False
return True
@classmethod
def from_dataframe(cls, qa_data: pd.DataFrame) -> List["MetricInput"]:
"""
Convert a pandas DataFrame into a list of MetricInput instances.
        :param qa_data: The QA DataFrame containing metric data.
:returns: List[MetricInput]: List of MetricInput objects created from DataFrame rows.
"""
instances = []
for _, row in qa_data.iterrows():
instance = cls()
for attr_name in cls.__annotations__:
if attr_name in row:
value = row[attr_name]
if isinstance(value, str):
setattr(
instance,
attr_name,
value.strip() if value.strip() != "" else None,
)
elif isinstance(value, list):
setattr(instance, attr_name, value if len(value) > 0 else None)
else:
setattr(instance, attr_name, value)
instances.append(instance)
return instances
@staticmethod
def _check_list(lst_or_arr: Union[List[Any], np.ndarray]) -> bool:
if isinstance(lst_or_arr, np.ndarray):
lst_or_arr = lst_or_arr.flatten().tolist()
if len(lst_or_arr) == 0:
return False
for item in lst_or_arr:
if item is None:
return False
item_type = type(item)
if item_type in type_checks:
if not type_checks[item_type](item):
return False
else:
return False
return True
type_checks: Dict[type, Callable[[Any], bool]] = {
str: lambda x: len(x.strip()) > 0,
list: MetricInput._check_list,
np.ndarray: MetricInput._check_list,
int: lambda _: True,
float: lambda _: True,
}
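# Illustrative sketch (not part of the commit): building MetricInput objects from a QA
# frame and checking which fields are populated before computing a metric. The row
# values are hypothetical.
def _example_metric_input():
    qa_df = pd.DataFrame(
        {"query": ["What is AutoRAG?"], "generation_gt": [["An AutoML tool for RAG"]]}
    )
    inputs = MetricInput.from_dataframe(qa_df)
    return inputs[0].is_fields_notnone(["query", "generation_gt"])  # -> True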

24
autorag/schema/module.py Normal file

@@ -0,0 +1,24 @@
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Callable, Dict
from autorag.support import get_support_modules
@dataclass
class Module:
module_type: str
module_param: Dict
module: Callable = field(init=False)
def __post_init__(self):
self.module = get_support_modules(self.module_type)
if self.module is None:
raise ValueError(f"Module type {self.module_type} is not supported.")
@classmethod
def from_dict(cls, module_dict: Dict) -> "Module":
_module_dict = deepcopy(module_dict)
module_type = _module_dict.pop("module_type")
module_params = _module_dict
return cls(module_type, module_params)

143
autorag/schema/node.py Normal file

@@ -0,0 +1,143 @@
import itertools
import logging
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, List, Callable, Tuple, Any
import pandas as pd
from autorag.schema.module import Module
from autorag.support import get_support_nodes
from autorag.utils.util import make_combinations, explode, find_key_values
logger = logging.getLogger("AutoRAG")
@dataclass
class Node:
node_type: str
strategy: Dict
node_params: Dict
modules: List[Module]
run_node: Callable = field(init=False)
def __post_init__(self):
self.run_node = get_support_nodes(self.node_type)
if self.run_node is None:
raise ValueError(f"Node type {self.node_type} is not supported.")
def get_param_combinations(self) -> Tuple[List[Callable], List[Dict]]:
"""
        This method returns all combinations of module and node parameters, together with the corresponding modules.
:return: Each module and its module parameters.
:rtype: Tuple[List[Callable], List[Dict]]
"""
def make_single_combination(module: Module) -> List[Dict]:
input_dict = {**self.node_params, **module.module_param}
return make_combinations(input_dict)
combinations = list(map(make_single_combination, self.modules))
module_list, combination_list = explode(self.modules, combinations)
return list(map(lambda x: x.module, module_list)), combination_list
@classmethod
def from_dict(cls, node_dict: Dict) -> "Node":
_node_dict = deepcopy(node_dict)
node_type = _node_dict.pop("node_type")
strategy = _node_dict.pop("strategy")
modules = list(map(lambda x: Module.from_dict(x), _node_dict.pop("modules")))
node_params = _node_dict
return cls(node_type, strategy, node_params, modules)
def run(self, previous_result: pd.DataFrame, node_line_dir: str) -> pd.DataFrame:
logger.info(f"Running node {self.node_type}...")
input_modules, input_params = self.get_param_combinations()
return self.run_node(
modules=input_modules,
module_params=input_params,
previous_result=previous_result,
node_line_dir=node_line_dir,
strategies=self.strategy,
)
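# Illustrative sketch (not part of the commit): a YAML-style node dict is split into
# strategy, node-level params, and modules. "retrieval" and "bm25" must be registered
# in autorag.support; all values here are hypothetical.
def _example_node_from_dict():
    node = Node.from_dict(
        {
            "node_type": "retrieval",
            "strategy": {"metrics": ["retrieval_f1"]},
            "top_k": [3, 5],
            "modules": [{"module_type": "bm25"}],
        }
    )
    modules, params = node.get_param_combinations()
    # params == [{"top_k": 3}, {"top_k": 5}], one entry per combination of the bm25 module
    return modules, params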
def extract_values(node: Node, key: str) -> List[str]:
"""
    This function extracts values from the module_param of the node's modules.
:param node: The node you want to extract values from.
:param key: The key of module_param that you want to extract.
:return: The list of extracted values.
It removes duplicated elements automatically.
"""
def extract_module_values(module: Module):
if key not in module.module_param:
return []
value = module.module_param[key]
if isinstance(value, str) or isinstance(value, int):
return [value]
elif isinstance(value, list):
return value
else:
raise ValueError(f"{key} must be str,list or int, but got {type(value)}")
values = list(map(extract_module_values, node.modules))
return list(set(list(itertools.chain.from_iterable(values))))
def extract_values_from_nodes(nodes: List[Node], key: str) -> List[str]:
"""
    This function extracts values from the module_param of the nodes' modules.
:param nodes: The nodes you want to extract values from.
:param key: The key of module_param that you want to extract.
:return: The list of extracted values.
It removes duplicated elements automatically.
"""
values = list(map(lambda node: extract_values(node, key), nodes))
return list(set(list(itertools.chain.from_iterable(values))))
def extract_values_from_nodes_strategy(nodes: List[Node], key: str) -> List[Any]:
"""
    This function extracts values from the nodes' strategies.
    :param nodes: The nodes you want to extract values from.
    :param key: The key string that you want to extract.
    :return: The list of extracted values.
    Unlike the two functions above, duplicates are not removed.
"""
values = []
for node in nodes:
value_list = find_key_values(node.strategy, key)
if value_list:
values.extend(value_list)
return values
def module_type_exists(nodes: List[Node], module_type: str) -> bool:
"""
    This function checks if the module type exists in the given nodes.
:param nodes: The nodes you want to check.
:param module_type: The module type you want to check.
:return: True if the module type exists in the nodes.
"""
    return any(
        any(module.module_type == module_type for module in node.modules)
        for node in nodes
    )

8
autorag/utils/__init__.py Normal file

@@ -0,0 +1,8 @@
from .preprocess import (
validate_qa_dataset,
validate_corpus_dataset,
cast_qa_dataset,
cast_corpus_dataset,
validate_qa_from_corpus_dataset,
)
from .util import fetch_contents, result_to_dataframe, sort_by_scores

3 binary files not shown.

149
autorag/utils/preprocess.py Normal file

@@ -0,0 +1,149 @@
from datetime import datetime
import numpy as np
import pandas as pd
from autorag.utils.util import preprocess_text
def validate_qa_dataset(df: pd.DataFrame):
columns = ["qid", "query", "retrieval_gt", "generation_gt"]
assert set(columns).issubset(
df.columns
), f"df must have columns {columns}, but got {df.columns}"
def validate_corpus_dataset(df: pd.DataFrame):
columns = ["doc_id", "contents", "metadata"]
assert set(columns).issubset(
df.columns
), f"df must have columns {columns}, but got {df.columns}"
def cast_qa_dataset(df: pd.DataFrame):
def cast_retrieval_gt(gt):
if isinstance(gt, str):
return [[gt]]
elif isinstance(gt, list):
if isinstance(gt[0], str):
return [gt]
elif isinstance(gt[0], list):
return gt
elif isinstance(gt[0], np.ndarray):
return cast_retrieval_gt(list(map(lambda x: x.tolist(), gt)))
else:
raise ValueError(
f"retrieval_gt must be str or list, but got {type(gt[0])}"
)
elif isinstance(gt, np.ndarray):
return cast_retrieval_gt(gt.tolist())
else:
raise ValueError(f"retrieval_gt must be str or list, but got {type(gt)}")
def cast_generation_gt(gt):
if isinstance(gt, str):
return [gt]
elif isinstance(gt, list):
return gt
elif isinstance(gt, np.ndarray):
return cast_generation_gt(gt.tolist())
else:
raise ValueError(f"generation_gt must be str or list, but got {type(gt)}")
df = df.reset_index(drop=True)
validate_qa_dataset(df)
assert df["qid"].apply(lambda x: isinstance(x, str)).sum() == len(
df
), "qid must be string type."
assert df["query"].apply(lambda x: isinstance(x, str)).sum() == len(
df
), "query must be string type."
df["retrieval_gt"] = df["retrieval_gt"].apply(cast_retrieval_gt)
df["generation_gt"] = df["generation_gt"].apply(cast_generation_gt)
df["query"] = df["query"].apply(preprocess_text)
df["generation_gt"] = df["generation_gt"].apply(
lambda x: list(map(preprocess_text, x))
)
return df
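# Illustrative sketch (not part of the commit): cast_qa_dataset coerces the ground-truth
# columns into their nested-list form. All values below are hypothetical.
def _example_cast_qa_dataset():
    qa_df = pd.DataFrame(
        {
            "qid": ["q-1"],
            "query": ["What is AutoRAG?"],
            "retrieval_gt": ["doc-1"],
            "generation_gt": ["An AutoML tool for RAG"],
        }
    )
    casted = cast_qa_dataset(qa_df)
    # casted["retrieval_gt"][0] == [["doc-1"]]; casted["generation_gt"][0] == ["An AutoML tool for RAG"]
    return casted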
def cast_corpus_dataset(df: pd.DataFrame):
df = df.reset_index(drop=True)
validate_corpus_dataset(df)
# drop rows that have empty contents
df = df[~df["contents"].apply(lambda x: x is None or x.isspace())]
def make_datetime_metadata(x):
if x is None or x == {}:
return {"last_modified_datetime": datetime.now()}
elif x.get("last_modified_datetime") is None:
return {**x, "last_modified_datetime": datetime.now()}
else:
return x
df["metadata"] = df["metadata"].apply(make_datetime_metadata)
# check every metadata have a datetime key
assert sum(
df["metadata"].apply(lambda x: x.get("last_modified_datetime") is not None)
) == len(df), "Every metadata must have a datetime key."
def make_prev_next_id_metadata(x, id_type: str):
if x is None or x == {}:
return {id_type: None}
elif x.get(id_type) is None:
return {**x, id_type: None}
else:
return x
df["metadata"] = df["metadata"].apply(
lambda x: make_prev_next_id_metadata(x, "prev_id")
)
df["metadata"] = df["metadata"].apply(
lambda x: make_prev_next_id_metadata(x, "next_id")
)
df["contents"] = df["contents"].apply(preprocess_text)
def normalize_unicode_metadata(metadata: dict):
result = {}
for key, value in metadata.items():
if isinstance(value, str):
result[key] = preprocess_text(value)
else:
result[key] = value
return result
df["metadata"] = df["metadata"].apply(normalize_unicode_metadata)
# check every metadata have a prev_id, next_id key
assert all(
"prev_id" in metadata for metadata in df["metadata"]
), "Every metadata must have a prev_id key."
assert all(
"next_id" in metadata for metadata in df["metadata"]
), "Every metadata must have a next_id key."
return df
def validate_qa_from_corpus_dataset(qa_df: pd.DataFrame, corpus_df: pd.DataFrame):
qa_ids = []
for retrieval_gt in qa_df["retrieval_gt"].tolist():
if isinstance(retrieval_gt, list) and (
retrieval_gt[0] != [] or any(bool(g) is True for g in retrieval_gt)
):
for gt in retrieval_gt:
qa_ids.extend(gt)
elif isinstance(retrieval_gt, np.ndarray) and retrieval_gt[0].size > 0:
for gt in retrieval_gt:
qa_ids.extend(gt)
no_exist_ids = list(
filter(lambda qa_id: corpus_df[corpus_df["doc_id"] == qa_id].empty, qa_ids)
)
assert (
len(no_exist_ids) == 0
), f"{len(no_exist_ids)} doc_ids in retrieval_gt do not exist in corpus_df."

751
autorag/utils/util.py Normal file

@@ -0,0 +1,751 @@
import ast
import asyncio
import datetime
import functools
import glob
import inspect
import itertools
import json
import logging
import os
import re
import string
from copy import deepcopy
from json import JSONDecoder
from typing import List, Callable, Dict, Optional, Any, Collection, Iterable
from asyncio import AbstractEventLoop
import emoji
import numpy as np
import pandas as pd
import tiktoken
import unicodedata
import yaml
from llama_index.embeddings.openai import OpenAIEmbedding
from pydantic import BaseModel as BM
from pydantic.v1 import BaseModel
logger = logging.getLogger("AutoRAG")
def fetch_contents(
corpus_data: pd.DataFrame, ids: List[List[str]], column_name: str = "contents"
) -> List[List[Any]]:
def fetch_contents_pure(
ids: List[str], corpus_data: pd.DataFrame, column_name: str
):
return list(map(lambda x: fetch_one_content(corpus_data, x, column_name), ids))
result = flatten_apply(
fetch_contents_pure, ids, corpus_data=corpus_data, column_name=column_name
)
return result
def fetch_one_content(
corpus_data: pd.DataFrame,
id_: str,
column_name: str = "contents",
id_column_name: str = "doc_id",
) -> Any:
if isinstance(id_, str):
if id_ in ["", ""]:
return None
fetch_result = corpus_data[corpus_data[id_column_name] == id_]
if fetch_result.empty:
raise ValueError(f"doc_id: {id_} not found in corpus_data.")
else:
return fetch_result[column_name].iloc[0]
else:
return None
def result_to_dataframe(column_names: List[str]):
"""
Decorator for converting results to pd.DataFrame.
"""
def decorator_result_to_dataframe(func: Callable):
@functools.wraps(func)
def wrapper(*args, **kwargs) -> pd.DataFrame:
results = func(*args, **kwargs)
if len(column_names) == 1:
df_input = {column_names[0]: results}
else:
df_input = {
column_name: result
for result, column_name in zip(results, column_names)
}
result_df = pd.DataFrame(df_input)
return result_df
return wrapper
return decorator_result_to_dataframe
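# Illustrative sketch (not part of the commit): the decorated function returns one
# collection per column and the wrapper zips them into a DataFrame. Names are hypothetical.
@result_to_dataframe(["retrieved_ids", "retrieve_scores"])
def _example_retrieval_result():
    # one list per column; calling this yields a single-row DataFrame with both columns
    return [["doc-1", "doc-2"]], [[0.9, 0.4]]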
def load_summary_file(
summary_path: str, dict_columns: Optional[List[str]] = None
) -> pd.DataFrame:
"""
Load a summary file from summary_path.
:param summary_path: The path of the summary file.
:param dict_columns: The columns that are dictionary type.
    You must fill this parameter if you want to load the summary file properly.
Default is ['module_params'].
:return: The summary dataframe.
"""
if not os.path.exists(summary_path):
raise ValueError(f"summary.csv does not exist in {summary_path}.")
summary_df = pd.read_csv(summary_path)
if dict_columns is None:
dict_columns = ["module_params"]
if any([col not in summary_df.columns for col in dict_columns]):
raise ValueError(f"{dict_columns} must be in summary_df.columns.")
def convert_dict(elem):
try:
return ast.literal_eval(elem)
        except (ValueError, SyntaxError):
# convert datetime or date to its object (recency filter)
date_object = convert_datetime_string(elem)
if date_object is None:
raise ValueError(
f"Malformed dict received : {elem}\nCan't convert to dict properly"
)
return {"threshold": date_object}
summary_df[dict_columns] = summary_df[dict_columns].map(convert_dict)
return summary_df
def convert_datetime_string(s):
# Regex to extract datetime arguments from the string
m = re.search(r"(datetime|date)(\((\d+)(,\s*\d+)*\))", s)
if m:
args = ast.literal_eval(m.group(2))
if m.group(1) == "datetime":
return datetime.datetime(*args)
elif m.group(1) == "date":
return datetime.date(*args)
return None
def make_combinations(target_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Make combinations from target_dict.
    Each key of target_dict must be a string,
    and each value can be a single value or a list of values.
    It generates all combinations of values from target_dict,
    i.e., dictionaries that contain exactly one value for each key,
    and all generated dictionaries differ from each other.
:param target_dict: The target dictionary.
:return: The list of generated dictionaries.
"""
dict_with_lists = dict(
map(
lambda x: (x[0], x[1] if isinstance(x[1], list) else [x[1]]),
target_dict.items(),
)
)
def delete_duplicate(x):
def is_hashable(obj):
try:
hash(obj)
return True
except TypeError:
return False
if any([not is_hashable(elem) for elem in x]):
# TODO: add duplication check for unhashable objects
return x
else:
return list(set(x))
dict_with_lists = dict(
map(lambda x: (x[0], delete_duplicate(x[1])), dict_with_lists.items())
)
combination = list(itertools.product(*dict_with_lists.values()))
combination_dicts = [
dict(zip(dict_with_lists.keys(), combo)) for combo in combination
]
return combination_dicts
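# Illustrative sketch (not part of the commit): expanding a hypothetical parameter grid;
# each resulting dict holds exactly one value per key (ordering of results may vary).
def _example_make_combinations():
    # -> [{"top_k": 3, "method": "bm25"}, {"top_k": 5, "method": "bm25"}]
    return make_combinations({"top_k": [3, 5], "method": "bm25"})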
def explode(index_values: Collection[Any], explode_values: Collection[Collection[Any]]):
"""
Explode index_values and explode_values.
The index_values and explode_values must have the same length.
It will flatten explode_values and keep index_values as a pair.
:param index_values: The index values.
:param explode_values: The exploded values.
:return: Tuple of exploded index_values and exploded explode_values.
"""
assert len(index_values) == len(
explode_values
), "Index values and explode values must have same length"
df = pd.DataFrame({"index_values": index_values, "explode_values": explode_values})
df = df.explode("explode_values")
return df["index_values"].tolist(), df["explode_values"].tolist()
def replace_value_in_dict(target_dict: Dict, key: str, replace_value: Any) -> Dict:
"""
Replace the value of a certain key in target_dict.
If there is no targeted key in target_dict, it will return target_dict.
:param target_dict: The target dictionary.
    :param key: The key to replace.
:param replace_value: The value to replace.
:return: The replaced dictionary.
"""
replaced_dict = deepcopy(target_dict)
if key not in replaced_dict:
return replaced_dict
replaced_dict[key] = replace_value
return replaced_dict
def normalize_string(s: str) -> str:
"""
Taken from the official evaluation script for v1.1 of the SQuAD dataset.
Lower text and remove punctuation, articles, and extra whitespace.
"""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
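# Illustrative sketch (not part of the commit): SQuAD-style normalization lowercases and
# strips punctuation, articles, and extra whitespace before string comparison.
def _example_normalize_string():
    return normalize_string("The  Quick, Brown Fox!")  # -> "quick brown fox"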
def convert_string_to_tuple_in_dict(d):
"""Recursively converts strings that start with '(' and end with ')' to tuples in a dictionary."""
for key, value in d.items():
# If the value is a dictionary, recurse
if isinstance(value, dict):
convert_string_to_tuple_in_dict(value)
# If the value is a list, iterate through its elements
elif isinstance(value, list):
for i, item in enumerate(value):
# If an item in the list is a dictionary, recurse
if isinstance(item, dict):
convert_string_to_tuple_in_dict(item)
# If an item in the list is a string matching the criteria, convert it to a tuple
elif (
isinstance(item, str)
and item.startswith("(")
and item.endswith(")")
):
value[i] = ast.literal_eval(item)
# If the value is a string matching the criteria, convert it to a tuple
elif isinstance(value, str) and value.startswith("(") and value.endswith(")"):
d[key] = ast.literal_eval(value)
return d
def convert_env_in_dict(d: Dict):
"""
Recursively converts environment variable string in a dictionary to actual environment variable.
:param d: The dictionary to convert.
:return: The converted dictionary.
"""
env_pattern = re.compile(r".*?\${(.*?)}.*?")
def convert_env(val: str):
matches = env_pattern.findall(val)
for match in matches:
val = val.replace(f"${{{match}}}", os.environ.get(match, ""))
return val
for key, value in d.items():
if isinstance(value, dict):
convert_env_in_dict(value)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
convert_env_in_dict(item)
elif isinstance(item, str):
value[i] = convert_env(item)
elif isinstance(value, str):
d[key] = convert_env(value)
return d
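# Illustrative sketch (not part of the commit): ${VAR} placeholders in a loaded config
# dict are replaced with environment variable values. The variable name is hypothetical.
def _example_convert_env():
    os.environ["EXAMPLE_API_KEY"] = "sk-demo"
    # -> {"llm": {"api_key": "sk-demo"}}
    return convert_env_in_dict({"llm": {"api_key": "${EXAMPLE_API_KEY}"}})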
async def process_batch(tasks, batch_size: int = 64) -> List[Any]:
"""
Processes tasks in batches asynchronously.
    :param tasks: A list of coroutines (awaitables) to be executed.
:param batch_size: The number of tasks to process in a single batch.
Default is 64.
:return: A list of results from the processed tasks.
"""
results = []
for i in range(0, len(tasks), batch_size):
batch = tasks[i : i + batch_size]
batch_results = await asyncio.gather(*batch)
results.extend(batch_results)
return results
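# Illustrative sketch (not part of the commit): awaiting many coroutines in fixed-size
# batches, e.g. to stay under API rate limits. Run with asyncio.run(_example_process_batch()).
async def _example_process_batch():
    async def _echo(i):
        return i

    tasks = [_echo(i) for i in range(5)]
    return await process_batch(tasks, batch_size=2)  # -> [0, 1, 2, 3, 4]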
def make_batch(elems: List[Any], batch_size: int) -> List[List[Any]]:
"""
Make a batch of elems with batch_size.
"""
return [elems[i : i + batch_size] for i in range(0, len(elems), batch_size)]
def save_parquet_safe(df: pd.DataFrame, filepath: str, upsert: bool = False):
output_file_dir = os.path.dirname(filepath)
if not os.path.isdir(output_file_dir):
raise NotADirectoryError(f"directory {output_file_dir} not found.")
if not filepath.endswith("parquet"):
raise NameError(
f'file path: {filepath} filename extension need to be ".parquet"'
)
if os.path.exists(filepath) and not upsert:
raise FileExistsError(
f"file {filepath} already exists."
"Set upsert True if you want to overwrite the file."
)
df.to_parquet(filepath, index=False)
def openai_truncate_by_token(
texts: List[str], token_limit: int, model_name: str
) -> List[str]:
try:
tokenizer = tiktoken.encoding_for_model(model_name)
except KeyError:
# This is not a real OpenAI model
return texts
def truncate_text(text: str, limit: int, tokenizer):
tokens = tokenizer.encode(text)
if len(tokens) <= limit:
return text
truncated_text = tokenizer.decode(tokens[:limit])
return truncated_text
return list(map(lambda x: truncate_text(x, token_limit, tokenizer), texts))
def reconstruct_list(flat_list: List[Any], lengths: List[int]) -> List[List[Any]]:
result = []
start = 0
for length in lengths:
result.append(flat_list[start : start + length])
start += length
return result
def flatten_apply(
func: Callable, nested_list: List[List[Any]], **kwargs
) -> List[List[Any]]:
"""
This function flattens the input list and applies the function to the elements.
After that, it reconstructs the list to the original shape.
    The inner lists may have different lengths from each other.
:param func: The function that applies to the flattened list.
:param nested_list: The nested list to be flattened.
:return: The list that is reconstructed after applying the function.
"""
df = pd.DataFrame({"col1": nested_list})
df = df.explode("col1")
df["result"] = func(df["col1"].tolist(), **kwargs)
return df.groupby(level=0, sort=False)["result"].apply(list).tolist()
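# Illustrative sketch (not part of the commit): the callable sees the flattened list once,
# must return one result per element, and the original nested shape is then restored.
def _example_flatten_apply():
    def double_all(flat_list):
        return [x * 2 for x in flat_list]

    return flatten_apply(double_all, [[1, 2], [3]])  # -> [[2, 4], [6]]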
async def aflatten_apply(
func: Callable, nested_list: List[List[Any]], **kwargs
) -> List[List[Any]]:
"""
This function flattens the input list and applies the function to the elements.
After that, it reconstructs the list to the original shape.
    The inner lists may have different lengths from each other.
:param func: The function that applies to the flattened list.
:param nested_list: The nested list to be flattened.
:return: The list that is reconstructed after applying the function.
"""
df = pd.DataFrame({"col1": nested_list})
df = df.explode("col1")
df["result"] = await func(df["col1"].tolist(), **kwargs)
return df.groupby(level=0, sort=False)["result"].apply(list).tolist()
def sort_by_scores(row, reverse=True):
"""
Sorts each row by 'scores' column.
The input column names must be 'contents', 'ids', and 'scores'.
And its elements must be list type.
"""
results = sorted(
zip(row["contents"], row["ids"], row["scores"]),
key=lambda x: x[2],
reverse=reverse,
)
reranked_contents, reranked_ids, reranked_scores = zip(*results)
return list(reranked_contents), list(reranked_ids), list(reranked_scores)
def select_top_k(df, column_names: List[str], top_k: int):
for column_name in column_names:
df[column_name] = df[column_name].apply(lambda x: x[:top_k])
return df
def filter_dict_keys(dict_, keys: List[str]):
result = {}
for key in keys:
if key in dict_:
result[key] = dict_[key]
else:
raise KeyError(f"Key '{key}' not found in dictionary.")
return result
def split_dataframe(df, chunk_size):
num_chunks = (
len(df) // chunk_size + 1
if len(df) % chunk_size != 0
else len(df) // chunk_size
)
result = list(
map(lambda x: df[x * chunk_size : (x + 1) * chunk_size], range(num_chunks))
)
result = list(map(lambda x: x.reset_index(drop=True), result))
return result
def find_trial_dir(project_dir: str) -> List[str]:
# Pattern to match directories named with numbers
pattern = os.path.join(project_dir, "[0-9]*")
all_entries = glob.glob(pattern)
# Filter out only directories
trial_dirs = [
entry
for entry in all_entries
if os.path.isdir(entry) and entry.split(os.sep)[-1].isdigit()
]
return trial_dirs
def find_node_summary_files(trial_dir: str) -> List[str]:
# Find all summary.csv files recursively
all_summary_files = glob.glob(
os.path.join(trial_dir, "**", "summary.csv"), recursive=True
)
    # Keep only node-level summary files (skip trial-level and node-line-level summary.csv)
filtered_files = [
f for f in all_summary_files if f.count(os.sep) > trial_dir.count(os.sep) + 2
]
return filtered_files
def preprocess_text(text: str) -> str:
return normalize_unicode(demojize(text))
def demojize(text: str) -> str:
return emoji.demojize(text)
def normalize_unicode(text: str) -> str:
return unicodedata.normalize("NFC", text)
def dict_to_markdown(d, level=1):
"""
Convert a dictionary to a Markdown formatted string.
:param d: Dictionary to convert
:param level: Current level of heading (used for nested dictionaries)
:return: Markdown formatted string
"""
markdown = ""
for key, value in d.items():
if isinstance(value, dict):
markdown += f"{'#' * level} {key}\n"
markdown += dict_to_markdown(value, level + 1)
elif isinstance(value, list):
markdown += f"{'#' * level} {key}\n"
for item in value:
if isinstance(item, dict):
markdown += dict_to_markdown(item, level + 1)
else:
markdown += f"- {item}\n"
else:
markdown += f"{'#' * level} {key}\n{value}\n"
return markdown
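# Illustrative sketch (not part of the commit): nested dicts become nested Markdown
# headings, as the dashboard does for best-module summaries. Values are hypothetical.
def _example_dict_to_markdown():
    # -> "# retrieval\n## top_k\n5\n## metrics\n- f1\n- recall\n"
    return dict_to_markdown({"retrieval": {"top_k": 5, "metrics": ["f1", "recall"]}})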
def dict_to_markdown_table(data, key_column_name: str, value_column_name: str):
# Check if the input is a dictionary
if not isinstance(data, dict):
raise ValueError("Input must be a dictionary")
# Create the header of the table
header = f"| {key_column_name} | {value_column_name} |\n| :---: | :-----: |\n"
# Create the rows of the table
rows = ""
for key, value in data.items():
rows += f"| {key} | {value} |\n"
# Combine header and rows
markdown_table = header + rows
return markdown_table
def embedding_query_content(
queries: List[str],
contents_list: List[List[str]],
    embedding_model=None,  # an embedding model instance (e.g., OpenAIEmbedding), not a string
batch: int = 128,
):
flatten_contents = list(itertools.chain.from_iterable(contents_list))
    openai_embedding_limit = 8000  # OpenAI embedding models accept roughly 8k input tokens; truncate conservatively
if isinstance(embedding_model, OpenAIEmbedding):
queries = openai_truncate_by_token(
queries, openai_embedding_limit, embedding_model.model_name
)
flatten_contents = openai_truncate_by_token(
flatten_contents, openai_embedding_limit, embedding_model.model_name
)
# Embedding using batch
embedding_model.embed_batch_size = batch
query_embeddings = embedding_model.get_text_embedding_batch(queries)
content_lengths = list(map(len, contents_list))
content_embeddings_flatten = embedding_model.get_text_embedding_batch(
flatten_contents
)
content_embeddings = reconstruct_list(content_embeddings_flatten, content_lengths)
return query_embeddings, content_embeddings
def to_list(item):
"""Recursively convert collections to Python lists."""
if isinstance(item, np.ndarray):
# Convert numpy array to list and recursively process each element
return [to_list(sub_item) for sub_item in item.tolist()]
elif isinstance(item, pd.Series):
# Convert pandas Series to list and recursively process each element
return [to_list(sub_item) for sub_item in item.tolist()]
elif isinstance(item, Iterable) and not isinstance(
item, (str, bytes, BaseModel, BM)
):
# Recursively process each element in other iterables
return [to_list(sub_item) for sub_item in item]
else:
return item
def convert_inputs_to_list(func):
"""Decorator to convert all function inputs to Python lists."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
new_args = [to_list(arg) for arg in args]
new_kwargs = {k: to_list(v) for k, v in kwargs.items()}
return func(*new_args, **new_kwargs)
return wrapper
def get_best_row(
summary_df: pd.DataFrame, best_column_name: str = "is_best"
) -> pd.Series:
"""
From the summary dataframe, find the best result row by 'is_best' column and return it.
:param summary_df: Summary dataframe created by AutoRAG.
:param best_column_name: The column name that indicates the best result.
Default is 'is_best'.
You don't have to change this unless the column name is different.
:return: Best row pandas Series instance.
"""
bests = summary_df.loc[summary_df[best_column_name]]
assert len(bests) == 1, "There must be only one best result."
return bests.iloc[0]
def get_event_loop() -> AbstractEventLoop:
"""
Get asyncio event loop safely.
"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def find_key_values(data, target_key: str) -> List[Any]:
"""
Recursively find all values for a specific key in a nested dictionary or list.
:param data: The dictionary or list to search.
:param target_key: The key to search for.
:return: A list of values associated with the target key.
"""
values = []
if isinstance(data, dict):
for key, value in data.items():
if key == target_key:
values.append(value)
if isinstance(value, (dict, list)):
values.extend(find_key_values(value, target_key))
elif isinstance(data, list):
for item in data:
if isinstance(item, (dict, list)):
values.extend(find_key_values(item, target_key))
return values
def pop_params(func: Callable, kwargs: Dict) -> Dict:
"""
Pop parameters from the given func and return them.
    It automatically ignores parameters like "self" or "cls".
:param func: The function to pop parameters.
:param kwargs: kwargs to pop parameters.
:return: The popped parameters.
"""
ignore_params = ["self", "cls"]
target_params = list(inspect.signature(func).parameters.keys())
target_params = list(filter(lambda x: x not in ignore_params, target_params))
init_params = {}
kwargs_keys = list(kwargs.keys())
for key in kwargs_keys:
if key in target_params:
init_params[key] = kwargs.pop(key)
return init_params
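# Illustrative sketch (not part of the commit): splitting kwargs between a callable's own
# parameters and the rest. _retrieve and its arguments are hypothetical.
def _example_pop_params():
    def _retrieve(query, top_k=5):
        return query, top_k

    kwargs = {"top_k": 10, "batch": 64}
    init_params = pop_params(_retrieve, kwargs)
    # init_params == {"top_k": 10}; kwargs is now {"batch": 64}
    return init_params, kwargs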
def apply_recursive(func, data):
"""
Recursively apply a function to all elements in a list, tuple, set, np.ndarray, or pd.Series and return as List.
:param func: Function to apply to each element.
:param data: List or nested list.
:return: List with the function applied to each element.
"""
    if isinstance(data, (list, tuple, set, np.ndarray, pd.Series)):
return [apply_recursive(func, item) for item in data]
else:
return func(data)
def empty_cuda_cache():
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except ImportError:
pass
def load_yaml_config(yaml_path: str) -> Dict:
"""
Load a YAML configuration file for AutoRAG.
    It performs safe loading, converts tuple-like strings to tuples, and substitutes environment variables.
:param yaml_path: The path of the YAML configuration file.
:return: The loaded configuration dictionary.
"""
if not os.path.exists(yaml_path):
raise ValueError(f"YAML file {yaml_path} does not exist.")
with open(yaml_path, "r", encoding="utf-8") as stream:
try:
yaml_dict = yaml.safe_load(stream)
except yaml.YAMLError as exc:
raise ValueError(f"YAML file {yaml_path} could not be loaded.") from exc
yaml_dict = convert_string_to_tuple_in_dict(yaml_dict)
yaml_dict = convert_env_in_dict(yaml_dict)
return yaml_dict
def decode_multiple_json_from_bytes(byte_data: bytes) -> list:
"""
    Decode multiple JSON objects from bytes received from an SSE server.
    :param byte_data: Bytes containing one or more JSON objects.
    :return: List of decoded JSON objects.
"""
# Decode bytes to string
try:
text_data = byte_data.decode("utf-8").strip()
except UnicodeDecodeError:
raise ValueError("Invalid byte data: Unable to decode as UTF-8")
# Initialize decoder and result list
decoder = JSONDecoder()
result = []
# Keep track of position in string
pos = 0
text_data = text_data.strip()
while pos < len(text_data):
try:
# Try to decode next JSON object
json_obj, json_end = decoder.raw_decode(text_data[pos:])
result.append(json_obj)
# Move position to end of current JSON object
pos += json_end
# Skip any whitespace
while pos < len(text_data) and text_data[pos].isspace():
pos += 1
except json.JSONDecodeError:
# If we can't decode at current position, move forward one character
pos += 1
return result
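# Illustrative sketch (not part of the commit): an SSE-style payload carrying two
# concatenated JSON objects decodes into a list of dicts.
def _example_decode_multiple_json():
    payload = b'{"type": "start"}\n{"type": "done", "ok": true}'
    # -> [{"type": "start"}, {"type": "done", "ok": True}]
    return decode_multiple_json_from_bytes(payload)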