Fix Dockerfile build issue
This commit is contained in:
3
autorag/schema/__init__.py
Normal file
3
autorag/schema/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .module import Module
|
||||
from .node import Node
|
||||
from .base import BaseModule
|
||||
35
autorag/schema/base.py
Normal file
35
autorag/schema/base.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class BaseModule(metaclass=ABCMeta):
    """
    Abstract base for modules.

    Subclasses must implement ``pure``, ``_pure`` and ``cast_to_run``.
    ``run_evaluator`` is the concrete entry point: it builds an instance,
    calls ``pure`` on it once, and returns whatever ``pure`` returned.
    """

    @abstractmethod
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Process ``previous_result``; invoked by ``run_evaluator``.

        :param previous_result: The DataFrame handed to ``run_evaluator``.
        """

    @abstractmethod
    def _pure(self, *args, **kwargs):
        """Abstract hook that every subclass has to provide."""

    @classmethod
    def run_evaluator(
        cls,
        project_dir: Union[str, Path],
        previous_result: pd.DataFrame,
        *args,
        **kwargs,
    ):
        """
        Instantiate the module, run ``pure`` once and return its result.

        The same ``*args`` / ``**kwargs`` are forwarded both to the
        constructor and to ``pure``.

        :param project_dir: Passed as the first constructor argument.
        :param previous_result: Passed as the first argument of ``pure``.
        :return: The value returned by ``pure``.
        """
        module_instance = cls(project_dir, *args, **kwargs)
        output = module_instance.pure(previous_result, *args, **kwargs)
        # Drop the local reference explicitly before returning.
        del module_instance
        return output

    @abstractmethod
    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Cast function (a.k.a. decorator) counterpart, intended only for the
        pure function in the whole node.
        """
|
||||
99
autorag/schema/metricinput.py
Normal file
99
autorag/schema/metricinput.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict, Callable, Any, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@dataclass
class MetricInput:
    """
    Bundle of optional query / retrieval / generation values for one row.

    Every field defaults to ``None``. ``is_fields_notnone`` reports whether
    a chosen subset of fields holds usable (non-empty) values.
    """

    query: Optional[str] = None
    queries: Optional[List[str]] = None
    retrieval_gt_contents: Optional[List[List[str]]] = None
    retrieved_contents: Optional[List[str]] = None
    retrieval_gt: Optional[List[List[str]]] = None
    retrieved_ids: Optional[List[str]] = None
    prompt: Optional[str] = None
    generated_texts: Optional[str] = None
    generation_gt: Optional[List[str]] = None
    generated_log_probs: Optional[List[float]] = None

    def is_fields_notnone(self, fields_to_check: List[str]) -> bool:
        """
        Tell whether every named field holds a usable value.

        A field fails when it is ``None``, when its exact type has no entry
        in ``type_checks``, when its validator returns a falsy result, or
        when the validator raises.

        :param fields_to_check: Attribute names to inspect.
        :return: True only if all of the named fields pass.
        """
        for name in fields_to_check:
            candidate = getattr(self, name)
            if candidate is None:
                return False

            try:
                # Exact-type lookup; unknown types fall back to "reject".
                validator = type_checks.get(type(candidate), lambda _: False)
                passed = bool(validator(candidate))
            except Exception:
                return False

            if not passed:
                return False

        return True

    @classmethod
    def from_dataframe(cls, qa_data: pd.DataFrame) -> List["MetricInput"]:
        """
        Build one MetricInput per DataFrame row.

        Only columns whose name matches an annotated field are copied.
        Strings are stripped, and blank strings and empty lists become
        ``None``; every other value is stored unchanged.

        :param qa_data: DataFrame containing metric data.
        :returns: List of MetricInput objects, in row order.
        """
        field_names = tuple(cls.__annotations__)
        results = []

        for _, record in qa_data.iterrows():
            metric_input = cls()

            for name in field_names:
                # `in` on a row tests its index labels, i.e. the column names.
                if name not in record:
                    continue

                raw = record[name]
                if isinstance(raw, str):
                    stripped = raw.strip()
                    cleaned = stripped if stripped != "" else None
                elif isinstance(raw, list):
                    cleaned = raw if len(raw) > 0 else None
                else:
                    cleaned = raw
                setattr(metric_input, name, cleaned)

            results.append(metric_input)

        return results

    @staticmethod
    def _check_list(lst_or_arr: Union[List[Any], np.ndarray]) -> bool:
        """
        Validate a list or numpy array element by element.

        Arrays are flattened and converted to a plain list first. The
        sequence is rejected when it is empty, contains ``None``, contains
        an element whose exact type is missing from ``type_checks``, or
        contains an element its validator rejects.

        :param lst_or_arr: The list or array to validate.
        :return: True when every element passes.
        """
        if isinstance(lst_or_arr, np.ndarray):
            items = lst_or_arr.flatten().tolist()
        else:
            items = lst_or_arr

        if len(items) == 0:
            return False

        for element in items:
            if element is None:
                return False

            checker = type_checks.get(type(element))
            if checker is None or not checker(element):
                return False

        return True


# Validators keyed by exact value type. Nested lists / arrays are checked
# recursively through MetricInput._check_list.
type_checks: Dict[type, Callable[[Any], bool]] = {
    str: lambda text: bool(text.strip()),
    list: MetricInput._check_list,
    np.ndarray: MetricInput._check_list,
    int: lambda _: True,
    float: lambda _: True,
}
|
||||
24
autorag/schema/module.py
Normal file
24
autorag/schema/module.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict
|
||||
|
||||
from autorag.support import get_support_modules
|
||||
|
||||
|
||||
@dataclass
class Module:
    """
    A module type, its parameters, and the callable resolved for that type.

    ``module`` is not an init argument; it is looked up from ``module_type``
    via ``get_support_modules`` right after construction.
    """

    module_type: str
    module_param: Dict
    module: Callable = field(init=False)

    def __post_init__(self):
        """
        Resolve ``module`` from ``module_type``.

        :raises ValueError: If ``get_support_modules`` returns ``None``.
        """
        resolved = get_support_modules(self.module_type)
        self.module = resolved
        if resolved is None:
            raise ValueError(f"Module type {self.module_type} is not supported.")

    @classmethod
    def from_dict(cls, module_dict: Dict) -> "Module":
        """
        Build a Module from a dict holding ``module_type`` plus parameters.

        The input is deep-copied, so the caller's dict is left untouched.
        Every key other than ``module_type`` becomes part of ``module_param``.

        :param module_dict: Dict with a ``module_type`` key.
        :return: The constructed Module.
        """
        params = deepcopy(module_dict)
        module_type = params.pop("module_type")
        return cls(module_type, params)
|
||||
143
autorag/schema/node.py
Normal file
143
autorag/schema/node.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import itertools
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Callable, Tuple, Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from autorag.schema.module import Module
|
||||
from autorag.support import get_support_nodes
|
||||
from autorag.utils.util import make_combinations, explode, find_key_values
|
||||
|
||||
logger = logging.getLogger("AutoRAG")
|
||||
|
||||
|
||||
@dataclass
class Node:
    """
    A node type with its strategy, node-level parameters, and modules.

    ``run_node`` is not an init argument; it is looked up from ``node_type``
    via ``get_support_nodes`` right after construction.
    """

    node_type: str
    strategy: Dict
    node_params: Dict
    modules: List[Module]
    run_node: Callable = field(init=False)

    def __post_init__(self):
        """
        Resolve ``run_node`` from ``node_type``.

        :raises ValueError: If ``get_support_nodes`` returns ``None``.
        """
        runner = get_support_nodes(self.node_type)
        self.run_node = runner
        if runner is None:
            raise ValueError(f"Node type {self.node_type} is not supported.")

    def get_param_combinations(self) -> Tuple[List[Callable], List[Dict]]:
        """
        Return the module callables together with their parameter combinations.

        For each module, the node parameters are merged with the module's own
        parameters (module values win on key clashes) and handed to
        ``make_combinations``. The per-module results are then paired with
        their modules through ``explode``.

        :return: Each module and its module parameters.
        :rtype: Tuple[List[Callable], List[Dict]]
        """
        per_module_combinations = [
            make_combinations({**self.node_params, **module.module_param})
            for module in self.modules
        ]
        exploded_modules, exploded_params = explode(
            self.modules, per_module_combinations
        )
        callables = [module.module for module in exploded_modules]
        return callables, exploded_params

    @classmethod
    def from_dict(cls, node_dict: Dict) -> "Node":
        """
        Build a Node from a dict with ``node_type``, ``strategy`` and ``modules``.

        The input is deep-copied, so the caller's dict is left untouched.
        Whatever keys remain after removing those three become ``node_params``.

        :param node_dict: Dict describing the node.
        :return: The constructed Node.
        """
        remaining = deepcopy(node_dict)
        node_type = remaining.pop("node_type")
        strategy = remaining.pop("strategy")
        modules = [Module.from_dict(spec) for spec in remaining.pop("modules")]
        return cls(node_type, strategy, remaining, modules)

    def run(self, previous_result: pd.DataFrame, node_line_dir: str) -> pd.DataFrame:
        """
        Run this node through its resolved ``run_node`` callable.

        :param previous_result: DataFrame forwarded to ``run_node``.
        :param node_line_dir: Directory string forwarded to ``run_node``.
        :return: Whatever ``run_node`` returns.
        """
        logger.info(f"Running node {self.node_type}...")
        modules, params = self.get_param_combinations()
        call_kwargs = {
            "modules": modules,
            "module_params": params,
            "previous_result": previous_result,
            "node_line_dir": node_line_dir,
            "strategies": self.strategy,
        }
        return self.run_node(**call_kwargs)
|
||||
|
||||
|
||||
def extract_values(node: "Node", key: str) -> List[str]:
    """
    Extract the values stored under ``key`` in the module_param of each of
    the node's modules.

    Modules without ``key`` are skipped. A str or int value contributes
    itself; a list value contributes each of its elements.

    :param node: The node you want to extract values from.
    :param key: The key of module_param that you want to extract.
    :return: The list of extracted values, duplicates removed, in order of
        first appearance.
    :raises ValueError: If a value under ``key`` is not a str, int or list.
    """
    collected = []
    for module in node.modules:
        if key not in module.module_param:
            continue
        value = module.module_param[key]
        if isinstance(value, (str, int)):
            collected.append(value)
        elif isinstance(value, list):
            collected.extend(value)
        else:
            raise ValueError(f"{key} must be str,list or int, but got {type(value)}")

    # dict.fromkeys dedups like set() but keeps first-seen order, so the
    # result no longer depends on per-process string hashing.
    return list(dict.fromkeys(collected))
|
||||
|
||||
|
||||
def extract_values_from_nodes(nodes: List["Node"], key: str) -> List[str]:
    """
    Extract the values stored under ``key`` in the module_param of every
    module of every given node.

    :param nodes: The nodes you want to extract values from.
    :param key: The key of module_param that you want to extract.
    :return: The list of extracted values, duplicates removed, in order of
        first appearance across the concatenated per-node results.
    """
    per_node_values = (extract_values(node, key) for node in nodes)
    # dict.fromkeys dedups like set() but keeps first-seen order instead of
    # a hash-dependent one.
    return list(dict.fromkeys(itertools.chain.from_iterable(per_node_values)))
|
||||
|
||||
|
||||
def extract_values_from_nodes_strategy(nodes: List[Node], key: str) -> List[Any]:
    """
    Collect the values that ``find_key_values`` finds for ``key`` in each
    node's strategy.

    :param nodes: The nodes you want to extract values from.
    :param key: The key string that you want to extract.
    :return: The found values concatenated in node order. Nodes that yield
        nothing are skipped. No de-duplication is applied, so repeated
        values are kept.
    """
    # A falsy result (nothing found) contributes no elements.
    found_per_node = (find_key_values(node.strategy, key) or () for node in nodes)
    return list(itertools.chain.from_iterable(found_per_node))
|
||||
|
||||
|
||||
def module_type_exists(nodes: List["Node"], module_type: str) -> bool:
    """
    Check whether any module in the given nodes has the given module type.

    :param nodes: The nodes you want to check.
    :param module_type: The module type you want to check.
    :return: True if the module type exists in the nodes.
    """
    # A single generator lets any() stop at the first match instead of
    # materialising a list of results for every node first.
    return any(
        module.module_type == module_type
        for node in nodes
        for module in node.modules
    )
|
||||
Reference in New Issue
Block a user