Initial Commit
This commit is contained in:
1
autorag/VERSION
Normal file
1
autorag/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
0.3.14
|
||||
113
autorag/__init__.py
Normal file
113
autorag/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import sys
|
||||
from random import random
|
||||
from typing import List, Any
|
||||
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from llama_index.core.base.llms.types import CompletionResponse
|
||||
from llama_index.core.llms.mock import MockLLM
|
||||
from llama_index.llms.bedrock import Bedrock
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbeddingModelType
|
||||
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
from langchain_openai.embeddings import OpenAIEmbeddings
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from llama_index.llms.ollama import Ollama
|
||||
|
||||
version_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "VERSION")
|
||||
|
||||
with open(version_path, "r") as f:
|
||||
__version__ = f.read().strip()
|
||||
|
||||
|
||||
class LazyInit:
|
||||
def __init__(self, factory, *args, **kwargs):
|
||||
self._factory = factory
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
self._instance = None
|
||||
|
||||
def __call__(self):
|
||||
if self._instance is None:
|
||||
self._instance = self._factory(*self._args, **self._kwargs)
|
||||
return self._instance
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._instance is None:
|
||||
self._instance = self._factory(*self._args, **self._kwargs)
|
||||
return getattr(self._instance, name)
|
||||
|
||||
|
||||
rich_format = "[%(filename)s:%(lineno)s] >> %(message)s"
|
||||
logging.basicConfig(
|
||||
level="INFO", format=rich_format, handlers=[RichHandler(rich_tracebacks=True)]
|
||||
)
|
||||
logger = logging.getLogger("AutoRAG")
|
||||
|
||||
|
||||
def handle_exception(exc_type, exc_value, exc_traceback):
|
||||
logger = logging.getLogger("AutoRAG")
|
||||
logger.error("Unexpected exception", exc_info=(exc_type, exc_value, exc_traceback))
|
||||
|
||||
|
||||
sys.excepthook = handle_exception
|
||||
|
||||
|
||||
class AutoRAGBedrock(Bedrock):
|
||||
async def acomplete(
|
||||
self, prompt: str, formatted: bool = False, **kwargs: Any
|
||||
) -> CompletionResponse:
|
||||
return self.complete(prompt, formatted=formatted, **kwargs)
|
||||
|
||||
|
||||
generator_models = {
|
||||
"openai": OpenAI,
|
||||
"openailike": OpenAILike,
|
||||
"mock": MockLLM,
|
||||
"bedrock": AutoRAGBedrock,
|
||||
"ollama": Ollama,
|
||||
}
|
||||
|
||||
# embedding_models = {
|
||||
|
||||
# }
|
||||
|
||||
try:
|
||||
from llama_index.llms.huggingface import HuggingFaceLLM
|
||||
from llama_index.llms.ollama import Ollama
|
||||
|
||||
generator_models["huggingfacellm"] = HuggingFaceLLM
|
||||
generator_models["ollama"] = Ollama
|
||||
|
||||
except ImportError:
|
||||
logger.info(
|
||||
"You are using API version of AutoRAG. "
|
||||
"To use local version, run pip install 'AutoRAG[gpu]'"
|
||||
)
|
||||
|
||||
# try:
|
||||
# from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
# embedding_models["hf_all_mpnet_base_v2"] = HuggingFaceEmbedding # 250312 변경 - 김용연
|
||||
# embedding_models["hf_KURE-v1"] = HuggingFaceEmbedding # 250312 변경 - 김용연
|
||||
# embedding_models["hf_snowflake-arctic-embed-l-v2.0-ko"] = HuggingFaceEmbedding # 250313 변경 - 김용연
|
||||
|
||||
# except ImportError:
|
||||
# logger.info(
|
||||
# "You are using API version of AutoRAG."
|
||||
# "To use local version, run pip install 'AutoRAG[gpu]'"
|
||||
# )
|
||||
|
||||
try:
|
||||
import transformers
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
except ImportError:
|
||||
logger.info(
|
||||
"You are using API version of AutoRAG."
|
||||
"To use local version, run pip install 'AutoRAG[gpu]'"
|
||||
)
|
||||
40
autorag/cli.py
Normal file
40
autorag/cli.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from autorag import dashboard
|
||||
|
||||
logger = logging.getLogger("AutoRAG")
|
||||
|
||||
autorag_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
version_file = os.path.join(autorag_dir, "VERSION")
|
||||
with open(version_file, "r") as f:
|
||||
__version__ = f.read().strip()
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(__version__)
|
||||
def cli():
|
||||
pass
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--trial_dir",
|
||||
type=click.Path(dir_okay=True, file_okay=False, exists=True),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--port", type=int, default=7690, help="Port number. The default is 7690."
|
||||
)
|
||||
def run_dashboard(trial_dir: str, port: int):
|
||||
"""Runs the AutoRAG Dashboard."""
|
||||
logger.info(f"Starting AutoRAG Dashboard on port {port}...")
|
||||
dashboard.run(trial_dir, port=port)
|
||||
|
||||
|
||||
cli.add_command(run_dashboard, "dashboard")
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
215
autorag/dashboard.py
Normal file
215
autorag/dashboard.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import panel as pn
|
||||
import seaborn as sns
|
||||
import yaml
|
||||
from bokeh.models import NumberFormatter, BooleanFormatter
|
||||
|
||||
from autorag.utils.util import dict_to_markdown, dict_to_markdown_table
|
||||
|
||||
pn.extension(
|
||||
"terminal",
|
||||
"tabulator",
|
||||
"mathjax",
|
||||
"ipywidgets",
|
||||
console_output="disable",
|
||||
sizing_mode="stretch_width",
|
||||
css_files=[
|
||||
"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css"
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger("AutoRAG")
|
||||
|
||||
|
||||
def find_node_dir(trial_dir: str) -> List[str]:
|
||||
trial_summary_df = pd.read_csv(os.path.join(trial_dir, "summary.csv"))
|
||||
result_paths = []
|
||||
for idx, row in trial_summary_df.iterrows():
|
||||
node_line_name = row["node_line_name"]
|
||||
node_type = row["node_type"]
|
||||
result_paths.append(os.path.join(trial_dir, node_line_name, node_type))
|
||||
return result_paths
|
||||
|
||||
|
||||
def get_metric_values(node_summary_df: pd.DataFrame) -> Dict:
|
||||
non_metric_column_names = [
|
||||
"filename",
|
||||
"module_name",
|
||||
"module_params",
|
||||
"execution_time",
|
||||
"average_output_token",
|
||||
"is_best",
|
||||
]
|
||||
best_row = node_summary_df.loc[node_summary_df["is_best"]].drop(
|
||||
columns=non_metric_column_names, errors="ignore"
|
||||
)
|
||||
assert len(best_row) == 1, "The best module must be only one."
|
||||
return best_row.iloc[0].to_dict()
|
||||
|
||||
|
||||
def make_trial_summary_md(trial_dir):
|
||||
markdown_text = f"""# Trial Result Summary
|
||||
- Trial Directory : {trial_dir}
|
||||
|
||||
"""
|
||||
node_dirs = find_node_dir(trial_dir)
|
||||
for node_dir in node_dirs:
|
||||
node_summary_filepath = os.path.join(node_dir, "summary.csv")
|
||||
node_type = os.path.basename(node_dir)
|
||||
node_summary_df = pd.read_csv(node_summary_filepath)
|
||||
best_row = node_summary_df.loc[node_summary_df["is_best"]].iloc[0]
|
||||
metric_dict = get_metric_values(node_summary_df)
|
||||
markdown_text += f"""---
|
||||
|
||||
## {node_type} best module
|
||||
|
||||
### Module Name
|
||||
|
||||
{best_row['module_name']}
|
||||
|
||||
### Module Params
|
||||
|
||||
{dict_to_markdown(ast.literal_eval(best_row['module_params']), level=3)}
|
||||
|
||||
### Metric Values
|
||||
|
||||
{dict_to_markdown_table(metric_dict, key_column_name='metric_name', value_column_name='metric_value')}
|
||||
|
||||
"""
|
||||
|
||||
return markdown_text
|
||||
|
||||
|
||||
def node_view(node_dir: str):
|
||||
non_metric_column_names = [
|
||||
"filename",
|
||||
"module_name",
|
||||
"module_params",
|
||||
"execution_time",
|
||||
"average_output_token",
|
||||
"is_best",
|
||||
]
|
||||
summary_df = pd.read_csv(os.path.join(node_dir, "summary.csv"))
|
||||
bokeh_formatters = {
|
||||
"float": NumberFormatter(format="0.000"),
|
||||
"bool": BooleanFormatter(),
|
||||
}
|
||||
first_df = pd.read_parquet(os.path.join(node_dir, "0.parquet"), engine="pyarrow")
|
||||
|
||||
each_module_df_widget = pn.widgets.Tabulator(
|
||||
pd.DataFrame(columns=first_df.columns),
|
||||
name="Module DataFrame",
|
||||
formatters=bokeh_formatters,
|
||||
pagination="local",
|
||||
page_size=20,
|
||||
widths=150,
|
||||
)
|
||||
|
||||
def change_module_widget(event):
|
||||
if event.column == "detail":
|
||||
filename = summary_df["filename"].iloc[event.row]
|
||||
filepath = os.path.join(node_dir, filename)
|
||||
each_module_df = pd.read_parquet(filepath, engine="pyarrow")
|
||||
each_module_df_widget.value = each_module_df
|
||||
|
||||
df_widget = pn.widgets.Tabulator(
|
||||
summary_df,
|
||||
name="Summary DataFrame",
|
||||
formatters=bokeh_formatters,
|
||||
buttons={"detail": '<i class="fa fa-eye"></i>'},
|
||||
widths=150,
|
||||
)
|
||||
df_widget.on_click(change_module_widget)
|
||||
|
||||
try:
|
||||
fig, ax = plt.subplots(figsize=(10, 5))
|
||||
metric_df = summary_df.drop(columns=non_metric_column_names, errors="ignore")
|
||||
sns.stripplot(data=metric_df, ax=ax)
|
||||
strip_plot_pane = pn.pane.Matplotlib(fig, tight=True)
|
||||
|
||||
fig2, ax2 = plt.subplots(figsize=(10, 5))
|
||||
sns.boxplot(data=metric_df, ax=ax2)
|
||||
box_plot_pane = pn.pane.Matplotlib(fig2, tight=True)
|
||||
plot_pane = pn.Row(strip_plot_pane, box_plot_pane)
|
||||
|
||||
layout = pn.Column(
|
||||
"## Summary distribution plot",
|
||||
plot_pane,
|
||||
"## Summary DataFrame",
|
||||
df_widget,
|
||||
"## Module Result DataFrame",
|
||||
each_module_df_widget,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Skipping make boxplot and stripplot with error {e}")
|
||||
layout = pn.Column("## Summary DataFrame", df_widget)
|
||||
layout.servable()
|
||||
return layout
|
||||
|
||||
|
||||
CSS = """
|
||||
div.card-margin:nth-child(1) {
|
||||
max-height: 300px;
|
||||
}
|
||||
div.card-margin:nth-child(2) {
|
||||
max-height: 400px;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def yaml_to_markdown(yaml_filepath):
|
||||
markdown_content = ""
|
||||
with open(yaml_filepath, "r", encoding="utf-8") as file:
|
||||
try:
|
||||
content = yaml.safe_load(file)
|
||||
markdown_content += f"## {os.path.basename(yaml_filepath)}\n```yaml\n{yaml.safe_dump(content, allow_unicode=True)}\n```\n\n"
|
||||
except yaml.YAMLError as exc:
|
||||
print(f"Error in {yaml_filepath}: {exc}")
|
||||
return markdown_content
|
||||
|
||||
|
||||
def run(trial_dir: str, port: int = 7690):
|
||||
trial_summary_md = make_trial_summary_md(trial_dir=trial_dir)
|
||||
trial_summary_tab = pn.pane.Markdown(trial_summary_md, sizing_mode="stretch_width")
|
||||
|
||||
node_views = [
|
||||
(str(os.path.basename(node_dir)), pn.bind(node_view, node_dir))
|
||||
for node_dir in find_node_dir(trial_dir)
|
||||
]
|
||||
"""
|
||||
수정 전
|
||||
node_views = [
|
||||
(str(os.path.basename(node_dir)), node_view(node_dir))
|
||||
for node_dir in find_node_dir(trial_dir)
|
||||
]
|
||||
"""
|
||||
|
||||
yaml_file_markdown = yaml_to_markdown(os.path.join(trial_dir, "config.yaml"))
|
||||
|
||||
yaml_file = pn.pane.Markdown(yaml_file_markdown, sizing_mode="stretch_width")
|
||||
|
||||
tabs = pn.Tabs(
|
||||
("Summary", trial_summary_tab),
|
||||
*node_views,
|
||||
("Used YAML file", yaml_file),
|
||||
dynamic=True,
|
||||
)
|
||||
'''
|
||||
수정 전
|
||||
template = pn.template.FastListTemplate(
|
||||
site="AutoRAG", title="Dashboard", main=[tabs], raw_css=[CSS]
|
||||
).servable()
|
||||
template.show(port=port)
|
||||
'''
|
||||
if CSS not in pn.config.raw_css:
|
||||
pn.config.raw_css.append(CSS)
|
||||
|
||||
template = pn.template.FastListTemplate(
|
||||
site="AutoRAG", title="Dashboard", main=[tabs]
|
||||
).servable()
|
||||
pn.serve(template, port=port, show=False)
|
||||
3
autorag/schema/__init__.py
Normal file
3
autorag/schema/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .module import Module
|
||||
from .node import Node
|
||||
from .base import BaseModule
|
||||
BIN
autorag/schema/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
autorag/schema/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
autorag/schema/__pycache__/base.cpython-310.pyc
Normal file
BIN
autorag/schema/__pycache__/base.cpython-310.pyc
Normal file
Binary file not shown.
BIN
autorag/schema/__pycache__/metricinput.cpython-310.pyc
Normal file
BIN
autorag/schema/__pycache__/metricinput.cpython-310.pyc
Normal file
Binary file not shown.
BIN
autorag/schema/__pycache__/module.cpython-310.pyc
Normal file
BIN
autorag/schema/__pycache__/module.cpython-310.pyc
Normal file
Binary file not shown.
BIN
autorag/schema/__pycache__/node.cpython-310.pyc
Normal file
BIN
autorag/schema/__pycache__/node.cpython-310.pyc
Normal file
Binary file not shown.
35
autorag/schema/base.py
Normal file
35
autorag/schema/base.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class BaseModule(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _pure(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def run_evaluator(
|
||||
cls,
|
||||
project_dir: Union[str, Path],
|
||||
previous_result: pd.DataFrame,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
instance = cls(project_dir, *args, **kwargs)
|
||||
result = instance.pure(previous_result, *args, **kwargs)
|
||||
del instance
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
|
||||
"""
|
||||
This function is for cast function (a.k.a decorator) only for pure function in the whole node.
|
||||
"""
|
||||
pass
|
||||
99
autorag/schema/metricinput.py
Normal file
99
autorag/schema/metricinput.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict, Callable, Any, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricInput:
|
||||
query: Optional[str] = None
|
||||
queries: Optional[List[str]] = None
|
||||
retrieval_gt_contents: Optional[List[List[str]]] = None
|
||||
retrieved_contents: Optional[List[str]] = None
|
||||
retrieval_gt: Optional[List[List[str]]] = None
|
||||
retrieved_ids: Optional[List[str]] = None
|
||||
prompt: Optional[str] = None
|
||||
generated_texts: Optional[str] = None
|
||||
generation_gt: Optional[List[str]] = None
|
||||
generated_log_probs: Optional[List[float]] = None
|
||||
|
||||
def is_fields_notnone(self, fields_to_check: List[str]) -> bool:
|
||||
for field in fields_to_check:
|
||||
actual_value = getattr(self, field)
|
||||
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if not type_checks.get(type(actual_value), lambda _: False)(
|
||||
actual_value
|
||||
):
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def from_dataframe(cls, qa_data: pd.DataFrame) -> List["MetricInput"]:
|
||||
"""
|
||||
Convert a pandas DataFrame into a list of MetricInput instances.
|
||||
qa_data: pd.DataFrame: qa_data DataFrame containing metric data.
|
||||
|
||||
:returns: List[MetricInput]: List of MetricInput objects created from DataFrame rows.
|
||||
"""
|
||||
instances = []
|
||||
|
||||
for _, row in qa_data.iterrows():
|
||||
instance = cls()
|
||||
|
||||
for attr_name in cls.__annotations__:
|
||||
if attr_name in row:
|
||||
value = row[attr_name]
|
||||
|
||||
if isinstance(value, str):
|
||||
setattr(
|
||||
instance,
|
||||
attr_name,
|
||||
value.strip() if value.strip() != "" else None,
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
setattr(instance, attr_name, value if len(value) > 0 else None)
|
||||
else:
|
||||
setattr(instance, attr_name, value)
|
||||
|
||||
instances.append(instance)
|
||||
|
||||
return instances
|
||||
|
||||
@staticmethod
|
||||
def _check_list(lst_or_arr: Union[List[Any], np.ndarray]) -> bool:
|
||||
if isinstance(lst_or_arr, np.ndarray):
|
||||
lst_or_arr = lst_or_arr.flatten().tolist()
|
||||
|
||||
if len(lst_or_arr) == 0:
|
||||
return False
|
||||
|
||||
for item in lst_or_arr:
|
||||
if item is None:
|
||||
return False
|
||||
|
||||
item_type = type(item)
|
||||
|
||||
if item_type in type_checks:
|
||||
if not type_checks[item_type](item):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
type_checks: Dict[type, Callable[[Any], bool]] = {
|
||||
str: lambda x: len(x.strip()) > 0,
|
||||
list: MetricInput._check_list,
|
||||
np.ndarray: MetricInput._check_list,
|
||||
int: lambda _: True,
|
||||
float: lambda _: True,
|
||||
}
|
||||
24
autorag/schema/module.py
Normal file
24
autorag/schema/module.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict
|
||||
|
||||
from autorag.support import get_support_modules
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module:
|
||||
module_type: str
|
||||
module_param: Dict
|
||||
module: Callable = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.module = get_support_modules(self.module_type)
|
||||
if self.module is None:
|
||||
raise ValueError(f"Module type {self.module_type} is not supported.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, module_dict: Dict) -> "Module":
|
||||
_module_dict = deepcopy(module_dict)
|
||||
module_type = _module_dict.pop("module_type")
|
||||
module_params = _module_dict
|
||||
return cls(module_type, module_params)
|
||||
143
autorag/schema/node.py
Normal file
143
autorag/schema/node.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import itertools
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Callable, Tuple, Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from autorag.schema.module import Module
|
||||
from autorag.support import get_support_nodes
|
||||
from autorag.utils.util import make_combinations, explode, find_key_values
|
||||
|
||||
logger = logging.getLogger("AutoRAG")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Node:
|
||||
node_type: str
|
||||
strategy: Dict
|
||||
node_params: Dict
|
||||
modules: List[Module]
|
||||
run_node: Callable = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.run_node = get_support_nodes(self.node_type)
|
||||
if self.run_node is None:
|
||||
raise ValueError(f"Node type {self.node_type} is not supported.")
|
||||
|
||||
def get_param_combinations(self) -> Tuple[List[Callable], List[Dict]]:
|
||||
"""
|
||||
This method returns a combination of module and node parameters, also corresponding modules.
|
||||
|
||||
:return: Each module and its module parameters.
|
||||
:rtype: Tuple[List[Callable], List[Dict]]
|
||||
"""
|
||||
|
||||
def make_single_combination(module: Module) -> List[Dict]:
|
||||
input_dict = {**self.node_params, **module.module_param}
|
||||
return make_combinations(input_dict)
|
||||
|
||||
combinations = list(map(make_single_combination, self.modules))
|
||||
module_list, combination_list = explode(self.modules, combinations)
|
||||
return list(map(lambda x: x.module, module_list)), combination_list
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, node_dict: Dict) -> "Node":
|
||||
_node_dict = deepcopy(node_dict)
|
||||
node_type = _node_dict.pop("node_type")
|
||||
strategy = _node_dict.pop("strategy")
|
||||
modules = list(map(lambda x: Module.from_dict(x), _node_dict.pop("modules")))
|
||||
node_params = _node_dict
|
||||
return cls(node_type, strategy, node_params, modules)
|
||||
|
||||
def run(self, previous_result: pd.DataFrame, node_line_dir: str) -> pd.DataFrame:
|
||||
logger.info(f"Running node {self.node_type}...")
|
||||
input_modules, input_params = self.get_param_combinations()
|
||||
return self.run_node(
|
||||
modules=input_modules,
|
||||
module_params=input_params,
|
||||
previous_result=previous_result,
|
||||
node_line_dir=node_line_dir,
|
||||
strategies=self.strategy,
|
||||
)
|
||||
|
||||
|
||||
def extract_values(node: Node, key: str) -> List[str]:
|
||||
"""
|
||||
This function extract values from node's modules' module_param.
|
||||
|
||||
:param node: The node you want to extract values from.
|
||||
:param key: The key of module_param that you want to extract.
|
||||
:return: The list of extracted values.
|
||||
It removes duplicated elements automatically.
|
||||
"""
|
||||
|
||||
def extract_module_values(module: Module):
|
||||
if key not in module.module_param:
|
||||
return []
|
||||
value = module.module_param[key]
|
||||
if isinstance(value, str) or isinstance(value, int):
|
||||
return [value]
|
||||
elif isinstance(value, list):
|
||||
return value
|
||||
else:
|
||||
raise ValueError(f"{key} must be str,list or int, but got {type(value)}")
|
||||
|
||||
values = list(map(extract_module_values, node.modules))
|
||||
return list(set(list(itertools.chain.from_iterable(values))))
|
||||
|
||||
|
||||
def extract_values_from_nodes(nodes: List[Node], key: str) -> List[str]:
|
||||
"""
|
||||
This function extract values from nodes' modules' module_param.
|
||||
|
||||
:param nodes: The nodes you want to extract values from.
|
||||
:param key: The key of module_param that you want to extract.
|
||||
:return: The list of extracted values.
|
||||
It removes duplicated elements automatically.
|
||||
"""
|
||||
values = list(map(lambda node: extract_values(node, key), nodes))
|
||||
return list(set(list(itertools.chain.from_iterable(values))))
|
||||
|
||||
|
||||
def extract_values_from_nodes_strategy(nodes: List[Node], key: str) -> List[Any]:
|
||||
"""
|
||||
This function extract values from nodes' strategy.
|
||||
|
||||
:param nodes: The nodes you want to extract values from.
|
||||
:param key: The key string that you want to extract.
|
||||
:return: The list of extracted values.
|
||||
It removes duplicated elements automatically.
|
||||
"""
|
||||
values = []
|
||||
for node in nodes:
|
||||
value_list = find_key_values(node.strategy, key)
|
||||
if value_list:
|
||||
values.extend(value_list)
|
||||
return values
|
||||
|
||||
|
||||
def module_type_exists(nodes: List[Node], module_type: str) -> bool:
|
||||
"""
|
||||
This function check if the module type exists in the nodes.
|
||||
|
||||
:param nodes: The nodes you want to check.
|
||||
:param module_type: The module type you want to check.
|
||||
:return: True if the module type exists in the nodes.
|
||||
"""
|
||||
return any(
|
||||
list(
|
||||
map(
|
||||
lambda node: any(
|
||||
list(
|
||||
map(
|
||||
lambda module: module.module_type == module_type,
|
||||
node.modules,
|
||||
)
|
||||
)
|
||||
),
|
||||
nodes,
|
||||
)
|
||||
)
|
||||
)
|
||||
8
autorag/utils/__init__.py
Normal file
8
autorag/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .preprocess import (
|
||||
validate_qa_dataset,
|
||||
validate_corpus_dataset,
|
||||
cast_qa_dataset,
|
||||
cast_corpus_dataset,
|
||||
validate_qa_from_corpus_dataset,
|
||||
)
|
||||
from .util import fetch_contents, result_to_dataframe, sort_by_scores
|
||||
BIN
autorag/utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
autorag/utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
autorag/utils/__pycache__/preprocess.cpython-310.pyc
Normal file
BIN
autorag/utils/__pycache__/preprocess.cpython-310.pyc
Normal file
Binary file not shown.
BIN
autorag/utils/__pycache__/util.cpython-310.pyc
Normal file
BIN
autorag/utils/__pycache__/util.cpython-310.pyc
Normal file
Binary file not shown.
149
autorag/utils/preprocess.py
Normal file
149
autorag/utils/preprocess.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from autorag.utils.util import preprocess_text
|
||||
|
||||
|
||||
def validate_qa_dataset(df: pd.DataFrame):
|
||||
columns = ["qid", "query", "retrieval_gt", "generation_gt"]
|
||||
assert set(columns).issubset(
|
||||
df.columns
|
||||
), f"df must have columns {columns}, but got {df.columns}"
|
||||
|
||||
|
||||
def validate_corpus_dataset(df: pd.DataFrame):
|
||||
columns = ["doc_id", "contents", "metadata"]
|
||||
assert set(columns).issubset(
|
||||
df.columns
|
||||
), f"df must have columns {columns}, but got {df.columns}"
|
||||
|
||||
|
||||
def cast_qa_dataset(df: pd.DataFrame):
|
||||
def cast_retrieval_gt(gt):
|
||||
if isinstance(gt, str):
|
||||
return [[gt]]
|
||||
elif isinstance(gt, list):
|
||||
if isinstance(gt[0], str):
|
||||
return [gt]
|
||||
elif isinstance(gt[0], list):
|
||||
return gt
|
||||
elif isinstance(gt[0], np.ndarray):
|
||||
return cast_retrieval_gt(list(map(lambda x: x.tolist(), gt)))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"retrieval_gt must be str or list, but got {type(gt[0])}"
|
||||
)
|
||||
elif isinstance(gt, np.ndarray):
|
||||
return cast_retrieval_gt(gt.tolist())
|
||||
else:
|
||||
raise ValueError(f"retrieval_gt must be str or list, but got {type(gt)}")
|
||||
|
||||
def cast_generation_gt(gt):
|
||||
if isinstance(gt, str):
|
||||
return [gt]
|
||||
elif isinstance(gt, list):
|
||||
return gt
|
||||
elif isinstance(gt, np.ndarray):
|
||||
return cast_generation_gt(gt.tolist())
|
||||
else:
|
||||
raise ValueError(f"generation_gt must be str or list, but got {type(gt)}")
|
||||
|
||||
df = df.reset_index(drop=True)
|
||||
validate_qa_dataset(df)
|
||||
assert df["qid"].apply(lambda x: isinstance(x, str)).sum() == len(
|
||||
df
|
||||
), "qid must be string type."
|
||||
assert df["query"].apply(lambda x: isinstance(x, str)).sum() == len(
|
||||
df
|
||||
), "query must be string type."
|
||||
df["retrieval_gt"] = df["retrieval_gt"].apply(cast_retrieval_gt)
|
||||
df["generation_gt"] = df["generation_gt"].apply(cast_generation_gt)
|
||||
df["query"] = df["query"].apply(preprocess_text)
|
||||
df["generation_gt"] = df["generation_gt"].apply(
|
||||
lambda x: list(map(preprocess_text, x))
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
def cast_corpus_dataset(df: pd.DataFrame):
|
||||
df = df.reset_index(drop=True)
|
||||
validate_corpus_dataset(df)
|
||||
|
||||
# drop rows that have empty contents
|
||||
df = df[~df["contents"].apply(lambda x: x is None or x.isspace())]
|
||||
|
||||
def make_datetime_metadata(x):
|
||||
if x is None or x == {}:
|
||||
return {"last_modified_datetime": datetime.now()}
|
||||
elif x.get("last_modified_datetime") is None:
|
||||
return {**x, "last_modified_datetime": datetime.now()}
|
||||
else:
|
||||
return x
|
||||
|
||||
df["metadata"] = df["metadata"].apply(make_datetime_metadata)
|
||||
|
||||
# check every metadata have a datetime key
|
||||
assert sum(
|
||||
df["metadata"].apply(lambda x: x.get("last_modified_datetime") is not None)
|
||||
) == len(df), "Every metadata must have a datetime key."
|
||||
|
||||
def make_prev_next_id_metadata(x, id_type: str):
|
||||
if x is None or x == {}:
|
||||
return {id_type: None}
|
||||
elif x.get(id_type) is None:
|
||||
return {**x, id_type: None}
|
||||
else:
|
||||
return x
|
||||
|
||||
df["metadata"] = df["metadata"].apply(
|
||||
lambda x: make_prev_next_id_metadata(x, "prev_id")
|
||||
)
|
||||
df["metadata"] = df["metadata"].apply(
|
||||
lambda x: make_prev_next_id_metadata(x, "next_id")
|
||||
)
|
||||
|
||||
df["contents"] = df["contents"].apply(preprocess_text)
|
||||
|
||||
def normalize_unicode_metadata(metadata: dict):
|
||||
result = {}
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
result[key] = preprocess_text(value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
df["metadata"] = df["metadata"].apply(normalize_unicode_metadata)
|
||||
|
||||
# check every metadata have a prev_id, next_id key
|
||||
assert all(
|
||||
"prev_id" in metadata for metadata in df["metadata"]
|
||||
), "Every metadata must have a prev_id key."
|
||||
assert all(
|
||||
"next_id" in metadata for metadata in df["metadata"]
|
||||
), "Every metadata must have a next_id key."
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def validate_qa_from_corpus_dataset(qa_df: pd.DataFrame, corpus_df: pd.DataFrame):
|
||||
qa_ids = []
|
||||
for retrieval_gt in qa_df["retrieval_gt"].tolist():
|
||||
if isinstance(retrieval_gt, list) and (
|
||||
retrieval_gt[0] != [] or any(bool(g) is True for g in retrieval_gt)
|
||||
):
|
||||
for gt in retrieval_gt:
|
||||
qa_ids.extend(gt)
|
||||
elif isinstance(retrieval_gt, np.ndarray) and retrieval_gt[0].size > 0:
|
||||
for gt in retrieval_gt:
|
||||
qa_ids.extend(gt)
|
||||
|
||||
no_exist_ids = list(
|
||||
filter(lambda qa_id: corpus_df[corpus_df["doc_id"] == qa_id].empty, qa_ids)
|
||||
)
|
||||
|
||||
assert (
|
||||
len(no_exist_ids) == 0
|
||||
), f"{len(no_exist_ids)} doc_ids in retrieval_gt do not exist in corpus_df."
|
||||
751
autorag/utils/util.py
Normal file
751
autorag/utils/util.py
Normal file
@@ -0,0 +1,751 @@
|
||||
import ast
|
||||
import asyncio
|
||||
import datetime
|
||||
import functools
|
||||
import glob
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from copy import deepcopy
|
||||
from json import JSONDecoder
|
||||
from typing import List, Callable, Dict, Optional, Any, Collection, Iterable
|
||||
|
||||
from asyncio import AbstractEventLoop
|
||||
import emoji
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
import unicodedata
|
||||
|
||||
import yaml
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from pydantic import BaseModel as BM
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
logger = logging.getLogger("AutoRAG")
|
||||
|
||||
|
||||
def fetch_contents(
|
||||
corpus_data: pd.DataFrame, ids: List[List[str]], column_name: str = "contents"
|
||||
) -> List[List[Any]]:
|
||||
def fetch_contents_pure(
|
||||
ids: List[str], corpus_data: pd.DataFrame, column_name: str
|
||||
):
|
||||
return list(map(lambda x: fetch_one_content(corpus_data, x, column_name), ids))
|
||||
|
||||
result = flatten_apply(
|
||||
fetch_contents_pure, ids, corpus_data=corpus_data, column_name=column_name
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def fetch_one_content(
|
||||
corpus_data: pd.DataFrame,
|
||||
id_: str,
|
||||
column_name: str = "contents",
|
||||
id_column_name: str = "doc_id",
|
||||
) -> Any:
|
||||
if isinstance(id_, str):
|
||||
if id_ in ["", ""]:
|
||||
return None
|
||||
fetch_result = corpus_data[corpus_data[id_column_name] == id_]
|
||||
if fetch_result.empty:
|
||||
raise ValueError(f"doc_id: {id_} not found in corpus_data.")
|
||||
else:
|
||||
return fetch_result[column_name].iloc[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def result_to_dataframe(column_names: List[str]):
|
||||
"""
|
||||
Decorator for converting results to pd.DataFrame.
|
||||
"""
|
||||
|
||||
def decorator_result_to_dataframe(func: Callable):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> pd.DataFrame:
|
||||
results = func(*args, **kwargs)
|
||||
if len(column_names) == 1:
|
||||
df_input = {column_names[0]: results}
|
||||
else:
|
||||
df_input = {
|
||||
column_name: result
|
||||
for result, column_name in zip(results, column_names)
|
||||
}
|
||||
result_df = pd.DataFrame(df_input)
|
||||
return result_df
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator_result_to_dataframe
|
||||
|
||||
|
||||
def load_summary_file(
|
||||
summary_path: str, dict_columns: Optional[List[str]] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Load a summary file from summary_path.
|
||||
|
||||
:param summary_path: The path of the summary file.
|
||||
:param dict_columns: The columns that are dictionary type.
|
||||
You must fill this parameter if you want to load summary file properly.
|
||||
Default is ['module_params'].
|
||||
:return: The summary dataframe.
|
||||
"""
|
||||
if not os.path.exists(summary_path):
|
||||
raise ValueError(f"summary.csv does not exist in {summary_path}.")
|
||||
summary_df = pd.read_csv(summary_path)
|
||||
if dict_columns is None:
|
||||
dict_columns = ["module_params"]
|
||||
|
||||
if any([col not in summary_df.columns for col in dict_columns]):
|
||||
raise ValueError(f"{dict_columns} must be in summary_df.columns.")
|
||||
|
||||
def convert_dict(elem):
|
||||
try:
|
||||
return ast.literal_eval(elem)
|
||||
except:
|
||||
# convert datetime or date to its object (recency filter)
|
||||
date_object = convert_datetime_string(elem)
|
||||
if date_object is None:
|
||||
raise ValueError(
|
||||
f"Malformed dict received : {elem}\nCan't convert to dict properly"
|
||||
)
|
||||
return {"threshold": date_object}
|
||||
|
||||
summary_df[dict_columns] = summary_df[dict_columns].map(convert_dict)
|
||||
return summary_df
|
||||
|
||||
|
||||
def convert_datetime_string(s):
|
||||
# Regex to extract datetime arguments from the string
|
||||
m = re.search(r"(datetime|date)(\((\d+)(,\s*\d+)*\))", s)
|
||||
if m:
|
||||
args = ast.literal_eval(m.group(2))
|
||||
if m.group(1) == "datetime":
|
||||
return datetime.datetime(*args)
|
||||
elif m.group(1) == "date":
|
||||
return datetime.date(*args)
|
||||
return None
|
||||
|
||||
|
||||
def make_combinations(target_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Make combinations from target_dict.
|
||||
The target_dict key value must be a string,
|
||||
and the value can be a list of values or single value.
|
||||
If generates all combinations of values from target_dict,
|
||||
which means generating dictionaries that contain only one value for each key,
|
||||
and all dictionaries will be different from each other.
|
||||
|
||||
:param target_dict: The target dictionary.
|
||||
:return: The list of generated dictionaries.
|
||||
"""
|
||||
dict_with_lists = dict(
|
||||
map(
|
||||
lambda x: (x[0], x[1] if isinstance(x[1], list) else [x[1]]),
|
||||
target_dict.items(),
|
||||
)
|
||||
)
|
||||
|
||||
def delete_duplicate(x):
|
||||
def is_hashable(obj):
|
||||
try:
|
||||
hash(obj)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
if any([not is_hashable(elem) for elem in x]):
|
||||
# TODO: add duplication check for unhashable objects
|
||||
return x
|
||||
else:
|
||||
return list(set(x))
|
||||
|
||||
dict_with_lists = dict(
|
||||
map(lambda x: (x[0], delete_duplicate(x[1])), dict_with_lists.items())
|
||||
)
|
||||
combination = list(itertools.product(*dict_with_lists.values()))
|
||||
combination_dicts = [
|
||||
dict(zip(dict_with_lists.keys(), combo)) for combo in combination
|
||||
]
|
||||
return combination_dicts
|
||||
|
||||
|
||||
def explode(index_values: Collection[Any], explode_values: Collection[Collection[Any]]):
|
||||
"""
|
||||
Explode index_values and explode_values.
|
||||
The index_values and explode_values must have the same length.
|
||||
It will flatten explode_values and keep index_values as a pair.
|
||||
|
||||
:param index_values: The index values.
|
||||
:param explode_values: The exploded values.
|
||||
:return: Tuple of exploded index_values and exploded explode_values.
|
||||
"""
|
||||
assert len(index_values) == len(
|
||||
explode_values
|
||||
), "Index values and explode values must have same length"
|
||||
df = pd.DataFrame({"index_values": index_values, "explode_values": explode_values})
|
||||
df = df.explode("explode_values")
|
||||
return df["index_values"].tolist(), df["explode_values"].tolist()
|
||||
|
||||
|
||||
def replace_value_in_dict(target_dict: Dict, key: str, replace_value: Any) -> Dict:
|
||||
"""
|
||||
Replace the value of a certain key in target_dict.
|
||||
If there is no targeted key in target_dict, it will return target_dict.
|
||||
|
||||
:param target_dict: The target dictionary.
|
||||
:param key: The key is to replace.
|
||||
:param replace_value: The value to replace.
|
||||
:return: The replaced dictionary.
|
||||
"""
|
||||
replaced_dict = deepcopy(target_dict)
|
||||
if key not in replaced_dict:
|
||||
return replaced_dict
|
||||
replaced_dict[key] = replace_value
|
||||
return replaced_dict
|
||||
|
||||
|
||||
def normalize_string(s: str) -> str:
|
||||
"""
|
||||
Taken from the official evaluation script for v1.1 of the SQuAD dataset.
|
||||
Lower text and remove punctuation, articles, and extra whitespace.
|
||||
"""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def convert_string_to_tuple_in_dict(d):
|
||||
"""Recursively converts strings that start with '(' and end with ')' to tuples in a dictionary."""
|
||||
for key, value in d.items():
|
||||
# If the value is a dictionary, recurse
|
||||
if isinstance(value, dict):
|
||||
convert_string_to_tuple_in_dict(value)
|
||||
# If the value is a list, iterate through its elements
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
# If an item in the list is a dictionary, recurse
|
||||
if isinstance(item, dict):
|
||||
convert_string_to_tuple_in_dict(item)
|
||||
# If an item in the list is a string matching the criteria, convert it to a tuple
|
||||
elif (
|
||||
isinstance(item, str)
|
||||
and item.startswith("(")
|
||||
and item.endswith(")")
|
||||
):
|
||||
value[i] = ast.literal_eval(item)
|
||||
# If the value is a string matching the criteria, convert it to a tuple
|
||||
elif isinstance(value, str) and value.startswith("(") and value.endswith(")"):
|
||||
d[key] = ast.literal_eval(value)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def convert_env_in_dict(d: Dict):
|
||||
"""
|
||||
Recursively converts environment variable string in a dictionary to actual environment variable.
|
||||
|
||||
:param d: The dictionary to convert.
|
||||
:return: The converted dictionary.
|
||||
"""
|
||||
env_pattern = re.compile(r".*?\${(.*?)}.*?")
|
||||
|
||||
def convert_env(val: str):
|
||||
matches = env_pattern.findall(val)
|
||||
for match in matches:
|
||||
val = val.replace(f"${{{match}}}", os.environ.get(match, ""))
|
||||
return val
|
||||
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
convert_env_in_dict(value)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
convert_env_in_dict(item)
|
||||
elif isinstance(item, str):
|
||||
value[i] = convert_env(item)
|
||||
elif isinstance(value, str):
|
||||
d[key] = convert_env(value)
|
||||
return d
|
||||
|
||||
|
||||
async def process_batch(tasks, batch_size: int = 64) -> List[Any]:
|
||||
"""
|
||||
Processes tasks in batches asynchronously.
|
||||
|
||||
:param tasks: A list of no-argument functions or coroutines to be executed.
|
||||
:param batch_size: The number of tasks to process in a single batch.
|
||||
Default is 64.
|
||||
:return: A list of results from the processed tasks.
|
||||
"""
|
||||
results = []
|
||||
|
||||
for i in range(0, len(tasks), batch_size):
|
||||
batch = tasks[i : i + batch_size]
|
||||
batch_results = await asyncio.gather(*batch)
|
||||
results.extend(batch_results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def make_batch(elems: List[Any], batch_size: int) -> List[List[Any]]:
|
||||
"""
|
||||
Make a batch of elems with batch_size.
|
||||
"""
|
||||
return [elems[i : i + batch_size] for i in range(0, len(elems), batch_size)]
|
||||
|
||||
|
||||
def save_parquet_safe(df: pd.DataFrame, filepath: str, upsert: bool = False):
|
||||
output_file_dir = os.path.dirname(filepath)
|
||||
if not os.path.isdir(output_file_dir):
|
||||
raise NotADirectoryError(f"directory {output_file_dir} not found.")
|
||||
if not filepath.endswith("parquet"):
|
||||
raise NameError(
|
||||
f'file path: {filepath} filename extension need to be ".parquet"'
|
||||
)
|
||||
if os.path.exists(filepath) and not upsert:
|
||||
raise FileExistsError(
|
||||
f"file {filepath} already exists."
|
||||
"Set upsert True if you want to overwrite the file."
|
||||
)
|
||||
|
||||
df.to_parquet(filepath, index=False)
|
||||
|
||||
|
||||
def openai_truncate_by_token(
|
||||
texts: List[str], token_limit: int, model_name: str
|
||||
) -> List[str]:
|
||||
try:
|
||||
tokenizer = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
# This is not a real OpenAI model
|
||||
return texts
|
||||
|
||||
def truncate_text(text: str, limit: int, tokenizer):
|
||||
tokens = tokenizer.encode(text)
|
||||
if len(tokens) <= limit:
|
||||
return text
|
||||
truncated_text = tokenizer.decode(tokens[:limit])
|
||||
return truncated_text
|
||||
|
||||
return list(map(lambda x: truncate_text(x, token_limit, tokenizer), texts))
|
||||
|
||||
|
||||
def reconstruct_list(flat_list: List[Any], lengths: List[int]) -> List[List[Any]]:
|
||||
result = []
|
||||
start = 0
|
||||
for length in lengths:
|
||||
result.append(flat_list[start : start + length])
|
||||
start += length
|
||||
return result
|
||||
|
||||
|
||||
def flatten_apply(
|
||||
func: Callable, nested_list: List[List[Any]], **kwargs
|
||||
) -> List[List[Any]]:
|
||||
"""
|
||||
This function flattens the input list and applies the function to the elements.
|
||||
After that, it reconstructs the list to the original shape.
|
||||
Its speciality is that the first dimension length of the list can be different from each other.
|
||||
|
||||
:param func: The function that applies to the flattened list.
|
||||
:param nested_list: The nested list to be flattened.
|
||||
:return: The list that is reconstructed after applying the function.
|
||||
"""
|
||||
df = pd.DataFrame({"col1": nested_list})
|
||||
df = df.explode("col1")
|
||||
df["result"] = func(df["col1"].tolist(), **kwargs)
|
||||
return df.groupby(level=0, sort=False)["result"].apply(list).tolist()
|
||||
|
||||
|
||||
async def aflatten_apply(
|
||||
func: Callable, nested_list: List[List[Any]], **kwargs
|
||||
) -> List[List[Any]]:
|
||||
"""
|
||||
This function flattens the input list and applies the function to the elements.
|
||||
After that, it reconstructs the list to the original shape.
|
||||
Its speciality is that the first dimension length of the list can be different from each other.
|
||||
|
||||
:param func: The function that applies to the flattened list.
|
||||
:param nested_list: The nested list to be flattened.
|
||||
:return: The list that is reconstructed after applying the function.
|
||||
"""
|
||||
df = pd.DataFrame({"col1": nested_list})
|
||||
df = df.explode("col1")
|
||||
df["result"] = await func(df["col1"].tolist(), **kwargs)
|
||||
return df.groupby(level=0, sort=False)["result"].apply(list).tolist()
|
||||
|
||||
|
||||
def sort_by_scores(row, reverse=True):
|
||||
"""
|
||||
Sorts each row by 'scores' column.
|
||||
The input column names must be 'contents', 'ids', and 'scores'.
|
||||
And its elements must be list type.
|
||||
"""
|
||||
results = sorted(
|
||||
zip(row["contents"], row["ids"], row["scores"]),
|
||||
key=lambda x: x[2],
|
||||
reverse=reverse,
|
||||
)
|
||||
reranked_contents, reranked_ids, reranked_scores = zip(*results)
|
||||
return list(reranked_contents), list(reranked_ids), list(reranked_scores)
|
||||
|
||||
|
||||
def select_top_k(df, column_names: List[str], top_k: int):
|
||||
for column_name in column_names:
|
||||
df[column_name] = df[column_name].apply(lambda x: x[:top_k])
|
||||
return df
|
||||
|
||||
|
||||
def filter_dict_keys(dict_, keys: List[str]):
|
||||
result = {}
|
||||
for key in keys:
|
||||
if key in dict_:
|
||||
result[key] = dict_[key]
|
||||
else:
|
||||
raise KeyError(f"Key '{key}' not found in dictionary.")
|
||||
return result
|
||||
|
||||
|
||||
def split_dataframe(df, chunk_size):
|
||||
num_chunks = (
|
||||
len(df) // chunk_size + 1
|
||||
if len(df) % chunk_size != 0
|
||||
else len(df) // chunk_size
|
||||
)
|
||||
result = list(
|
||||
map(lambda x: df[x * chunk_size : (x + 1) * chunk_size], range(num_chunks))
|
||||
)
|
||||
result = list(map(lambda x: x.reset_index(drop=True), result))
|
||||
return result
|
||||
|
||||
|
||||
def find_trial_dir(project_dir: str) -> List[str]:
|
||||
# Pattern to match directories named with numbers
|
||||
pattern = os.path.join(project_dir, "[0-9]*")
|
||||
all_entries = glob.glob(pattern)
|
||||
|
||||
# Filter out only directories
|
||||
trial_dirs = [
|
||||
entry
|
||||
for entry in all_entries
|
||||
if os.path.isdir(entry) and entry.split(os.sep)[-1].isdigit()
|
||||
]
|
||||
|
||||
return trial_dirs
|
||||
|
||||
|
||||
def find_node_summary_files(trial_dir: str) -> List[str]:
|
||||
# Find all summary.csv files recursively
|
||||
all_summary_files = glob.glob(
|
||||
os.path.join(trial_dir, "**", "summary.csv"), recursive=True
|
||||
)
|
||||
|
||||
# Filter out files that are at a lower directory level
|
||||
filtered_files = [
|
||||
f for f in all_summary_files if f.count(os.sep) > trial_dir.count(os.sep) + 2
|
||||
]
|
||||
|
||||
return filtered_files
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
return normalize_unicode(demojize(text))
|
||||
|
||||
|
||||
def demojize(text: str) -> str:
|
||||
return emoji.demojize(text)
|
||||
|
||||
|
||||
def normalize_unicode(text: str) -> str:
|
||||
return unicodedata.normalize("NFC", text)
|
||||
|
||||
|
||||
def dict_to_markdown(d, level=1):
|
||||
"""
|
||||
Convert a dictionary to a Markdown formatted string.
|
||||
|
||||
:param d: Dictionary to convert
|
||||
:param level: Current level of heading (used for nested dictionaries)
|
||||
:return: Markdown formatted string
|
||||
"""
|
||||
markdown = ""
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
markdown += f"{'#' * level} {key}\n"
|
||||
markdown += dict_to_markdown(value, level + 1)
|
||||
elif isinstance(value, list):
|
||||
markdown += f"{'#' * level} {key}\n"
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
markdown += dict_to_markdown(item, level + 1)
|
||||
else:
|
||||
markdown += f"- {item}\n"
|
||||
else:
|
||||
markdown += f"{'#' * level} {key}\n{value}\n"
|
||||
return markdown
|
||||
|
||||
|
||||
def dict_to_markdown_table(data, key_column_name: str, value_column_name: str):
|
||||
# Check if the input is a dictionary
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Input must be a dictionary")
|
||||
|
||||
# Create the header of the table
|
||||
header = f"| {key_column_name} | {value_column_name} |\n| :---: | :-----: |\n"
|
||||
|
||||
# Create the rows of the table
|
||||
rows = ""
|
||||
for key, value in data.items():
|
||||
rows += f"| {key} | {value} |\n"
|
||||
|
||||
# Combine header and rows
|
||||
markdown_table = header + rows
|
||||
return markdown_table
|
||||
|
||||
|
||||
def embedding_query_content(
|
||||
queries: List[str],
|
||||
contents_list: List[List[str]],
|
||||
embedding_model: Optional[str] = None,
|
||||
batch: int = 128,
|
||||
):
|
||||
flatten_contents = list(itertools.chain.from_iterable(contents_list))
|
||||
|
||||
openai_embedding_limit = 8000 # all openai embedding model has 8000 max token input
|
||||
if isinstance(embedding_model, OpenAIEmbedding):
|
||||
queries = openai_truncate_by_token(
|
||||
queries, openai_embedding_limit, embedding_model.model_name
|
||||
)
|
||||
flatten_contents = openai_truncate_by_token(
|
||||
flatten_contents, openai_embedding_limit, embedding_model.model_name
|
||||
)
|
||||
|
||||
# Embedding using batch
|
||||
embedding_model.embed_batch_size = batch
|
||||
query_embeddings = embedding_model.get_text_embedding_batch(queries)
|
||||
|
||||
content_lengths = list(map(len, contents_list))
|
||||
content_embeddings_flatten = embedding_model.get_text_embedding_batch(
|
||||
flatten_contents
|
||||
)
|
||||
content_embeddings = reconstruct_list(content_embeddings_flatten, content_lengths)
|
||||
return query_embeddings, content_embeddings
|
||||
|
||||
|
||||
def to_list(item):
|
||||
"""Recursively convert collections to Python lists."""
|
||||
if isinstance(item, np.ndarray):
|
||||
# Convert numpy array to list and recursively process each element
|
||||
return [to_list(sub_item) for sub_item in item.tolist()]
|
||||
elif isinstance(item, pd.Series):
|
||||
# Convert pandas Series to list and recursively process each element
|
||||
return [to_list(sub_item) for sub_item in item.tolist()]
|
||||
elif isinstance(item, Iterable) and not isinstance(
|
||||
item, (str, bytes, BaseModel, BM)
|
||||
):
|
||||
# Recursively process each element in other iterables
|
||||
return [to_list(sub_item) for sub_item in item]
|
||||
else:
|
||||
return item
|
||||
|
||||
|
||||
def convert_inputs_to_list(func):
|
||||
"""Decorator to convert all function inputs to Python lists."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
new_args = [to_list(arg) for arg in args]
|
||||
new_kwargs = {k: to_list(v) for k, v in kwargs.items()}
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_best_row(
|
||||
summary_df: pd.DataFrame, best_column_name: str = "is_best"
|
||||
) -> pd.Series:
|
||||
"""
|
||||
From the summary dataframe, find the best result row by 'is_best' column and return it.
|
||||
|
||||
:param summary_df: Summary dataframe created by AutoRAG.
|
||||
:param best_column_name: The column name that indicates the best result.
|
||||
Default is 'is_best'.
|
||||
You don't have to change this unless the column name is different.
|
||||
:return: Best row pandas Series instance.
|
||||
"""
|
||||
bests = summary_df.loc[summary_df[best_column_name]]
|
||||
assert len(bests) == 1, "There must be only one best result."
|
||||
return bests.iloc[0]
|
||||
|
||||
|
||||
def get_event_loop() -> AbstractEventLoop:
|
||||
"""
|
||||
Get asyncio event loop safely.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
|
||||
def find_key_values(data, target_key: str) -> List[Any]:
|
||||
"""
|
||||
Recursively find all values for a specific key in a nested dictionary or list.
|
||||
|
||||
:param data: The dictionary or list to search.
|
||||
:param target_key: The key to search for.
|
||||
:return: A list of values associated with the target key.
|
||||
"""
|
||||
values = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if key == target_key:
|
||||
values.append(value)
|
||||
if isinstance(value, (dict, list)):
|
||||
values.extend(find_key_values(value, target_key))
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
if isinstance(item, (dict, list)):
|
||||
values.extend(find_key_values(item, target_key))
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def pop_params(func: Callable, kwargs: Dict) -> Dict:
|
||||
"""
|
||||
Pop parameters from the given func and return them.
|
||||
It automatically deletes the parameters like "self" or "cls".
|
||||
|
||||
:param func: The function to pop parameters.
|
||||
:param kwargs: kwargs to pop parameters.
|
||||
:return: The popped parameters.
|
||||
"""
|
||||
ignore_params = ["self", "cls"]
|
||||
target_params = list(inspect.signature(func).parameters.keys())
|
||||
target_params = list(filter(lambda x: x not in ignore_params, target_params))
|
||||
|
||||
init_params = {}
|
||||
kwargs_keys = list(kwargs.keys())
|
||||
for key in kwargs_keys:
|
||||
if key in target_params:
|
||||
init_params[key] = kwargs.pop(key)
|
||||
return init_params
|
||||
|
||||
|
||||
def apply_recursive(func, data):
|
||||
"""
|
||||
Recursively apply a function to all elements in a list, tuple, set, np.ndarray, or pd.Series and return as List.
|
||||
|
||||
:param func: Function to apply to each element.
|
||||
:param data: List or nested list.
|
||||
:return: List with the function applied to each element.
|
||||
"""
|
||||
if (
|
||||
isinstance(data, list)
|
||||
or isinstance(data, tuple)
|
||||
or isinstance(data, set)
|
||||
or isinstance(data, np.ndarray)
|
||||
or isinstance(data, pd.Series)
|
||||
):
|
||||
return [apply_recursive(func, item) for item in data]
|
||||
else:
|
||||
return func(data)
|
||||
|
||||
|
||||
def empty_cuda_cache():
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def load_yaml_config(yaml_path: str) -> Dict:
|
||||
"""
|
||||
Load a YAML configuration file for AutoRAG.
|
||||
It contains safe loading, converting string to tuple, and insert environment variables.
|
||||
|
||||
:param yaml_path: The path of the YAML configuration file.
|
||||
:return: The loaded configuration dictionary.
|
||||
"""
|
||||
if not os.path.exists(yaml_path):
|
||||
raise ValueError(f"YAML file {yaml_path} does not exist.")
|
||||
with open(yaml_path, "r", encoding="utf-8") as stream:
|
||||
try:
|
||||
yaml_dict = yaml.safe_load(stream)
|
||||
except yaml.YAMLError as exc:
|
||||
raise ValueError(f"YAML file {yaml_path} could not be loaded.") from exc
|
||||
|
||||
yaml_dict = convert_string_to_tuple_in_dict(yaml_dict)
|
||||
yaml_dict = convert_env_in_dict(yaml_dict)
|
||||
return yaml_dict
|
||||
|
||||
|
||||
def decode_multiple_json_from_bytes(byte_data: bytes) -> list:
|
||||
"""
|
||||
Decode multiple JSON objects from bytes received from SSE server.
|
||||
|
||||
Args:
|
||||
byte_data: Bytes containing one or more JSON objects
|
||||
|
||||
Returns:
|
||||
List of decoded JSON objects
|
||||
"""
|
||||
# Decode bytes to string
|
||||
try:
|
||||
text_data = byte_data.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError("Invalid byte data: Unable to decode as UTF-8")
|
||||
|
||||
# Initialize decoder and result list
|
||||
decoder = JSONDecoder()
|
||||
result = []
|
||||
|
||||
# Keep track of position in string
|
||||
pos = 0
|
||||
text_data = text_data.strip()
|
||||
|
||||
while pos < len(text_data):
|
||||
try:
|
||||
# Try to decode next JSON object
|
||||
json_obj, json_end = decoder.raw_decode(text_data[pos:])
|
||||
result.append(json_obj)
|
||||
|
||||
# Move position to end of current JSON object
|
||||
pos += json_end
|
||||
|
||||
# Skip any whitespace
|
||||
while pos < len(text_data) and text_data[pos].isspace():
|
||||
pos += 1
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# If we can't decode at current position, move forward one character
|
||||
pos += 1
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user