Fix Dockerfile build issue
109
autorag/data/__init__.py
Normal file
@@ -0,0 +1,109 @@
import logging
from typing import List, Callable

from langchain_community.document_loaders import (
    PDFMinerLoader,
    PDFPlumberLoader,
    PyPDFium2Loader,
    PyPDFLoader,
    PyMuPDFLoader,
    UnstructuredPDFLoader,
    CSVLoader,
    JSONLoader,
    UnstructuredMarkdownLoader,
    BSHTMLLoader,
    UnstructuredXMLLoader,
    DirectoryLoader,
)
from langchain_unstructured import UnstructuredLoader
from langchain_upstage import UpstageDocumentParseLoader

from llama_index.core.node_parser import (
    TokenTextSplitter,
    SentenceSplitter,
    SentenceWindowNodeParser,
    SemanticSplitterNodeParser,
    SemanticDoubleMergingSplitterNodeParser,
    SimpleFileNodeParser,
)
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    KonlpyTextSplitter,
    SentenceTransformersTokenTextSplitter,
)

from autorag import LazyInit

logger = logging.getLogger("AutoRAG")

parse_modules = {
    # PDF
    "pdfminer": PDFMinerLoader,
    "pdfplumber": PDFPlumberLoader,
    "pypdfium2": PyPDFium2Loader,
    "pypdf": PyPDFLoader,
    "pymupdf": PyMuPDFLoader,
    "unstructuredpdf": UnstructuredPDFLoader,
    # Common File Types
    # 1. CSV
    "csv": CSVLoader,
    # 2. JSON
    "json": JSONLoader,
    # 3. Markdown
    "unstructuredmarkdown": UnstructuredMarkdownLoader,
    # 4. HTML
    "bshtml": BSHTMLLoader,
    # 5. XML
    "unstructuredxml": UnstructuredXMLLoader,
    # 6. All files
    "directory": DirectoryLoader,
    "unstructured": UnstructuredLoader,
    "upstagedocumentparse": UpstageDocumentParseLoader,
}

chunk_modules = {
    # Llama Index
    # Token
    "token": TokenTextSplitter,
    # Sentence
    "sentence": SentenceSplitter,
    # Window
    "sentencewindow": SentenceWindowNodeParser,
    # Semantic
    "semantic_llama_index": SemanticSplitterNodeParser,
    "semanticdoublemerging": SemanticDoubleMergingSplitterNodeParser,
    # Simple
    "simplefile": SimpleFileNodeParser,
    # LangChain
    # Token
    "sentencetransformerstoken": SentenceTransformersTokenTextSplitter,
    # Character
    "recursivecharacter": RecursiveCharacterTextSplitter,
    "character": CharacterTextSplitter,
    # Sentence
    "konlpy": KonlpyTextSplitter,
}


def split_by_sentence_kiwi() -> Callable[[str], List[str]]:
    try:
        from kiwipiepy import Kiwi
    except ImportError:
        raise ImportError(
            "You need to install kiwipiepy to use 'ko_kiwi' tokenizer. "
            "Please install kiwipiepy by running 'pip install kiwipiepy'. "
            "Or install Korean version of AutoRAG by running 'pip install AutoRAG[ko]'."
        )
    kiwi = Kiwi()

    def split(text: str) -> List[str]:
        kiwi_result = kiwi.split_into_sents(text)
        sentences = list(map(lambda x: x.text, kiwi_result))

        return sentences

    return split


sentence_splitter_modules = {"kiwi": LazyInit(split_by_sentence_kiwi)}
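Illustrative usage of the registries above (not part of this commit; the sample text and parameter values are assumptions):

from autorag.data import parse_modules, chunk_modules

loader_cls = parse_modules["pdfminer"]  # LangChain PDFMinerLoader class
splitter = chunk_modules["token"](chunk_size=512, chunk_overlap=50)  # llama index TokenTextSplitter
chunks = splitter.split_text("AutoRAG registers loaders and splitters by name.")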
2
autorag/data/chunk/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .llama_index_chunk import llama_index_chunk
from .langchain_chunk import langchain_chunk
128
autorag/data/chunk/base.py
Normal file
@@ -0,0 +1,128 @@
import functools
import logging
from typing import Tuple, List, Dict, Any

import pandas as pd

from autorag.embedding.base import EmbeddingModel
from autorag.data import chunk_modules, sentence_splitter_modules
from autorag.utils import result_to_dataframe

logger = logging.getLogger("AutoRAG")


def chunker_node(func):
    @functools.wraps(func)
    @result_to_dataframe(["doc_id", "contents", "path", "start_end_idx", "metadata"])
    def wrapper(
        parsed_result: pd.DataFrame, chunk_method: str, **kwargs
    ) -> Tuple[
        List[str], List[str], List[str], List[Tuple[int, int]], List[Dict[str, Any]]
    ]:
        logger.info(f"Running chunker - {func.__name__} module...")

        # get texts from parsed_result
        texts = parsed_result["texts"].tolist()

        # get filenames from parsed_result when 'add_file_name' is set
        file_name_language = kwargs.pop("add_file_name", None)
        metadata_list = make_metadata_list(parsed_result)

        # run chunk module
        if func.__name__ in ["llama_index_chunk", "langchain_chunk"]:
            chunk_instance = __get_chunk_instance(
                func.__name__, chunk_method.lower(), **kwargs
            )
            result = func(
                texts=texts,
                chunker=chunk_instance,
                file_name_language=file_name_language,
                metadata_list=metadata_list,
            )
            del chunk_instance
            return result
        else:
            raise ValueError(f"Unsupported module_type: {func.__name__}")

    return wrapper


def make_metadata_list(parsed_result: pd.DataFrame) -> List[Dict[str, str]]:
    metadata_list = [{} for _ in range(len(parsed_result["texts"]))]

    def _make_metadata_pure(
        lst: List[str], key: str, metadata_lst: List[Dict[str, str]]
    ):
        for value, metadata in zip(lst, metadata_lst):
            metadata[key] = value

    for column in ["page", "last_modified_datetime", "path"]:
        if column in parsed_result.columns:
            _make_metadata_pure(parsed_result[column].tolist(), column, metadata_list)
    return metadata_list


def __get_chunk_instance(module_type: str, chunk_method: str, **kwargs):
    # Add sentence_splitter to kwargs
    sentence_available_methods = [
        "semantic_llama_index",
        "semanticdoublemerging",
        "sentencewindow",
    ]
    if chunk_method in sentence_available_methods:
        # llama index's default sentence_splitter is nltk's PunktSentenceTokenizer
        if "sentence_splitter" in kwargs.keys():
            sentence_splitter_str = kwargs.pop("sentence_splitter")
            sentence_splitter_func = sentence_splitter_modules[sentence_splitter_str]()
            kwargs.update({"sentence_splitter": sentence_splitter_func})

    def get_embedding_model(_embed_model_str: str, _module_type: str):
        if _embed_model_str == "openai":
            if _module_type == "langchain_chunk":
                _embed_model_str = "openai_langchain"
        return EmbeddingModel.load(_embed_model_str)()

    # Add embed_model to kwargs
    embedding_available_methods = ["semantic_llama_index", "semantic_langchain"]
    if chunk_method in embedding_available_methods:
        # there is no default embed_model, so we have to get it from the parameters and add it.
        if "embed_model" not in kwargs.keys():
            raise ValueError(f"embed_model is required for {chunk_method} method.")
        embed_model_str = kwargs.pop("embed_model")
        embed_model = get_embedding_model(embed_model_str, module_type)
        if chunk_method == "semantic_llama_index":
            kwargs.update({"embed_model": embed_model})
        elif chunk_method == "semantic_langchain":
            kwargs.update({"embeddings": embed_model})

    return chunk_modules[chunk_method](**kwargs)


def add_file_name(
    file_name_language: str, file_names: List[str], chunk_texts: List[str]
) -> List[str]:
    if file_name_language == "en":
        return list(
            map(
                lambda x: f"file_name: {x[1]}\n contents: {x[0]}",
                zip(chunk_texts, file_names),
            )
        )
    elif file_name_language == "ko":
        return list(
            map(
                lambda x: f"파일 제목: {x[1]}\n 내용: {x[0]}",
                zip(chunk_texts, file_names),
            )
        )
    elif file_name_language == "ja":
        return list(
            map(
                lambda x: f"ファイル名: {x[1]}\n 内容: {x[0]}",
                zip(chunk_texts, file_names),
            )
        )
    else:
        raise ValueError(
            f"Unsupported file_name_language: {file_name_language}. Choose from 'en', 'ko', or 'ja'."
        )
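A small usage sketch of the add_file_name helper above (illustrative values only, not part of this commit):

from autorag.data.chunk.base import add_file_name

contents = add_file_name("en", ["report.pdf"], ["The revenue grew 10%."])
# -> ["file_name: report.pdf\n contents: The revenue grew 10%."]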
76
autorag/data/chunk/langchain_chunk.py
Normal file
@@ -0,0 +1,76 @@
import os
from itertools import chain
import uuid
from typing import Tuple, List, Dict, Any, Optional

from langchain_text_splitters import TextSplitter

from autorag.data.chunk.base import chunker_node, add_file_name
from autorag.data.utils.util import add_essential_metadata, get_start_end_idx


@chunker_node
def langchain_chunk(
    texts: List[str],
    chunker: TextSplitter,
    file_name_language: Optional[str] = None,
    metadata_list: Optional[List[Dict[str, str]]] = None,
) -> Tuple[
    List[str], List[str], List[str], List[Tuple[int, int]], List[Dict[str, Any]]
]:
    """
    Chunk texts from the parsed result using the langchain chunk method.

    :param texts: The list of texts to chunk from the parsed result
    :param chunker: A langchain TextSplitter (chunker) instance.
    :param file_name_language: The language to use for the 'add_file_name' feature.
        Set one of 'en', 'ko', or 'ja'.
        The 'add_file_name' feature adds the file name to the chunked contents.
        This is used to prevent hallucination by retrieving contents from the wrong document.
        The 'en' form is "file_name: {file_name}\n contents: {content}"
    :param metadata_list: The list of dicts of metadata from the parsed result
    :return: tuple of lists containing the chunked doc_id, contents, path, start_end_idx and metadata
    """
    results = [
        langchain_chunk_pure(text, chunker, file_name_language, meta)
        for text, meta in zip(texts, metadata_list)
    ]

    doc_id, contents, path, start_end_idx, metadata = (
        list(chain.from_iterable(item)) for item in zip(*results)
    )

    return doc_id, contents, path, start_end_idx, metadata


def langchain_chunk_pure(
    text: str,
    chunker: TextSplitter,
    file_name_language: Optional[str] = None,
    _metadata: Optional[Dict[str, str]] = None,
):
    # chunk
    chunk_results = chunker.create_documents([text], metadatas=[_metadata])

    # make doc_id
    doc_id = list(str(uuid.uuid4()) for _ in range(len(chunk_results)))

    # make path
    path_lst = list(map(lambda x: x.metadata.get("path", ""), chunk_results))

    # make contents and start_end_idx
    if file_name_language:
        chunked_file_names = list(map(lambda x: os.path.basename(x), path_lst))
        chunked_texts = list(map(lambda x: x.page_content, chunk_results))
        start_end_idx = list(map(lambda x: get_start_end_idx(text, x), chunked_texts))
        contents = add_file_name(file_name_language, chunked_file_names, chunked_texts)
    else:
        contents = list(map(lambda node: node.page_content, chunk_results))
        start_end_idx = list(map(lambda x: get_start_end_idx(text, x), contents))

    # make metadata
    metadata = list(
        map(lambda node: add_essential_metadata(node.metadata), chunk_results)
    )

    return doc_id, contents, path_lst, start_end_idx, metadata
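For orientation, a hedged usage sketch of the decorated langchain_chunk node; the parsed DataFrame columns follow chunker_node above, while the sample text and path are made up:

import pandas as pd
from autorag.data.chunk import langchain_chunk

parsed_df = pd.DataFrame(
    {"texts": ["First document body ..."], "path": ["raw_docs/first.txt"]}
)
chunked_df = langchain_chunk(
    parsed_df,
    chunk_method="recursivecharacter",  # maps to RecursiveCharacterTextSplitter
    chunk_size=200,
    chunk_overlap=20,
    add_file_name="en",
)
# chunked_df has doc_id, contents, path, start_end_idx and metadata columns.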
96
autorag/data/chunk/llama_index_chunk.py
Normal file
@@ -0,0 +1,96 @@
import os.path
from itertools import chain
from typing import Tuple, List, Dict, Any, Optional

from llama_index.core import Document
from llama_index.core.node_parser.interface import NodeParser

from autorag.utils.util import process_batch, get_event_loop
from autorag.data.chunk.base import chunker_node, add_file_name
from autorag.data.utils.util import (
    add_essential_metadata_llama_text_node,
    get_start_end_idx,
)


@chunker_node
def llama_index_chunk(
    texts: List[str],
    chunker: NodeParser,
    file_name_language: Optional[str] = None,
    metadata_list: Optional[List[Dict[str, str]]] = None,
    batch: int = 8,
) -> Tuple[
    List[str], List[str], List[str], List[Tuple[int, int]], List[Dict[str, Any]]
]:
    """
    Chunk texts from the parsed result using the llama index chunk method.

    :param texts: The list of texts to chunk from the parsed result
    :param chunker: A llama index NodeParser (chunker) instance.
    :param file_name_language: The language to use for the 'add_file_name' feature.
        Set one of 'en', 'ko', or 'ja'.
        The 'add_file_name' feature adds the file name to the chunked contents.
        This is used to prevent hallucination by retrieving contents from the wrong document.
        The 'en' form is "file_name: {file_name}\n contents: {content}"
    :param metadata_list: The list of dicts of metadata from the parsed result
    :param batch: The batch size for chunking texts. Default is 8.
    :return: tuple of lists containing the chunked doc_id, contents, path, start_end_idx and metadata
    """
    tasks = [
        llama_index_chunk_pure(text, chunker, file_name_language, meta)
        for text, meta in zip(texts, metadata_list)
    ]
    loop = get_event_loop()
    results = loop.run_until_complete(process_batch(tasks, batch))

    doc_id, contents, path, start_end_idx, metadata = (
        list(chain.from_iterable(item)) for item in zip(*results)
    )

    return list(doc_id), list(contents), list(path), list(start_end_idx), list(metadata)


async def llama_index_chunk_pure(
    text: str,
    chunker: NodeParser,
    file_name_language: Optional[str] = None,
    _metadata: Optional[Dict[str, str]] = None,
):
    # set document
    document = [Document(text=text, metadata=_metadata)]

    # chunk document
    chunk_results = await chunker.aget_nodes_from_documents(documents=document)

    # make doc_id
    doc_id = list(map(lambda node: node.node_id, chunk_results))

    # make path
    path_lst = list(map(lambda x: x.metadata.get("path", ""), chunk_results))

    # make contents and start_end_idx
    if file_name_language:
        chunked_file_names = list(map(lambda x: os.path.basename(x), path_lst))
        chunked_texts = list(map(lambda x: x.text, chunk_results))
        start_end_idx = list(
            map(
                lambda x: get_start_end_idx(text, x),
                chunked_texts,
            )
        )
        contents = add_file_name(file_name_language, chunked_file_names, chunked_texts)
    else:
        contents = list(map(lambda x: x.text, chunk_results))
        start_end_idx = list(map(lambda x: get_start_end_idx(text, x), contents))

    metadata = list(
        map(
            lambda node: add_essential_metadata_llama_text_node(
                node.metadata, node.relationships
            ),
            chunk_results,
        )
    )

    return doc_id, contents, path_lst, start_end_idx, metadata
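A parallel sketch for the llama index variant (values again illustrative, not part of this commit):

import pandas as pd
from autorag.data.chunk import llama_index_chunk

parsed_df = pd.DataFrame(
    {"texts": ["Second document body ..."], "path": ["raw_docs/second.txt"]}
)
chunked_df = llama_index_chunk(
    parsed_df, chunk_method="token", chunk_size=256, chunk_overlap=32
)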
38
autorag/data/chunk/run.py
Normal file
@@ -0,0 +1,38 @@
import os
from typing import Callable, List, Dict
import pandas as pd

from autorag.strategy import measure_speed


def run_chunker(
    modules: List[Callable],
    module_params: List[Dict],
    parsed_result: pd.DataFrame,
    project_dir: str,
):
    results, execution_times = zip(
        *map(
            lambda x: measure_speed(x[0], parsed_result=parsed_result, **x[1]),
            zip(modules, module_params),
        )
    )
    average_times = list(map(lambda x: x / len(results[0]), execution_times))

    # save results to parquet files
    filepaths = list(
        map(lambda x: os.path.join(project_dir, f"{x}.parquet"), range(len(modules)))
    )
    list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths)))
    filenames = list(map(lambda x: os.path.basename(x), filepaths))

    summary_df = pd.DataFrame(
        {
            "filename": filenames,
            "module_name": list(map(lambda module: module.__name__, modules)),
            "module_params": module_params,
            "execution_time": average_times,
        }
    )
    summary_df.to_csv(os.path.join(project_dir, "summary.csv"), index=False)
    return summary_df
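A hedged sketch of how run_chunker could be driven; the parsed_df variable and the project directory path are assumptions for illustration:

from autorag.data.chunk import llama_index_chunk, langchain_chunk
from autorag.data.chunk.run import run_chunker

summary = run_chunker(
    modules=[llama_index_chunk, langchain_chunk],
    module_params=[
        {"chunk_method": "token", "chunk_size": 512},
        {"chunk_method": "recursivecharacter", "chunk_size": 500},
    ],
    parsed_result=parsed_df,        # DataFrame produced by a parse module
    project_dir="./chunk_project",  # must exist; 0.parquet, 1.parquet and summary.csv land here
)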
0
autorag/data/legacy/__init__.py
Normal file
2
autorag/data/legacy/corpus/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .langchain import langchain_documents_to_parquet
from .llama_index import llama_documents_to_parquet, llama_text_node_to_parquet
47
autorag/data/legacy/corpus/langchain.py
Normal file
@@ -0,0 +1,47 @@
import uuid
from typing import List, Optional

import pandas as pd
from langchain_core.documents import Document

from autorag.data.utils.util import add_essential_metadata
from autorag.utils.util import save_parquet_safe


def langchain_documents_to_parquet(
    langchain_documents: List[Document],
    output_filepath: Optional[str] = None,
    upsert: bool = False,
) -> pd.DataFrame:
    """
    Convert LangChain documents to a corpus dataframe.
    The corpus dataframe will be saved to the filepath (file_dir/filename) if one is given.
    Returns the corpus dataframe whether or not the filepath is given.
    You can use this method to create corpus.parquet after loading and chunking with LangChain.

    :param langchain_documents: List of langchain documents.
    :param output_filepath: Optional filepath to save the parquet file.
        If None, the function returns the processed data as pd.DataFrame but does not save it as parquet.
        The file directory must exist. The file extension must be .parquet.
    :param upsert: If true, the function will overwrite the existing file if it exists.
        Default is False.
    :return: Corpus data as pd.DataFrame
    """

    corpus_df = pd.DataFrame(
        list(
            map(
                lambda doc: {
                    "doc_id": str(uuid.uuid4()),
                    "contents": doc.page_content,
                    "metadata": add_essential_metadata(doc.metadata),
                },
                langchain_documents,
            )
        )
    )

    if output_filepath is not None:
        save_parquet_safe(corpus_df, output_filepath, upsert=upsert)

    return corpus_df
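Minimal usage sketch for the converter above (document text and paths are placeholders):

from langchain_core.documents import Document
from autorag.data.legacy.corpus import langchain_documents_to_parquet

docs = [Document(page_content="chunked text ...", metadata={"path": "raw_docs/first.txt"})]
corpus_df = langchain_documents_to_parquet(docs, "./corpus.parquet", upsert=True)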
93
autorag/data/legacy/corpus/llama_index.py
Normal file
@@ -0,0 +1,93 @@
import uuid
from typing import List, Optional

import pandas as pd
from llama_index.core import Document
from llama_index.core.schema import TextNode

from autorag.data.utils.util import (
    add_essential_metadata,
    add_essential_metadata_llama_text_node,
)
from autorag.utils.util import save_parquet_safe


def llama_documents_to_parquet(
    llama_documents: List[Document],
    output_filepath: Optional[str] = None,
    upsert: bool = False,
) -> pd.DataFrame:
    """
    Convert Llama Index documents to a corpus dataframe.
    The corpus dataframe will be saved to the filepath (file_dir/filename) if one is given.
    Returns the corpus dataframe whether or not the filepath is given.
    You can use this method to create corpus.parquet after loading and chunking with Llama Index.

    :param llama_documents: List[Document]
    :param output_filepath: Optional filepath to save the parquet file.
        If None, the function returns the processed data as pd.DataFrame but does not save it as parquet.
        The file directory must exist. The file extension must be .parquet.
    :param upsert: If true, the function will overwrite the existing file if it exists.
        Default is False.
    :return: Corpus data as pd.DataFrame
    """

    doc_lst = pd.DataFrame(
        list(
            map(
                lambda doc: {
                    "doc_id": str(uuid.uuid4()),
                    "contents": doc.text,
                    "metadata": add_essential_metadata(doc.metadata),
                },
                llama_documents,
            )
        )
    )

    processed_df = pd.DataFrame(doc_lst)

    if output_filepath is not None:
        save_parquet_safe(processed_df, output_filepath, upsert=upsert)

    return processed_df


def llama_text_node_to_parquet(
    text_nodes: List[TextNode],
    output_filepath: Optional[str] = None,
    upsert: bool = False,
) -> pd.DataFrame:
    """
    Convert Llama Index text nodes to a corpus dataframe.
    The corpus dataframe will be saved to the filepath (file_dir/filename) if one is given.
    Returns the corpus dataframe whether or not the filepath is given.
    You can use this method to create corpus.parquet after loading and chunking with Llama Index.

    :param text_nodes: List of llama index text nodes.
    :param output_filepath: Optional filepath to save the parquet file.
        If None, the function returns the processed data as pd.DataFrame but does not save it as parquet.
        The file directory must exist. The file extension must be .parquet.
    :param upsert: If true, the function will overwrite the existing file if it exists.
        Default is False.
    :return: Corpus data as pd.DataFrame
    """
    corpus_df = pd.DataFrame(
        list(
            map(
                lambda node: {
                    "doc_id": node.node_id,
                    "contents": node.text,
                    "metadata": add_essential_metadata_llama_text_node(
                        node.metadata, node.relationships
                    ),
                },
                text_nodes,
            )
        )
    )

    if output_filepath is not None:
        save_parquet_safe(corpus_df, output_filepath, upsert=upsert)

    return corpus_df
6
autorag/data/legacy/qacreation/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .base import make_single_content_qa, make_qa_with_existing_qa
from .llama_index import (
    generate_qa_llama_index,
    generate_answers,
    generate_qa_llama_index_by_ratio,
)
239
autorag/data/legacy/qacreation/base.py
Normal file
@@ -0,0 +1,239 @@
import logging
import uuid
from typing import Callable, Optional, List

import chromadb
import numpy as np
import pandas as pd
from tqdm import tqdm

import autorag
from autorag.nodes.retrieval.vectordb import vectordb_ingest, vectordb_pure
from autorag.utils.util import (
    save_parquet_safe,
    fetch_contents,
    get_event_loop,
    process_batch,
)

logger = logging.getLogger("AutoRAG")


def make_single_content_qa(
    corpus_df: pd.DataFrame,
    content_size: int,
    qa_creation_func: Callable,
    output_filepath: Optional[str] = None,
    upsert: bool = False,
    random_state: int = 42,
    cache_batch: int = 32,
    **kwargs,
) -> pd.DataFrame:
    """
    Make a single-content (single-hop, single-document) QA dataset using the given qa_creation_func.
    It generates a single-content QA dataset, which means each question's retrieval ground truth contains only one passage.
    It is the most basic form of QA dataset.

    :param corpus_df: The corpus dataframe to make the QA dataset from.
    :param content_size: This function will generate a QA dataset for the given number of contents.
    :param qa_creation_func: The function to create QA pairs.
        You can use functions like generate_qa_llama_index or generate_qa_llama_index_by_ratio.
        The input func must have a contents parameter for the list of content strings.
    :param output_filepath: Optional filepath to save the parquet file.
        If None, the function returns the processed data as pd.DataFrame but does not save it as parquet.
        The file directory must exist. The file extension must be .parquet.
    :param upsert: If true, the function will overwrite the existing file if it exists.
        Default is False.
    :param random_state: The random state for sampling corpus from the given corpus_df.
    :param cache_batch: The number of batches to use for caching the generated QA dataset.
        Each time a cache_batch-sized chunk of data is generated, the dataset is saved to the designated output_filepath.
        If the cache_batch size is too small, the process time will be longer.
    :param kwargs: The keyword arguments for qa_creation_func.
    :return: QA dataset dataframe.
        You can save this as a parquet file to use with AutoRAG.
    """
    assert content_size > 0, "content_size must be greater than 0."
    if content_size > len(corpus_df):
        logger.warning(
            f"content_size {content_size} is larger than the corpus size {len(corpus_df)}. "
            "Setting content_size to the corpus size."
        )
        content_size = len(corpus_df)
    sampled_corpus = corpus_df.sample(n=content_size, random_state=random_state)
    sampled_corpus = sampled_corpus.reset_index(drop=True)

    def make_query_generation_gt(row):
        return row["qa"]["query"], row["qa"]["generation_gt"]

    qa_data = pd.DataFrame()
    for idx, i in tqdm(enumerate(range(0, len(sampled_corpus), cache_batch))):
        qa = qa_creation_func(
            contents=sampled_corpus["contents"].tolist()[i : i + cache_batch], **kwargs
        )

        temp_qa_data = pd.DataFrame(
            {
                "qa": qa,
                "retrieval_gt": sampled_corpus["doc_id"].tolist()[i : i + cache_batch],
            }
        )
        temp_qa_data = temp_qa_data.explode("qa", ignore_index=True)
        temp_qa_data["qid"] = [str(uuid.uuid4()) for _ in range(len(temp_qa_data))]
        temp_qa_data[["query", "generation_gt"]] = temp_qa_data.apply(
            make_query_generation_gt, axis=1, result_type="expand"
        )
        temp_qa_data = temp_qa_data.drop(columns=["qa"])

        temp_qa_data["retrieval_gt"] = temp_qa_data["retrieval_gt"].apply(
            lambda x: [[x]]
        )
        temp_qa_data["generation_gt"] = temp_qa_data["generation_gt"].apply(
            lambda x: [x]
        )

        if idx == 0:
            qa_data = temp_qa_data
        else:
            qa_data = pd.concat([qa_data, temp_qa_data], ignore_index=True)
        if output_filepath is not None:
            save_parquet_safe(qa_data, output_filepath, upsert=upsert)

    return qa_data


def make_qa_with_existing_qa(
    corpus_df: pd.DataFrame,
    existing_query_df: pd.DataFrame,
    content_size: int,
    answer_creation_func: Optional[Callable] = None,
    exist_gen_gt: Optional[bool] = False,
    output_filepath: Optional[str] = None,
    embedding_model: str = "openai_embed_3_large",
    collection: Optional[chromadb.Collection] = None,
    upsert: bool = False,
    random_state: int = 42,
    cache_batch: int = 32,
    top_k: int = 3,
    **kwargs,
) -> pd.DataFrame:
    """
    Make a single-hop QA dataset using the given answer_creation_func and existing queries.

    :param corpus_df: The corpus dataframe to make the QA dataset from.
    :param existing_query_df: Dataframe containing existing queries to use for QA pair creation.
    :param content_size: This function will generate a QA dataset for the given number of contents.
    :param answer_creation_func: Optional function to create an answer from an input query.
        If exist_gen_gt is False, this function must be given.
    :param exist_gen_gt: Optional boolean to use existing generation_gt.
        If True, the existing_query_df must have a 'generation_gt' column.
        If False, the answer_creation_func must be given.
    :param output_filepath: Optional filepath to save the parquet file.
    :param embedding_model: The embedding model to use for vectorization.
        You can add your own embedding model in autorag.embedding_models.
        Please refer to how to add an embedding model in this doc: https://docs.auto-rag.com/local_model.html
        The default is 'openai_embed_3_large'.
    :param collection: The chromadb collection to use for the vector DB.
        You can make any chromadb collection and use it here.
        If you already ingested the corpus_df into the collection, the embedding process will not be repeated.
        The default is None. If None, it makes a temporary collection.
    :param upsert: If true, the function will overwrite the existing file if it exists.
    :param random_state: The random state for sampling corpus from the given corpus_df.
    :param cache_batch: The number of batches to use for caching the generated QA dataset.
    :param top_k: The number of sources for the model to refer to.
        Default is 3.
    :param kwargs: The keyword arguments for answer_creation_func.
    :return: QA dataset dataframe.
    """
    raise DeprecationWarning("This function is deprecated.")
    assert (
        "query" in existing_query_df.columns
    ), "existing_query_df must have 'query' column."

    if exist_gen_gt:
        assert (
            "generation_gt" in existing_query_df.columns
        ), "existing_query_df must have 'generation_gt' column."
    else:
        assert (
            answer_creation_func is not None
        ), "answer_creation_func must be given when exist_gen_gt is False."

    assert content_size > 0, "content_size must be greater than 0."
    if content_size > len(corpus_df):
        logger.warning(
            f"content_size {content_size} is larger than the corpus size {len(corpus_df)}. "
            "Setting content_size to the corpus size."
        )
        content_size = len(corpus_df)

    logger.info("Loading local embedding model...")
    embeddings = autorag.embedding_models[embedding_model]()

    # Vector DB creation
    if collection is None:
        chroma_client = chromadb.Client()
        collection_name = "auto-rag"
        collection = chroma_client.get_or_create_collection(collection_name)

    # embed corpus_df
    vectordb_ingest(collection, corpus_df, embeddings)
    query_embeddings = embeddings.get_text_embedding_batch(
        existing_query_df["query"].tolist()
    )

    loop = get_event_loop()
    tasks = [
        vectordb_pure([query_embedding], top_k, collection)
        for query_embedding in query_embeddings
    ]
    results = loop.run_until_complete(process_batch(tasks, batch_size=cache_batch))
    retrieved_ids = list(map(lambda x: x[0], results))

    retrieved_contents: List[List[str]] = fetch_contents(corpus_df, retrieved_ids)
    input_passage_strs: List[str] = list(
        map(
            lambda x: "\n".join(
                [f"Document {i + 1}\n{content}" for i, content in enumerate(x)]
            ),
            retrieved_contents,
        )
    )

    retrieved_qa_df = pd.DataFrame(
        {
            "qid": [str(uuid.uuid4()) for _ in range(len(existing_query_df))],
            "query": existing_query_df["query"].tolist(),
            "retrieval_gt": list(map(lambda x: [x], retrieved_ids)),
            "input_passage_str": input_passage_strs,
        }
    )

    if exist_gen_gt:
        generation_gt = existing_query_df["generation_gt"].tolist()
        if isinstance(generation_gt[0], np.ndarray):
            retrieved_qa_df["generation_gt"] = generation_gt
        else:
            raise ValueError(
                "In existing_query_df, generation_gt (per query) must be in the form of List[str]."
            )

    sample_qa_df = retrieved_qa_df.sample(
        n=min(content_size, len(retrieved_qa_df)), random_state=random_state
    )

    qa_df = sample_qa_df.copy(deep=True)
    qa_df.drop(columns=["input_passage_str"], inplace=True)

    if not exist_gen_gt:
        generation_gt = answer_creation_func(
            contents=sample_qa_df["input_passage_str"].tolist(),
            queries=sample_qa_df["query"].tolist(),
            batch=cache_batch,
            **kwargs,
        )
        qa_df["generation_gt"] = generation_gt

    if output_filepath is not None:
        save_parquet_safe(qa_df, output_filepath, upsert=upsert)

    return qa_df
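A hedged sketch of how make_single_content_qa is typically driven; the corpus_df variable, output path, and LLM model name are assumptions for illustration:

from llama_index.llms.openai import OpenAI
from autorag.data.legacy.qacreation import make_single_content_qa, generate_qa_llama_index

qa_df = make_single_content_qa(
    corpus_df,
    content_size=10,
    qa_creation_func=generate_qa_llama_index,
    output_filepath="./qa.parquet",
    upsert=True,
    llm=OpenAI(model="gpt-4o-mini"),   # forwarded to generate_qa_llama_index via **kwargs
    question_num_per_content=1,
)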
253
autorag/data/legacy/qacreation/llama_index.py
Normal file
@@ -0,0 +1,253 @@
import os.path
import random
from typing import Optional, List, Dict, Any

import pandas as pd
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.llms import LLM

from autorag.utils.util import process_batch, get_event_loop

package_dir = os.path.dirname(os.path.realpath(__file__))


def generate_qa_llama_index(
    llm: LLM,
    contents: List[str],
    prompt: Optional[str] = None,
    question_num_per_content: int = 1,
    max_retries: int = 3,
    batch: int = 4,
) -> List[List[Dict]]:
    """
    Generate a qa set from the list of contents.
    It uses a single prompt for all contents.
    If you want to use more than one prompt for generating qa,
    consider using generate_qa_llama_index_by_ratio.

    :param llm: Llama index model
    :param contents: List of content strings.
    :param prompt: The prompt to use for the qa generation.
        The prompt must include the following placeholders:
        - {{text}}: The content string
        - {{num_questions}}: The number of questions to generate
        By default, the prompt is set to the default prompt for the question type.
    :param question_num_per_content: Number of questions to generate for each content.
        Default is 1.
    :param max_retries: The maximum number of retries when the generated question number is not equal to the target number.
        Default is 3.
    :param batch: The batch size to process asynchronously.
        Default is 4.
    :return: 2-d list of dictionaries containing the query and generation_gt.
    """
    # load default prompt
    if prompt is None:
        prompt = open(
            os.path.join(package_dir, "llama_index_default_prompt.txt"), "r"
        ).read()

    tasks = [
        async_qa_gen_llama_index(
            content, llm, prompt, question_num_per_content, max_retries
        )
        for content in contents
    ]
    loops = get_event_loop()
    results = loops.run_until_complete(process_batch(tasks, batch))
    return results


def generate_answers(
    llm: LLM,
    contents: List[str],
    queries: List[str],
    batch: int = 4,
) -> List[List[Dict]]:
    """
    Generate qa sets from the list of contents using existing queries.

    :param llm: Llama index model
    :param contents: List of content strings.
    :param queries: List of existing queries.
    :param batch: The batch size to process asynchronously.
    :return: 2-d list of dictionaries containing the query and generation_gt.
    """

    tasks = [
        generate_basic_answer(llm, content, query)
        for content, query in zip(contents, queries)
    ]
    loops = get_event_loop()
    results = loops.run_until_complete(process_batch(tasks, batch))
    return results


def generate_qa_llama_index_by_ratio(
    llm: LLM,
    contents: List[str],
    prompts_ratio: Dict,
    question_num_per_content: int = 1,
    max_retries: int = 3,
    random_state: int = 42,
    batch: int = 4,
) -> List[List[Dict]]:
    """
    Generate a qa set from the list of contents.
    You can set the ratio of prompts that you want to use for generating qa.
    It randomly distributes the contents across the prompts by the given ratio.

    :param llm: Llama index model
    :param contents: List of content strings.
    :param prompts_ratio: Dictionary of prompt paths and their ratios.
        Example: {"prompt/prompt1.txt": 0.5, "prompt/prompt2.txt": 0.5}
        The value sum doesn't have to be 1.
        The path must be an absolute path, and the file must exist.
        Plus, it has to be a text file which contains a proper prompt.
        Each prompt must contain the following placeholders:
        - {{text}}: The content string
        - {{num_questions}}: The number of questions to generate
    :param question_num_per_content: Number of questions to generate for each content.
        Default is 1.
    :param max_retries: The maximum number of retries when the generated question number is not equal to the target number.
        Default is 3.
    :param random_state: Random seed.
        Default is 42.
    :param batch: The batch size to process asynchronously.
        Default is 4.
    :return: 2-d list of dictionaries containing the query and generation_gt.
    """
    prompts = list(map(lambda path: open(path, "r").read(), prompts_ratio.keys()))
    assert all([validate_llama_index_prompt(prompt) for prompt in prompts])

    content_indices = list(range(len(contents)))
    random.seed(random_state)
    random.shuffle(content_indices)

    slice_content_indices: List[List[int]] = distribute_list_by_ratio(
        content_indices, list(prompts_ratio.values())
    )
    temp_df = pd.DataFrame({"idx": slice_content_indices, "prompt": prompts})
    temp_df = temp_df.explode("idx", ignore_index=True)
    temp_df = temp_df.sort_values(by="idx", ascending=True)

    final_df = pd.DataFrame({"content": contents, "prompt": temp_df["prompt"].tolist()})

    tasks = [
        async_qa_gen_llama_index(
            content, llm, prompt, question_num_per_content, max_retries
        )
        for content, prompt in zip(
            final_df["content"].tolist(), final_df["prompt"].tolist()
        )
    ]

    loops = get_event_loop()
    results = loops.run_until_complete(process_batch(tasks, batch))

    return results


async def async_qa_gen_llama_index(
    content: str,
    llm: LLM,
    prompt: str,
    question_num: int = 1,
    max_retries: int = 3,
):
    """
    Generate a qa set by using the given content and the llama index model.
    You must select the question type.

    :param content: Content string
    :param llm: Llama index model
    :param prompt: The prompt to use for the qa generation.
        The prompt must include the following placeholders:
        - {{text}}: The content string
        - {{num_questions}}: The number of questions to generate
    :param question_num: The number of questions to generate
    :param max_retries: Maximum number of retries when the generated question number is not equal to the target number
    :return: List of dictionaries containing the query and generation_gt
    """
    validate_llama_index_prompt(prompt)

    async def generate(content: str, llm: LLM):
        for _ in range(max_retries):
            output = await llm.acomplete(
                prompt.replace("{{text}}", content).replace(
                    "{{num_questions}}", str(question_num)
                )
            )
            result = parse_output(output.text)
            if len(result) == question_num:
                return result
        raise InterruptedError(
            f"Failed to generate output of length {question_num} after {max_retries} retries."
        )

    return await generate(content, llm)


async def generate_basic_answer(llm: LLM, passage_str: str, query: str) -> str:
    basic_answer_system_prompt = """You are an AI assistant to answer the given question in the provided evidence text.
You can find the evidence about the question in the given text, and you have to write a proper answer to the given question.
You have to preserve the question's language in the answer.
For example, if the input question is Korean, the output answer must be in Korean.
"""
    user_prompt = f"Text:\n<|text_start|>\n{passage_str}\n<|text_end|>\n\nQuestion:\n{query}\n\nAnswer:"

    response = await llm.achat(
        messages=[
            ChatMessage(role=MessageRole.SYSTEM, content=basic_answer_system_prompt),
            ChatMessage(role=MessageRole.USER, content=user_prompt),
        ],
        temperature=1.0,
    )
    return response.message.content


def validate_llama_index_prompt(prompt: str) -> bool:
    """
    Validate the prompt for the llama index model.
    The prompt must include the following placeholders:
    - {{text}}: The content string
    - {{num_questions}}: The number of questions to generate
    """
    if "{{text}}" not in prompt:
        raise ValueError("The prompt must include the placeholder {{text}}.")
    if "{{num_questions}}" not in prompt:
        raise ValueError("The prompt must include the placeholder {{num_questions}}.")
    return True


def parse_output(result: str) -> List[Dict]:
    result = result.strip()
    result = result.split("[Q]:")
    final_result = list()
    for res in result:
        res = res.strip()
        if res and "\n[A]:" in res:
            qa = res.split("\n[A]:")
            final_result.append(
                {"query": qa[0].strip(), "generation_gt": qa[1].strip()}
            )
    return final_result


def distribute_list_by_ratio(input_list, ratio) -> List[List[Any]]:
    total_ratio = sum(ratio)
    total_length = len(input_list)

    # Calculate the length of each slice
    slice_lengths = [int((r / total_ratio) * total_length) for r in ratio]

    # Adjust the last slice in case of rounding issues
    slice_lengths[-1] = total_length - sum(slice_lengths[:-1])

    slices = []
    start = 0
    for length in slice_lengths:
        end = start + length
        slices.append(input_list[start:end])
        start = end

    return slices
54
autorag/data/legacy/qacreation/llama_index_default_prompt.txt
Normal file
@@ -0,0 +1,54 @@
You're an AI tasked to convert Text into a question and answer set.
Cover as many details from Text as possible in the QnA set.

Instructions:
1. Both Questions and Answers MUST BE extracted from given Text
2. Answers must be full sentences
3. Questions should be as detailed as possible from Text
4. Output must always have the provided number of QnAs
5. Create questions that ask about information from the Text
6. MUST include specific keywords from the Text.
7. Do not mention any of these in the questions: "in the given text", "in the provided information", etc.

Question examples:
1. How do owen and riggs know each other?
2. What does the word "fore" mean in golf?
3. What makes charging bull in nyc popular to tourists?
4. What kind of pistol does the army use?
5. Who was the greatest violin virtuoso in the romantic period?
<|separator|>

Text:
<|text_start|>
Mark Hamill as Luke Skywalker : One of the last living Jedi , trained by Obi - Wan and Yoda , who is also a skilled X-wing fighter pilot allied with the Rebellion .
Harrison Ford as Han Solo : A rogue smuggler , who aids the Rebellion against the Empire . Han is Luke and Leia 's friend , as well as Leia 's love interest .
Carrie Fisher as Leia Organa : The former Princess of the destroyed planet Alderaan , who joins the Rebellion ; Luke 's twin sister , and Han 's love interest .
Billy Dee Williams as Lando Calrissian : The former Baron Administrator of Cloud City and one of Han 's friends who aids the Rebellion .
Anthony Daniels as C - 3PO : A humanoid protocol droid , who sides with the Rebellion .
Peter Mayhew as Chewbacca : A Wookiee who is Han 's longtime friend , who takes part in the Rebellion .
Kenny Baker as R2 - D2 : An astromech droid , bought by Luke ; and long - time friend to C - 3PO . He also portrays a GONK power droid in the background .
Ian McDiarmid as the Emperor : The evil founding supreme ruler of the Galactic Empire , and Vader 's Sith Master .
Frank Oz as Yoda : The wise , centuries - old Grand Master of the Jedi , who is Luke 's self - exiled Jedi Master living on Dagobah . After dying , he reappears to Luke as a Force - ghost . Yoda 's Puppetry was assisted by Mike Quinn .
David Prowse as Darth Vader / Anakin Skywalker : A powerful Sith lord and the second in command of the Galactic Empire ; Luke and Leia 's father .
<|text_end|>
Output with 4 QnAs:
<|separator|>

[Q]: who played luke father in return of the jedi
[A]: David Prowse acted as Darth Vader, a.k.a Anakin Skywalker, which is Luke and Leia's father.
[Q]: Who is Han Solo's best friend? And what species is he?
[A]: Han Solo's best friend is Chewbacca, who is a Wookiee.
[Q]: Who played luke's teacher in the return of the jedi
[A]: Yoda, the wise, centuries-old Grand Master of the Jedi, who is Luke's self-exiled Jedi Master living on Dagobah, was played by Frank Oz.
Also, there is a mention of Obi-Wan Kenobi, who trained Luke Skywalker.
But I can't find who played Obi-Wan Kenobi in the given text.
[Q]: Where Yoda lives in the return of the jedi?
[A]: Yoda, the Jedi Master, lives on Dagobah.
<|separator|>

Text:
<|text_start|>
{{text}}
<|text_end|>
Output with {{num_questions}} QnAs:
<|separator|>
75
autorag/data/legacy/qacreation/ragas.py
Normal file
@@ -0,0 +1,75 @@
import uuid
from typing import Optional

import pandas as pd
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from autorag.data.utils.util import corpus_df_to_langchain_documents
from autorag.utils import cast_qa_dataset


def generate_qa_ragas(
    corpus_df: pd.DataFrame,
    test_size: int,
    distributions: Optional[dict] = None,
    generator_llm: Optional[BaseChatModel] = None,
    critic_llm: Optional[BaseChatModel] = None,
    embedding_model: Optional[Embeddings] = None,
    **kwargs,
) -> pd.DataFrame:
    """
    QA dataset generation using RAGAS.
    Returns the qa dataset dataframe.

    :param corpus_df: Corpus dataframe.
    :param test_size: Number of queries to generate.
    :param distributions: Distributions of different types of questions.
        Default is "simple is 0.5, multi_context is 0.4, and reasoning is 0.1."
        Each question type refers to a Ragas evolution type.
    :param generator_llm: Generator language model from Langchain.
    :param critic_llm: Critic language model from Langchain.
    :param embedding_model: Embedding model from Langchain.
    :param kwargs: Additional options to pass to the 'generate_with_langchain_docs' method.
        You can input 'with_debugging_logs', 'is_async', 'raise_exceptions', and 'run_config'.
    :return: QA dataset dataframe.
    """
    from ragas.testset import TestsetGenerator
    from ragas.testset.evolutions import simple, reasoning, multi_context

    if generator_llm is None:
        generator_llm = ChatOpenAI(model="gpt-3.5-turbo-16k")
    if critic_llm is None:
        critic_llm = ChatOpenAI(model="gpt-4-turbo")
    if embedding_model is None:
        embedding_model = OpenAIEmbeddings()
    if distributions is None:
        distributions = {simple: 0.5, multi_context: 0.4, reasoning: 0.1}

    assert sum(list(distributions.values())) == 1.0, "Sum of distributions must be 1.0"

    generator = TestsetGenerator.from_langchain(
        generator_llm, critic_llm, embedding_model
    )

    langchain_docs = corpus_df_to_langchain_documents(corpus_df)

    test_df = generator.generate_with_langchain_docs(
        langchain_docs, test_size, distributions=distributions, **kwargs
    ).to_pandas()

    result_df = pd.DataFrame(
        {
            "qid": [str(uuid.uuid4()) for _ in range(len(test_df))],
            "query": test_df["question"].tolist(),
            "generation_gt": list(map(lambda x: x, test_df["ground_truth"].tolist())),
        }
    )

    result_df["retrieval_gt"] = test_df["metadata"].apply(
        lambda x: list(map(lambda y: y["filename"], x))
    )
    result_df = cast_qa_dataset(result_df)

    return result_df
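Minimal usage sketch for the RAGAS-based generator (corpus_df is assumed to be an existing corpus dataframe):

from autorag.data.legacy.qacreation.ragas import generate_qa_ragas

# Uses the default generator/critic models and OpenAI embeddings above,
# so OPENAI_API_KEY must be set in the environment.
qa_df = generate_qa_ragas(corpus_df, test_size=20)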
99
autorag/data/legacy/qacreation/simple.py
Normal file
@@ -0,0 +1,99 @@
import os
import pathlib
import uuid
from typing import Callable

import pandas as pd


def generate_qa_row(llm, corpus_data_row):
    """
    Sample code to generate a RAG dataset row using a guidance chat model.

    :param llm: guidance model
    :param corpus_data_row: needs a "contents" column
    :return: should be a dict which has "query" and "generation_gt" keys at least.
    """
    from guidance import gen
    import guidance

    temp_llm = llm
    with guidance.user():
        temp_llm += f"""
        You have to find a passage to solve "the problem".
        You need to build a clean and clear set of (problem, passage, answer) in JSON format
        so that you don't have to ask about "the problem" again.
        The problem needs to end with a question mark ("?").
        The process of approaching the answer based on the information of the given passage
        must be clearly and neatly displayed in the answer.\n
        \n
        Here is a set of (problem, passage, answer) in JSON format:\n
        {{\n
        "passage": {corpus_data_row["contents"]}\n
        "problem":
        """

    with guidance.assistant():
        temp_llm += gen("query", stop="?")
    with guidance.user():
        temp_llm += """
        "answer":
        """
    with guidance.assistant():
        temp_llm += gen("generation_gt")

    corpus_data_row["metadata"]["qa_generation"] = "simple"

    response = {"query": temp_llm["query"], "generation_gt": temp_llm["generation_gt"]}
    return response


def generate_simple_qa_dataset(
    llm,
    corpus_data: pd.DataFrame,
    output_filepath: str,
    generate_row_function: Callable,
    **kwargs,
):
    """
    Convert corpus_data to a qa_dataset.
    The qa_dataset will be saved to the filepath (file_dir/filename).

    :param llm: guidance.models.Model
    :param corpus_data: pd.DataFrame. Refer to the basic structure.
    :param output_filepath: file_dir must exist, filepath must not exist. The file extension must be .parquet
    :param generate_row_function: input(llm, corpus_data_row, kwargs) output(dict with at least "query" and "generation_gt" keys)
    :param kwargs: if generate_row_function requires more args, use kwargs
    :return: qa_dataset as pd.DataFrame
    """
    output_file_dir = pathlib.PurePath(output_filepath).parent
    if not os.path.isdir(output_file_dir):
        raise NotADirectoryError(f"directory {output_file_dir} not found.")
    if not output_filepath.endswith("parquet"):
        raise NameError(
            f'file path: {output_filepath} filename extension needs to be ".parquet"'
        )
    if os.path.exists(output_filepath):
        raise FileExistsError(
            f"{output_filepath.split('/')[-1]} already exists in {output_file_dir}."
        )

    qa_data_lst = []
    for _, corpus_data_row in corpus_data.iterrows():
        response = generate_row_function(
            llm=llm, corpus_data_row=corpus_data_row, **kwargs
        )
        qa_data_lst.append(
            {
                "qid": str(uuid.uuid4()),
                "query": response["query"],
                "retrieval_gt": [[corpus_data_row["doc_id"]]],
                "generation_gt": [response["generation_gt"]],
                "metadata": corpus_data_row["metadata"],
            }
        )

    qa_dataset = pd.DataFrame(qa_data_lst)
    qa_dataset.to_parquet(output_filepath, index=False)

    return qa_dataset
1
autorag/data/parse/__init__.py
Normal file
@@ -0,0 +1 @@
from .langchain_parse import langchain_parse
79
autorag/data/parse/base.py
Normal file
@@ -0,0 +1,79 @@
import functools
import logging
from datetime import datetime
from glob import glob
from typing import Tuple, List, Optional
import os

from autorag.utils import result_to_dataframe
from autorag.data.utils.util import get_file_metadata

logger = logging.getLogger("AutoRAG")


def parser_node(func):
    @functools.wraps(func)
    @result_to_dataframe(["texts", "path", "page", "last_modified_datetime"])
    def wrapper(
        data_path_glob: str,
        file_type: str,
        parse_method: Optional[str] = None,
        **kwargs,
    ) -> Tuple[List[str], List[str], List[int], List[datetime]]:
        logger.info(f"Running parser - {func.__name__} module...")

        data_path_list = glob(data_path_glob)
        if not data_path_list:
            raise FileNotFoundError(f"data does not exist in {data_path_glob}")

        assert file_type in [
            "pdf",
            "csv",
            "json",
            "md",
            "html",
            "xml",
            "all_files",
        ], f"file type {file_type} is not supported"

        # extract only files from data_path_list based on the file_type set in the YAML file
        data_paths = (
            [
                data_path
                for data_path in data_path_list
                if os.path.basename(data_path).split(".")[-1] == file_type
            ]
            if file_type != "all_files"
            else data_path_list
        )

        if func.__name__ == "langchain_parse":
            parse_method = parse_method.lower()
            if parse_method == "directory":
                path_split_list = data_path_glob.split("/")
                glob_path = path_split_list.pop()
                folder_path = "/".join(path_split_list)
                kwargs.update({"glob": glob_path, "path": folder_path})
            result = func(
                data_path_list=data_paths, parse_method=parse_method, **kwargs
            )
        elif func.__name__ in ["clova_ocr", "llama_parse", "table_hybrid_parse"]:
            result = func(data_path_list=data_paths, **kwargs)
        else:
            raise ValueError(f"Unsupported module_type: {func.__name__}")
        result = _add_last_modified_datetime(result)
        return result

    return wrapper


def _add_last_modified_datetime(result):
    last_modified_datetime_lst = list(
        map(lambda x: get_file_metadata(x)["last_modified_datetime"], result[1])
    )
    result_with_dates = result + (last_modified_datetime_lst,)
    return result_with_dates
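Illustrative call to a parser_node-decorated module (paths are placeholders, not part of this commit):

from autorag.data.parse import langchain_parse

parsed_df = langchain_parse(
    data_path_glob="./raw_docs/*.pdf",
    file_type="pdf",
    parse_method="pdfminer",  # key into parse_modules
)
# parsed_df has texts, path, page and last_modified_datetime columns.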
194
autorag/data/parse/clova.py
Normal file
194
autorag/data/parse/clova.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import base64
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
import fitz # PyMuPDF
|
||||
|
||||
from autorag.data.parse.base import parser_node
|
||||
from autorag.utils.util import process_batch, get_event_loop
|
||||
|
||||
|
||||
@parser_node
|
||||
def clova_ocr(
|
||||
data_path_list: List[str],
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
batch: int = 5,
|
||||
table_detection: bool = False,
|
||||
) -> Tuple[List[str], List[str], List[int]]:
|
||||
"""
|
||||
Parse documents using Naver Clova OCR.
|
||||
|
||||
:param data_path_list: The list of data paths to parse.
|
||||
:param url: The URL for Clova OCR.
|
||||
You can get the URL with the guide at https://guide.ncloud-docs.com/docs/clovaocr-example01
|
||||
You can set the environment variable CLOVA_URL, or you can set it directly as a parameter.
|
||||
:param api_key: The API key for Clova OCR.
|
||||
You can get the API key with the guide at https://guide.ncloud-docs.com/docs/clovaocr-example01
|
||||
You can set the environment variable CLOVA_API_KEY, or you can set it directly as a parameter.
|
||||
:param batch: The batch size for parsing documents. Default is 5, which is also the maximum allowed.
|
||||
:param table_detection: Whether to enable table detection. Default is False.
|
||||
:return: tuple of lists containing the parsed texts, path and pages.
|
||||
"""
|
||||
url = os.getenv("CLOVA_URL", None) if url is None else url
|
||||
if url is None:
|
||||
raise KeyError(
|
||||
"Please set the URL for Clova OCR in the environment variable CLOVA_URL "
|
||||
"or directly set it on the config YAML file."
|
||||
)
|
||||
|
||||
api_key = os.getenv("CLOVA_API_KEY", None) if api_key is None else api_key
|
||||
if api_key is None:
|
||||
raise KeyError(
|
||||
"Please set the API key for Clova OCR in the environment variable CLOVA_API_KEY "
|
||||
"or directly set it on the config YAML file."
|
||||
)
|
||||
if batch > 5:
|
||||
raise ValueError("The batch size should be less than or equal to 5.")
|
||||
|
||||
image_data_lst = list(
|
||||
map(lambda data_path: pdf_to_images(data_path), data_path_list)
|
||||
)
|
||||
image_info_lst = [
|
||||
generate_image_info(pdf_path, len(image_data))
|
||||
for pdf_path, image_data in zip(data_path_list, image_data_lst)
|
||||
]
|
||||
|
||||
image_data_list = list(itertools.chain(*image_data_lst))
|
||||
image_info_list = list(itertools.chain(*image_info_lst))
|
||||
|
||||
tasks = [
|
||||
clova_ocr_pure(image_data, image_info, url, api_key, table_detection)
|
||||
for image_data, image_info in zip(image_data_list, image_info_list)
|
||||
]
|
||||
loop = get_event_loop()
|
||||
results = loop.run_until_complete(process_batch(tasks, batch))
|
||||
|
||||
texts, path, pages = zip(*results)
|
||||
return list(texts), list(path), list(pages)
|
||||
|
||||
|
||||
async def clova_ocr_pure(
|
||||
image_data: bytes,
|
||||
image_info: dict,
|
||||
url: str,
|
||||
api_key: str,
|
||||
table_detection: bool = False,
|
||||
) -> Tuple[str, str, int]:
|
||||
session = aiohttp.ClientSession()
|
||||
table_html = ""
|
||||
headers = {"X-OCR-SECRET": api_key, "Content-Type": "application/json"}
|
||||
|
||||
# Convert image data to base64
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
# Set data
|
||||
data = {
|
||||
"version": "V2",
|
||||
"requestId": "sample_id",
|
||||
"timestamp": 0,
|
||||
"images": [{"format": "png", "name": "sample_image", "data": image_base64}],
|
||||
"enableTableDetection": table_detection,
|
||||
}
|
||||
|
||||
async with session.post(url, headers=headers, data=json.dumps(data)) as response:
|
||||
resp_json = await response.json()
|
||||
if "images" not in resp_json:
|
||||
raise RuntimeError(
|
||||
f"Invalid response from Clova API: {resp_json['detail']}"
|
||||
)
|
||||
if "tables" in resp_json["images"][0].keys():
|
||||
table_html = json_to_html_table(
|
||||
resp_json["images"][0]["tables"][0]["cells"]
|
||||
)
|
||||
page_text = extract_text_from_fields(resp_json["images"][0]["fields"])
|
||||
|
||||
if table_html:
|
||||
page_text += f"\n\ntable html:\n{table_html}"
|
||||
|
||||
await session.close()
|
||||
return page_text, image_info["pdf_path"], image_info["pdf_page"]
|
||||
|
||||
|
||||
def pdf_to_images(pdf_path: str) -> List[bytes]:
|
||||
"""Convert each page of the PDF to an image and return the image data."""
|
||||
pdf_document = fitz.open(pdf_path)
|
||||
image_data_lst = []
|
||||
for page_num in range(len(pdf_document)):
|
||||
page = pdf_document.load_page(page_num)
|
||||
pix = page.get_pixmap()
|
||||
img_data = pix.tobytes("png")
|
||||
image_data_lst.append(img_data)
|
||||
return image_data_lst
|
||||
|
||||
|
||||
def generate_image_info(pdf_path: str, num_pages: int) -> List[dict]:
|
||||
"""Generate image names based on the PDF file name and the number of pages."""
|
||||
image_info_lst = [
|
||||
{"pdf_path": pdf_path, "pdf_page": page_num + 1}
|
||||
for page_num in range(num_pages)
|
||||
]
|
||||
return image_info_lst
|
||||
|
||||
|
||||
def extract_text_from_fields(fields):
|
||||
text = ""
|
||||
for field in fields:
|
||||
text += field["inferText"]
|
||||
if field["lineBreak"]:
|
||||
text += "\n"
|
||||
else:
|
||||
text += " "
|
||||
return text.strip()
|
||||
|
||||
|
||||
def json_to_html_table(json_data):
|
||||
# Initialize the HTML table
|
||||
html = '<table border="1">\n'
|
||||
# Determine the number of rows and columns
|
||||
max_row = max(cell["rowIndex"] + cell["rowSpan"] for cell in json_data)
|
||||
max_col = max(cell["columnIndex"] + cell["columnSpan"] for cell in json_data)
|
||||
# Create a 2D array to keep track of merged cells
|
||||
table = [["" for _ in range(max_col)] for _ in range(max_row)]
|
||||
# Fill the table with cell data
|
||||
for cell in json_data:
|
||||
row = cell["rowIndex"]
|
||||
col = cell["columnIndex"]
|
||||
row_span = cell["rowSpan"]
|
||||
col_span = cell["columnSpan"]
|
||||
cell_text = (
|
||||
" ".join(
|
||||
line["inferText"] for line in cell["cellTextLines"][0]["cellWords"]
|
||||
)
|
||||
if cell["cellTextLines"]
|
||||
else ""
|
||||
)
|
||||
# Place the cell in the table
|
||||
table[row][col] = {"text": cell_text, "rowSpan": row_span, "colSpan": col_span}
|
||||
# Mark merged cells as occupied
|
||||
for r in range(row, row + row_span):
|
||||
for c in range(col, col + col_span):
|
||||
if r != row or c != col:
|
||||
table[r][c] = None
|
||||
# Generate HTML from the table array
|
||||
for row in table:
|
||||
html += " <tr>\n"
|
||||
for cell in row:
|
||||
if cell is None:
|
||||
continue
|
||||
if cell == "":
|
||||
html += " <td></td>\n"
|
||||
else:
|
||||
row_span_attr = (
|
||||
f' rowspan="{cell["rowSpan"]}"' if cell["rowSpan"] > 1 else ""
|
||||
)
|
||||
col_span_attr = (
|
||||
f' colspan="{cell["colSpan"]}"' if cell["colSpan"] > 1 else ""
|
||||
)
|
||||
html += f' <td{row_span_attr}{col_span_attr}>{cell["text"]}</td>\n'
|
||||
html += " </tr>\n"
|
||||
html += "</table>"
|
||||
return html
|
||||
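A hedged usage sketch for clova_ocr (not part of this commit): it assumes CLOVA_URL and CLOVA_API_KEY are already exported and uses a placeholder glob.

from autorag.data.parse.clova import clova_ocr

# CLOVA_URL and CLOVA_API_KEY are read from the environment when not passed directly.
result_df = clova_ocr(
    data_path_glob="./raw_docs/*.pdf",  # placeholder path
    file_type="pdf",
    batch=5,               # 5 is both the default and the maximum batch size
    table_detection=True,  # detected tables are appended to the page text as HTML
)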
87
autorag/data/parse/langchain_parse.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import multiprocessing as mp
|
||||
from itertools import chain
|
||||
from typing import List, Tuple
|
||||
|
||||
from autorag.data import parse_modules
|
||||
from autorag.data.parse.base import parser_node
|
||||
|
||||
|
||||
@parser_node
|
||||
def langchain_parse(
|
||||
data_path_list: List[str], parse_method: str, **kwargs
|
||||
) -> Tuple[List[str], List[str], List[int]]:
|
||||
"""
|
||||
Parse documents using a LangChain document_loaders (parse) method.
|
||||
|
||||
:param data_path_list: The list of data paths to parse.
|
||||
:param parse_method: A langchain document_loaders(parse) method to use.
|
||||
:param kwargs: The extra parameters for creating the langchain document_loaders(parse) instance.
|
||||
:return: tuple of lists containing the parsed texts, path and pages.
|
||||
"""
|
||||
if parse_method in ["directory", "unstructured"]:
|
||||
results = parse_all_files(data_path_list, parse_method, **kwargs)
|
||||
texts, path = results[0], results[1]
|
||||
pages = [-1] * len(texts)
|
||||
|
||||
else:
|
||||
num_workers = mp.cpu_count()
|
||||
# Execute parallel processing
|
||||
with mp.Pool(num_workers) as pool:
|
||||
results = pool.starmap(
|
||||
langchain_parse_pure,
|
||||
[(data_path, parse_method, kwargs) for data_path in data_path_list],
|
||||
)
|
||||
|
||||
texts, path, pages = (list(chain.from_iterable(item)) for item in zip(*results))
|
||||
|
||||
return texts, path, pages
|
||||
|
||||
|
||||
def langchain_parse_pure(
|
||||
data_path: str, parse_method: str, kwargs
|
||||
) -> Tuple[List[str], List[str], List[int]]:
|
||||
"""
|
||||
Parses a single file using the specified parse method.
|
||||
|
||||
Args:
|
||||
data_path (str): The file path to parse.
|
||||
parse_method (str): The parsing method to use.
|
||||
kwargs (Dict): Additional keyword arguments for the parsing method.
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], List[str], List[int]]: The parsed texts, file paths, and page numbers.
|
||||
"""
|
||||
|
||||
parse_instance = parse_modules[parse_method](data_path, **kwargs)
|
||||
|
||||
# Load the text from the file
|
||||
documents = parse_instance.load()
|
||||
|
||||
texts = list(map(lambda x: x.page_content, documents))
|
||||
path = [data_path] * len(texts)
|
||||
if parse_method in ["pymupdf", "pdfplumber", "pypdf", "pypdfium2"]:
|
||||
pages = list(range(1, len(documents) + 1))
|
||||
else:
|
||||
pages = [-1] * len(texts)
|
||||
|
||||
# Clean up the parse instance
|
||||
del parse_instance
|
||||
|
||||
return texts, path, pages
|
||||
|
||||
|
||||
def parse_all_files(
|
||||
data_path_list: List[str], parse_method: str, **kwargs
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
if parse_method == "unstructured":
|
||||
parse_instance = parse_modules[parse_method](data_path_list, **kwargs)
|
||||
elif parse_method == "directory":
|
||||
parse_instance = parse_modules[parse_method](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported parse method: {parse_method}")
|
||||
docs = parse_instance.load()
|
||||
texts = [doc.page_content for doc in docs]
|
||||
file_names = [doc.metadata["source"] for doc in docs]
|
||||
|
||||
del parse_instance
|
||||
return texts, file_names
|
||||
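A sketch of the all-files route (placeholder paths): with parse_method "directory" the wrapper splits the glob into a folder path and a pattern and forwards them as DirectoryLoader's path and glob kwargs, and page numbers are set to -1.

from autorag.data.parse import langchain_parse

# "directory" ignores the filtered path list and lets DirectoryLoader walk the folder.
result_df = langchain_parse(
    data_path_glob="./raw_docs/*.md",  # placeholder; becomes path="./raw_docs", glob="*.md"
    file_type="all_files",
    parse_method="directory",
)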
126
autorag/data/parse/llamaparse.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
from itertools import chain
|
||||
|
||||
from llama_parse import LlamaParse
|
||||
|
||||
from autorag.data.parse.base import parser_node
|
||||
from autorag.utils.util import process_batch, get_event_loop
|
||||
|
||||
|
||||
@parser_node
|
||||
def llama_parse(
|
||||
data_path_list: List[str],
|
||||
batch: int = 8,
|
||||
use_vendor_multimodal_model: bool = False,
|
||||
vendor_multimodal_model_name: str = "openai-gpt4o",
|
||||
use_own_key: bool = False,
|
||||
vendor_multimodal_api_key: str = None,
|
||||
**kwargs,
|
||||
) -> Tuple[List[str], List[str], List[int]]:
|
||||
"""
|
||||
Parse documents using LlamaParse.
|
||||
LLAMA_CLOUD_API_KEY environment variable should be set.
|
||||
You can get the key from https://cloud.llamaindex.ai/api-key
|
||||
|
||||
:param data_path_list: The list of data paths to parse.
|
||||
:param batch: The batch size for parsing documents. Default is 8.
|
||||
:param use_vendor_multimodal_model: Whether to use the vendor multimodal model. Default is False.
|
||||
:param vendor_multimodal_model_name: The name of the vendor multimodal model. Default is "openai-gpt4o".
|
||||
:param use_own_key: Whether to use your own API key for the vendor multimodal model. Default is False.
|
||||
:param vendor_multimodal_api_key: The API key for the vendor multimodal model.
|
||||
:param kwargs: The extra parameters for creating the llama_parse instance.
|
||||
:return: tuple of lists containing the parsed texts, path and pages.
|
||||
"""
|
||||
if use_vendor_multimodal_model:
|
||||
kwargs = _add_multimodal_params(
|
||||
kwargs,
|
||||
use_vendor_multimodal_model,
|
||||
vendor_multimodal_model_name,
|
||||
use_own_key,
|
||||
vendor_multimodal_api_key,
|
||||
)
|
||||
|
||||
parse_instance = LlamaParse(**kwargs)
|
||||
|
||||
tasks = [
|
||||
llama_parse_pure(data_path, parse_instance) for data_path in data_path_list
|
||||
]
|
||||
loop = get_event_loop()
|
||||
results = loop.run_until_complete(process_batch(tasks, batch))
|
||||
|
||||
del parse_instance
|
||||
|
||||
texts, path, pages = (list(chain.from_iterable(item)) for item in zip(*results))
|
||||
|
||||
return texts, path, pages
|
||||
|
||||
|
||||
async def llama_parse_pure(
|
||||
data_path: str, parse_instance
|
||||
) -> Tuple[List[str], List[str], List[int]]:
|
||||
documents = await parse_instance.aload_data(data_path)
|
||||
|
||||
texts = list(map(lambda x: x.text, documents))
|
||||
path = [data_path] * len(texts)
|
||||
pages = list(range(1, len(documents) + 1))
|
||||
|
||||
return texts, path, pages
|
||||
|
||||
|
||||
def _add_multimodal_params(
|
||||
kwargs,
|
||||
use_vendor_multimodal_model,
|
||||
vendor_multimodal_model_name,
|
||||
use_own_key,
|
||||
vendor_multimodal_api_key,
|
||||
) -> dict:
|
||||
kwargs["use_vendor_multimodal_model"] = use_vendor_multimodal_model
|
||||
kwargs["vendor_multimodal_model_name"] = vendor_multimodal_model_name
|
||||
|
||||
def set_multimodal_api_key(
|
||||
multimodal_model_name: str = "openai-gpt4o", _api_key: str = None
|
||||
) -> str:
|
||||
if multimodal_model_name in ["openai-gpt4o", "openai-gpt-4o-mini"]:
|
||||
_api_key = (
|
||||
os.getenv("OPENAI_API_KEY", None) if _api_key is None else _api_key
|
||||
)
|
||||
if _api_key is None:
|
||||
raise KeyError(
|
||||
"Please set the OPENAI_API_KEY in the environment variable OPENAI_API_KEY "
|
||||
"or directly set it on the config YAML file."
|
||||
)
|
||||
elif multimodal_model_name in ["anthropic-sonnet-3.5"]:
|
||||
_api_key = (
|
||||
os.getenv("ANTHROPIC_API_KEY", None) if _api_key is None else _api_key
|
||||
)
|
||||
if _api_key is None:
|
||||
raise KeyError(
|
||||
"Please set the ANTHROPIC_API_KEY in the environment variable ANTHROPIC_API_KEY "
|
||||
"or directly set it on the config YAML file."
|
||||
)
|
||||
elif multimodal_model_name in ["gemini-1.5-flash", "gemini-1.5-pro"]:
|
||||
_api_key = (
|
||||
os.getenv("GEMINI_API_KEY", None) if _api_key is None else _api_key
|
||||
)
|
||||
if _api_key is None:
|
||||
raise KeyError(
|
||||
"Please set the GEMINI_API_KEY in the environment variable GEMINI_API_KEY "
|
||||
"or directly set it on the config YAML file."
|
||||
)
|
||||
elif multimodal_model_name in ["custom-azure-model"]:
|
||||
raise NotImplementedError(
|
||||
"Custom Azure multimodal model is not supported yet."
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid multimodal model name.")
|
||||
|
||||
return _api_key
|
||||
|
||||
if use_own_key:
|
||||
api_key = set_multimodal_api_key(
|
||||
vendor_multimodal_model_name, vendor_multimodal_api_key
|
||||
)
|
||||
kwargs["vendor_multimodal_api_key"] = api_key
|
||||
|
||||
return kwargs
|
||||
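A hedged usage sketch for llama_parse (not part of this commit): it assumes LLAMA_CLOUD_API_KEY is set; any extra kwargs are forwarded to the LlamaParse constructor.

from autorag.data.parse.llamaparse import llama_parse

result_df = llama_parse(
    data_path_glob="./raw_docs/*.pdf",  # placeholder path
    file_type="pdf",
    batch=8,
    use_vendor_multimodal_model=True,
    vendor_multimodal_model_name="openai-gpt4o",
    use_own_key=True,  # reads OPENAI_API_KEY unless vendor_multimodal_api_key is given
)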
141
autorag/data/parse/run.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import os
|
||||
from typing import List, Callable, Dict
|
||||
import pandas as pd
|
||||
from glob import glob
|
||||
|
||||
from autorag.strategy import measure_speed
|
||||
from autorag.data.utils.util import get_param_combinations
|
||||
|
||||
default_map = {
|
||||
"pdf": {
|
||||
"file_type": "pdf",
|
||||
"module_type": "langchain_parse",
|
||||
"parse_method": "pdfminer",
|
||||
},
|
||||
"csv": {
|
||||
"file_type": "csv",
|
||||
"module_type": "langchain_parse",
|
||||
"parse_method": "csv",
|
||||
},
|
||||
"md": {
|
||||
"file_type": "md",
|
||||
"module_type": "langchain_parse",
|
||||
"parse_method": "unstructuredmarkdown",
|
||||
},
|
||||
"html": {
|
||||
"file_type": "html",
|
||||
"module_type": "langchain_parse",
|
||||
"parse_method": "bshtml",
|
||||
},
|
||||
"xml": {
|
||||
"file_type": "xml",
|
||||
"module_type": "langchain_parse",
|
||||
"parse_method": "unstructuredxml",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def run_parser(
|
||||
modules: List[Callable],
|
||||
module_params: List[Dict],
|
||||
data_path_glob: str,
|
||||
project_dir: str,
|
||||
all_files: bool,
|
||||
):
|
||||
if not all_files:
|
||||
# Fall back to the default parsing module for file types that appear in the paths but are not set in the YAML.
|
||||
data_path_list = glob(data_path_glob)
|
||||
if not data_path_list:
|
||||
raise FileNotFoundError(f"data does not exits in {data_path_glob}")
|
||||
|
||||
file_types = set(
|
||||
[os.path.basename(data_path).split(".")[-1] for data_path in data_path_list]
|
||||
)
|
||||
set_file_types = set([module["file_type"] for module in module_params])
|
||||
|
||||
# Calculate the set difference once
|
||||
file_types_to_remove = set_file_types - file_types
|
||||
|
||||
# Use list comprehension to filter out unwanted elements
|
||||
module_params = [
|
||||
param
|
||||
for param in module_params
|
||||
if param["file_type"] not in file_types_to_remove
|
||||
]
|
||||
modules = [
|
||||
module
|
||||
for module, param in zip(modules, module_params)
|
||||
if param["file_type"] not in file_types_to_remove
|
||||
]
|
||||
|
||||
# file types that appear in the data paths but are not configured in the YAML
|
||||
missing_file_types = list(file_types - set_file_types)
|
||||
|
||||
if missing_file_types:
|
||||
add_modules_list = []
|
||||
for missing_file_type in missing_file_types:
|
||||
if missing_file_type == "json":
|
||||
raise ValueError(
|
||||
"JSON file type must have a jq_schema so you must set it in the YAML file."
|
||||
)
|
||||
|
||||
add_modules_list.append(default_map[missing_file_type])
|
||||
|
||||
add_modules, add_params = get_param_combinations(add_modules_list)
|
||||
modules.extend(add_modules)
|
||||
module_params.extend(add_params)
|
||||
|
||||
results, execution_times = zip(
|
||||
*map(
|
||||
lambda x: measure_speed(x[0], data_path_glob=data_path_glob, **x[1]),
|
||||
zip(modules, module_params),
|
||||
)
|
||||
)
|
||||
average_times = list(map(lambda x: x / len(results[0]), execution_times))
|
||||
|
||||
# save results to parquet files
|
||||
if all_files:
|
||||
if len(module_params) > 1:
|
||||
raise ValueError(
|
||||
"All files is set to True, You can only use one parsing module."
|
||||
)
|
||||
filepaths = [os.path.join(project_dir, "parsed_result.parquet")]
|
||||
else:
|
||||
filepaths = list(
|
||||
map(
|
||||
lambda x: os.path.join(project_dir, f"{x['file_type']}.parquet"),
|
||||
module_params,
|
||||
)
|
||||
)
|
||||
|
||||
_files = {}
|
||||
for result, filepath in zip(results, filepaths):
|
||||
_files.setdefault(filepath, []).append(result)
|
||||
# Save files with a specific file type as Parquet files.
|
||||
for filepath, value in _files.items():
|
||||
pd.concat(value).to_parquet(filepath, index=False)
|
||||
|
||||
filenames = list(map(lambda x: os.path.basename(x), filepaths))
|
||||
|
||||
summary_df = pd.DataFrame(
|
||||
{
|
||||
"filename": filenames,
|
||||
"module_name": list(map(lambda module: module.__name__, modules)),
|
||||
"module_params": module_params,
|
||||
"execution_time": average_times,
|
||||
}
|
||||
)
|
||||
summary_df.to_csv(os.path.join(project_dir, "summary.csv"), index=False)
|
||||
|
||||
# concat all parquet files here if not all_files.
|
||||
_filepaths = list(_files.keys())
|
||||
if not all_files:
|
||||
dataframes = [pd.read_parquet(file) for file in _filepaths]
|
||||
combined_df = pd.concat(dataframes, ignore_index=True)
|
||||
combined_df.to_parquet(
|
||||
os.path.join(project_dir, "parsed_result.parquet"), index=False
|
||||
)
|
||||
|
||||
return summary_df
|
||||
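A minimal sketch of calling run_parser directly (in normal use the module list and params come from the parse YAML config, so treat the literals below as placeholders; project_dir is assumed to exist).

from autorag.data.parse import langchain_parse
from autorag.data.parse.run import run_parser

summary_df = run_parser(
    modules=[langchain_parse],
    module_params=[{"file_type": "pdf", "parse_method": "pdfminer"}],
    data_path_glob="./raw_docs/*.pdf",  # placeholder path
    project_dir="./parse_project",      # parquet results and summary.csv are written here
    all_files=False,
)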
134
autorag/data/parse/table_hybrid_parse.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import os
|
||||
import tempfile
|
||||
from glob import glob
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
from PyPDF2 import PdfFileReader, PdfFileWriter
|
||||
import pdfplumber
|
||||
|
||||
from autorag.support import get_support_modules
|
||||
from autorag.data.parse.base import parser_node
|
||||
|
||||
|
||||
@parser_node
|
||||
def table_hybrid_parse(
|
||||
data_path_list: List[str],
|
||||
text_parse_module: str,
|
||||
text_params: Dict,
|
||||
table_parse_module: str,
|
||||
table_params: Dict,
|
||||
) -> Tuple[List[str], List[str], List[int]]:
|
||||
"""
|
||||
Parse documents using the table_hybrid_parse method.
|
||||
The table_hybrid_parse method combines the parsing results of PDF pages with and without tables.
|
||||
It splits the PDF file into pages, separates pages with and without tables, and then parses and merges the results.
|
||||
|
||||
:param data_path_list: The list of data paths to parse.
|
||||
:param text_parse_module: The text parsing module to use. The type should be a string.
|
||||
:param text_params: The extra parameters for the text parsing module. The type should be a dictionary.
|
||||
:param table_parse_module: The table parsing module to use. The type should be a string.
|
||||
:param table_params: The extra parameters for the table parsing module. The type should be a dictionary.
|
||||
:return: tuple of lists containing the parsed texts, path and pages.
|
||||
"""
|
||||
# make save folder directory
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as save_dir:
|
||||
text_dir = os.path.join(save_dir, "text")
|
||||
table_dir = os.path.join(save_dir, "table")
|
||||
|
||||
os.makedirs(text_dir, exist_ok=True)
|
||||
os.makedirs(table_dir, exist_ok=True)
|
||||
|
||||
# Split PDF file into pages and Save PDFs with and without tables
|
||||
path_map_dict_lst = [
|
||||
save_page_by_table(data_path, text_dir, table_dir)
|
||||
for data_path in data_path_list
|
||||
]
|
||||
path_map_dict = {k: v for d in path_map_dict_lst for k, v in d.items()}
|
||||
|
||||
# Parse pages that contain tables
|
||||
table_results, table_file_path = get_each_module_result(
|
||||
table_parse_module, table_params, os.path.join(table_dir, "*")
|
||||
)
|
||||
|
||||
# Parse text-only pages
|
||||
text_results, text_file_path = get_each_module_result(
|
||||
text_parse_module, text_params, os.path.join(text_dir, "*")
|
||||
)
|
||||
|
||||
# Merge parsing results of PDFs with and without tables
|
||||
texts = table_results + text_results
|
||||
temp_path_lst = table_file_path + text_file_path
|
||||
|
||||
# Sort by file names
|
||||
temp_path_lst, texts = zip(*sorted(zip(temp_path_lst, texts)))
|
||||
|
||||
# get original file path
|
||||
path = list(map(lambda temp_path: path_map_dict[temp_path], temp_path_lst))
|
||||
|
||||
# get pages
|
||||
pages = list(map(lambda x: get_page_from_path(x), temp_path_lst))
|
||||
|
||||
return list(texts), path, pages
|
||||
|
||||
|
||||
# Save PDFs with and without tables
|
||||
def save_page_by_table(data_path: str, text_dir: str, table_dir: str) -> Dict[str, str]:
|
||||
file_name = os.path.basename(data_path).split(".pdf")[0]
|
||||
|
||||
with open(data_path, "rb") as input_data:
|
||||
pdf_reader = PdfFileReader(input_data)
|
||||
num_pages = pdf_reader.getNumPages()
|
||||
|
||||
path_map_dict = {}
|
||||
for page_num in range(num_pages):
|
||||
output_pdf_path = _get_output_path(
|
||||
data_path, page_num, file_name, text_dir, table_dir
|
||||
)
|
||||
_save_single_page(pdf_reader, page_num, output_pdf_path)
|
||||
path_map_dict.update({output_pdf_path: data_path})
|
||||
|
||||
return path_map_dict
|
||||
|
||||
|
||||
def _get_output_path(
|
||||
data_path: str, page_num: int, file_name: str, text_dir: str, table_dir: str
|
||||
) -> str:
|
||||
with pdfplumber.open(data_path) as pdf:
|
||||
page = pdf.pages[page_num]
|
||||
tables = page.extract_tables()
|
||||
directory = table_dir if tables else text_dir
|
||||
return os.path.join(directory, f"{file_name}_page_{page_num + 1}.pdf")
|
||||
|
||||
|
||||
def _save_single_page(pdf_reader: PdfFileReader, page_num: int, output_pdf_path: str):
|
||||
pdf_writer = PdfFileWriter()
|
||||
pdf_writer.addPage(pdf_reader.getPage(page_num))
|
||||
|
||||
with open(output_pdf_path, "wb") as output_file:
|
||||
pdf_writer.write(output_file)
|
||||
|
||||
|
||||
def get_each_module_result(
|
||||
module: str, module_params: Dict, data_path_glob: str
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
module_params["module_type"] = module
|
||||
|
||||
data_path_list = glob(data_path_glob)
|
||||
if not data_path_list:
|
||||
return [], []
|
||||
|
||||
module_name = module_params.pop("module_type")
|
||||
module_callable = get_support_modules(module_name)
|
||||
module_original = module_callable.__wrapped__
|
||||
texts, path, _ = module_original(data_path_list, **module_params)
|
||||
|
||||
return texts, path
|
||||
|
||||
|
||||
def get_page_from_path(file_path: str) -> int:
|
||||
file_name = os.path.basename(file_path)
|
||||
split_result = file_name.rsplit("_page_", 1)
|
||||
page_number_with_extension = split_result[1]
|
||||
page_number, _ = page_number_with_extension.split(".")
|
||||
|
||||
return int(page_number)
|
||||
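A hedged usage sketch for table_hybrid_parse (placeholder paths; it assumes "langchain_parse" is registered in get_support_modules): pages with tables go to the table module and the remaining pages to the text module.

from autorag.data.parse.table_hybrid_parse import table_hybrid_parse

result_df = table_hybrid_parse(
    data_path_glob="./raw_docs/*.pdf",  # placeholder path
    file_type="pdf",
    text_parse_module="langchain_parse",
    text_params={"parse_method": "pdfminer"},
    table_parse_module="langchain_parse",
    table_params={"parse_method": "pdfplumber"},
)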
3
autorag/data/qa/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# This is the v2 version, the next generation of data creation
|
||||
# The legacy (v1) version will be deprecated in AutoRAG version 0.3
|
||||
# The legacy (v1) version and the new v2 data creation are not compatible with each other
|
||||
0
autorag/data/qa/evolve/__init__.py
Normal file
64
autorag/data/qa/evolve/llama_index_query_evolve.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import itertools
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_index.core.base.llms.base import BaseLLM
|
||||
from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
|
||||
|
||||
from autorag.data.qa.evolve.prompt import QUERY_EVOLVE_PROMPT
|
||||
|
||||
|
||||
async def llama_index_generate_base(
|
||||
row: Dict,
|
||||
llm: BaseLLM,
|
||||
messages: List[ChatMessage],
|
||||
) -> Dict:
|
||||
original_query = row["query"]
|
||||
context = list(itertools.chain.from_iterable(row["retrieval_gt_contents"]))
|
||||
context_str = "Text:\n" + "\n".join(
|
||||
[f"{i + 1}. {c}" for i, c in enumerate(context)]
|
||||
)
|
||||
user_prompt = f"Question: {original_query}\nContext: {context_str}\nOutput: "
|
||||
messages.append(ChatMessage(role=MessageRole.USER, content=user_prompt))
|
||||
|
||||
chat_response: ChatResponse = await llm.achat(messages=messages)
|
||||
row["query"] = chat_response.message.content
|
||||
return row
|
||||
|
||||
|
||||
async def conditional_evolve_ragas(
|
||||
row: Dict,
|
||||
llm: BaseLLM,
|
||||
lang: str = "en",
|
||||
) -> Dict:
|
||||
return await llama_index_generate_base(
|
||||
row,
|
||||
llm,
|
||||
QUERY_EVOLVE_PROMPT["conditional_evolve_ragas"][lang],
|
||||
)
|
||||
|
||||
|
||||
async def reasoning_evolve_ragas(
|
||||
row: Dict,
|
||||
llm: BaseLLM,
|
||||
lang: str = "en",
|
||||
) -> Dict:
|
||||
return await llama_index_generate_base(
|
||||
row,
|
||||
llm,
|
||||
QUERY_EVOLVE_PROMPT["reasoning_evolve_ragas"][lang],
|
||||
)
|
||||
|
||||
|
||||
async def compress_ragas(
|
||||
row: Dict,
|
||||
llm: BaseLLM,
|
||||
lang: str = "en",
|
||||
) -> Dict:
|
||||
original_query = row["query"]
|
||||
user_prompt = f"Question: {original_query}\nOutput: "
|
||||
messages = QUERY_EVOLVE_PROMPT["compress_ragas"][lang]
|
||||
messages.append(ChatMessage(role=MessageRole.USER, content=user_prompt))
|
||||
|
||||
chat_response: ChatResponse = await llm.achat(messages=messages)
|
||||
row["query"] = chat_response.message.content
|
||||
return row
|
||||
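A minimal sketch of evolving a query with a llama_index LLM (not part of this commit): the row mirrors the QA dataset layout, the contents are placeholders, and an OPENAI_API_KEY is assumed for the llama_index OpenAI wrapper.

import asyncio

from llama_index.llms.openai import OpenAI
from autorag.data.qa.evolve.llama_index_query_evolve import reasoning_evolve_ragas

row = {
    "query": "What is the capital of France?",
    "retrieval_gt_contents": [["Paris is the capital and administrative center of France."]],
}
evolved = asyncio.run(reasoning_evolve_ragas(row, OpenAI(model="gpt-4o-mini"), lang="en"))
print(evolved["query"])  # rewritten as a multi-hop reasoning question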
81
autorag/data/qa/evolve/openai_query_evolve.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import itertools
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
||||
from llama_index.llms.openai.utils import to_openai_message_dicts
|
||||
from openai import AsyncClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autorag.data.qa.evolve.prompt import QUERY_EVOLVE_PROMPT
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
evolved_query: str
|
||||
|
||||
|
||||
async def query_evolve_openai_base(
|
||||
row: Dict,
|
||||
client: AsyncClient,
|
||||
messages: List[ChatMessage],
|
||||
model_name: str = "gpt-4o-2024-08-06",
|
||||
):
|
||||
"""
|
||||
Evolve the original query to a new evolved query using OpenAI structured outputs.
|
||||
"""
|
||||
original_query = row["query"]
|
||||
context = list(itertools.chain.from_iterable(row["retrieval_gt_contents"]))
|
||||
context_str = "Text:\n" + "\n".join(
|
||||
[f"{i + 1}. {c}" for i, c in enumerate(context)]
|
||||
)
|
||||
user_prompt = f"Question: {original_query}\nContext: {context_str}\nOutput: "
|
||||
messages.append(ChatMessage(role=MessageRole.USER, content=user_prompt))
|
||||
|
||||
completion = await client.beta.chat.completions.parse(
|
||||
model=model_name,
|
||||
messages=to_openai_message_dicts(messages),
|
||||
response_format=Response,
|
||||
)
|
||||
row["query"] = completion.choices[0].message.parsed.evolved_query
|
||||
return row
|
||||
|
||||
|
||||
async def conditional_evolve_ragas(
|
||||
row: Dict,
|
||||
client: AsyncClient,
|
||||
model_name: str = "gpt-4o-2024-08-06",
|
||||
lang: str = "en",
|
||||
) -> Dict:
|
||||
return await query_evolve_openai_base(
|
||||
row, client, QUERY_EVOLVE_PROMPT["conditional_evolve_ragas"][lang], model_name
|
||||
)
|
||||
|
||||
|
||||
async def reasoning_evolve_ragas(
|
||||
row: Dict,
|
||||
client: AsyncClient,
|
||||
model_name: str = "gpt-4o-2024-08-06",
|
||||
lang: str = "en",
|
||||
) -> Dict:
|
||||
return await query_evolve_openai_base(
|
||||
row, client, QUERY_EVOLVE_PROMPT["reasoning_evolve_ragas"][lang], model_name
|
||||
)
|
||||
|
||||
|
||||
async def compress_ragas(
|
||||
row: Dict,
|
||||
client: AsyncClient,
|
||||
model_name: str = "gpt-4o-2024-08-06",
|
||||
lang: str = "en",
|
||||
) -> Dict:
|
||||
original_query = row["query"]
|
||||
messages = QUERY_EVOLVE_PROMPT["compress_ragas"][lang]
|
||||
user_prompt = f"Question: {original_query}\nOutput: "
|
||||
messages.append(ChatMessage(role=MessageRole.USER, content=user_prompt))
|
||||
|
||||
completion = await client.beta.chat.completions.parse(
|
||||
model=model_name,
|
||||
messages=to_openai_message_dicts(messages),
|
||||
response_format=Response,
|
||||
)
|
||||
row["query"] = completion.choices[0].message.parsed.evolved_query
|
||||
return row
|
||||
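The structured-output variant can be sketched like this (assumes OPENAI_API_KEY is set; the question is a placeholder):

import asyncio

from openai import AsyncClient
from autorag.data.qa.evolve.openai_query_evolve import compress_ragas

row = {"query": "What ingredients are required to bake a chocolate cake?"}
compressed = asyncio.run(compress_ragas(row, AsyncClient(), model_name="gpt-4o-2024-08-06"))
print(compressed["query"])  # a shorter, more indirect version of the question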
288
autorag/data/qa/evolve/prompt.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# The RAGAS prompts are coming from RAGAS under Apache-2.0 License. (English version) (the AutoRAG team translates Korean version prompt)
|
||||
# You can see the original prompts at the RAGAS library at https://github.com/explodinggradients/ragas/blob/main/src/ragas/testset/prompts.py
|
||||
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
||||
|
||||
QUERY_EVOLVE_PROMPT = {
|
||||
"conditional_evolve_ragas": {
|
||||
"en": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""Rewrite the provided question to increase its complexity by introducing a conditional element.
|
||||
The goal is to make the question more intricate by incorporating a scenario or condition that affects the context of the question.
|
||||
Follow the rules given below while rewriting the question.
|
||||
1. The rewritten question should not be longer than 25 words. Use abbreviation wherever possible.
|
||||
2. The rewritten question must be reasonable and must be understood and responded by humans.
|
||||
3. The rewritten question must be fully answerable from information present context.
|
||||
4. phrases like 'provided context','according to the context?',etc are not allowed to appear in the question.
|
||||
""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question : What is the function of the roots of a plant?
|
||||
Context : The roots of a plant absorb water and nutrients from the soil, anchor the plant in the ground, and store food.
|
||||
Output : """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="What dual purpose do plant roots serve concerning soil nutrients and stability?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question : How do vaccines protect against diseases?
|
||||
Context : Vaccines protect against diseases by stimulating the body's immune response to produce antibodies, which recognize and combat pathogens.
|
||||
Output : """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="How do vaccines utilize the body's immune system to defend against pathogens?",
|
||||
),
|
||||
],
|
||||
"ko": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""제공된 질문에 조건에 관련한 내용을 추가하여 복잡성을 높이세요.
|
||||
질문의 Context에 영향을 미치는 시나리오나 조건을 포함하여 질문을 더 복잡하게 만드는 것이 목표입니다.
|
||||
질문을 다시 작성할 때 다음 규칙을 따르십시오.
|
||||
1. 다시 작성된 질문은 100자를 넘지 않아야 합니다. 가능한 경우 약어를 사용하십시오.
|
||||
2. 다시 작성된 질문은 합리적이어야 하며 사람이 이해하고 응답할 수 있어야 합니다.
|
||||
3. 다시 작성된 질문은 현재 Context에서 완전히 답변할 수 있어야 합니다.
|
||||
4. '제공된 글', '단락에 따르면?', 'Context에 의하면' 등의 문구는 질문에 나타날 수 없습니다.
|
||||
5. 한국어로 질문을 작성하세요.
|
||||
""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: 식물의 뿌리 기능이 뭐야?
|
||||
Context: 식물의 뿌리는 토양에서 물과 영양분을 흡수하고, 식물을 땅에 고정하며, 영양분을 저장합니다.
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="식물의 뿌리는 토양 영양분과 안정성에 대해 어떤 역할을 하나요?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: 백신은 질병을 어떻게 예방하나요?
|
||||
Context: 백신은 신체의 면역 반응을 자극하여 병원체를 인식하고 싸우는 항체를 생성함으로써 질병으로부터 보호합니다.
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="백신은 신체의 면역 체계를 어떻게 활용해서 질병을 예방합니까?",
|
||||
),
|
||||
],
|
||||
"ja": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""提供された質問に条件に関する内容を追加して、複雑さを高めます。
|
||||
質問のContextに影響を与えるシナリオや条件を含めて、質問をより複雑にすることが目標です。
|
||||
質問を再作成するときは、次のルールに従います。
|
||||
1. 再作成された質問は100文字を超えてはいけません。 可能であれば略語を使ってください
|
||||
2. 再作成された質問は合理的でなければならず、人が理解して回答できるものでなければなりません。
|
||||
3. 再作成された質問は、現在のContextで完全に答えられる必要があります。
|
||||
4. 「提供された文」、「段落によると?」、「Contextによると」などのフレーズは質問に表示されません。
|
||||
5. 日本語で質問を書きましょう。
|
||||
""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: 植物の根の機能は何ですか?
|
||||
Context: 植物の根は土壌から水や栄養分を吸収し、植物を地面に固定し、栄養分を蓄えます。
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="植物の根は土壌栄養分と安定性に対してどのような役割をしますか?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: ワクチンは病気をどのように予防しますか?
|
||||
Context: ワクチンは、体の免疫反応を刺激して病原体を認識し、戦う抗体を生成することで病気から守ります。
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="ワクチンは体の免疫システムをどのように活用して病気を予防しますか?",
|
||||
),
|
||||
],
|
||||
},
|
||||
"reasoning_evolve_ragas": {
|
||||
"en": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""Complicate the given question by rewriting question into a multi-hop reasoning question based on the provided context.
|
||||
Answering the question should require the reader to make multiple logical connections or inferences using the information available in given context.
|
||||
Rules to follow when rewriting question:
|
||||
1. Ensure that the rewritten question can be answered entirely from the information present in the contexts.
|
||||
2. Do not frame questions that contains more than 15 words. Use abbreviation wherever possible.
|
||||
3. Make sure the question is clear and unambiguous.
|
||||
4. phrases like 'based on the provided context','according to the context',etc are not allowed to appear in the question.""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: What is the capital of France?,
|
||||
Context: France is a country in Western Europe. It has several cities, including Paris, Lyon, and Marseille. Paris is not only known for its cultural landmarks like the Eiffel Tower and the Louvre Museum but also as the administrative center.
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="Linking the Eiffel Tower and administrative center, which city stands as both?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: What does the append() method do in Python?
|
||||
Context: In Python, lists are used to store multiple items in a single variable. Lists are one of 4 built-in data types used to store collections of data. The append() method adds a single item to the end of a list.
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="If a list represents a variable collection, what method extends it by one item?",
|
||||
),
|
||||
],
|
||||
"ko": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""주어진 Context를 기반으로 기존 질문을 복잡하게 만들어 여러 논리적인 사고가 필요한 질문으로 다시 작성하세요.
|
||||
질문에 답하려면 주어진 Context의 정보를 사용해 여러 논리적 사고나 추론을 해야 합니다.
|
||||
질문을 다시 작성할 때 따라야 할 규칙:
|
||||
1. 다시 작성된 질문은 Context에 있는 정보만으로 완전히 답변할 수 있어야 합니다.
|
||||
2. 100자를 초과하는 질문을 작성하지 마세요. 가능한 경우 약어를 사용하세요.
|
||||
3. 질문이 명확하고 모호하지 않도록 하세요.
|
||||
4. '제공된 Context에 기반하여', '해당 단락에 따르면' 등의 문구는 질문에 포함되지 않아야 합니다.
|
||||
5. 한국어로 질문을 작성하세요.""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: 프랑스의 수도는 어디인가요?,
|
||||
Context: 프랑스는 서유럽에 있는 나라입니다. 파리, 리옹, 마르세유를 포함한 여러 도시가 있습니다. 파리는 에펠탑과 루브르 박물관 같은 문화적 랜드마크로 유명할 뿐만 아니라 행정 중심지로도 알려져 있습니다.
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="에펠탑과 행정 중심지, 두 단어는 어떤 도시를 가리키나요?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""질문: Python에서 append() 메서드는 무엇을 하나요?
|
||||
컨텍스트: Python에서 리스트는 하나의 변수에 여러 항목을 저장하는 데 사용됩니다. 리스트는 데이터를 저장하는 데 사용되는 4가지 내장 데이터 유형 중 하나입니다. append() 메서드는 리스트의 끝에 새로운 항목을 추가합니다.
|
||||
출력: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="리스트가 변수들을 모아 놓은 것을 나타낸다면, 어떤 메서드를 사용해야 항목을 하나 더 추가할 수 있습니까?",
|
||||
),
|
||||
],
|
||||
"ja": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""与えられたContextに基づいて既存の質問を複雑にして、様々な論理的思考が必要な質問として書き直しましょう。
|
||||
質問に答えるためには、与えられたContextの情報を使って様々な論理的思考や推論をしなければなりません。
|
||||
質問を再作成するときに従うべきルール:
|
||||
1. 再作成された質問は、Contextにある情報だけで完全に答えられる必要があります。
|
||||
2. 100文字を超える質問を作成してはいけません。 可能であれば略語を使ってください。
|
||||
3. 質問が明確で曖昧にならないようにしましょう。
|
||||
4. 「提供されたContextに基づいて」、「当該段落によると」などのフレーズは、質問に含まれてはいけません。
|
||||
5. 日本語で質問を書きましょう。""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: フランスの首都はどこですか?,
|
||||
Context: フランスは西ヨーロッパにある国です。 パリ、リヨン、マルセイユを含むいくつかの都市があります。 パリはエッフェル塔やルーブル博物館のような文化的ランドマークとして有名なだけでなく、行政の中心地としても知られています。
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="エッフェル塔と行政の中心地、二つの単語はどんな都市を指していますか?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: Pythonでappend() メソッドは何をしますか?
|
||||
Context: Pythonで、リストは 1 つの変数に複数の項目を保存するために使用されます。 リストは、データを保存するために使用される 4 つの組み込みデータ タイプの 1 つです。 append()メソッドは、リストの最後に新しい項目を追加します。
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="リストが変数を集めたものである場合、どのメソッドを使えば項目を一つ追加することができますか?",
|
||||
),
|
||||
],
|
||||
},
|
||||
"compress_ragas": {
|
||||
"en": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""Rewrite the following question to make it more indirect and shorter while retaining the essence of the original question.
|
||||
The goal is to create a question that conveys the same meaning but in a less direct manner. The rewritten question should shorter so use abbreviation wherever possible.""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: What is the distance between the Earth and the Moon?
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="How far is the Moon from Earth?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: What ingredients are required to bake a chocolate cake?
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="What's needed for a chocolate cake?",
|
||||
),
|
||||
],
|
||||
"ko": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""주어진 질문을 더 간접적이고 짧게 다시 작성하세요.
|
||||
목표는 질문을 원래 질문의 본질을 유지하면서 너무 직설적이지 않게 만드는 것입니다.
|
||||
약어 등을 사용하여 질문을 더 짧게 만드세요.""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: 지구와 달 사이의 거리는 얼마입니까?
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="달은 지구에서 얼마나 떨어져 있나요?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: 초콜릿 케이크를 굽기 위해 필요한 재료는 무엇입니까?
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="초콜릿 케이크에 필요한 것은 무엇인가요?",
|
||||
),
|
||||
],
|
||||
"ja": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""与えられた質問をより間接的かつ短く書き換えます。
|
||||
目標は、質問を元の質問の本質を保ちながら、あまりストレートにならないようにすることです。
|
||||
略語などを使用して、質問をより短くします。""",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: 地球と月の間の距離はどれくらいですか?
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="月は地球からどれくらい離れていますか?",
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content="""Question: チョコレートケーキを焼くために必要な材料は何ですか?
|
||||
Output: """,
|
||||
),
|
||||
ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="チョコレートケーキに必要なものは何ですか?",
|
||||
),
|
||||
],
|
||||
},
|
||||
}
|
||||
1
autorag/data/qa/extract_evidence.py
Normal file
@@ -0,0 +1 @@
|
||||
# This module extracts evidence from the given retrieval ground-truth passage
|
||||
0
autorag/data/qa/filter/__init__.py
Normal file
117
autorag/data/qa/filter/dontknow.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_index.core.base.llms.base import BaseLLM
|
||||
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse
|
||||
from llama_index.llms.openai.utils import to_openai_message_dicts
|
||||
from openai import AsyncClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autorag.data.qa.filter.prompt import FILTER_PROMPT
|
||||
|
||||
dont_know_phrases = {
|
||||
"en": [
|
||||
"I don't know",
|
||||
"I do not know",
|
||||
"Don't know",
|
||||
"Do not know",
|
||||
],
|
||||
"ko": [
|
||||
"몰라요",
|
||||
"모르겠습니다",
|
||||
"모르겠어요",
|
||||
"몰라",
|
||||
"내가 어떻게 알아?",
|
||||
"모르겠소",
|
||||
"몰라유",
|
||||
"모르것는디",
|
||||
"모르겠어유",
|
||||
"모르겠네유",
|
||||
"모르겠네요",
|
||||
],
|
||||
"ja": [
|
||||
"知りません",
|
||||
"わかりません",
|
||||
"分かりません",
|
||||
"知らないです",
|
||||
"よく分かってません",
|
||||
"わかりかねます",
|
||||
"存じません",
|
||||
"お答えいたしかねます",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def dontknow_filter_rule_based(row: Dict, lang: str = "en") -> bool:
|
||||
assert (
|
||||
"generation_gt" in row.keys()
|
||||
), "generation_gt column is not in the DataFrame."
|
||||
dont_know_phrase = dont_know_phrases[lang]
|
||||
return not any(
|
||||
phrase in s for phrase in dont_know_phrase for s in row["generation_gt"]
|
||||
)
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
is_dont_know: bool
|
||||
|
||||
|
||||
async def dontknow_filter_openai(
|
||||
row: Dict,
|
||||
client: AsyncClient,
|
||||
model_name: str = "gpt-4o-mini-2024-07-18",
|
||||
lang: str = "en",
|
||||
) -> bool:
|
||||
"""
|
||||
This will drop rows that have a "don't know" answer.
|
||||
It will drop unanswerable questions from the QA dataset.
|
||||
You can use this filter with the `batch_filter` method of the `QA` class.
|
||||
|
||||
:param row: The row dict from QA dataset.
|
||||
:param client: The OpenAI client.
|
||||
:param model_name: The model name.
|
||||
You have to use gpt-4o-2024-08-06 or gpt-4o-mini-2024-07-18.
|
||||
:param lang: The supported language is en, ko or ja.
|
||||
:return: False if the row generation_gt is a "don't know" meaning.
|
||||
"""
|
||||
assert "generation_gt" in row.keys(), "generation_gt column is not in the row."
|
||||
system_prompt: List[ChatMessage] = FILTER_PROMPT["dontknow_filter"][lang]
|
||||
result = []
|
||||
for gen_gt in row["generation_gt"]:
|
||||
completion = await client.beta.chat.completions.parse(
|
||||
model=model_name,
|
||||
messages=to_openai_message_dicts(
|
||||
system_prompt + [ChatMessage(role=MessageRole.USER, content=gen_gt)]
|
||||
),
|
||||
response_format=Response,
|
||||
)
|
||||
result.append(completion.choices[0].message.parsed.is_dont_know)
|
||||
return not any(result)
|
||||
|
||||
|
||||
async def dontknow_filter_llama_index(
|
||||
row: Dict,
|
||||
llm: BaseLLM,
|
||||
lang: str = "en",
|
||||
) -> bool:
|
||||
"""
|
||||
This will drop rows that have a "don't know" answer.
|
||||
It will drop unanswerable questions from the QA dataset.
|
||||
You can use this filter with the `batch_filter` method of the `QA` class.
|
||||
|
||||
:param row: The row dict from QA dataset.
|
||||
:param llm: The Llama index llm instance.
|
||||
Setting a low max tokens value is recommended to save tokens.
|
||||
:param lang: The supported language is en, ko or ja.
|
||||
:return: False if the row generation_gt is a "don't know" meaning.
|
||||
"""
|
||||
assert "generation_gt" in row.keys(), "generation_gt column is not in the row."
|
||||
system_prompt: List[ChatMessage] = FILTER_PROMPT["dontknow_filter"][lang]
|
||||
results = []
|
||||
for gen_gt in row["generation_gt"]:
|
||||
response: ChatResponse = await llm.achat(
|
||||
messages=system_prompt
|
||||
+ [ChatMessage(role=MessageRole.USER, content=gen_gt)]
|
||||
)
|
||||
result_str = response.message.content
|
||||
results.append("true" in result_str.lower().strip())
|
||||
return not any(results)
|
||||
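The rule-based variant needs no LLM and can be sketched directly (placeholder answer text):

from autorag.data.qa.filter.dontknow import dontknow_filter_rule_based

# Returns False (row should be dropped) when any generation_gt string
# contains a "don't know" phrase for the selected language.
keep = dontknow_filter_rule_based({"generation_gt": ["I don't know the answer."]}, lang="en")
print(keep)  # False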
88
autorag/data/qa/filter/passage_dependency.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_index.core.base.llms.base import BaseLLM
|
||||
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse
|
||||
from llama_index.llms.openai.utils import to_openai_message_dicts
|
||||
from openai import AsyncClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autorag.data.qa.filter.prompt import FILTER_PROMPT
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
is_passage_dependent: bool
|
||||
|
||||
|
||||
async def passage_dependency_filter_openai(
|
||||
row: Dict,
|
||||
client: AsyncClient,
|
||||
model_name: str = "gpt-4o-mini-2024-07-18",
|
||||
lang: str = "en",
|
||||
) -> bool:
|
||||
"""
|
||||
This will drop passage-dependent question rows.
|
||||
Passage-dependent questions are questions whose answer changes depending on which passage is chosen.
|
||||
Passage-dependent questions are not suitable for RAG evaluation, because no retrieval system can find the right passage for such a question.
|
||||
For example, when someone asks "What is the highest score according to the table?", the answer differs depending on which table is meant.
|
||||
Since the question does not identify the table, the retrieval system cannot find the right passage.
|
||||
You can use this filter with the `batch_filter` method of the `QA` class.
|
||||
|
||||
:param row: The row dict from QA dataset.
|
||||
:param client: The OpenAI client.
|
||||
:param model_name: The model name.
|
||||
You have to use gpt-4o-2024-08-06 or gpt-4o-mini-2024-07-18.
|
||||
:param lang: The supported language is en, ko or ja.
|
||||
:return: False if the row question is a passage-dependent question (to be filtered).
|
||||
"""
|
||||
assert "query" in row.keys(), "query column is not in the row."
|
||||
system_prompt: List[ChatMessage] = FILTER_PROMPT["passage_dependency"][lang]
|
||||
query = row["query"]
|
||||
completion = await client.beta.chat.completions.parse(
|
||||
model=model_name,
|
||||
messages=to_openai_message_dicts(
|
||||
system_prompt
|
||||
+ [
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content=f"Question: {query}\nIs this the question passage dependent?",
|
||||
)
|
||||
]
|
||||
),
|
||||
response_format=Response,
|
||||
)
|
||||
return not completion.choices[0].message.parsed.is_passage_dependent
|
||||
|
||||
|
||||
async def passage_dependency_filter_llama_index(
|
||||
row: Dict,
|
||||
llm: BaseLLM,
|
||||
lang: str = "en",
|
||||
) -> bool:
|
||||
"""
|
||||
This will drop passage-dependent question rows.
|
||||
Passage-dependent questions are questions whose answer changes depending on which passage is chosen.
|
||||
Passage-dependent questions are not suitable for RAG evaluation, because no retrieval system can find the right passage for such a question.
|
||||
For example, when someone asks "What is the highest score according to the table?", the answer differs depending on which table is meant.
|
||||
Since the question does not identify the table, the retrieval system cannot find the right passage.
|
||||
You can use this filter with the `batch_filter` method of the `QA` class.
|
||||
|
||||
:param row: The row dict from QA dataset.
|
||||
:param llm: The Llama index llm instance.
|
||||
Setting a low max tokens value is recommended to save tokens.
|
||||
:param lang: The supported language is en, ko or ja.
|
||||
:return: False if the row question is a passage-dependent question (to be filtered).
|
||||
"""
|
||||
assert "query" in row.keys(), "query column is not in the row."
|
||||
system_prompt: List[ChatMessage] = FILTER_PROMPT["passage_dependency"][lang]
|
||||
query = row["query"]
|
||||
response: ChatResponse = await llm.achat(
|
||||
messages=system_prompt
|
||||
+ [
|
||||
ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content=f"Question: {query}\nIs this the question passage dependent?",
|
||||
)
|
||||
]
|
||||
)
|
||||
result_str = response.message.content
|
||||
return "true" not in result_str.lower().strip()
|
||||
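A hedged usage sketch (assumes OPENAI_API_KEY is set; the query is a placeholder):

import asyncio

from openai import AsyncClient
from autorag.data.qa.filter.passage_dependency import passage_dependency_filter_openai

row = {"query": "What is the highest score according to the table?"}
keep = asyncio.run(passage_dependency_filter_openai(row, AsyncClient()))
print(keep)  # expected False, since the question depends on an unspecified table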
73
autorag/data/qa/filter/prompt.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
||||
|
||||
FILTER_PROMPT = {
|
||||
"dontknow_filter": {
|
||||
"en": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""The following sentence is an answer about a question. You have to decide the answer implies 'I don't know'.
|
||||
If the answer implies 'I don't know', return True. If not, return False.""",
|
||||
),
|
||||
],
|
||||
"ko": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""다음 문장은 어떠한 질문에 대한 대답입니다. 해당 문장이 질문에 대해서 '모른다고' 답한 것인지 판단하십시오.
|
||||
만약 해당 문장이 '모른다고' 답한 것이라면, True를 반환하세요. 그렇지 않다면 False를 반환하세요.""",
|
||||
)
|
||||
],
|
||||
"ja": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""次の文章はある質問に対する答えです。 該当文章が質問に対して「知らない」と答えたのか判断します。
|
||||
もし、その文章が「知らない」と答えたのであれば、Trueを返します。 そうでなければFalseを返します。""",
|
||||
)
|
||||
],
|
||||
},
|
||||
"passage_dependency": {
|
||||
"en": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""You are a classifier that recognize 'passage dependent' questions.
|
||||
A 'passage dependent' question is one whose answer will change depending on which passage you choose.
|
||||
For example) 'What is the highest score according to the table?'
|
||||
This sentence is the passage dependent question because the answer will be different depending on the table.
|
||||
|
||||
In contrast, the following sentences are not passage dependent.
|
||||
'What is the highest score of the KBO baseball history in one game?'
|
||||
'What is the capital of France?'
|
||||
These sentences will have the same answer regardless of the passage.
|
||||
|
||||
Please return True if the input question is passage dependent. Else return False.""",
|
||||
)
|
||||
],
|
||||
"ko": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""당신은 '단락 의존' 질문을 인식하는 분류기입니다.
|
||||
'단락 의존'이란 어떤 단락이 선택 되는지 따라 답이 달라지는 질문을 의미합니다.
|
||||
예를 들어, '주어진 표에 따르면 가장 높은 점수는 무엇인가요?'라는 질문은 단락 의존 질문입니다. 왜냐하면 표가 어떤 것인지에 따라 그 답이 달라지기 때문입니다.
|
||||
|
||||
반면에, 다음 문장들은 단락 의존적이지 않습니다.
|
||||
'KBO 야구 역사상 한 경기에서 가장 높은 점수는 무엇인가요?' 또는 '프랑스의 수도는 무엇인가요?'
|
||||
이러한 문장은 단락에 관계 없이 동일한 답을 가집니다.
|
||||
|
||||
입력된 질문이 단락 의존적이라면 True를 반환하고, 그렇지 않으면 False를 반환하세요.""",
|
||||
)
|
||||
],
|
||||
"ja": [
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content="""あなたは「段落依存」の質問を認識する分類器です。
|
||||
「段落依存」とは、どの段落が選択されるかによって答えが変わる質問を意味します。
|
||||
たとえば、「与えられた表によると、最も高い点数は何ですか?」という質問は、段落依存の質問です。 なぜなら、表がどんなものかによってその答えが変わるからです。
|
||||
|
||||
一方、次の文章は段落依存的ではありません。
|
||||
KBO野球史上1試合で最も高い点数は何ですか?またはフランスの首都は何ですか?'
|
||||
このような文章は段落に関係なく同じ答えを持ちます。
|
||||
|
||||
入力された質問が段落依存的である場合はTrueを返し、そうでない場合はFalseを返します。""",
|
||||
)
|
||||
],
|
||||
},
|
||||
}
|
||||
0
autorag/data/qa/generation_gt/__init__.py
Normal file
16
autorag/data/qa/generation_gt/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def add_gen_gt(row: Dict, new_gen_gt: str) -> Dict:
|
||||
if "generation_gt" in list(row.keys()):
|
||||
if isinstance(row["generation_gt"], list):
|
||||
row["generation_gt"].append(new_gen_gt)
|
||||
elif isinstance(row["generation_gt"], str):
|
||||
row["generation_gt"] = [row["generation_gt"], new_gen_gt]
|
||||
else:
|
||||
raise ValueError(
|
||||
"generation_gt should be either a string or a list of strings."
|
||||
)
|
||||
return row
|
||||
row["generation_gt"] = [new_gen_gt]
|
||||
return row
|
||||
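add_gen_gt normalizes generation_gt to a list and appends the new answer; a minimal sketch:

from autorag.data.qa.generation_gt.base import add_gen_gt

row = {"query": "What is 2 + 2?", "generation_gt": "4"}
row = add_gen_gt(row, "four")
print(row["generation_gt"])  # ['4', 'four']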
41
autorag/data/qa/generation_gt/llama_index_gen_gt.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import itertools
|
||||
from typing import Dict
|
||||
|
||||
|
||||
from llama_index.core.base.llms.base import BaseLLM
|
||||
from llama_index.core.base.llms.types import MessageRole, ChatMessage
|
||||
|
||||
from autorag.data.qa.generation_gt.base import add_gen_gt
|
||||
from autorag.data.qa.generation_gt.prompt import GEN_GT_SYSTEM_PROMPT
|
||||
|
||||
|
||||
async def make_gen_gt_llama_index(row: Dict, llm: BaseLLM, system_prompt: str) -> Dict:
|
||||
retrieval_gt_contents = list(
|
||||
itertools.chain.from_iterable(row["retrieval_gt_contents"])
|
||||
)
|
||||
query = row["query"]
|
||||
passage_str = "\n".join(retrieval_gt_contents)
|
||||
user_prompt = f"Text:\n<|text_start|>\n{passage_str}\n<|text_end|>\n\nQuestion:\n{query}\n\nAnswer:"
|
||||
|
||||
response = await llm.achat(
|
||||
messages=[
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt),
|
||||
ChatMessage(role=MessageRole.USER, content=user_prompt),
|
||||
],
|
||||
temperature=0.0,
|
||||
)
|
||||
return add_gen_gt(row, response.message.content)
|
||||
|
||||
|
||||
async def make_concise_gen_gt(row: Dict, llm: BaseLLM, lang: str = "en") -> Dict:
|
||||
return await make_gen_gt_llama_index(
|
||||
row, llm, GEN_GT_SYSTEM_PROMPT["concise"][lang]
|
||||
)
|
||||
|
||||
|
||||
async def make_basic_gen_gt(row: Dict, llm: BaseLLM, lang: str = "en") -> Dict:
|
||||
return await make_gen_gt_llama_index(row, llm, GEN_GT_SYSTEM_PROMPT["basic"][lang])
|
||||
|
||||
|
||||
async def make_custom_gen_gt(row: Dict, llm: BaseLLM, system_prompt: str) -> Dict:
|
||||
return await make_gen_gt_llama_index(row, llm, system_prompt)
|
||||
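A minimal sketch of generating a concise ground-truth answer with a llama_index LLM (placeholder contents; OPENAI_API_KEY assumed):

import asyncio

from llama_index.llms.openai import OpenAI
from autorag.data.qa.generation_gt.llama_index_gen_gt import make_concise_gen_gt

row = {
    "query": "What is the capital of France?",
    "retrieval_gt_contents": [["Paris is the capital of France."]],
}
row = asyncio.run(make_concise_gen_gt(row, OpenAI(model="gpt-4o-mini"), lang="en"))
print(row["generation_gt"])  # e.g. ['Paris']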
84
autorag/data/qa/generation_gt/openai_gen_gt.py
Normal file
@@ -0,0 +1,84 @@
import itertools
from typing import Dict

from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.qa.generation_gt.base import add_gen_gt
from autorag.data.qa.generation_gt.prompt import GEN_GT_SYSTEM_PROMPT


class Response(BaseModel):
    answer: str


async def make_gen_gt_openai(
    row: Dict,
    client: AsyncClient,
    system_prompt: str,
    model_name: str = "gpt-4o-2024-08-06",
):
    retrieval_gt_contents = list(
        itertools.chain.from_iterable(row["retrieval_gt_contents"])
    )
    query = row["query"]
    passage_str = "\n".join(retrieval_gt_contents)
    user_prompt = f"Text:\n<|text_start|>\n{passage_str}\n<|text_end|>\n\nQuestion:\n{query}\n\nAnswer:"

    completion = await client.beta.chat.completions.parse(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.0,
        response_format=Response,
    )
    response: Response = completion.choices[0].message.parsed
    return add_gen_gt(row, response.answer)


async def make_concise_gen_gt(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-2024-08-06",
    lang: str = "en",
):
    """
    Generate concise generation_gt using OpenAI Structured Outputs to prevent parsing errors.
    It generates a concise answer, so it is generally a word or just a phrase.

    :param row: The input row of the qa dataframe.
    :param client: The OpenAI async client.
    :param model_name: The model name that supports structured output.
        It has to be "gpt-4o-2024-08-06" or "gpt-4o-mini-2024-07-18".
    :param lang: The language code of the prompt.
        Default is "en".
    :return: The output row of the qa dataframe with added "generation_gt" in it.
    """
    return await make_gen_gt_openai(
        row, client, GEN_GT_SYSTEM_PROMPT["concise"][lang], model_name
    )


async def make_basic_gen_gt(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-2024-08-06",
    lang: str = "en",
):
    """
    Generate basic generation_gt using OpenAI Structured Outputs to prevent parsing errors.
    It generates a "basic" answer, and its prompt is simple.

    :param row: The input row of the qa dataframe.
    :param client: The OpenAI async client.
    :param model_name: The model name that supports structured output.
        It has to be "gpt-4o-2024-08-06" or "gpt-4o-mini-2024-07-18".
    :param lang: The language code of the prompt.
        Default is "en".
    :return: The output row of the qa dataframe with added "generation_gt" in it.
    """
    return await make_gen_gt_openai(
        row, client, GEN_GT_SYSTEM_PROMPT["basic"][lang], model_name
    )
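A matching sketch for the OpenAI structured-output path, assuming OPENAI_API_KEY is set in the environment; the row is the same shape as above and purely illustrative.

import asyncio

from openai import AsyncClient

from autorag.data.qa.generation_gt.openai_gen_gt import make_concise_gen_gt


async def main():
    client = AsyncClient()  # reads OPENAI_API_KEY from the environment
    row = {  # hypothetical row
        "query": "When was Wikipedia founded?",
        "retrieval_gt_contents": [["Wikipedia was launched on January 15, 2001."]],
    }
    row = await make_concise_gen_gt(row, client)
    print(row["generation_gt"])  # e.g. ['January 15, 2001']


asyncio.run(main())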
27
autorag/data/qa/generation_gt/prompt.py
Normal file
27
autorag/data/qa/generation_gt/prompt.py
Normal file
@@ -0,0 +1,27 @@
GEN_GT_SYSTEM_PROMPT = {
    "concise": {
        "en": """You are an AI assistant that answers the given question from the provided evidence text.
Find the evidence about the question in the given text and write a proper answer to the question.
Your answer has to be concise and relevant to the question.
Do not write a verbose answer; make it clear.
It does not have to be a full sentence. The answer can be a single word or a short phrase.""",
        "ko": """당신은 주어진 질문에 대해 제공된 Text 내에서 답을 찾는 AI 비서입니다.
질문에 대한 답을 Text에서 찾아 적절한 답변을 작성하세요.
답변은 간결하고 질문에 관련된 내용만 포함해야 합니다.
불필요하게 길게 답변하지 말고, 명확하게 작성하세요.
완전한 문장이 아니어도 되며, 답은 단어나 요약일 수 있습니다.""",
        "ja": """
あなたは与えられた質問に対して提供されたText内で答えを探すAI秘書です。
質問に対する答えをTextで探して適切な答えを作成しましょう。
回答は簡潔で、質問に関連する内容のみを含める必要があります。
不必要に長く答えず、明確に作成しましょう。
完全な文章でなくてもいいし、答えは単語や要約かもしれません。
""",
    },
    "basic": {
        "en": """You are an AI assistant that answers the given question from the provided evidence text.
Find the evidence about the question in the given text and write a proper answer to the question.""",
        "ko": "당신은 주어진 질문에 대한 답을 제공된 Text 내에서 찾는 AI 비서입니다. 질문과 관련된 증거를 Text에서 찾아 적절한 답변을 작성하세요.",
        "ja": "あなたは与えられた質問に対する答えを提供されたText内で探すAI秘書です。 質問に関する証拠をTextで探して適切な回答を作成しましょう。",
    },
}
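If a project needs an answer style that the "concise"/"basic" prompts above do not cover, make_custom_gen_gt from llama_index_gen_gt.py accepts any system prompt string; the prompt text below is a hypothetical example, not part of AutoRAG.

from autorag.data.qa.generation_gt.llama_index_gen_gt import make_custom_gen_gt

CITATION_PROMPT = (  # hypothetical prompt
    "You are an AI assistant. Answer the question using only the evidence text, "
    "and quote the exact sentence you used as evidence."
)
# Typically passed per-row through QA.batch_apply (see schema.py below):
# qa.batch_apply(make_custom_gen_gt, llm=llm, system_prompt=CITATION_PROMPT)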
0
autorag/data/qa/query/__init__.py
Normal file
0
autorag/data/qa/query/__init__.py
Normal file
82
autorag/data/qa/query/llama_gen_query.py
Normal file
82
autorag/data/qa/query/llama_gen_query.py
Normal file
@@ -0,0 +1,82 @@
import itertools
from typing import Dict, List

from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole

from autorag.data.qa.query.prompt import QUERY_GEN_PROMPT, QUERY_GEN_PROMPT_EXTRA


async def llama_index_generate_base(
    row: Dict,
    llm: BaseLLM,
    messages: List[ChatMessage],
) -> Dict:
    context = list(itertools.chain.from_iterable(row["retrieval_gt_contents"]))
    context_str = "\n".join([f"{i + 1}. {c}" for i, c in enumerate(context)])
    user_prompt = f"Text:\n{context_str}\n\nGenerated Question from the Text:\n"
    user_message = ChatMessage(role=MessageRole.USER, content=user_prompt)
    new_messages = [*messages, user_message]
    chat_response: ChatResponse = await llm.achat(messages=new_messages)
    row["query"] = chat_response.message.content
    return row


async def factoid_query_gen(
    row: Dict,
    llm: BaseLLM,
    lang: str = "en",
) -> Dict:
    return await llama_index_generate_base(
        row, llm, QUERY_GEN_PROMPT["factoid_single_hop"][lang]
    )


async def concept_completion_query_gen(
    row: Dict,
    llm: BaseLLM,
    lang: str = "en",
) -> Dict:
    return await llama_index_generate_base(
        row, llm, QUERY_GEN_PROMPT["concept_completion"][lang]
    )


async def two_hop_incremental(
    row: Dict,
    llm: BaseLLM,
    lang: str = "en",
) -> Dict:
    # Copy the template so the shared prompt list is not mutated across rows.
    messages = [*QUERY_GEN_PROMPT["two_hop_incremental"][lang]]
    passages = row["retrieval_gt_contents"]
    assert (
        len(passages) >= 2
    ), "You have to sample at least two passages for making two-hop questions."
    context_str = f"Document 1: {passages[0][0]}\nDocument 2: {passages[1][0]}"
    user_prompt = f"{context_str}\n\nGenerated two-hop Question from two Documents:\n"
    messages.append(ChatMessage(role=MessageRole.USER, content=user_prompt))

    chat_response: ChatResponse = await llm.achat(messages=messages)
    response = chat_response.message.content
    row["query"] = response.split(":")[-1].strip()
    return row


async def custom_query_gen(
    row: Dict,
    llm: BaseLLM,
    messages: List[ChatMessage],
) -> Dict:
    return await llama_index_generate_base(row, llm, messages)


# Experimental feature: can only use factoid_single_hop
async def multiple_queries_gen(
    row: Dict,
    llm: BaseLLM,
    lang: str = "en",
    n: int = 3,
) -> Dict:
    base_messages = QUERY_GEN_PROMPT["factoid_single_hop"][lang]
    extra = QUERY_GEN_PROMPT_EXTRA["multiple_queries"][lang].format(n=n)
    # Build a new system message instead of mutating the shared template in place.
    _messages = [
        ChatMessage(role=base_messages[0].role, content=base_messages[0].content + extra),
        *base_messages[1:],
    ]
    return await llama_index_generate_base(row, llm, _messages)
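A small sketch of calling the single-hop query generator directly on one row; in practice these coroutines are passed to QA.batch_apply (see schema.py below). The LLM and the row contents are assumptions for illustration.

import asyncio

from llama_index.llms.openai import OpenAI

from autorag.data.qa.query.llama_gen_query import factoid_query_gen

row = {  # hypothetical row; only retrieval_gt_contents is required here
    "retrieval_gt_contents": [["The Eiffel Tower was completed in 1889 for the World's Fair."]],
}
row = asyncio.run(factoid_query_gen(row, OpenAI(model="gpt-4o-mini"), lang="en"))
print(row["query"])  # e.g. "In which year was the Eiffel Tower completed?"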
95
autorag/data/qa/query/openai_gen_query.py
Normal file
95
autorag/data/qa/query/openai_gen_query.py
Normal file
@@ -0,0 +1,95 @@
import itertools
from typing import Dict, List

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.llms.openai.utils import to_openai_message_dicts
from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.qa.query.prompt import QUERY_GEN_PROMPT


class Response(BaseModel):
    query: str


# Single hop QA generation OpenAI
async def query_gen_openai_base(
    row: Dict,
    client: AsyncClient,
    messages: List[ChatMessage],
    model_name: str = "gpt-4o-2024-08-06",
):
    context = list(itertools.chain.from_iterable(row["retrieval_gt_contents"]))
    context_str = "Text:\n" + "\n".join(
        [f"{i + 1}. {c}" for i, c in enumerate(context)]
    )
    user_prompt = f"{context_str}\n\nGenerated Question from the Text:\n"
    # Copy the template so the shared prompt list is not mutated across rows.
    messages = [*messages, ChatMessage(role=MessageRole.USER, content=user_prompt)]

    completion = await client.beta.chat.completions.parse(
        model=model_name,
        messages=to_openai_message_dicts(messages),
        response_format=Response,
    )
    row["query"] = completion.choices[0].message.parsed.query
    return row


async def factoid_query_gen(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-2024-08-06",
    lang: str = "en",
) -> Dict:
    return await query_gen_openai_base(
        row, client, QUERY_GEN_PROMPT["factoid_single_hop"][lang], model_name
    )


async def concept_completion_query_gen(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-2024-08-06",
    lang: str = "en",
) -> Dict:
    return await query_gen_openai_base(
        row, client, QUERY_GEN_PROMPT["concept_completion"][lang], model_name
    )


class TwoHopIncrementalResponse(BaseModel):
    answer: str
    one_hop_question: str
    two_hop_question: str


async def two_hop_incremental(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-2024-08-06",
    lang: str = "en",
) -> Dict:
    """
    Create a two-hop question using an incremental prompt.
    An incremental prompt is more effective for creating multi-hop questions.
    The input retrieval_gt has to include more than one passage.

    :return: The two-hop question made with the openai incremental prompt.
    """
    # Copy the template so the shared prompt list is not mutated across rows.
    messages = [*QUERY_GEN_PROMPT["two_hop_incremental"][lang]]
    passages = row["retrieval_gt_contents"]
    assert (
        len(passages) >= 2
    ), "You have to sample at least two passages for making two-hop questions."
    context_str = f"Document 1: {passages[0][0]}\nDocument 2: {passages[1][0]}"
    user_prompt = f"{context_str}\n\nGenerated two-hop Question from two Documents:\n"
    messages.append(ChatMessage(role=MessageRole.USER, content=user_prompt))

    completion = await client.beta.chat.completions.parse(
        model=model_name,
        messages=to_openai_message_dicts(messages),
        response_format=TwoHopIncrementalResponse,
    )
    row["query"] = completion.choices[0].message.parsed.two_hop_question
    return row
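For two_hop_incremental (both the LlamaIndex and the OpenAI variant) the row must carry at least two retrieval-gt groups, since the prompt reads passages[0][0] and passages[1][0]. A hedged sketch of the expected shape, with made-up passages and OPENAI_API_KEY assumed to be set:

import asyncio

from openai import AsyncClient

from autorag.data.qa.query.openai_gen_query import two_hop_incremental

row = {  # hypothetical row with two retrieval-gt groups
    "retrieval_gt_contents": [
        ["Nuevo Laredo is located in the Mexican state of Tamaulipas."],
        ["The Ciudad Deportiva is a sports complex in Nuevo Laredo."],
    ],
}
row = asyncio.run(two_hop_incremental(row, AsyncClient()))
print(row["query"])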
202
autorag/data/qa/query/prompt.py
Normal file
202
autorag/data/qa/query/prompt.py
Normal file
@@ -0,0 +1,202 @@
from llama_index.core.base.llms.types import ChatMessage, MessageRole

QUERY_GEN_PROMPT = {
    "factoid_single_hop": {
        "en": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="""You're an AI tasked with converting Text into a factoid question.
Factoid questions are those seeking brief, factual information that can be easily verified. They typically require a yes or no answer or a brief explanation and often inquire about specific details such as dates, names, places, or events.

Examples of factoid questions include:

- What is the capital of France?
- Who invented the light bulb?
- When was Wikipedia founded?

Instructions:
1. Questions MUST BE extracted from the given Text
2. Questions should be as detailed as possible, based on the Text
3. Create questions that ask about factual information from the Text
4. Do not mention any of these in the questions: "in the given text", "in the provided information", etc.
Users do not know the passage source of the question, so it should not be mentioned in the question.
5. Do not ask about the file name or the file title. Ask about the content of the file.
For example, avoid writing questions like `What is the file name of the document?`""",
            )
        ],
        "ko": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="""당신은 주어진 Text를 '사실 질문'으로 변환하는 AI입니다.

사실 질문(factoid questions)이란 사실적인 정보를 요구하는 질문으로, 쉽게 검증할 수 있는 답변을 필요로 합니다. 일반적으로 예/아니오 답변이나 간단한 설명을 요구하며, 날짜, 이름, 장소 또는 사건과 같은 구체적인 세부사항에 대해 묻는 질문입니다.

사실 질문의 예는 다음과 같습니다:

• 프랑스의 수도는 어디입니까?
• 전구를 발명한 사람은 누구입니까?
• 위키피디아는 언제 설립되었습니까?

지침:
1. 질문은 반드시 주어진 Text를 기반으로 작성되어야 합니다.
2. 질문은 Text를 기반으로 가능한 한 구체적으로 작성되어야 합니다.
3. Text에서 사실적 정보를 요구하는 질문을 만들어야 합니다. 즉, Text를 기반으로 사실 질문을 만드세요.
4. 질문에 “주어진 Text에서” 또는 “제공된 단락에서”와 같은 표현을 포함해서는 안 됩니다.
사용자는 질문의 출처가 Text라는 것을 모르기 때문에 반드시 그 출처를 언급해서는 안 됩니다.
5. 파일 이름이나 파일 제목에 대한 질문을 하지 마세요. 파일의 내용에 대해 물어보세요.
예를 들어, '문서의 파일 이름은 무엇입니까?'와 같은 질문을 작성하지 마세요.
6. 질문을 한국어로 작성하세요.""",
            )
        ],
        "ja": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="""あなたは与えられたTextを「実は質問」に変換するAIです。

事実質問(factoid questions)とは、事実的な情報を求める質問であり、容易に検証できる回答を必要とします。 一般的に、「はい/いいえ」の返答や簡単な説明を要求し、日付、名前、場所、または事件のような具体的な詳細事項について尋ねる質問です。

実は質問の例は次の通りです:

• フランスの首都はどこですか?
• 電球を発明したのは誰ですか?
• ウィキペディアはいつ設立されましたか?

指針:
1. 質問は、必ず与えられたTextに基づいて作成されなければなりません。
2. 質問は、Textに基づいて可能な限り具体的に作成されなければなりません。
3. Textで事実的情報を要求する質問を作らなければなりません。 つまり、Textに基づいて質問を作成します。
4. 質問に「与えられたTextで」または「提供された段落で」のような表現を含めてはいけません。
ユーザーは質問の出所がTextだということを知らないので、必ずしもその出所を言及してはいけません。
5. ファイル名やファイルタイトルを訊かないでください。ファイルの内容について聞いてください。
例えば、「このドキュメントのファイル名は何ですか?
6. 質問を日本語で作成しなさい。""",
            )
        ],
    },
    "concept_completion": {
        "en": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="""You're an AI tasked with converting Text into a "Concept Completion" Question.
A “concept completion” question asks directly about the essence or identity of a concept.

Follow the instructions below.
Instructions:
1. Questions MUST BE extracted from the given Text
2. Questions should be as detailed as possible, based on the Text
3. Create questions that ask about information from the Text
4. MUST include specific keywords from the Text.
5. Do not mention any of these in the questions: "in the given text", "in the provided information", etc.
Users do not know the passage source of the question, so it should not be mentioned in the question.
6. Do not ask about the file name or the file title. Ask about the content of the file.
For example, avoid writing questions like `What is the file name of the document?`""",
            )
        ],
        "ko": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="""당신은 Text를 “개념 완성” 질문으로 변환하는 AI입니다.
"개념 완성" 질문은 개념의 본질이나 정체성에 대해 직접적으로 묻는 질문입니다.

다음 지시사항을 따르세요.
지시사항:
1. 질문은 반드시 주어진 Text를 기반으로 작성되어야 합니다.
2. 질문은 Text를 기반으로 가능한 한 자세하게 작성되어야 합니다.
3. Text에서 제공된 정보를 묻는 질문을 생성하세요.
4. Text의 특정 키워드를 반드시 질문에 포함하세요.
5. 질문에 “주어진 Text에서” 또는 “제공된 단락에서”와 같은 표현을 포함해서는 안 됩니다.
사용자는 질문의 출처가 Text라는 것을 모르기 때문에 반드시 그 출처를 언급해서는 안 됩니다.
6. 파일 이름이나 파일 제목에 대한 질문을 하지 마세요. 파일의 내용에 대해 물어보세요.
예를 들어, '문서의 파일 이름은 무엇입니까?'와 같은 질문을 작성하지 마세요.
7. 질문을 한국어로 작성하세요.""",
            )
        ],
        "ja": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="""あなたはTextを「概念完成」の質問に変換するAIです。
「概念完成」の質問は概念の本質やアイデンティティについて直接的に尋ねる質問です。

次の指示に従います。
指示事項:
1. 質問は、必ず与えられたTextに基づいて作成されなければなりません。
2. 質問は、Textに基づいてできるだけ詳しく作成されなければなりません。
3. Textで提供された情報を尋ねる質問を作成します。
4. Textの特定のキーワードを必ず質問に含みます。
5. 質問に「与えられたTextで」または「提供された段落で」のような表現を含めてはいけません。
ユーザーは質問の出所がTextだということを知らないので、必ずしもその出所を言及してはいけません。
6. ファイル名やファイルタイトルを訊かないでください。ファイルの内容について聞いてください。
例えば、「このドキュメントのファイル名は何ですか?
7. 質問を日本語で書きましょう。""",
            )
        ],
    },
    "two_hop_incremental": {
        "en": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="Generate a multi-hop question for the given answer which requires reference to all of the given documents.",
            ),
            ChatMessage(
                role=MessageRole.USER,
                content="""Document 1: The Municipality of Nuevo Laredo is located in the Mexican state of Tamaulipas.
Document 2: The Ciudad Deportiva ("Sports City") is a sports
complex in Nuevo Laredo, Mexico. It is home to the Tecolotes de
Nuevo Laredo Mexican Baseball League team and ...""",
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content="""Answer: Tamaulipas
One-hop question (using Document 1): In which Mexican state is Nuevo Laredo located?
Two-hop question (using Document 2): In which Mexican state can one find the Ciudad Deportiva, home to the Tecolotes de Nuevo Laredo?""",
            ),
        ],
        "ko": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="Generate a multi-hop question for the given answer which requires reference to all of the given documents.",
            ),
            ChatMessage(
                role=MessageRole.USER,
                content="""Document 1: The Municipality of Nuevo Laredo is located in the Mexican state of Tamaulipas.
Document 2: The Ciudad Deportiva ("Sports City") is a sports
complex in Nuevo Laredo, Mexico. It is home to the Tecolotes de
Nuevo Laredo Mexican Baseball League team and ...""",
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content="""Answer: Tamaulipas
One-hop question (using Document 1): In which Mexican state is Nuevo Laredo located?
Two-hop question (using Document 2): In which Mexican state can one find the Ciudad Deportiva, home to the Tecolotes de Nuevo Laredo?""",
            ),
        ],
        "ja": [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="与えられた答えに対するマルチホップ質問を生成し、与えられたすべての文書を参照する必要があります。",
            ),
            ChatMessage(
                role=MessageRole.USER,
                content="""Document 1: ヌエヴォ·ラレド自治体はメキシコのタマウリパス州にあります。
Ciudad Deportiva(スポーツシティ)はスポーツです
メキシコのヌエボ·ラレドにある複合施設です。 テコロテス·デ·テコロテスの故郷です
Nuevo Larredo メキシコ野球リーグのチームです···""",
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content="""Answer: Tamaulipas
One-hop question (using Document 1): ヌエヴォ·ラレド自治体はどのメキシコの州にありますか?
Two-hop question (using Document 2): ヌエヴォ·ラレドのテコロテス·デ·テコロテスの故郷であるメキシコの州はどこですか?""",
            ),
        ],
    },
}

# Experimental feature
QUERY_GEN_PROMPT_EXTRA = {
    "multiple_queries": {
        "en": "\nAdditional instructions:\n - Please make {n} questions.",
        "ko": "\n추가 지침:\n - 질문은 {n}개를 만드세요.",
        "ja": "\n追加指示:\n - 質問を{n}個作成してください。",
    }
}
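New question styles do not require touching this template module: custom_query_gen in llama_gen_query.py takes any list of ChatMessages shaped like the entries above. The template below is a hypothetical example.

from llama_index.core.base.llms.types import ChatMessage, MessageRole

from autorag.data.qa.query.llama_gen_query import custom_query_gen

yes_no_messages = [  # hypothetical template, same shape as a QUERY_GEN_PROMPT entry
    ChatMessage(
        role=MessageRole.SYSTEM,
        content="Convert the Text into a yes/no question about one specific fact in it.",
    )
]
# Typically passed per-row through QA.batch_apply (see schema.py below):
# qa.batch_apply(custom_query_gen, llm=llm, messages=yes_no_messages)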
26
autorag/data/qa/sample.py
Normal file
26
autorag/data/qa/sample.py
Normal file
@@ -0,0 +1,26 @@
import uuid
from typing import Iterable

import pandas as pd


def random_single_hop(
    corpus_df: pd.DataFrame, n: int, random_state: int = 42
) -> pd.DataFrame:
    sample_df = corpus_df.sample(n, random_state=random_state)
    return pd.DataFrame(
        {
            "qid": [str(uuid.uuid4()) for _ in range(len(sample_df))],
            "retrieval_gt": [[[id_]] for id_ in sample_df["doc_id"].tolist()],
        }
    )


def range_single_hop(corpus_df: pd.DataFrame, idx_range: Iterable):
    sample_df = corpus_df.iloc[idx_range]
    return pd.DataFrame(
        {
            "qid": [str(uuid.uuid4()) for _ in range(len(sample_df))],
            "retrieval_gt": [[[id_]] for id_ in sample_df["doc_id"].tolist()],
        }
    )
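Both samplers return one QA row per sampled passage, with retrieval_gt as a doubly nested list ([[doc_id]]), which is the shape the rest of the pipeline expects. A minimal check on a toy corpus dataframe:

import pandas as pd

from autorag.data.qa.sample import random_single_hop

corpus_df = pd.DataFrame({"doc_id": ["d1", "d2", "d3"], "contents": ["a", "b", "c"]})
qa_df = random_single_hop(corpus_df, n=2, random_state=0)
print(qa_df["retrieval_gt"].iloc[0])  # e.g. [['d3']]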
322
autorag/data/qa/schema.py
Normal file
322
autorag/data/qa/schema.py
Normal file
@@ -0,0 +1,322 @@
import logging
from typing import Callable, Optional, Dict, Awaitable, Any, Tuple, List
import uuid
import pandas as pd
from autorag.utils.util import process_batch, get_event_loop, fetch_contents

from autorag.support import get_support_modules

logger = logging.getLogger("AutoRAG")


class Raw:
    """
    The Raw class stores document parsing results.
    It can do chunking.
    It has two columns: 'raw_id' and 'contents'.
    """

    def __init__(self, raw_df: Optional[pd.DataFrame] = None):
        self.data = raw_df

    def batch_apply(
        self, fn: Callable[[Dict, Any], Awaitable[Dict]], batch_size: int = 32, **kwargs
    ) -> "Raw":
        raw_dicts = self.data.to_dict(orient="records")
        loop = get_event_loop()
        tasks = [fn(raw_dict, **kwargs) for raw_dict in raw_dicts]
        results = loop.run_until_complete(process_batch(tasks, batch_size))
        return Raw(pd.DataFrame(results))

    def map(self, fn: Callable[[pd.DataFrame, Any], pd.DataFrame], **kwargs) -> "Raw":
        return Raw(fn(self.data, **kwargs))

    def flatmap(self, fn: Callable, **kwargs) -> "Raw":
        return fn(self.data, **kwargs)

    def chunk(self, module_name: str, **module_params) -> "Corpus":
        chunk_module = get_support_modules(module_name)
        chunked_result = chunk_module(parsed_result=self.data, **module_params)
        return Corpus(chunked_result, self)

    def __add__(self, other):
        assert isinstance(other, Raw), "You can only add Raw instances."
        self.data = pd.concat([self.data, other.data], ignore_index=True).reset_index(
            drop=True
        )
        return self


class Corpus:
    """
    The Corpus class stores chunked passages.
    It can generate a QA set and is linked with a Raw instance.
    """

    def __init__(
        self,
        corpus_df: Optional[pd.DataFrame] = None,
        linked_raw: Optional[Raw] = None,
    ):
        self.data = corpus_df
        self._linked_raw = linked_raw

    @property
    def linked_raw(self) -> Raw:
        return self._linked_raw

    @linked_raw.setter
    def linked_raw(self, raw: Raw):
        raise NotImplementedError("linked_raw is read-only.")

    def to_parquet(self, save_path: str):
        """
        Save the corpus to an AutoRAG-compatible parquet file.
        It is not for data creation, but for running AutoRAG.
        If you want to save the dataframe directly, use the code below.
        `corpus.data.to_parquet(save_path)`

        :param save_path: The path to save the corpus.
        """
        if not save_path.endswith(".parquet"):
            raise ValueError("save_path must be ended with .parquet")
        save_df = self.data.reset_index(drop=True)
        save_df.to_parquet(save_path)

    def batch_apply(
        self, fn: Callable[[Dict, Any], Awaitable[Dict]], batch_size: int = 32, **kwargs
    ) -> "Corpus":
        corpus_dicts = self.data.to_dict(orient="records")
        loop = get_event_loop()
        tasks = [fn(corpus_dict, **kwargs) for corpus_dict in corpus_dicts]
        results = loop.run_until_complete(process_batch(tasks, batch_size))
        return Corpus(pd.DataFrame(results), self.linked_raw)

    def map(
        self, fn: Callable[[pd.DataFrame, Any], pd.DataFrame], **kwargs
    ) -> "Corpus":
        return Corpus(fn(self.data, **kwargs), self.linked_raw)

    def sample(self, fn: Callable[[pd.DataFrame, Any], pd.DataFrame], **kwargs) -> "QA":
        """
        Sample the corpus for making a QA set.
        It selects a subset of the corpus and makes the base QA set from it.
        You can then generate questions for the sampled passages.
        It is the first step to make a QA set from the corpus.
        If you select just one passage for each QA row, it will be a single-hop QA set.
        If you select multiple passages for each QA row, it will be a multi-hop QA set.

        :param fn: The select function to perform.
            It returns a QA dataframe.
        :return: The QA instance that is selected.
            It contains qid and retrieval_gt columns.
        """
        return QA(fn(self.data, **kwargs), self)


class QA:
    def __init__(
        self,
        qa_df: Optional[pd.DataFrame] = None,
        linked_corpus: Optional[Corpus] = None,
    ):
        self.data = qa_df
        self._linked_corpus = linked_corpus

    @property
    def linked_corpus(self) -> Corpus:
        return self._linked_corpus

    @linked_corpus.setter
    def linked_corpus(self, corpus: Corpus):
        raise NotImplementedError("linked_corpus is read-only.")

    def batch_apply(
        self, fn: Callable[[Dict, Any], Awaitable[Dict]], batch_size: int = 32, **kwargs
    ) -> "QA":
        qa_dicts = self.data.to_dict(orient="records")
        loop = get_event_loop()
        tasks = [fn(qa_dict, **kwargs) for qa_dict in qa_dicts]
        results = loop.run_until_complete(process_batch(tasks, batch_size))

        # Experimental feature
        if fn.__name__ == "multiple_queries_gen":
            return self._process_multiple_queries_gen(results)

        return QA(pd.DataFrame(results), self.linked_corpus)

    def batch_filter(
        self, fn: Callable[[Dict, Any], Awaitable[bool]], batch_size: int = 32, **kwargs
    ) -> "QA":
        qa_dicts = self.data.to_dict(orient="records")
        loop = get_event_loop()
        tasks = [fn(qa_dict, **kwargs) for qa_dict in qa_dicts]
        masks = loop.run_until_complete(process_batch(tasks, batch_size))
        return QA(self.data[masks], self.linked_corpus)

    def filter(self, fn: Callable[[Dict, Any], bool], **kwargs) -> "QA":
        qa_dicts = self.data.to_dict(orient="records")
        masks = [fn(qa_dict, **kwargs) for qa_dict in qa_dicts]
        return QA(self.data[masks], self.linked_corpus)

    def map(self, fn: Callable[[pd.DataFrame, Any], pd.DataFrame], **kwargs) -> "QA":
        return QA(fn(self.data, **kwargs), self.linked_corpus)

    def make_retrieval_gt_contents(self) -> "QA":
        """
        Make the retrieval_gt_contents column from the retrieval_gt column.
        :return: The QA instance that has a retrieval_gt_contents column.
        """
        self.data["retrieval_gt_contents"] = self.data["retrieval_gt"].apply(
            lambda x: fetch_contents(self.linked_corpus.data, x)
        )
        return self

    def to_parquet(self, qa_save_path: str, corpus_save_path: str):
        """
        Save the qa and corpus to AutoRAG-compatible parquet files.
        It is not for data creation, but for running AutoRAG.
        If you want to save the dataframe directly, use the code below.
        `qa.data.to_parquet(save_path)`

        :param qa_save_path: The path to save the qa dataset.
        :param corpus_save_path: The path to save the corpus.
        """
        if not qa_save_path.endswith(".parquet"):
            raise ValueError("save_path must be ended with .parquet")
        if not corpus_save_path.endswith(".parquet"):
            raise ValueError("save_path must be ended with .parquet")
        save_df = self.data[
            ["qid", "query", "retrieval_gt", "generation_gt"]
        ].reset_index(drop=True)
        save_df.to_parquet(qa_save_path)
        self.linked_corpus.to_parquet(corpus_save_path)

    def update_corpus(self, new_corpus: Corpus) -> "QA":
        """
        Update the linked corpus.
        It does not just replace linked_corpus with the new Corpus;
        it remaps the whole `retrieval_gt` to the new corpus using `linked_raw`.
        The QA data must have a `retrieval_gt` column.

        :param new_corpus: Corpus that you want to replace.
            Must have valid `linked_raw` and `raw_id`, `raw_start_idx`, `raw_end_idx` columns.
        :return: The QA instance with the updated linked corpus.
        """
        self.data["evidence_path"] = (
            self.data["retrieval_gt"]
            .apply(
                lambda x: fetch_contents(
                    self.linked_corpus.data,
                    x,
                    column_name="path",
                )
            )
            .tolist()
        )
        self.data["evidence_page"] = self.data["retrieval_gt"].apply(
            lambda x: list(
                map(
                    lambda lst: list(map(lambda x: x.get("page", -1), lst)),
                    fetch_contents(self.linked_corpus.data, x, column_name="metadata"),
                )
            )
        )
        if "evidence_start_end_idx" not in self.data.columns:
            # make evidence start_end_idx
            self.data["evidence_start_end_idx"] = (
                self.data["retrieval_gt"]
                .apply(
                    lambda x: fetch_contents(
                        self.linked_corpus.data,
                        x,
                        column_name="start_end_idx",
                    )
                )
                .tolist()
            )

        # matching the new corpus with the old corpus
        path_corpus_dict = QA.__make_path_corpus_dict(new_corpus.data)
        new_retrieval_gt = self.data.apply(
            lambda row: QA.__match_index_row(
                row["evidence_start_end_idx"],
                row["evidence_path"],
                row["evidence_page"],
                path_corpus_dict,
            ),
            axis=1,
        ).tolist()
        new_qa = self.data.copy(deep=True)[["qid", "query", "generation_gt"]]
        new_qa["retrieval_gt"] = new_retrieval_gt
        return QA(new_qa, new_corpus)

    @staticmethod
    def __match_index(target_idx: Tuple[int, int], dst_idx: Tuple[int, int]) -> bool:
        """
        Check if the target_idx overlaps the dst_idx.
        """
        target_start, target_end = target_idx
        dst_start, dst_end = dst_idx
        return (
            dst_start <= target_start <= dst_end or dst_start <= target_end <= dst_end
        )

    @staticmethod
    def __match_index_row(
        evidence_indices: List[List[Tuple[int, int]]],
        evidence_paths: List[List[str]],
        evidence_pages: List[List[int]],
        path_corpus_dict: Dict,
    ) -> List[List[str]]:
        """
        Find the matched passage from new_corpus.

        :param evidence_indices: The evidence indices at the corresponding Raw.
            Its shape is the same as the retrieval_gt.
        :param evidence_paths: The evidence paths at the corresponding Raw.
            Its shape is the same as the retrieval_gt.
        :param path_corpus_dict: The key is the path name, and the value is the corpus dataframe that only contains the path in the key.
            You can make it using `QA.__make_path_corpus_dict`.
        :return:
        """
        result = []
        for i, idx_list in enumerate(evidence_indices):
            sub_result = []
            for j, idx in enumerate(idx_list):
                path_corpus_df = path_corpus_dict[evidence_paths[i][j]]
                if evidence_pages[i][j] >= 0:
                    path_corpus_df = path_corpus_df.loc[
                        path_corpus_df["metadata"].apply(lambda x: x.get("page", -1))
                        == evidence_pages[i][j]
                    ]
                matched_corpus = path_corpus_df.loc[
                    path_corpus_df["start_end_idx"].apply(
                        lambda x: QA.__match_index(idx, x)
                    )
                ]
                sub_result.extend(matched_corpus["doc_id"].tolist())
            result.append(sub_result)
        return result

    @staticmethod
    def __make_path_corpus_dict(corpus_df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
        return {
            path: corpus_df[corpus_df["path"] == path]
            for path in corpus_df["path"].unique()
        }

    # Experimental feature
    def _process_multiple_queries_gen(self, results: List[Dict]) -> "QA":
        data = []
        for result in results:
            queries = result["query"].split("\n")
            for query in queries:
                new_result = {
                    key: (str(uuid.uuid4()) if key == "qid" else result[key])
                    for key in result.keys()
                }
                new_result["query"] = query
                data.append(new_result)
        df = pd.DataFrame(data)
        return QA(df, self.linked_corpus)
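Putting the pieces together, a hedged end-to-end sketch of the intended Raw -> Corpus -> QA flow: parsed results go into Raw, a chunked corpus becomes Corpus, sample plus the generators above build the QA set, and to_parquet writes AutoRAG-compatible files. The file paths and model are assumptions; the corpus parquet is assumed to have at least doc_id and contents columns.

import pandas as pd

from llama_index.llms.openai import OpenAI

from autorag.data.qa.schema import Raw, Corpus
from autorag.data.qa.sample import random_single_hop
from autorag.data.qa.query.llama_gen_query import factoid_query_gen
from autorag.data.qa.generation_gt.llama_index_gen_gt import make_basic_gen_gt

llm = OpenAI(model="gpt-4o-mini")  # assumed model
raw = Raw(pd.read_parquet("parse_result.parquet"))  # assumed path
corpus = Corpus(pd.read_parquet("corpus.parquet"), raw)  # assumed path; raw.chunk(...) also works

qa = (
    corpus.sample(random_single_hop, n=50)
    .make_retrieval_gt_contents()
    .batch_apply(factoid_query_gen, llm=llm)
    .batch_apply(make_basic_gen_gt, llm=llm)
)
qa.to_parquet("qa.parquet", "saved_corpus.parquet")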
0
autorag/data/utils/__init__.py
Normal file
0
autorag/data/utils/__init__.py
Normal file
103
autorag/data/utils/util.py
Normal file
103
autorag/data/utils/util.py
Normal file
@@ -0,0 +1,103 @@
import mimetypes
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import pandas as pd
import yaml
from langchain_core.documents import Document
from llama_index.core.schema import NodeRelationship

from autorag.schema import Module
from autorag.utils.util import make_combinations, explode


def get_file_metadata(file_path: str) -> Dict:
    """Get some handy metadata from the filesystem.

    Args:
        file_path: str: file path in str
    """
    return {
        "file_path": file_path,
        "file_name": os.path.basename(file_path),
        "file_type": mimetypes.guess_type(file_path)[0],
        "file_size": os.path.getsize(file_path),
        "creation_datetime": datetime.fromtimestamp(
            Path(file_path).stat().st_ctime
        ).strftime("%Y-%m-%d"),
        "last_modified_datetime": datetime.fromtimestamp(
            Path(file_path).stat().st_mtime
        ).strftime("%Y-%m-%d"),
        "last_accessed_datetime": datetime.fromtimestamp(
            Path(file_path).stat().st_atime
        ).strftime("%Y-%m-%d"),
    }


def add_essential_metadata(metadata: Dict) -> Dict:
    if "last_modified_datetime" not in metadata:
        metadata["last_modified_datetime"] = datetime.now()
    return metadata


def corpus_df_to_langchain_documents(corpus_df: pd.DataFrame) -> List[Document]:
    page_contents = corpus_df["contents"].tolist()
    ids = corpus_df["doc_id"].tolist()
    metadatas = corpus_df["metadata"].tolist()
    return list(
        map(
            lambda x: Document(page_content=x[0], metadata={"filename": x[1], **x[2]}),
            zip(page_contents, ids, metadatas),
        )
    )


def add_essential_metadata_llama_text_node(metadata: Dict, relationships: Dict) -> Dict:
    if "last_modified_datetime" not in metadata:
        metadata["last_modified_datetime"] = datetime.now()

    if "prev_id" not in metadata:
        if NodeRelationship.PREVIOUS in relationships:
            prev_node = relationships.get(NodeRelationship.PREVIOUS, None)
            if prev_node:
                metadata["prev_id"] = prev_node.node_id

    if "next_id" not in metadata:
        if NodeRelationship.NEXT in relationships:
            next_node = relationships.get(NodeRelationship.NEXT, None)
            if next_node:
                metadata["next_id"] = next_node.node_id
    return metadata


def load_yaml(yaml_path: str):
    if not os.path.exists(yaml_path):
        raise ValueError(f"YAML file {yaml_path} does not exist.")
    with open(yaml_path, "r", encoding="utf-8") as stream:
        try:
            yaml_dict = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            raise ValueError(f"YAML file {yaml_path} could not be loaded.") from exc
    return yaml_dict["modules"]


def get_param_combinations(modules: List[Dict]) -> Tuple[List[Callable], List[Dict]]:
    module_callable_list, module_params_list = [], []
    for module in modules:
        module_instance = Module.from_dict(module)
        module_params_list.append(module_instance.module_param)
        module_callable_list.append(module_instance.module)

    combinations = list(map(make_combinations, module_params_list))
    module_list, combination_list = explode(module_callable_list, combinations)
    return module_list, combination_list


def get_start_end_idx(original_text: str, search_str: str) -> Tuple[int, int]:
    start_idx = original_text.find(search_str)
    if start_idx == -1:
        return 0, 0
    end_idx = start_idx + len(search_str)
    return start_idx, end_idx - 1
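A short note on get_start_end_idx: the returned end index is inclusive (hence the `- 1`), and a missing substring collapses to (0, 0) rather than raising. For example:

from autorag.data.utils.util import get_start_end_idx

print(get_start_end_idx("hello world", "world"))  # (6, 10) - inclusive end index
print(get_start_end_idx("hello world", "mars"))   # (0, 0) when the substring is missing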