Fix Dockerfile build issue
4  autorag/nodes/passagecompressor/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .longllmlingua import LongLLMLingua
from .pass_compressor import PassCompressor
from .refine import Refine
from .tree_summarize import TreeSummarize
83  autorag/nodes/passagecompressor/base.py  Normal file
@@ -0,0 +1,83 @@
import abc
import logging
from typing import Dict

import pandas as pd
from llama_index.core.llms import LLM

from autorag import generator_models
from autorag.schema import BaseModule
from autorag.utils import result_to_dataframe

logger = logging.getLogger("AutoRAG")


class BasePassageCompressor(BaseModule, metaclass=abc.ABCMeta):
    def __init__(self, project_dir: str, *args, **kwargs):
        logger.info(
            f"Initialize passage compressor node - {self.__class__.__name__} module..."
        )

    def __del__(self):
        logger.info(
            f"Deleting passage compressor node - {self.__class__.__name__} module..."
        )

    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        logger.info(
            f"Running passage compressor node - {self.__class__.__name__} module..."
        )
        assert all(
            [
                column in previous_result.columns
                for column in [
                    "query",
                    "retrieved_contents",
                ]
            ]
        ), "previous_result must have query and retrieved_contents columns."
        assert len(previous_result) > 0, "previous_result must have at least one row."

        queries = previous_result["query"].tolist()
        retrieved_contents = previous_result["retrieved_contents"].tolist()
        return queries, retrieved_contents


class LlamaIndexCompressor(BasePassageCompressor, metaclass=abc.ABCMeta):
    param_list = ["prompt", "chat_prompt", "batch"]

    def __init__(self, project_dir: str, **kwargs):
        """
        Initialize the passage compressor module.

        :param project_dir: The project directory.
        :param llm: The name of the llm that will be used to summarize.
            Any LlamaIndex LLM model can be used here.
        :param kwargs: Extra parameters for initializing the llm.
        """
        super().__init__(project_dir)
        kwargs_dict = dict(
            filter(lambda x: x[0] not in self.param_list, kwargs.items())
        )
        llm_name = kwargs_dict.pop("llm")
        self.llm: LLM = make_llm(llm_name, kwargs_dict)

    def __del__(self):
        del self.llm
        super().__del__()

    @result_to_dataframe(["retrieved_contents"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        queries, retrieved_contents = self.cast_to_run(previous_result)
        param_dict = dict(filter(lambda x: x[0] in self.param_list, kwargs.items()))
        result = self._pure(queries, retrieved_contents, **param_dict)
        return list(map(lambda x: [x], result))


def make_llm(llm_name: str, kwargs: Dict) -> LLM:
    if llm_name not in generator_models:
        raise KeyError(
            f"{llm_name} is not supported. "
            "You can add it manually by calling autorag.generator_models."
        )
    return generator_models[llm_name](**kwargs)
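The make_llm helper above resolves a string key through the autorag.generator_models registry and instantiates the matching LlamaIndex LLM class with the remaining kwargs. A minimal sketch of that flow, assuming generator_models behaves like a plain name-to-class dict and using llama_index's MockLLM as a stand-in; the "mock" key is made up for illustration:

import autorag
from llama_index.core.llms import MockLLM  # stand-in LLM, no API key needed
from autorag.nodes.passagecompressor.base import make_llm

# assumption: generator_models is a name -> LLM-class registry that can be extended
autorag.generator_models["mock"] = MockLLM

# make_llm looks up the key and forwards the remaining kwargs to the class constructor;
# an unknown key would raise the KeyError defined above
llm = make_llm("mock", {"max_tokens": 64})
print(type(llm).__name__)  # MockLLM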
115  autorag/nodes/passagecompressor/longllmlingua.py  Normal file
@@ -0,0 +1,115 @@
from typing import List, Optional

import pandas as pd

from autorag.nodes.passagecompressor.base import BasePassageCompressor
from autorag.utils.util import pop_params, result_to_dataframe, empty_cuda_cache


# TODO: Parallel Processing Refactoring at #460


class LongLLMLingua(BasePassageCompressor):
    def __init__(
        self, project_dir: str, model_name: str = "NousResearch/Llama-2-7b-hf", **kwargs
    ):
        try:
            from llmlingua import PromptCompressor
        except ImportError:
            raise ImportError(
                "LongLLMLingua is not installed. Please install it by running `pip install llmlingua`."
            )

        super().__init__(project_dir)
        model_init_params = pop_params(PromptCompressor.__init__, kwargs)
        self.llm_lingua = PromptCompressor(model_name=model_name, **model_init_params)

    def __del__(self):
        del self.llm_lingua
        empty_cuda_cache()
        super().__del__()

    @result_to_dataframe(["retrieved_contents"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        queries, retrieved_contents = self.cast_to_run(previous_result)
        results = self._pure(queries, retrieved_contents, **kwargs)
        return list(map(lambda x: [x], results))

    def _pure(
        self,
        queries: List[str],
        contents: List[List[str]],
        instructions: Optional[str] = None,
        target_token: int = 300,
        **kwargs,
    ) -> List[str]:
        """
        Compress the retrieved texts using LongLLMLingua.
        For more information, visit https://github.com/microsoft/LLMLingua.

        :param queries: The queries for retrieved passages.
        :param contents: The contents of retrieved passages.
        :param instructions: The instructions for compression.
            Default is None. When it is None, the default instructions are used.
        :param target_token: The target token count for compression.
            Default is 300.
        :param kwargs: Additional keyword arguments.
        :return: The list of compressed texts.
        """
        if instructions is None:
            instructions = "Given the context, please answer the final question"
        results = [
            llmlingua_pure(
                query, contents_, self.llm_lingua, instructions, target_token, **kwargs
            )
            for query, contents_ in zip(queries, contents)
        ]

        return results


def llmlingua_pure(
    query: str,
    contents: List[str],
    llm_lingua,
    instructions: str,
    target_token: int = 300,
    **kwargs,
) -> str:
    """
    Return the compressed text.

    :param query: The query for retrieved passages.
    :param contents: The contents of retrieved passages.
    :param llm_lingua: The PromptCompressor instance that will be used to compress.
    :param instructions: The instructions for compression.
    :param target_token: The target token count for compression.
        Default is 300.
    :param kwargs: Additional keyword arguments.
    :return: The compressed text.
    """
    try:
        from llmlingua import PromptCompressor
    except ImportError:
        raise ImportError(
            "LongLLMLingua is not installed. Please install it by running `pip install llmlingua`."
        )
    # split by "\n\n" (recommended by LongLLMLingua authors)
    new_context_texts = [c for context in contents for c in context.split("\n\n")]
    compress_prompt_params = pop_params(PromptCompressor.compress_prompt, kwargs)
    compressed_prompt = llm_lingua.compress_prompt(
        new_context_texts,
        question=query,
        instruction=instructions,
        rank_method="longllmlingua",
        target_token=target_token,
        **compress_prompt_params,
    )
    compressed_prompt_txt = compressed_prompt["compressed_prompt"]

    # separate out the question and instruction
    result = "\n\n".join(compressed_prompt_txt.split("\n\n")[1:-1])

    return result
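A minimal usage sketch of llmlingua_pure, assuming the llmlingua package is installed and the (fairly large) default model can be downloaded; any other llmlingua-compatible model name could be substituted, and the query and passages below are made up for illustration:

from llmlingua import PromptCompressor
from autorag.nodes.passagecompressor.longllmlingua import llmlingua_pure

compressor = PromptCompressor(model_name="NousResearch/Llama-2-7b-hf")
passages = [
    "Paris is the capital and most populous city of France.",
    "Lyon is the third-largest city in France.\n\nIt sits at the confluence of the Rhone and Saone rivers.",
]
compressed = llmlingua_pure(
    query="What is the capital of France?",
    contents=passages,
    llm_lingua=compressor,
    instructions="Given the context, please answer the final question",
    target_token=100,
)
print(compressed)  # compressed context, roughly within the 100-token budget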
16  autorag/nodes/passagecompressor/pass_compressor.py  Normal file
@@ -0,0 +1,16 @@
from typing import List

import pandas as pd

from autorag.nodes.passagecompressor.base import BasePassageCompressor
from autorag.utils import result_to_dataframe


class PassCompressor(BasePassageCompressor):
    @result_to_dataframe(["retrieved_contents"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        _, contents = self.cast_to_run(previous_result)
        return self._pure(contents)

    def _pure(self, contents: List[List[str]]):
        return contents
54  autorag/nodes/passagecompressor/refine.py  Normal file
@@ -0,0 +1,54 @@
from typing import List, Optional

from llama_index.core import PromptTemplate
from llama_index.core.prompts import PromptType
from llama_index.core.prompts.utils import is_chat_model
from llama_index.core.response_synthesizers import Refine as rf

from autorag.nodes.passagecompressor.base import LlamaIndexCompressor
from autorag.utils.util import get_event_loop, process_batch


class Refine(LlamaIndexCompressor):
    def _pure(
        self,
        queries: List[str],
        contents: List[List[str]],
        prompt: Optional[str] = None,
        chat_prompt: Optional[str] = None,
        batch: int = 16,
    ) -> List[str]:
        """
        Refine a response to a query across text chunks.
        This function is a wrapper for llama_index.response_synthesizers.Refine.
        For more information, visit https://docs.llamaindex.ai/en/stable/examples/response_synthesizers/refine/.

        :param queries: The queries for retrieved passages.
        :param contents: The contents of retrieved passages.
        :param prompt: The prompt template for refine.
            If you want to use a chat prompt, pass chat_prompt instead.
            In the prompt, you must specify where to put 'context_msg' and 'query_str'.
            Default is None. When it is None, the llama index default prompt is used.
        :param chat_prompt: The chat prompt template for refine.
            If you want to use a normal prompt, pass prompt instead.
            In the prompt, you must specify where to put 'context_msg' and 'query_str'.
            Default is None. When it is None, the llama index default chat prompt is used.
        :param batch: The batch size for the llm.
            Set it low if you face errors.
            Default is 16.
        :return: The list of compressed texts.
        """
        if prompt is not None and not is_chat_model(self.llm):
            refine_template = PromptTemplate(prompt, prompt_type=PromptType.REFINE)
        elif chat_prompt is not None and is_chat_model(self.llm):
            refine_template = PromptTemplate(chat_prompt, prompt_type=PromptType.REFINE)
        else:
            refine_template = None
        summarizer = rf(llm=self.llm, refine_template=refine_template, verbose=True)
        tasks = [
            summarizer.aget_response(query, content)
            for query, content in zip(queries, contents)
        ]
        loop = get_event_loop()
        results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
        return results
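The docstring above requires any custom prompt to contain 'context_msg' and 'query_str' placeholders. A hedged illustration of what such a prompt string might look like when passed as the prompt module parameter for a non-chat llm; the wording is made up and is not the llama index default:

refine_prompt = (
    "The original question is: {query_str}\n"
    "We have some context from retrieved passages below:\n"
    "{context_msg}\n"
    "Keep only the sentences that help answer the question and drop the rest."
)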
186  autorag/nodes/passagecompressor/run.py  Normal file
@@ -0,0 +1,186 @@
import os.path
import pathlib
from typing import List, Dict

import pandas as pd

from autorag.evaluation.metric import (
    retrieval_token_recall,
    retrieval_token_precision,
    retrieval_token_f1,
)
from autorag.schema.metricinput import MetricInput
from autorag.strategy import measure_speed, filter_by_threshold, select_best
from autorag.utils.util import fetch_contents


def run_passage_compressor_node(
    modules: List,
    module_params: List[Dict],
    previous_result: pd.DataFrame,
    node_line_dir: str,
    strategies: Dict,
) -> pd.DataFrame:
    """
    Run evaluation and select the best module among passage compressor modules.

    :param modules: Passage compressor modules to run.
    :param module_params: Passage compressor module parameters.
    :param previous_result: Previous result dataframe.
        It could be the result of retrieval or reranker modules.
        It must contain 'query', 'retrieved_contents', 'retrieved_ids', and 'retrieve_scores' columns.
    :param node_line_dir: This node line's directory.
    :param strategies: Strategies for the passage compressor node.
        In this node, token-based retrieval metrics ('retrieval_token_f1',
        'retrieval_token_precision', 'retrieval_token_recall') are used for evaluation.
        You can skip evaluation when you use only one module and one module parameter.
    :return: The best result dataframe with the previous result columns.
        This node replaces 'retrieved_contents' with the compressed passages,
        so the contents list of each row will have length one.
    """
    if not os.path.exists(node_line_dir):
        os.makedirs(node_line_dir)
    project_dir = pathlib.PurePath(node_line_dir).parent.parent
    data_dir = os.path.join(project_dir, "data")
    save_dir = os.path.join(node_line_dir, "passage_compressor")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # make retrieval contents gt
    qa_data = pd.read_parquet(os.path.join(data_dir, "qa.parquet"), engine="pyarrow")
    corpus_data = pd.read_parquet(
        os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
    )
    # check that qa_data has retrieval_gt
    assert all(
        len(x[0]) > 0 for x in qa_data["retrieval_gt"].tolist()
    ), "Can't use passage compressor if you don't have retrieval gt values in the QA dataset."

    # run modules
    results, execution_times = zip(
        *map(
            lambda task: measure_speed(
                task[0].run_evaluator,
                project_dir=project_dir,
                previous_result=previous_result,
                **task[1],
            ),
            zip(modules, module_params),
        )
    )
    results = list(results)
    average_times = list(map(lambda x: x / len(results[0]), execution_times))

    retrieval_gt_contents = list(
        map(lambda x: fetch_contents(corpus_data, x), qa_data["retrieval_gt"].tolist())
    )

    metric_inputs = [
        MetricInput(retrieval_gt_contents=ret_cont_gt)
        for ret_cont_gt in retrieval_gt_contents
    ]

    # run metrics before filtering
    if strategies.get("metrics") is None:
        raise ValueError(
            "You must specify at least one metric for retrieval contents evaluation. "
            "It can be 'retrieval_token_f1', 'retrieval_token_precision', or 'retrieval_token_recall'."
        )
    results = list(
        map(
            lambda x: evaluate_passage_compressor_node(
                x, metric_inputs, strategies.get("metrics")
            ),
            results,
        )
    )

    # save results to folder
    filepaths = list(
        map(lambda x: os.path.join(save_dir, f"{x}.parquet"), range(len(modules)))
    )
    list(
        map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))
    )  # execute save to parquet
    filenames = list(map(lambda x: os.path.basename(x), filepaths))

    # make summary file
    summary_df = pd.DataFrame(
        {
            "filename": filenames,
            "module_name": list(map(lambda module: module.__name__, modules)),
            "module_params": module_params,
            "execution_time": average_times,
            **{
                f"passage_compressor_{metric}": list(
                    map(lambda result: result[metric].mean(), results)
                )
                for metric in strategies.get("metrics")
            },
        }
    )

    # filter by strategies
    if strategies.get("speed_threshold") is not None:
        results, filenames = filter_by_threshold(
            results, average_times, strategies["speed_threshold"], filenames
        )
    selected_result, selected_filename = select_best(
        results,
        strategies.get("metrics"),
        filenames,
        strategies.get("strategy", "mean"),
    )
    new_retrieved_contents = selected_result["retrieved_contents"]
    previous_result["retrieved_contents"] = new_retrieved_contents
    selected_result = selected_result.drop(columns=["retrieved_contents"])
    best_result = pd.concat([previous_result, selected_result], axis=1)

    # add 'is_best' column to summary file
    summary_df["is_best"] = summary_df["filename"] == selected_filename

    # add prefix 'passage_compressor' to best_result columns
    best_result = best_result.rename(
        columns={
            metric_name: f"passage_compressor_{metric_name}"
            for metric_name in strategies.get("metrics")
        }
    )

    # save the result files
    best_result.to_parquet(
        os.path.join(
            save_dir, f"best_{os.path.splitext(selected_filename)[0]}.parquet"
        ),
        index=False,
    )
    summary_df.to_csv(os.path.join(save_dir, "summary.csv"), index=False)
    return best_result


def evaluate_passage_compressor_node(
    result_df: pd.DataFrame, metric_inputs: List[MetricInput], metrics: List[str]
):
    metric_funcs = {
        retrieval_token_recall.__name__: retrieval_token_recall,
        retrieval_token_precision.__name__: retrieval_token_precision,
        retrieval_token_f1.__name__: retrieval_token_f1,
    }
    for metric_input, generated_text in zip(
        metric_inputs, result_df["retrieved_contents"].tolist()
    ):
        metric_input.retrieved_contents = generated_text
    metrics = list(filter(lambda x: x in metric_funcs.keys(), metrics))
    if len(metrics) <= 0:
        raise ValueError(f"metrics must be one of {metric_funcs.keys()}")
    metrics_scores = dict(
        map(
            lambda metric: (
                metric,
                metric_funcs[metric](
                    metric_inputs=metric_inputs,
                ),
            ),
            metrics,
        )
    )
    result_df = pd.concat([result_df, pd.DataFrame(metrics_scores)], axis=1)
    return result_df
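A small, hedged sketch of evaluate_passage_compressor_node on a toy dataframe, assuming the token metric functions accept metric_inputs= as shown above and return one score per row; the sentences are made up. The same metric names are what run_passage_compressor_node expects under the "metrics" key of its strategies dict:

import pandas as pd
from autorag.schema.metricinput import MetricInput
from autorag.nodes.passagecompressor.run import evaluate_passage_compressor_node

# one row: the compressed contents and the ground-truth passage contents for that query
toy_result = pd.DataFrame({"retrieved_contents": [["Paris is the capital of France."]]})
toy_inputs = [MetricInput(retrieval_gt_contents=[["Paris is the capital city of France."]])]

scored = evaluate_passage_compressor_node(toy_result, toy_inputs, ["retrieval_token_f1"])
print(scored["retrieval_token_f1"][0])  # token-overlap F1 between gt and compressed contents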
56  autorag/nodes/passagecompressor/tree_summarize.py  Normal file
@@ -0,0 +1,56 @@
from typing import List, Optional

from llama_index.core import PromptTemplate
from llama_index.core.prompts import PromptType
from llama_index.core.prompts.utils import is_chat_model
from llama_index.core.response_synthesizers import TreeSummarize as ts

from autorag.nodes.passagecompressor.base import LlamaIndexCompressor
from autorag.utils.util import get_event_loop, process_batch


class TreeSummarize(LlamaIndexCompressor):
    def _pure(
        self,
        queries: List[str],
        contents: List[List[str]],
        prompt: Optional[str] = None,
        chat_prompt: Optional[str] = None,
        batch: int = 16,
    ) -> List[str]:
        """
        Recursively merge retrieved texts and summarize them in a bottom-up fashion.
        This function is a wrapper for llama_index.response_synthesizers.TreeSummarize.
        For more information, visit https://docs.llamaindex.ai/en/latest/examples/response_synthesizers/tree_summarize.html.

        :param queries: The queries for retrieved passages.
        :param contents: The contents of retrieved passages.
        :param prompt: The prompt template for summarization.
            If you want to use a chat prompt, pass chat_prompt instead.
            In the prompt, you must specify where to put 'context_str' and 'query_str'.
            Default is None. When it is None, the llama index default prompt is used.
        :param chat_prompt: The chat prompt template for summarization.
            If you want to use a normal prompt, pass prompt instead.
            In the prompt, you must specify where to put 'context_str' and 'query_str'.
            Default is None. When it is None, the llama index default chat prompt is used.
        :param batch: The batch size for the llm.
            Set it low if you face errors.
            Default is 16.
        :return: The list of compressed texts.
        """
        if prompt is not None and not is_chat_model(self.llm):
            summary_template = PromptTemplate(prompt, prompt_type=PromptType.SUMMARY)
        elif chat_prompt is not None and is_chat_model(self.llm):
            summary_template = PromptTemplate(
                chat_prompt, prompt_type=PromptType.SUMMARY
            )
        else:
            summary_template = None
        summarizer = ts(llm=self.llm, summary_template=summary_template, use_async=True)
        tasks = [
            summarizer.aget_response(query, content)
            for query, content in zip(queries, contents)
        ]
        loop = get_event_loop()
        results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
        return results
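To show how the pieces above fit together, here is a hedged end-to-end sketch that reuses the MockLLM registration idea from the make_llm sketch earlier (the "mock" key and the dict-style registry are assumptions for illustration) and runs TreeSummarize.pure on a tiny previous_result dataframe:

import pandas as pd
import autorag
from llama_index.core.llms import MockLLM  # stand-in LLM, no API key needed
from autorag.nodes.passagecompressor import TreeSummarize

autorag.generator_models["mock"] = MockLLM  # assumption: a plain name -> class registry

# minimal previous_result with the two columns cast_to_run checks for
previous_result = pd.DataFrame(
    {
        "query": ["What is the capital of France?"],
        "retrieved_contents": [["Paris is the capital of France.", "Lyon is in France."]],
    }
)
module = TreeSummarize(project_dir=".", llm="mock")
compressed_df = module.pure(previous_result, batch=4)
print(compressed_df["retrieved_contents"][0])  # a single summarized passage per row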