Initial commit
4 autorag-workspace/autorag/nodes/queryexpansion/__init__.py Normal file
@@ -0,0 +1,4 @@
from .hyde import HyDE
from .multi_query_expansion import MultiQueryExpansion
from .pass_query_expansion import PassQueryExpansion
from .query_decompose import QueryDecompose
62 autorag-workspace/autorag/nodes/queryexpansion/base.py Normal file
@@ -0,0 +1,62 @@
import abc
import logging
from pathlib import Path
from typing import List, Union

import pandas as pd

from autorag.nodes.util import make_generator_callable_param
from autorag.schema import BaseModule
from autorag.utils import validate_qa_dataset

logger = logging.getLogger("AutoRAG")


class BaseQueryExpansion(BaseModule, metaclass=abc.ABCMeta):
    def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
        logger.info(
            f"Initialize query expansion node - {self.__class__.__name__} module..."
        )
        # set the generator module used for query expansion
        generator_class, generator_param = make_generator_callable_param(kwargs)
        self.generator = generator_class(project_dir, **generator_param)

    def __del__(self):
        del self.generator
        logger.info(
            f"Delete query expansion node - {self.__class__.__name__} module..."
        )

    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        logger.info(
            f"Running query expansion node - {self.__class__.__name__} module..."
        )
        validate_qa_dataset(previous_result)

        # find the query column
        assert (
            "query" in previous_result.columns
        ), "previous_result must have query column."
        queries = previous_result["query"].tolist()
        return queries

    @staticmethod
    def _check_expanded_query(queries: List[str], expanded_queries: List[List[str]]):
        return list(
            map(
                lambda query, expanded_query_list: check_expanded_query(
                    query, expanded_query_list
                ),
                queries,
                expanded_queries,
            )
        )


def check_expanded_query(query: str, expanded_query_list: List[str]):
    # fall back to the original query when an expanded query is empty
    expanded_query_list = list(map(lambda x: x.strip(), expanded_query_list))
    return [
        expanded_query if expanded_query else query
        for expanded_query in expanded_query_list
    ]
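The fallback behavior of check_expanded_query is easiest to see in isolation. A minimal sketch, using only the helper defined above (no generator needed):

from autorag.nodes.queryexpansion.base import check_expanded_query

# An LLM sometimes returns an empty or whitespace-only expansion;
# the helper keeps the non-empty ones and substitutes the original
# query in the empty slots.
expanded = check_expanded_query(
    "What is HyDE?",
    ["A hypothetical passage about HyDE.", "   "],
)
print(expanded)
# ['A hypothetical passage about HyDE.', 'What is HyDE?']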
43 autorag-workspace/autorag/nodes/queryexpansion/hyde.py Normal file
@@ -0,0 +1,43 @@
from typing import List

import pandas as pd

from autorag.nodes.queryexpansion.base import BaseQueryExpansion
from autorag.utils import result_to_dataframe

hyde_prompt = "Please write a passage to answer the question"


class HyDE(BaseQueryExpansion):
    @result_to_dataframe(["queries"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        queries = self.cast_to_run(previous_result, *args, **kwargs)

        # pop prompt from kwargs
        prompt = kwargs.pop("prompt", hyde_prompt)
        kwargs.pop("generator_module_type", None)

        expanded_queries = self._pure(queries, prompt, **kwargs)
        return self._check_expanded_query(queries, expanded_queries)

    def _pure(
        self, queries: List[str], prompt: str = hyde_prompt, **generator_params
    ) -> List[List[str]]:
        """
        HyDE, inspired by "Precise Zero-shot Dense Retrieval without Relevance Labels"
        (https://arxiv.org/pdf/2212.10496.pdf).
        The LLM creates a hypothetical passage for each query,
        and that passage is then used as the retrieval query.

        :param queries: List[str], queries to expand.
        :param prompt: Prompt to use when generating the hypothetical passage.
        :return: List[List[str]], list of HyDE results, one hypothetical passage per query.
        """
        full_prompts = list(
            map(
                # use the given prompt, falling back to the default when it is empty
                lambda x: (prompt if bool(prompt) else hyde_prompt)
                + f"\nQuestion: {x}\nPassage:",
                queries,
            )
        )
        input_df = pd.DataFrame({"prompts": full_prompts})
        result_df = self.generator.pure(previous_result=input_df, **generator_params)
        answers = result_df["generated_texts"].tolist()
        results = list(map(lambda x: [x], answers))
        return results
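Since the generator call is opaque here, the observable behavior of HyDE is the prompt it builds. A minimal sketch of that construction, with an illustrative query and no LLM call:

hyde_prompt = "Please write a passage to answer the question"
query = "Why is the sky blue?"

# _pure builds one such prompt per query and hands them to the generator
# as a DataFrame with a 'prompts' column.
full_prompt = hyde_prompt + f"\nQuestion: {query}\nPassage:"
print(full_prompt)
# Please write a passage to answer the question
# Question: Why is the sky blue?
# Passage: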
57 autorag-workspace/autorag/nodes/queryexpansion/multi_query_expansion.py Normal file
@@ -0,0 +1,57 @@
from typing import List

import pandas as pd

from autorag.nodes.queryexpansion.base import BaseQueryExpansion
from autorag.utils import result_to_dataframe

multi_query_expansion_prompt = """You are an AI language model assistant.
Your task is to generate 3 different versions of the given user
question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions separated by newlines. Original question: {query}"""


class MultiQueryExpansion(BaseQueryExpansion):
    @result_to_dataframe(["queries"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        queries = self.cast_to_run(previous_result, *args, **kwargs)

        # pop prompt from kwargs
        prompt = kwargs.pop("prompt", multi_query_expansion_prompt)
        kwargs.pop("generator_module_type", None)

        expanded_queries = self._pure(queries, prompt, **kwargs)
        return self._check_expanded_query(queries, expanded_queries)

    def _pure(
        self, queries: List[str], prompt: str = multi_query_expansion_prompt, **kwargs
    ) -> List[List[str]]:
        """
        Expand a list of queries using the multi-query expansion approach.
        The LLM generates 3 different versions of each input query.

        :param queries: List[str], queries to expand.
        :param prompt: str, prompt to use for multi-query expansion.
            The default prompt comes from the langchain MultiQueryRetriever default query prompt.
        :return: List[List[str]], list of expanded queries.
        """
        full_prompts = list(map(lambda x: prompt.format(query=x), queries))
        input_df = pd.DataFrame({"prompts": full_prompts})
        result_df = self.generator.pure(previous_result=input_df, **kwargs)
        answers = result_df["generated_texts"].tolist()
        results = list(
            map(lambda x: get_multi_query_expansion(x[0], x[1]), zip(queries, answers))
        )
        return results


def get_multi_query_expansion(query: str, answer: str) -> List[str]:
    try:
        # split the LLM answer on newlines and keep the original query first
        queries = answer.split("\n")
        queries.insert(0, query)
        return queries
    except Exception:
        return [query]
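get_multi_query_expansion prepends the original query to whatever the LLM returns, so downstream retrieval always sees the original query as well. A small sketch with a made-up LLM answer:

from autorag.nodes.queryexpansion.multi_query_expansion import (
    get_multi_query_expansion,
)

llm_answer = (
    "How do planes generate lift?\n"
    "What role do wings play in flight?\n"
    "Why can heavier-than-air machines fly?"
)
print(get_multi_query_expansion("Why can airplanes fly?", llm_answer))
# ['Why can airplanes fly?', 'How do planes generate lift?',
#  'What role do wings play in flight?', 'Why can heavier-than-air machines fly?']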
22 autorag-workspace/autorag/nodes/queryexpansion/pass_query_expansion.py Normal file
@@ -0,0 +1,22 @@
import pandas as pd

from autorag.nodes.queryexpansion.base import BaseQueryExpansion
from autorag.utils import result_to_dataframe


class PassQueryExpansion(BaseQueryExpansion):
    @result_to_dataframe(["queries"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Do not perform query expansion.
        Return the same queries.
        The result will be a 2-d list, and the column name will be 'queries'.
        """
        assert (
            "query" in previous_result.columns
        ), "previous_result must have query column."
        queries = previous_result["query"].tolist()
        return list(map(lambda x: [x], queries))

    def _pure(self, *args, **kwargs):
        pass
111 autorag-workspace/autorag/nodes/queryexpansion/query_decompose.py Normal file
@@ -0,0 +1,111 @@
from typing import List

import pandas as pd

from autorag.nodes.queryexpansion.base import BaseQueryExpansion
from autorag.utils import result_to_dataframe

decompose_prompt = """Decompose a question in self-contained sub-questions. Use \"The question needs no decomposition\" when no decomposition is needed.

Example 1:

Question: Is Hamlet more common on IMDB than Comedy of Errors?
Decompositions:
1: How many listings of Hamlet are there on IMDB?
2: How many listing of Comedy of Errors is there on IMDB?

Example 2:

Question: Are birds important to badminton?

Decompositions:
The question needs no decomposition

Example 3:

Question: Is it legal for a licensed child driving Mercedes-Benz to be employed in US?

Decompositions:
1: What is the minimum driving age in the US?
2: What is the minimum age for someone to be employed in the US?

Example 4:

Question: Are all cucumbers the same texture?

Decompositions:
The question needs no decomposition

Example 5:

Question: Hydrogen's atomic number squared exceeds number of Spice Girls?

Decompositions:
1: What is the atomic number of hydrogen?
2: How many Spice Girls are there?

Example 6:

Question: {question}

Decompositions:
"""


class QueryDecompose(BaseQueryExpansion):
    @result_to_dataframe(["queries"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        queries = self.cast_to_run(previous_result, *args, **kwargs)

        # pop prompt from kwargs
        prompt = kwargs.pop("prompt", decompose_prompt)
        kwargs.pop("generator_module_type", None)

        expanded_queries = self._pure(queries, prompt, **kwargs)
        return self._check_expanded_query(queries, expanded_queries)

    def _pure(
        self, queries: List[str], prompt: str = decompose_prompt, *args, **kwargs
    ) -> List[List[str]]:
        """
        Decompose each query into small self-contained sub-questions.

        :param queries: List[str], queries to decompose.
        :param prompt: str, prompt to use for query decomposition.
            The default prompt comes from Visconde's StrategyQA few-shot prompt.
        :return: List[List[str]], list of decomposed queries.
            Returns the input query if the query is not decomposable.
        """
        full_prompts = []
        for query in queries:
            if bool(prompt):
                full_prompt = f"prompt: {prompt}\n\n question: {query}"
            else:
                full_prompt = decompose_prompt.format(question=query)
            full_prompts.append(full_prompt)
        input_df = pd.DataFrame({"prompts": full_prompts})
        result_df = self.generator.pure(previous_result=input_df, *args, **kwargs)
        answers = result_df["generated_texts"].tolist()
        results = list(
            map(lambda x: get_query_decompose(x[0], x[1]), zip(queries, answers))
        )
        return results


def get_query_decompose(query: str, answer: str) -> List[str]:
    """
    Parse the LLM answer into a list of sub-questions.

    :param query: str, the original query.
    :param answer: str, answer from the query decomposition prompt.
    :return: List[str], list of decomposed queries.
        Returns the input query if the query is not decomposable.
    """
    if answer.lower() == "the question needs no decomposition":
        return [query]
    try:
        lines = [line.strip() for line in answer.splitlines() if line.strip()]
        if lines[0].startswith("Decompositions:"):
            lines.pop(0)
        # keep only lines of the form "N: sub-question"
        questions = [line.split(":", 1)[1].strip() for line in lines if ":" in line]
        if not questions:
            return [query]
        return questions
    except Exception:
        return [query]
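The parser strips an optional "Decompositions:" header, keeps only numbered "N: question" lines, and otherwise falls back to the original query. A sketch using the few-shot format from the prompt above:

from autorag.nodes.queryexpansion.query_decompose import get_query_decompose

answer = """Decompositions:
1: What is the atomic number of hydrogen?
2: How many Spice Girls are there?"""
query = "Hydrogen's atomic number squared exceeds number of Spice Girls?"
print(get_query_decompose(query, answer))
# ['What is the atomic number of hydrogen?', 'How many Spice Girls are there?']

print(get_query_decompose(query, "The question needs no decomposition"))
# ["Hydrogen's atomic number squared exceeds number of Spice Girls?"]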
276 autorag-workspace/autorag/nodes/queryexpansion/run.py Normal file
@@ -0,0 +1,276 @@
import logging
import os
import pathlib
from copy import deepcopy
from typing import List, Dict, Optional

import pandas as pd

from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.schema.metricinput import MetricInput
from autorag.strategy import measure_speed, filter_by_threshold, select_best
from autorag.support import get_support_modules
from autorag.utils.util import make_combinations, explode

logger = logging.getLogger("AutoRAG")


def run_query_expansion_node(
    modules: List,
    module_params: List[Dict],
    previous_result: pd.DataFrame,
    node_line_dir: str,
    strategies: Dict,
) -> pd.DataFrame:
    """
    Run evaluation and select the best module among query expansion node results.
    First, retrieval is run using the expanded queries produced by each query expansion module.
    The retrieval step is run with every combination of the retrieval_modules in strategies.
    If there are multiple retrieval_modules, all of them are run and the best result is chosen.
    If no retrieval_modules are given, bm25 is used as the default.
    In this way, the best retrieval result is selected for each module,
    and then the best module is selected from those results.

    :param modules: Query expansion modules to run.
    :param module_params: Query expansion module parameters.
    :param previous_result: Previous result dataframe.
        In this case, it would be the qa data.
    :param node_line_dir: This node line's directory.
    :param strategies: Strategies for the query expansion node.
    :return: The best result dataframe.
    """
    if not os.path.exists(node_line_dir):
        os.makedirs(node_line_dir)
    node_dir = os.path.join(node_line_dir, "query_expansion")
    if not os.path.exists(node_dir):
        os.makedirs(node_dir)
    project_dir = pathlib.PurePath(node_line_dir).parent.parent

    # run query expansion
    results, execution_times = zip(
        *map(
            lambda task: measure_speed(
                task[0].run_evaluator,
                project_dir=project_dir,
                previous_result=previous_result,
                **task[1],
            ),
            zip(modules, module_params),
        )
    )
    average_times = list(map(lambda x: x / len(results[0]), execution_times))

    # save results to folder
    pseudo_module_params = deepcopy(module_params)
    for i, module_param in enumerate(pseudo_module_params):
        if "prompt" in module_param:
            # replace the long prompt with its index so the summary stays readable
            module_param["prompt"] = str(i)
    filepaths = list(
        map(lambda x: os.path.join(node_dir, f"{x}.parquet"), range(len(modules)))
    )
    list(
        map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))
    )  # execute save to parquet
    filenames = list(map(lambda x: os.path.basename(x), filepaths))

    # make summary file
    summary_df = pd.DataFrame(
        {
            "filename": filenames,
            "module_name": list(map(lambda module: module.__name__, modules)),
            "module_params": pseudo_module_params,
            "execution_time": average_times,
        }
    )

    # Run evaluation when there is more than one module.
    if len(modules) > 1:
        # pop general keys from strategies (e.g. metrics, speed_threshold)
        general_key = ["metrics", "speed_threshold", "strategy"]
        general_strategy = dict(
            filter(lambda x: x[0] in general_key, strategies.items())
        )
        extra_strategy = dict(
            filter(lambda x: x[0] not in general_key, strategies.items())
        )

        # first, filter by threshold if it is enabled.
        if general_strategy.get("speed_threshold") is not None:
            results, filenames = filter_by_threshold(
                results, average_times, general_strategy["speed_threshold"], filenames
            )

        # check metrics in strategy
        if general_strategy.get("metrics") is None:
            raise ValueError(
                "You must specify at least one metric for query expansion evaluation."
            )

        if extra_strategy.get("top_k") is None:
            extra_strategy["top_k"] = 10  # default value

        # get retrieval modules from strategy
        retrieval_callables, retrieval_params = make_retrieval_callable_params(
            extra_strategy
        )

        # get retrieval_gt
        retrieval_gt = pd.read_parquet(
            os.path.join(project_dir, "data", "qa.parquet"), engine="pyarrow"
        )["retrieval_gt"].tolist()

        # make rows to metric_inputs
        metric_inputs = [
            MetricInput(retrieval_gt=ret_gt, query=query, generation_gt=gen_gt)
            for ret_gt, query, gen_gt in zip(
                retrieval_gt,
                previous_result["query"].tolist(),
                previous_result["generation_gt"].tolist(),
            )
        ]

        # run evaluation
        evaluation_results = list(
            map(
                lambda result: evaluate_one_query_expansion_node(
                    retrieval_callables,
                    retrieval_params,
                    [
                        setattr(metric_input, "queries", queries) or metric_input
                        for metric_input, queries in zip(
                            metric_inputs, result["queries"].to_list()
                        )
                    ],
                    general_strategy["metrics"],
                    project_dir,
                    previous_result,
                    general_strategy.get("strategy", "mean"),
                ),
                results,
            )
        )

        evaluation_df = pd.DataFrame(
            {
                "filename": filenames,
                **{
                    f"query_expansion_{metric_name}": list(
                        map(lambda x: x[metric_name].mean(), evaluation_results)
                    )
                    for metric_name in general_strategy["metrics"]
                },
            }
        )
        summary_df = pd.merge(summary_df, evaluation_df, on="filename", how="left")

        best_result, best_filename = select_best(
            evaluation_results,
            general_strategy["metrics"],
            filenames,
            strategies.get("strategy", "mean"),
        )
        # rename metric columns to query_expansion_{metric_name}
        best_result = best_result.rename(
            columns={
                metric_name: f"query_expansion_{metric_name}"
                for metric_name in strategies["metrics"]
            }
        )
        best_result = best_result.drop(
            columns=["retrieved_contents", "retrieved_ids", "retrieve_scores"]
        )
    else:
        best_result, best_filename = results[0], filenames[0]
    best_result = pd.concat([previous_result, best_result], axis=1)

    # add 'is_best' column to the summary file
    summary_df["is_best"] = summary_df["filename"] == best_filename

    # save files
    summary_df.to_csv(os.path.join(node_dir, "summary.csv"), index=False)
    best_result.to_parquet(
        os.path.join(node_dir, f"best_{os.path.splitext(best_filename)[0]}.parquet"),
        index=False,
    )

    return best_result


def evaluate_one_query_expansion_node(
    retrieval_funcs: List,
    retrieval_params: List[Dict],
    metric_inputs: List[MetricInput],
    metrics: List[str],
    project_dir,
    previous_result: pd.DataFrame,
    strategy_name: str,
) -> pd.DataFrame:
    previous_result["queries"] = [
        metric_input.queries for metric_input in metric_inputs
    ]
    retrieval_results = list(
        map(
            lambda x: x[0].run_evaluator(
                project_dir=project_dir, previous_result=previous_result, **x[1]
            ),
            zip(retrieval_funcs, retrieval_params),
        )
    )
    evaluation_results = list(
        map(
            lambda x: evaluate_retrieval_node(
                x,
                metric_inputs,
                metrics,
            ),
            retrieval_results,
        )
    )
    best_result, _ = select_best(
        evaluation_results, metrics, strategy_name=strategy_name
    )
    best_result = pd.concat([previous_result, best_result], axis=1)
    return best_result


def make_retrieval_callable_params(strategy_dict: Dict):
    """
    strategy_dict looks like this:

    .. code:: json

        {
            "metrics": ["retrieval_f1", "retrieval_recall"],
            "top_k": 50,
            "retrieval_modules": [
                {"module_type": "bm25"},
                {"module_type": "vectordb", "embedding_model": ["openai", "huggingface"]}
            ]
        }

    """
    node_dict = deepcopy(strategy_dict)
    retrieval_module_list: Optional[List[Dict]] = node_dict.pop(
        "retrieval_modules", None
    )
    if retrieval_module_list is None:
        retrieval_module_list = [
            {
                "module_type": "bm25",
            }
        ]
    node_params = node_dict
    modules = list(
        map(
            lambda module_dict: get_support_modules(module_dict.pop("module_type")),
            retrieval_module_list,
        )
    )
    param_combinations = list(
        map(
            lambda module_dict: make_combinations({**module_dict, **node_params}),
            retrieval_module_list,
        )
    )
    return explode(modules, param_combinations)
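make_retrieval_callable_params crosses each retrieval module with the remaining node params. A sketch of the expansion, assuming "bm25" and "vectordb" resolve through get_support_modules as in the docstring example, and that explode returns the flattened parallel lists its callers unpack:

from autorag.nodes.queryexpansion.run import make_retrieval_callable_params

strategy = {
    "top_k": 4,
    "retrieval_modules": [
        {"module_type": "bm25"},
        {"module_type": "vectordb", "embedding_model": ["openai", "huggingface"]},
    ],
}
modules, params = make_retrieval_callable_params(strategy)
# make_combinations expands the embedding_model list, and explode pairs
# each module with every one of its parameter combinations:
# modules -> [bm25, vectordb, vectordb]
# params  -> [{'top_k': 4},
#             {'embedding_model': 'openai', 'top_k': 4},
#             {'embedding_model': 'huggingface', 'top_k': 4}]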