Fix Dockerfile build issue

This commit is contained in:
kyy
2025-03-18 16:41:12 +09:00
parent 6814230bfb
commit 9323aa254a
228 changed files with 467 additions and 3488 deletions

View File

@@ -0,0 +1,3 @@
from .module import Module
from .node import Node
from .base import BaseModule

35
autorag/schema/base.py Normal file
View File

@@ -0,0 +1,35 @@
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Union
import pandas as pd
class BaseModule(metaclass=ABCMeta):
    """
    Abstract interface for modules.

    Concrete subclasses have to provide ``pure``, ``_pure`` and ``cast_to_run``.
    ``run_evaluator`` is the shared entry point: it builds an instance of the
    subclass and calls ``pure`` on it.
    """

    @abstractmethod
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Process ``previous_result`` and produce this module's result.

        This is the method that ``run_evaluator`` invokes.

        :param previous_result: The DataFrame handed over to this module.
        """

    @abstractmethod
    def _pure(self, *args, **kwargs):
        """
        Subclass-defined implementation hook.
        Its exact contract is not fixed by this base class.
        """

    @classmethod
    def run_evaluator(
        cls,
        project_dir: Union[str, Path],
        previous_result: pd.DataFrame,
        *args,
        **kwargs,
    ):
        """
        Instantiate the module, run ``pure`` once and return what it produced.

        Note that the very same ``*args`` / ``**kwargs`` are forwarded to both
        the constructor and ``pure``.

        :param project_dir: Passed as the first constructor argument.
        :param previous_result: Passed as the first argument of ``pure``.
        :return: Whatever ``pure`` returns.
        """
        module_instance = cls(project_dir, *args, **kwargs)
        output = module_instance.pure(previous_result, *args, **kwargs)
        # Drop the local reference to the instance before returning.
        del module_instance
        return output

    @abstractmethod
    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Cast step (a.k.a. decorator) applied only to the pure function of the node.

        :param previous_result: The DataFrame handed over to this module.
        """

View File

@@ -0,0 +1,99 @@
from dataclasses import dataclass
from typing import Optional, List, Dict, Callable, Any, Union
import numpy as np
import pandas as pd
@dataclass
class MetricInput:
    """
    Bundle of optional inputs that a metric may read.

    Every field defaults to ``None``; ``is_fields_notnone`` tells whether the
    fields a metric needs are present and non-empty.
    """

    query: Optional[str] = None
    queries: Optional[List[str]] = None
    retrieval_gt_contents: Optional[List[List[str]]] = None
    retrieved_contents: Optional[List[str]] = None
    retrieval_gt: Optional[List[List[str]]] = None
    retrieved_ids: Optional[List[str]] = None
    prompt: Optional[str] = None
    generated_texts: Optional[str] = None
    generation_gt: Optional[List[str]] = None
    generated_log_probs: Optional[List[float]] = None

    def is_fields_notnone(self, fields_to_check: List[str]) -> bool:
        """
        Check that every named field holds a usable value.

        A field fails when it is ``None``, when its exact type has no entry in
        ``type_checks``, when its validator rejects it, or when validating it raises.

        :param fields_to_check: Attribute names to inspect.
        :return: True only if all the fields pass.
        """
        for field_name in fields_to_check:
            current = getattr(self, field_name)
            if current is None:
                return False
            try:
                # Lookup is by exact type, so subclasses of the listed types are rejected.
                validator = type_checks.get(type(current))
                accepted = validator is not None and bool(validator(current))
            except Exception:
                return False
            if not accepted:
                return False
        return True

    @classmethod
    def from_dataframe(cls, qa_data: pd.DataFrame) -> List["MetricInput"]:
        """
        Build one MetricInput per DataFrame row.

        Only columns whose names match an annotated field are copied.
        Strings are stripped, and blank strings and empty lists become ``None``.
        Any other value is stored unchanged.

        :param qa_data: DataFrame containing metric data.
        :return: List of MetricInput objects, in row order.
        """
        converted = []
        for _, record in qa_data.iterrows():
            metric_input = cls()
            for name in cls.__annotations__:
                if name not in record:
                    continue
                cell = record[name]
                if isinstance(cell, str):
                    stripped = cell.strip()
                    cell = stripped if stripped != "" else None
                elif isinstance(cell, list):
                    cell = cell if len(cell) > 0 else None
                setattr(metric_input, name, cell)
            converted.append(metric_input)
        return converted

    @staticmethod
    def _check_list(lst_or_arr: Union[List[Any], np.ndarray]) -> bool:
        """
        Validate a list (or numpy array) element by element.

        Arrays are flattened to a plain list first. The container must be
        non-empty and every element must be non-None, of a type listed in
        ``type_checks``, and accepted by that type's validator.

        :param lst_or_arr: The list or array to validate.
        :return: True if the container and all of its elements pass.
        """
        elements = (
            lst_or_arr.flatten().tolist()
            if isinstance(lst_or_arr, np.ndarray)
            else lst_or_arr
        )
        if len(elements) == 0:
            return False
        return all(
            element is not None
            and type(element) in type_checks
            and type_checks[type(element)](element)
            for element in elements
        )


# Validators keyed by exact value type; used by MetricInput to decide whether a value is "filled".
type_checks: Dict[type, Callable[[Any], bool]] = {
    str: lambda text: len(text.strip()) > 0,
    list: MetricInput._check_list,
    np.ndarray: MetricInput._check_list,
    int: lambda _number: True,
    float: lambda _number: True,
}

24
autorag/schema/module.py Normal file
View File

@@ -0,0 +1,24 @@
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Callable, Dict
from autorag.support import get_support_modules
@dataclass
class Module:
    """
    A module entry: its type name, its parameters and the callable resolved for it.

    ``module`` is not a constructor argument; it is looked up from
    ``module_type`` through ``get_support_modules`` right after construction.
    """

    module_type: str
    module_param: Dict
    module: Callable = field(init=False)

    def __post_init__(self):
        """
        Resolve ``module`` from ``module_type``.

        :raises ValueError: If the lookup returns None.
        """
        resolved = get_support_modules(self.module_type)
        self.module = resolved
        if resolved is None:
            raise ValueError(f"Module type {self.module_type} is not supported.")

    @classmethod
    def from_dict(cls, module_dict: Dict) -> "Module":
        """
        Create a Module from a dictionary.

        The ``module_type`` key selects the module; every remaining key becomes
        a module parameter. The input dictionary is deep-copied, never mutated.

        :param module_dict: Dictionary that contains ``module_type`` and the module parameters.
        :return: The constructed Module.
        """
        params = deepcopy(module_dict)
        selected_type = params.pop("module_type")
        return cls(selected_type, params)

143
autorag/schema/node.py Normal file
View File

@@ -0,0 +1,143 @@
import itertools
import logging
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, List, Callable, Tuple, Any
import pandas as pd
from autorag.schema.module import Module
from autorag.support import get_support_nodes
from autorag.utils.util import make_combinations, explode, find_key_values
logger = logging.getLogger("AutoRAG")
@dataclass
class Node:
    """
    A node: its type, strategy, node-level parameters and candidate modules.

    ``run_node`` is not a constructor argument; it is looked up from
    ``node_type`` through ``get_support_nodes`` right after construction.
    """

    node_type: str
    strategy: Dict
    node_params: Dict
    modules: List[Module]
    run_node: Callable = field(init=False)

    def __post_init__(self):
        """
        Resolve ``run_node`` from ``node_type``.

        :raises ValueError: If the lookup returns None.
        """
        runner = get_support_nodes(self.node_type)
        self.run_node = runner
        if runner is None:
            raise ValueError(f"Node type {self.node_type} is not supported.")

    def get_param_combinations(self) -> Tuple[List[Callable], List[Dict]]:
        """
        Build the parameter combinations of every module in this node.

        For each module, the node parameters are merged with the module
        parameters (module values win on key clashes) and expanded with
        ``make_combinations``; the per-module results are then paired with
        their modules through ``explode``.

        :return: The module callables and their matching parameter dictionaries.
        :rtype: Tuple[List[Callable], List[Dict]]
        """
        per_module_combinations = [
            make_combinations({**self.node_params, **candidate.module_param})
            for candidate in self.modules
        ]
        exploded_modules, exploded_params = explode(
            self.modules, per_module_combinations
        )
        callables = [candidate.module for candidate in exploded_modules]
        return callables, exploded_params

    @classmethod
    def from_dict(cls, node_dict: Dict) -> "Node":
        """
        Create a Node from a dictionary.

        ``node_type``, ``strategy`` and ``modules`` are taken out of a deep copy
        of the input; whatever keys remain become the node parameters.

        :param node_dict: Dictionary describing the node.
        :return: The constructed Node.
        """
        remaining = deepcopy(node_dict)
        node_type = remaining.pop("node_type")
        strategy = remaining.pop("strategy")
        modules = [Module.from_dict(spec) for spec in remaining.pop("modules")]
        return cls(node_type, strategy, remaining, modules)

    def run(self, previous_result: pd.DataFrame, node_line_dir: str) -> pd.DataFrame:
        """
        Run this node through its resolved ``run_node`` callable.

        :param previous_result: The DataFrame forwarded to ``run_node``.
        :param node_line_dir: The node line directory forwarded to ``run_node``.
        :return: Whatever ``run_node`` returns.
        """
        logger.info(f"Running node {self.node_type}...")
        candidate_modules, candidate_params = self.get_param_combinations()
        return self.run_node(
            modules=candidate_modules,
            module_params=candidate_params,
            previous_result=previous_result,
            node_line_dir=node_line_dir,
            strategies=self.strategy,
        )
def extract_values(node: Node, key: str) -> List[str]:
    """
    Extract the values stored under ``key`` in the module_param of each module of a node.

    A str or int value counts as a single value; a list contributes each of
    its elements. Modules without the key are skipped.

    :param node: The node you want to extract values from.
    :param key: The key of module_param that you want to extract.
    :return: The list of extracted values.
        Duplicates are removed and the first-seen order is kept, so the result
        is the same on every run.
    :raises ValueError: If a value is neither a str, an int nor a list.
    """

    def extract_module_values(module: Module):
        if key not in module.module_param:
            return []
        value = module.module_param[key]
        if isinstance(value, (str, int)):
            return [value]
        if isinstance(value, list):
            return value
        raise ValueError(f"{key} must be str,list or int, but got {type(value)}")

    # Collect from every module first so an invalid value is reported before deduplication.
    per_module = [extract_module_values(module) for module in node.modules]
    # dict.fromkeys deduplicates like a set but keeps insertion order.
    return list(dict.fromkeys(itertools.chain.from_iterable(per_module)))
def extract_values_from_nodes(nodes: List[Node], key: str) -> List[str]:
    """
    Extract the values stored under ``key`` in the module_param of every module of several nodes.

    :param nodes: The nodes you want to extract values from.
    :param key: The key of module_param that you want to extract.
    :return: The list of extracted values.
        Duplicates are removed and the first-seen order is kept, so the result
        is the same on every run.
    """
    per_node = [extract_values(node, key) for node in nodes]
    # dict.fromkeys deduplicates like a set but keeps insertion order.
    return list(dict.fromkeys(itertools.chain.from_iterable(per_node)))
def extract_values_from_nodes_strategy(nodes: List[Node], key: str) -> List[Any]:
    """
    Gather the values that ``find_key_values`` finds for ``key`` in each node's strategy.

    :param nodes: The nodes you want to extract values from.
    :param key: The key string that you want to extract.
    :return: The extracted values, concatenated in node order.
        Duplicated elements are kept as they are.
    """
    collected = []
    for node in nodes:
        found = find_key_values(node.strategy, key)
        if not found:
            # Nothing found in this node's strategy.
            continue
        collected += found
    return collected
def module_type_exists(nodes: List[Node], module_type: str) -> bool:
    """
    Check whether any module of any node has the given module type.

    :param nodes: The nodes you want to check.
    :param module_type: The module type you want to check.
    :return: True if the module type exists in the nodes, False otherwise
        (including when there are no nodes or no modules).
    """
    # A single generator lets any() stop at the first match without building lists.
    return any(
        module.module_type == module_type
        for node in nodes
        for module in node.modules
    )