Fix Dockerfile build issue
1
autorag/data/parse/__init__.py
Normal file
@@ -0,0 +1 @@
from .langchain_parse import langchain_parse
79
autorag/data/parse/base.py
Normal file
@@ -0,0 +1,79 @@
import functools
import logging
import os
from datetime import datetime
from glob import glob
from typing import Tuple, List, Optional

from autorag.utils import result_to_dataframe
from autorag.data.utils.util import get_file_metadata

logger = logging.getLogger("AutoRAG")


def parser_node(func):
    @functools.wraps(func)
    @result_to_dataframe(["texts", "path", "page", "last_modified_datetime"])
    def wrapper(
        data_path_glob: str,
        file_type: str,
        parse_method: Optional[str] = None,
        **kwargs,
    ) -> Tuple[List[str], List[str], List[int], List[datetime]]:
        logger.info(f"Running parser - {func.__name__} module...")

        data_path_list = glob(data_path_glob)
        if not data_path_list:
            raise FileNotFoundError(f"data does not exist in {data_path_glob}")

        assert file_type in [
            "pdf",
            "csv",
            "json",
            "md",
            "html",
            "xml",
            "all_files",
        ], f"file type {file_type} is not supported"

        # extract only the files from data_path_list that match the file_type set in the YAML file
        data_paths = (
            [
                data_path
                for data_path in data_path_list
                if os.path.basename(data_path).split(".")[-1] == file_type
            ]
            if file_type != "all_files"
            else data_path_list
        )

        if func.__name__ == "langchain_parse":
            parse_method = parse_method.lower()
            if parse_method == "directory":
                # the directory loader needs a folder path and a glob pattern instead of explicit file paths
                path_split_list = data_path_glob.split("/")
                glob_path = path_split_list.pop()
                folder_path = "/".join(path_split_list)
                kwargs.update({"glob": glob_path, "path": folder_path})
            result = func(
                data_path_list=data_paths, parse_method=parse_method, **kwargs
            )
        elif func.__name__ in ["clova_ocr", "llama_parse", "table_hybrid_parse"]:
            result = func(data_path_list=data_paths, **kwargs)
        else:
            raise ValueError(f"Unsupported module_type: {func.__name__}")
        result = _add_last_modified_datetime(result)
        return result

    return wrapper


def _add_last_modified_datetime(result):
    last_modified_datetime_lst = list(
        map(lambda x: get_file_metadata(x)["last_modified_datetime"], result[1])
    )
    result_with_dates = result + (last_modified_datetime_lst,)
    return result_with_dates
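For illustration, a minimal usage sketch of a decorated parser module: parser_node exposes data_path_glob, file_type, and module-specific kwargs, and result_to_dataframe appears to convert the returned tuple of lists into a DataFrame with the columns texts, path, page, and last_modified_datetime. The glob path below is a hypothetical example.

    from autorag.data.parse import langchain_parse

    # Parse every PDF matched by the (hypothetical) glob with the pdfminer loader.
    result_df = langchain_parse(
        data_path_glob="./raw_docs/*.pdf",
        file_type="pdf",
        parse_method="pdfminer",
    )
    # result_df columns: texts, path, page, last_modified_datetime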
194
autorag/data/parse/clova.py
Normal file
@@ -0,0 +1,194 @@
import base64
import itertools
import json
import os
from typing import List, Optional, Tuple

import aiohttp
import fitz  # PyMuPDF

from autorag.data.parse.base import parser_node
from autorag.utils.util import process_batch, get_event_loop


@parser_node
def clova_ocr(
    data_path_list: List[str],
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    batch: int = 5,
    table_detection: bool = False,
) -> Tuple[List[str], List[str], List[int]]:
    """
    Parse documents using Naver Clova OCR.

    :param data_path_list: The list of data paths to parse.
    :param url: The URL for Clova OCR.
        You can get the URL with the guide at https://guide.ncloud-docs.com/docs/clovaocr-example01
        You can set the environment variable CLOVA_URL, or you can set it directly as a parameter.
    :param api_key: The API key for Clova OCR.
        You can get the API key with the guide at https://guide.ncloud-docs.com/docs/clovaocr-example01
        You can set the environment variable CLOVA_API_KEY, or you can set it directly as a parameter.
    :param batch: The batch size for parsing documents. Default is 5, which is also the maximum.
    :param table_detection: Whether to enable table detection. Default is False.
    :return: Tuple of lists containing the parsed texts, paths, and pages.
    """
    url = os.getenv("CLOVA_URL", None) if url is None else url
    if url is None:
        raise KeyError(
            "Please set the URL for Clova OCR in the environment variable CLOVA_URL "
            "or directly set it on the config YAML file."
        )

    api_key = os.getenv("CLOVA_API_KEY", None) if api_key is None else api_key
    if api_key is None:
        raise KeyError(
            "Please set the API key for Clova OCR in the environment variable CLOVA_API_KEY "
            "or directly set it on the config YAML file."
        )
    if batch > 5:
        raise ValueError("The batch size should be less than or equal to 5.")

    image_data_lst = list(
        map(lambda data_path: pdf_to_images(data_path), data_path_list)
    )
    image_info_lst = [
        generate_image_info(pdf_path, len(image_data))
        for pdf_path, image_data in zip(data_path_list, image_data_lst)
    ]

    image_data_list = list(itertools.chain(*image_data_lst))
    image_info_list = list(itertools.chain(*image_info_lst))

    tasks = [
        clova_ocr_pure(image_data, image_info, url, api_key, table_detection)
        for image_data, image_info in zip(image_data_list, image_info_list)
    ]
    loop = get_event_loop()
    results = loop.run_until_complete(process_batch(tasks, batch))

    texts, path, pages = zip(*results)
    return list(texts), list(path), list(pages)


async def clova_ocr_pure(
    image_data: bytes,
    image_info: dict,
    url: str,
    api_key: str,
    table_detection: bool = False,
) -> Tuple[str, str, int]:
    table_html = ""
    headers = {"X-OCR-SECRET": api_key, "Content-Type": "application/json"}

    # Convert image data to base64
    image_base64 = base64.b64encode(image_data).decode("utf-8")

    # Set the request payload
    data = {
        "version": "V2",
        "requestId": "sample_id",
        "timestamp": 0,
        "images": [{"format": "png", "name": "sample_image", "data": image_base64}],
        "enableTableDetection": table_detection,
    }

    # Use the session as a context manager so it is closed even if the request fails.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            url, headers=headers, data=json.dumps(data)
        ) as response:
            resp_json = await response.json()
            if "images" not in resp_json:
                raise RuntimeError(
                    f"Invalid response from Clova API: {resp_json['detail']}"
                )
            if "tables" in resp_json["images"][0].keys():
                table_html = json_to_html_table(
                    resp_json["images"][0]["tables"][0]["cells"]
                )
            page_text = extract_text_from_fields(resp_json["images"][0]["fields"])

    if table_html:
        page_text += f"\n\ntable html:\n{table_html}"

    return page_text, image_info["pdf_path"], image_info["pdf_page"]


def pdf_to_images(pdf_path: str) -> List[bytes]:
    """Convert each page of the PDF to a PNG image and return the image data."""
    pdf_document = fitz.open(pdf_path)
    image_data_lst = []
    for page_num in range(len(pdf_document)):
        page = pdf_document.load_page(page_num)
        pix = page.get_pixmap()
        img_data = pix.tobytes("png")
        image_data_lst.append(img_data)
    return image_data_lst


def generate_image_info(pdf_path: str, num_pages: int) -> List[dict]:
    """Generate per-page metadata (source PDF path and 1-based page number) for each image."""
    image_info_lst = [
        {"pdf_path": pdf_path, "pdf_page": page_num + 1}
        for page_num in range(num_pages)
    ]
    return image_info_lst


def extract_text_from_fields(fields):
    text = ""
    for field in fields:
        text += field["inferText"]
        if field["lineBreak"]:
            text += "\n"
        else:
            text += " "
    return text.strip()


def json_to_html_table(json_data):
    # Initialize the HTML table
    html = '<table border="1">\n'
    # Determine the number of rows and columns
    max_row = max(cell["rowIndex"] + cell["rowSpan"] for cell in json_data)
    max_col = max(cell["columnIndex"] + cell["columnSpan"] for cell in json_data)
    # Create a 2D array to keep track of merged cells
    table = [["" for _ in range(max_col)] for _ in range(max_row)]
    # Fill the table with cell data
    for cell in json_data:
        row = cell["rowIndex"]
        col = cell["columnIndex"]
        row_span = cell["rowSpan"]
        col_span = cell["columnSpan"]
        cell_text = (
            " ".join(
                line["inferText"] for line in cell["cellTextLines"][0]["cellWords"]
            )
            if cell["cellTextLines"]
            else ""
        )
        # Place the cell in the table
        table[row][col] = {"text": cell_text, "rowSpan": row_span, "colSpan": col_span}
        # Mark merged cells as occupied
        for r in range(row, row + row_span):
            for c in range(col, col + col_span):
                if r != row or c != col:
                    table[r][c] = None
    # Generate HTML from the table array
    for row in table:
        html += "  <tr>\n"
        for cell in row:
            if cell is None:
                continue
            if cell == "":
                html += "    <td></td>\n"
            else:
                row_span_attr = (
                    f' rowspan="{cell["rowSpan"]}"' if cell["rowSpan"] > 1 else ""
                )
                col_span_attr = (
                    f' colspan="{cell["colSpan"]}"' if cell["colSpan"] > 1 else ""
                )
                html += f'    <td{row_span_attr}{col_span_attr}>{cell["text"]}</td>\n'
        html += "  </tr>\n"
    html += "</table>"
    return html
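A sketch of calling clova_ocr through the parser_node wrapper, assuming CLOVA_URL and CLOVA_API_KEY are available; the endpoint, key, and glob values below are placeholders.

    import os
    from autorag.data.parse.clova import clova_ocr

    # Placeholders: use your own invoke URL and secret key from the Clova OCR console.
    os.environ["CLOVA_URL"] = "https://<your-clova-endpoint>"
    os.environ["CLOVA_API_KEY"] = "<your-secret-key>"

    result_df = clova_ocr(
        data_path_glob="./raw_docs/*.pdf",  # hypothetical path
        file_type="pdf",
        batch=5,  # 5 is the maximum allowed batch size
        table_detection=True,
    )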
87
autorag/data/parse/langchain_parse.py
Normal file
@@ -0,0 +1,87 @@
import multiprocessing as mp
from itertools import chain
from typing import List, Tuple

from autorag.data import parse_modules
from autorag.data.parse.base import parser_node


@parser_node
def langchain_parse(
    data_path_list: List[str], parse_method: str, **kwargs
) -> Tuple[List[str], List[str], List[int]]:
    """
    Parse documents using a langchain document_loaders (parse) method.

    :param data_path_list: The list of data paths to parse.
    :param parse_method: A langchain document_loaders (parse) method to use.
    :param kwargs: The extra parameters for creating the langchain document_loaders (parse) instance.
    :return: Tuple of lists containing the parsed texts, paths, and pages.
    """
    if parse_method in ["directory", "unstructured"]:
        results = parse_all_files(data_path_list, parse_method, **kwargs)
        texts, path = results[0], results[1]
        pages = [-1] * len(texts)

    else:
        num_workers = mp.cpu_count()
        # Execute parallel processing
        with mp.Pool(num_workers) as pool:
            results = pool.starmap(
                langchain_parse_pure,
                [(data_path, parse_method, kwargs) for data_path in data_path_list],
            )

        texts, path, pages = (list(chain.from_iterable(item)) for item in zip(*results))

    return texts, path, pages


def langchain_parse_pure(
    data_path: str, parse_method: str, kwargs
) -> Tuple[List[str], List[str], List[int]]:
    """
    Parses a single file using the specified parse method.

    Args:
        data_path (str): The file path to parse.
        parse_method (str): The parsing method to use.
        kwargs (Dict): Additional keyword arguments for the parsing method.

    Returns:
        Tuple[List[str], List[str], List[int]]: A tuple containing the parsed texts,
        file paths, and page numbers.
    """

    parse_instance = parse_modules[parse_method](data_path, **kwargs)

    # Load the text from the file
    documents = parse_instance.load()

    texts = list(map(lambda x: x.page_content, documents))
    path = [data_path] * len(texts)
    if parse_method in ["pymupdf", "pdfplumber", "pypdf", "pypdfium2"]:
        pages = list(range(1, len(documents) + 1))
    else:
        pages = [-1] * len(texts)

    # Clean up the parse instance
    del parse_instance

    return texts, path, pages


def parse_all_files(
    data_path_list: List[str], parse_method: str, **kwargs
) -> Tuple[List[str], List[str]]:
    if parse_method == "unstructured":
        parse_instance = parse_modules[parse_method](data_path_list, **kwargs)
    elif parse_method == "directory":
        parse_instance = parse_modules[parse_method](**kwargs)
    else:
        raise ValueError(f"Unsupported parse method: {parse_method}")
    docs = parse_instance.load()
    texts = [doc.page_content for doc in docs]
    file_names = [doc.metadata["source"] for doc in docs]

    del parse_instance
    return texts, file_names
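A sketch contrasting the two code paths above: a concrete loader such as pdfminer parses files one by one in a multiprocessing pool, while "directory" (and "unstructured") goes through parse_all_files, with parser_node translating the glob into path and glob kwargs. Paths are hypothetical.

    from autorag.data.parse import langchain_parse

    # Per-file parsing with a concrete langchain loader.
    pdf_df = langchain_parse(
        data_path_glob="./raw_docs/*.pdf",
        file_type="pdf",
        parse_method="pdfminer",
    )

    # Directory-style parsing: parser_node splits the glob into path="./raw_docs" and glob="*.md".
    md_df = langchain_parse(
        data_path_glob="./raw_docs/*.md",
        file_type="md",
        parse_method="directory",
    )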
126
autorag/data/parse/llamaparse.py
Normal file
@@ -0,0 +1,126 @@
import os
from itertools import chain
from typing import List, Optional, Tuple

from llama_parse import LlamaParse

from autorag.data.parse.base import parser_node
from autorag.utils.util import process_batch, get_event_loop


@parser_node
def llama_parse(
    data_path_list: List[str],
    batch: int = 8,
    use_vendor_multimodal_model: bool = False,
    vendor_multimodal_model_name: str = "openai-gpt4o",
    use_own_key: bool = False,
    vendor_multimodal_api_key: Optional[str] = None,
    **kwargs,
) -> Tuple[List[str], List[str], List[int]]:
    """
    Parse documents using llama_parse.
    The LLAMA_CLOUD_API_KEY environment variable must be set.
    You can get the key from https://cloud.llamaindex.ai/api-key

    :param data_path_list: The list of data paths to parse.
    :param batch: The batch size for parsing documents. Default is 8.
    :param use_vendor_multimodal_model: Whether to use a vendor multimodal model. Default is False.
    :param vendor_multimodal_model_name: The name of the vendor multimodal model. Default is "openai-gpt4o".
    :param use_own_key: Whether to use your own API key for the vendor multimodal model. Default is False.
    :param vendor_multimodal_api_key: The API key for the vendor multimodal model.
    :param kwargs: The extra parameters for creating the llama_parse instance.
    :return: Tuple of lists containing the parsed texts, paths, and pages.
    """
    if use_vendor_multimodal_model:
        kwargs = _add_multimodal_params(
            kwargs,
            use_vendor_multimodal_model,
            vendor_multimodal_model_name,
            use_own_key,
            vendor_multimodal_api_key,
        )

    parse_instance = LlamaParse(**kwargs)

    tasks = [
        llama_parse_pure(data_path, parse_instance) for data_path in data_path_list
    ]
    loop = get_event_loop()
    results = loop.run_until_complete(process_batch(tasks, batch))

    del parse_instance

    texts, path, pages = (list(chain.from_iterable(item)) for item in zip(*results))

    return texts, path, pages


async def llama_parse_pure(
    data_path: str, parse_instance
) -> Tuple[List[str], List[str], List[int]]:
    documents = await parse_instance.aload_data(data_path)

    texts = list(map(lambda x: x.text, documents))
    path = [data_path] * len(texts)
    pages = list(range(1, len(documents) + 1))

    return texts, path, pages


def _add_multimodal_params(
    kwargs,
    use_vendor_multimodal_model,
    vendor_multimodal_model_name,
    use_own_key,
    vendor_multimodal_api_key,
) -> dict:
    kwargs["use_vendor_multimodal_model"] = use_vendor_multimodal_model
    kwargs["vendor_multimodal_model_name"] = vendor_multimodal_model_name

    def set_multimodal_api_key(
        multimodal_model_name: str = "openai-gpt4o", _api_key: Optional[str] = None
    ) -> str:
        if multimodal_model_name in ["openai-gpt4o", "openai-gpt-4o-mini"]:
            _api_key = (
                os.getenv("OPENAI_API_KEY", None) if _api_key is None else _api_key
            )
            if _api_key is None:
                raise KeyError(
                    "Please set the OpenAI API key in the environment variable OPENAI_API_KEY "
                    "or directly set it on the config YAML file."
                )
        elif multimodal_model_name in ["anthropic-sonnet-3.5"]:
            _api_key = (
                os.getenv("ANTHROPIC_API_KEY", None) if _api_key is None else _api_key
            )
            if _api_key is None:
                raise KeyError(
                    "Please set the Anthropic API key in the environment variable ANTHROPIC_API_KEY "
                    "or directly set it on the config YAML file."
                )
        elif multimodal_model_name in ["gemini-1.5-flash", "gemini-1.5-pro"]:
            _api_key = (
                os.getenv("GEMINI_API_KEY", None) if _api_key is None else _api_key
            )
            if _api_key is None:
                raise KeyError(
                    "Please set the Gemini API key in the environment variable GEMINI_API_KEY "
                    "or directly set it on the config YAML file."
                )
        elif multimodal_model_name in ["custom-azure-model"]:
            raise NotImplementedError(
                "Custom Azure multimodal model is not supported yet."
            )
        else:
            raise ValueError("Invalid multimodal model name.")

        return _api_key

    if use_own_key:
        api_key = set_multimodal_api_key(
            vendor_multimodal_model_name, vendor_multimodal_api_key
        )
        kwargs["vendor_multimodal_api_key"] = api_key

    return kwargs
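A sketch of llama_parse with the vendor multimodal options, assuming LLAMA_CLOUD_API_KEY (and, when use_own_key is True, the matching vendor key) is set in the environment; the key values and glob below are placeholders, and result_type is an assumed LlamaParse option forwarded through kwargs.

    import os
    from autorag.data.parse.llamaparse import llama_parse

    os.environ["LLAMA_CLOUD_API_KEY"] = "<llama-cloud-key>"  # placeholder
    os.environ["OPENAI_API_KEY"] = "<openai-key>"  # placeholder, read when use_own_key=True

    result_df = llama_parse(
        data_path_glob="./raw_docs/*.pdf",  # hypothetical path
        file_type="pdf",
        batch=8,
        use_vendor_multimodal_model=True,
        vendor_multimodal_model_name="openai-gpt4o",
        use_own_key=True,
        result_type="markdown",  # assumed extra kwarg passed to LlamaParse
    )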
141
autorag/data/parse/run.py
Normal file
@@ -0,0 +1,141 @@
import os
from glob import glob
from typing import List, Callable, Dict

import pandas as pd

from autorag.strategy import measure_speed
from autorag.data.utils.util import get_param_combinations

default_map = {
    "pdf": {
        "file_type": "pdf",
        "module_type": "langchain_parse",
        "parse_method": "pdfminer",
    },
    "csv": {
        "file_type": "csv",
        "module_type": "langchain_parse",
        "parse_method": "csv",
    },
    "md": {
        "file_type": "md",
        "module_type": "langchain_parse",
        "parse_method": "unstructuredmarkdown",
    },
    "html": {
        "file_type": "html",
        "module_type": "langchain_parse",
        "parse_method": "bshtml",
    },
    "xml": {
        "file_type": "xml",
        "module_type": "langchain_parse",
        "parse_method": "unstructuredxml",
    },
}


def run_parser(
    modules: List[Callable],
    module_params: List[Dict],
    data_path_glob: str,
    project_dir: str,
    all_files: bool,
):
    if not all_files:
        # Fall back to the default parsing module for file types that exist in the data paths
        # but are not configured in the YAML file.
        data_path_list = glob(data_path_glob)
        if not data_path_list:
            raise FileNotFoundError(f"data does not exist in {data_path_glob}")

        file_types = set(
            [os.path.basename(data_path).split(".")[-1] for data_path in data_path_list]
        )
        set_file_types = set([module["file_type"] for module in module_params])

        # Calculate the set difference once
        file_types_to_remove = set_file_types - file_types

        # Filter modules and their params together so they stay aligned
        kept = [
            (module, param)
            for module, param in zip(modules, module_params)
            if param["file_type"] not in file_types_to_remove
        ]
        modules = [module for module, _ in kept]
        module_params = [param for _, param in kept]

        # File types that appear in the data paths but are not set in the YAML file
        missing_file_types = list(file_types - set_file_types)

        if missing_file_types:
            add_modules_list = []
            for missing_file_type in missing_file_types:
                if missing_file_type == "json":
                    raise ValueError(
                        "JSON file type must have a jq_schema, so you must set it in the YAML file."
                    )

                add_modules_list.append(default_map[missing_file_type])

            add_modules, add_params = get_param_combinations(add_modules_list)
            modules.extend(add_modules)
            module_params.extend(add_params)

    results, execution_times = zip(
        *map(
            lambda x: measure_speed(x[0], data_path_glob=data_path_glob, **x[1]),
            zip(modules, module_params),
        )
    )
    average_times = list(map(lambda x: x / len(results[0]), execution_times))

    # save results to parquet files
    if all_files:
        if len(module_params) > 1:
            raise ValueError(
                "all_files is set to True, so only one parsing module can be used."
            )
        filepaths = [os.path.join(project_dir, "parsed_result.parquet")]
    else:
        filepaths = list(
            map(
                lambda x: os.path.join(project_dir, f"{x['file_type']}.parquet"),
                module_params,
            )
        )

    _files = {}
    for result, filepath in zip(results, filepaths):
        _files.setdefault(filepath, []).append(result)
    # Save results that share a file type into one Parquet file.
    for filepath, value in _files.items():
        pd.concat(value).to_parquet(filepath, index=False)

    filenames = list(map(lambda x: os.path.basename(x), filepaths))

    summary_df = pd.DataFrame(
        {
            "filename": filenames,
            "module_name": list(map(lambda module: module.__name__, modules)),
            "module_params": module_params,
            "execution_time": average_times,
        }
    )
    summary_df.to_csv(os.path.join(project_dir, "summary.csv"), index=False)

    # Concatenate all parquet files into one if not all_files.
    _filepaths = list(_files.keys())
    if not all_files:
        dataframes = [pd.read_parquet(file) for file in _filepaths]
        combined_df = pd.concat(dataframes, ignore_index=True)
        combined_df.to_parquet(
            os.path.join(project_dir, "parsed_result.parquet"), index=False
        )

    return summary_df
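For context, a sketch of invoking run_parser directly. In normal use the modules and module_params lists presumably come from the parsing YAML config; here they are built by hand, and the paths are hypothetical.

    from autorag.data.parse import langchain_parse
    from autorag.data.parse.run import run_parser

    summary = run_parser(
        modules=[langchain_parse],
        module_params=[{"file_type": "pdf", "parse_method": "pdfminer"}],
        data_path_glob="./raw_docs/*",   # hypothetical path
        project_dir="./parse_project",   # hypothetical path
        all_files=False,
    )
    # Writes per-file-type parquet files, parsed_result.parquet, and summary.csv under project_dir.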
134
autorag/data/parse/table_hybrid_parse.py
Normal file
@@ -0,0 +1,134 @@
import os
import tempfile
from glob import glob
from typing import List, Tuple, Dict

from PyPDF2 import PdfFileReader, PdfFileWriter
import pdfplumber

from autorag.support import get_support_modules
from autorag.data.parse.base import parser_node


@parser_node
def table_hybrid_parse(
    data_path_list: List[str],
    text_parse_module: str,
    text_params: Dict,
    table_parse_module: str,
    table_params: Dict,
) -> Tuple[List[str], List[str], List[int]]:
    """
    Parse documents using the table_hybrid_parse method.
    The table_hybrid_parse method is a hybrid method that combines the parsing results of PDF pages with and without tables.
    It splits each PDF file into pages, separates pages with and without tables, parses them with the corresponding module, and merges the results.

    :param data_path_list: The list of data paths to parse.
    :param text_parse_module: The text parsing module to use. The type should be a string.
    :param text_params: The extra parameters for the text parsing module. The type should be a dictionary.
    :param table_parse_module: The table parsing module to use. The type should be a string.
    :param table_params: The extra parameters for the table parsing module. The type should be a dictionary.
    :return: Tuple of lists containing the parsed texts, paths, and pages.
    """
    # Make temporary save directories
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as save_dir:
        text_dir = os.path.join(save_dir, "text")
        table_dir = os.path.join(save_dir, "table")

        os.makedirs(text_dir, exist_ok=True)
        os.makedirs(table_dir, exist_ok=True)

        # Split each PDF file into pages and save pages with and without tables separately
        path_map_dict_lst = [
            save_page_by_table(data_path, text_dir, table_dir)
            for data_path in data_path_list
        ]
        path_map_dict = {k: v for d in path_map_dict_lst for k, v in d.items()}

        # Parse table pages
        table_results, table_file_path = get_each_module_result(
            table_parse_module, table_params, os.path.join(table_dir, "*")
        )

        # Parse text pages
        text_results, text_file_path = get_each_module_result(
            text_parse_module, text_params, os.path.join(text_dir, "*")
        )

        # Merge the parsing results of pages with and without tables
        texts = table_results + text_results
        temp_path_lst = table_file_path + text_file_path

        # Sort by file names
        temp_path_lst, texts = zip(*sorted(zip(temp_path_lst, texts)))

        # Get the original file paths
        path = list(map(lambda temp_path: path_map_dict[temp_path], temp_path_lst))

        # Get the page numbers
        pages = list(map(lambda x: get_page_from_path(x), temp_path_lst))

        return list(texts), path, pages


# Save PDF pages with and without tables into separate directories
def save_page_by_table(data_path: str, text_dir: str, table_dir: str) -> Dict[str, str]:
    file_name = os.path.basename(data_path).split(".pdf")[0]

    with open(data_path, "rb") as input_data:
        pdf_reader = PdfFileReader(input_data)
        num_pages = pdf_reader.getNumPages()

        path_map_dict = {}
        for page_num in range(num_pages):
            output_pdf_path = _get_output_path(
                data_path, page_num, file_name, text_dir, table_dir
            )
            _save_single_page(pdf_reader, page_num, output_pdf_path)
            path_map_dict.update({output_pdf_path: data_path})

    return path_map_dict


def _get_output_path(
    data_path: str, page_num: int, file_name: str, text_dir: str, table_dir: str
) -> str:
    with pdfplumber.open(data_path) as pdf:
        page = pdf.pages[page_num]
        tables = page.extract_tables()
        directory = table_dir if tables else text_dir
        return os.path.join(directory, f"{file_name}_page_{page_num + 1}.pdf")


def _save_single_page(pdf_reader: PdfFileReader, page_num: int, output_pdf_path: str):
    pdf_writer = PdfFileWriter()
    pdf_writer.addPage(pdf_reader.getPage(page_num))

    with open(output_pdf_path, "wb") as output_file:
        pdf_writer.write(output_file)


def get_each_module_result(
    module: str, module_params: Dict, data_path_glob: str
) -> Tuple[List[str], List[str]]:
    module_params["module_type"] = module

    data_path_list = glob(data_path_glob)
    if not data_path_list:
        return [], []

    module_name = module_params.pop("module_type")
    module_callable = get_support_modules(module_name)
    module_original = module_callable.__wrapped__
    texts, path, _ = module_original(data_path_list, **module_params)

    return texts, path


def get_page_from_path(file_path: str) -> int:
    file_name = os.path.basename(file_path)
    # Split on the last "_page_" so file names that themselves contain "_page_" still work
    split_result = file_name.rsplit("_page_", 1)
    page_number_with_extension = split_result[1]
    page_number, _ = page_number_with_extension.split(".")

    return int(page_number)
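A sketch of a hybrid call that parses table pages with pdfplumber and text pages with pdfminer, both through langchain_parse. Because get_each_module_result calls the undecorated functions, text_params and table_params carry only the module's own arguments (for example parse_method). Paths are hypothetical.

    from autorag.data.parse.table_hybrid_parse import table_hybrid_parse

    result_df = table_hybrid_parse(
        data_path_glob="./raw_docs/*.pdf",  # hypothetical path
        file_type="pdf",
        text_parse_module="langchain_parse",
        text_params={"parse_method": "pdfminer"},
        table_parse_module="langchain_parse",
        table_params={"parse_method": "pdfplumber"},
    )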