Change input/output logic

kyy
2025-10-27 15:36:17 +09:00
parent 758b9afe9a
commit 64550b1fd5
4 changed files with 310 additions and 212 deletions

config.py

@@ -22,7 +22,8 @@ MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
# Omnidocbench images path: run_dpsk_ocr_eval_batch.py
-INPUT_PATH = "/workspace/2018-0802140959-217049.pdf"
+PDF_INPUT_PATH = "/workspace/2018-0802140959-217049.pdf"
+IMAGE_INPUT_PATH = "/workspace/20250730180509-798-917-821.jpg"
OUTPUT_PATH = "/workspace/output/"
PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
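Both runner scripts treat these constants as fallback defaults: the new --input CLI flag takes precedence, and the constant is used only when the flag is omitted (see the argparse blocks added below). A minimal sketch of the fallback pattern, assuming a config module laid out as above:

import argparse
from config import PDF_INPUT_PATH

# --input overrides the config default; omitting it falls back to config.py.
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Path to the input PDF file.")
args = parser.parse_args()
input_path = args.input if args.input else PDF_INPUT_PATH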

new dispatcher script (filename not shown)

@@ -0,0 +1,24 @@
+import os
+import argparse
+import subprocess
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run OCR based on file type.")
+    parser.add_argument("input_path", type=str, help="Path to the input file (PDF or image).")
+    args = parser.parse_args()
+    input_path = args.input_path
+    file_extension = os.path.splitext(input_path)[1].lower()
+    if file_extension == '.pdf':
+        print(f"Detected PDF file. Running PDF OCR script for: {input_path}")
+        subprocess.run(["python", "run_dpsk_ocr_pdf.py", "--input", input_path])
+    elif file_extension in ['.jpg', '.jpeg', '.png', '.bmp', '.gif']:
+        print(f"Detected image file. Running image OCR script for: {input_path}")
+        subprocess.run(["python", "run_dpsk_ocr_image.py", "--input", input_path])
+    else:
+        print(f"Unsupported file type: {file_extension}")
+        print("Please provide a PDF or an image file (.jpg, .jpeg, .png, .bmp, .gif).")
+
+
+if __name__ == "__main__":
+    main()
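Note that subprocess.run as used here discards the child's exit status, so the dispatcher always exits 0 even when the OCR script fails. A hedged variant (a sketch, not part of the commit) that propagates failures to the caller:

import subprocess
import sys

# check=True raises CalledProcessError when the child exits non-zero;
# forwarding e.returncode makes the dispatcher usable in shell pipelines.
try:
    subprocess.run(["python", "run_dpsk_ocr_pdf.py", "--input", "doc.pdf"], check=True)
except subprocess.CalledProcessError as e:
    sys.exit(e.returncode)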

run_dpsk_ocr_image.py

@@ -1,39 +1,42 @@
+import argparse
import asyncio
-import re
import os
import json
+import re
import torch
-if torch.version.cuda == '11.8':
+if torch.version.cuda == "11.8":
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
-os.environ['VLLM_USE_V1'] = '0'
-os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+os.environ["VLLM_USE_V1"] = "0"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+import time
+import numpy as np
+from config import CROP_MODE, IMAGE_INPUT_PATH, MODEL_PATH, OUTPUT_PATH, PROMPT
+from PIL import Image, ImageDraw, ImageFont, ImageOps
+from process.image_process import DeepseekOCRProcessor
+from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
+from tqdm import tqdm
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.registry import ModelRegistry
-import time
from deepseek_ocr import DeepseekOCRForCausalLM
-from PIL import Image, ImageDraw, ImageFont, ImageOps
-import numpy as np
-from tqdm import tqdm
-from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
-from process.image_process import DeepseekOCRProcessor
-from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
def load_image(image_path):
    try:
        image = Image.open(image_path)
        corrected_image = ImageOps.exif_transpose(image)
        return corrected_image
    except Exception as e:
        print(f"error: {e}")
        try:
@@ -43,14 +46,13 @@ def load_image(image_path):
def re_match(text):
-    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
+    pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
    matches = re.findall(pattern, text, re.DOTALL)
    mathes_image = []
    mathes_other = []
    for a_match in matches:
-        if '<|ref|>image<|/ref|>' in a_match[0]:
+        if "<|ref|>image<|/ref|>" in a_match[0]:
            mathes_image.append(a_match[0])
        else:
            mathes_other.append(a_match[0])
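For reference, re_match splits the grounded model output into image regions and everything else. An illustration of what the pattern captures (the sample string is made up):

import re

pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
sample = "<|ref|>image<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|> caption text"
matches = re.findall(pattern, sample, re.DOTALL)
# One tuple per match: (full span, label, coordinate list)
# [('<|ref|>image<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>',
#   'image', '[[100, 200, 300, 400]]')]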
@@ -58,8 +60,6 @@ def re_match(text):
def extract_coordinates_and_label(ref_text, image_width, image_height):
    try:
        label_type = ref_text[1]
        cor_list = eval(ref_text[2])
@@ -71,28 +71,31 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
def draw_bounding_boxes(image, refs):
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
-    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
+    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    # except IOError:
    font = ImageFont.load_default()
    img_idx = 0
    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result
-                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
-                color_a = color + (20, )
+                color = (
+                    np.random.randint(0, 200),
+                    np.random.randint(0, 200),
+                    np.random.randint(0, 255),
+                )
+                color_a = color + (20,)
                for points in points_list:
                    x1, y1, x2, y2 = points
@@ -102,7 +105,7 @@ def draw_bounding_boxes(image, refs):
                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)
-                    if label_type == 'image':
+                    if label_type == "image":
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
@@ -110,24 +113,36 @@ def draw_bounding_boxes(image, refs):
                            print(e)
                            pass
                        img_idx += 1
                    try:
-                        if label_type == 'title':
+                        if label_type == "title":
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
-                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
+                            draw2.rectangle(
+                                [x1, y1, x2, y2],
+                                fill=color_a,
+                                outline=(0, 0, 0, 0),
+                                width=1,
+                            )
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
-                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
+                            draw2.rectangle(
+                                [x1, y1, x2, y2],
+                                fill=color_a,
+                                outline=(0, 0, 0, 0),
+                                width=1,
+                            )
                        text_x = x1
                        text_y = max(0, y1 - 15)
                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
-                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
-                                       fill=(255, 255, 255, 30))
+                        draw.rectangle(
+                            [text_x, text_y, text_x + text_width, text_y + text_height],
+                            fill=(255, 255, 255, 30),
+                        )
                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except:
                        pass
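The boxes arrive on a 0-999 normalized grid, which is why each coordinate is rescaled by the image size before drawing. A standalone example of the conversion (the numbers are made up):

# Convert a 0-999 normalized <|det|> box to pixels, as done above.
image_width, image_height = 1654, 2339
x1, y1, x2, y2 = 100, 50, 900, 120
px1 = int(x1 / 999 * image_width)    # 165
py1 = int(y1 / 999 * image_height)   # 117
px2 = int(x2 / 999 * image_width)    # 1490
py2 = int(y2 / 999 * image_height)   # 280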
@@ -142,24 +157,24 @@ def process_image_with_refs(image, ref_texts):
    return result_image

-async def stream_generate(image=None, prompt=''):
+async def stream_generate(image=None, prompt=""):
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
        block_size=256,
        max_model_len=8192,
        enforce_eager=False,
        trust_remote_code=True,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
-    logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90, whitelist_token_ids= {128821, 128822})]  # whitelist: <td>, </td>
+    logits_processors = [
+        NoRepeatNGramLogitsProcessor(
+            ngram_size=30, window_size=90, whitelist_token_ids={128821, 128822}
+        )
+    ]  # whitelist: <td>, </td>
    sampling_params = SamplingParams(
        temperature=0.0,
@@ -167,137 +182,157 @@ async def stream_generate(image=None, prompt=''):
        logits_processors=logits_processors,
        skip_special_tokens=False,
        # ignore_eos=False,
    )
    request_id = f"request-{int(time.time())}"
    printed_length = 0
-    if image and '<image>' in prompt:
-        request = {
-            "prompt": prompt,
-            "multi_modal_data": {"image": image}
-        }
+    if image and "<image>" in prompt:
+        request = {"prompt": prompt, "multi_modal_data": {"image": image}}
    elif prompt:
-        request = {
-            "prompt": prompt
-        }
+        request = {"prompt": prompt}
    else:
-        assert False, f'prompt is none!!!'
-    async for request_output in engine.generate(
-        request, sampling_params, request_id
-    ):
+        assert False, "prompt is none!!!"
+    async for request_output in engine.generate(request, sampling_params, request_id):
        if request_output.outputs:
            full_text = request_output.outputs[0].text
            new_text = full_text[printed_length:]
-            print(new_text, end='', flush=True)
+            print(new_text, end="", flush=True)
            printed_length = len(full_text)
            final_output = full_text
-    print('\n')
+    print("\n")
    return final_output

if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", type=str, help="Path to the input image file.")
+    args = parser.parse_args()
+    input_path = args.input if args.input else IMAGE_INPUT_PATH
    os.makedirs(OUTPUT_PATH, exist_ok=True)
-    os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
+    os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
-    image = load_image(INPUT_PATH).convert('RGB')
+    image = load_image(input_path).convert("RGB")
-    if '<image>' in PROMPT:
-        image_features = DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)
+    if "<image>" in PROMPT:
+        image_features = DeepseekOCRProcessor().tokenize_with_images(
+            images=[image], bos=True, eos=True, cropping=CROP_MODE
+        )
    else:
-        image_features = ''
+        image_features = ""
    prompt = PROMPT
    result_out = asyncio.run(stream_generate(image_features, prompt))
    save_results = 1
-    if save_results and '<image>' in prompt:
-        print('='*15 + 'save results:' + '='*15)
+    if save_results and "<image>" in prompt:
+        print("=" * 15 + "save results:" + "=" * 15)
+        base_name = os.path.basename(input_path)
+        file_name_without_ext = os.path.splitext(base_name)[0]
+        output_json_det_path = f'{OUTPUT_PATH}/{file_name_without_ext}_det.json'
+        output_json_path = f'{OUTPUT_PATH}/{file_name_without_ext}.json'
        image_draw = image.copy()
        outputs = result_out
-        with open(f'{OUTPUT_PATH}/result_ori.mmd', 'w', encoding = 'utf-8') as afile:
-            afile.write(outputs)
+        with open(output_json_det_path, "w", encoding="utf-8") as afile:
+            json.dump({"parsed": outputs}, afile, ensure_ascii=False, indent=4)
        matches_ref, matches_images, mathes_other = re_match(outputs)
        # print(matches_ref)
        result = process_image_with_refs(image_draw, matches_ref)
        for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
-            outputs = outputs.replace(a_match_image, f'![](images/' + str(idx) + '.jpg)\n')
+            outputs = outputs.replace(
+                a_match_image, "![](images/" + str(idx) + ".jpg)\n"
+            )
        for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
-            outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
+            outputs = (
+                outputs.replace(a_match_other, "")
+                .replace("\\coloneqq", ":=")
+                .replace("\\eqqcolon", "=:")
+            )
        # if 'structural formula' in conversation[0]['content']:
        #     outputs = '<smiles>' + outputs + '</smiles>'
-        with open(f'{OUTPUT_PATH}/result.mmd', 'w', encoding = 'utf-8') as afile:
-            afile.write(outputs)
+        with open(output_json_path, "w", encoding="utf-8") as afile:
+            json.dump({"parsed": outputs}, afile, ensure_ascii=False, indent=4)
-        if 'line_type' in outputs:
+        if "line_type" in outputs:
            import matplotlib.pyplot as plt
            from matplotlib.patches import Circle
-            lines = eval(outputs)['Line']['line']
-            line_type = eval(outputs)['Line']['line_type']
+            lines = eval(outputs)["Line"]["line"]
+            line_type = eval(outputs)["Line"]["line_type"]
            # print(lines)
-            endpoints = eval(outputs)['Line']['line_endpoint']
+            endpoints = eval(outputs)["Line"]["line_endpoint"]
-            fig, ax = plt.subplots(figsize=(3,3), dpi=200)
+            fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
            ax.set_xlim(-15, 15)
            ax.set_ylim(-15, 15)
            for idx, line in enumerate(lines):
                try:
-                    p0 = eval(line.split(' -- ')[0])
-                    p1 = eval(line.split(' -- ')[-1])
+                    p0 = eval(line.split(" -- ")[0])
+                    p1 = eval(line.split(" -- ")[-1])
-                    if line_type[idx] == '--':
-                        ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
+                    if line_type[idx] == "--":
+                        ax.plot(
+                            [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
+                        )
                    else:
-                        ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
+                        ax.plot(
+                            [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
+                        )
-                    ax.scatter(p0[0], p0[1], s=5, color = 'k')
-                    ax.scatter(p1[0], p1[1], s=5, color = 'k')
+                    ax.scatter(p0[0], p0[1], s=5, color="k")
+                    ax.scatter(p1[0], p1[1], s=5, color="k")
                except:
                    pass
            for endpoint in endpoints:
+                label = endpoint.split(": ")[0]
+                (x, y) = eval(endpoint.split(": ")[1])
+                ax.annotate(
+                    label,
+                    (x, y),
+                    xytext=(1, 1),
+                    textcoords="offset points",
+                    fontsize=5,
+                    fontweight="light",
+                )
-                label = endpoint.split(': ')[0]
-                (x, y) = eval(endpoint.split(': ')[1])
-                ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
-                            fontsize=5, fontweight='light')
            try:
-                if 'Circle' in eval(outputs).keys():
-                    circle_centers = eval(outputs)['Circle']['circle_center']
-                    radius = eval(outputs)['Circle']['radius']
+                if "Circle" in eval(outputs).keys():
+                    circle_centers = eval(outputs)["Circle"]["circle_center"]
+                    radius = eval(outputs)["Circle"]["radius"]
                    for center, r in zip(circle_centers, radius):
-                        center = eval(center.split(': ')[1])
-                        circle = Circle(center, radius=r, fill=False, edgecolor='black', linewidth=0.8)
+                        center = eval(center.split(": ")[1])
+                        circle = Circle(
+                            center,
+                            radius=r,
+                            fill=False,
+                            edgecolor="black",
+                            linewidth=0.8,
+                        )
                        ax.add_patch(circle)
            except:
                pass
-            plt.savefig(f'{OUTPUT_PATH}/geo.jpg')
+            plt.savefig(f"{OUTPUT_PATH}/geo.jpg")
            plt.close()
-        result.save(f'{OUTPUT_PATH}/result_with_boxes.jpg')
+        result.save(f"{OUTPUT_PATH}/result_with_boxes.jpg")
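Both runners decode with temperature 0.0 plus a NoRepeatNGramLogitsProcessor, which suppresses the repetition loops OCR models fall into on tables and headers; token ids 128821 and 128822 (<td> and </td>) are exempt so long tables are not mangled. A simplified reconstruction of the idea (a sketch, not the repo's actual implementation):

import torch

def ban_repeated_ngrams(token_ids, logits, ngram_size, window_size, whitelist):
    # If the last (ngram_size - 1) tokens already occurred inside the recent
    # window, ban every token that previously followed that prefix, except
    # whitelisted ids; banned tokens get -inf logits.
    window = token_ids[-window_size:]
    if len(window) < ngram_size:
        return logits
    prefix = tuple(window[-(ngram_size - 1):])
    for i in range(len(window) - ngram_size + 1):
        if tuple(window[i:i + ngram_size - 1]) == prefix:
            nxt = window[i + ngram_size - 1]
            if nxt not in whitelist:
                logits[nxt] = float("-inf")
    return logits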

run_dpsk_ocr_pdf.py

@@ -1,30 +1,39 @@
+import argparse
+import io
+import json
import os
+import re
+from concurrent.futures import ThreadPoolExecutor
import fitz
import img2pdf
-import io
-import re
-from tqdm import tqdm
import torch
-from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
-if torch.version.cuda == '11.8':
+if torch.version.cuda == "11.8":
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
-os.environ['VLLM_USE_V1'] = '0'
-os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+os.environ["VLLM_USE_V1"] = "0"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
-from PIL import Image, ImageDraw, ImageFont
import numpy as np
-from deepseek_ocr import DeepseekOCRForCausalLM
+from config import (
+    CROP_MODE,
+    MAX_CONCURRENCY,
+    MODEL_PATH,
+    NUM_WORKERS,
+    OUTPUT_PATH,
+    PDF_INPUT_PATH,
+    PROMPT,
+    SKIP_REPEAT,
+)
+from PIL import Image, ImageDraw, ImageFont
+from process.image_process import DeepseekOCRProcessor
+from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
+from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
-from vllm import LLM, SamplingParams
-from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
-from process.image_process import DeepseekOCRProcessor
+from deepseek_ocr import DeepseekOCRForCausalLM
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
@@ -34,16 +43,20 @@ llm = LLM(
    hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
    block_size=256,
    enforce_eager=False,
    trust_remote_code=True,
    max_model_len=8192,
    swap_space=0,
    max_num_seqs=MAX_CONCURRENCY,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
-    disable_mm_preprocessor_cache=True
+    disable_mm_preprocessor_cache=True,
)

-logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids= {128821, 128822})]  # small window for speed; whitelist_token_ids: <td>, </td>
+logits_processors = [
+    NoRepeatNGramLogitsProcessor(
+        ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
+    )
+]  # small window for speed; whitelist_token_ids: <td>, </td>

sampling_params = SamplingParams(
    temperature=0.0,
@@ -55,23 +68,24 @@ sampling_params = SamplingParams(
class Colors:
-    RED = '\033[31m'
-    GREEN = '\033[32m'
-    YELLOW = '\033[33m'
-    BLUE = '\033[34m'
-    RESET = '\033[0m'
+    RED = "\033[31m"
+    GREEN = "\033[32m"
+    YELLOW = "\033[33m"
+    BLUE = "\033[34m"
+    RESET = "\033[0m"

def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
    """
    pdf2images
    """
    images = []
    pdf_document = fitz.open(pdf_path)
    zoom = dpi / 72.0
    matrix = fitz.Matrix(zoom, zoom)
    for page_num in range(pdf_document.page_count):
        page = pdf_document[page_num]
@@ -84,32 +98,34 @@ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
        else:
            img_data = pixmap.tobytes("png")
        img = Image.open(io.BytesIO(img_data))
-        if img.mode in ('RGBA', 'LA'):
-            background = Image.new('RGB', img.size, (255, 255, 255))
-            background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
+        if img.mode in ("RGBA", "LA"):
+            background = Image.new("RGB", img.size, (255, 255, 255))
+            background.paste(
+                img, mask=img.split()[-1] if img.mode == "RGBA" else None
+            )
            img = background
        images.append(img)
    pdf_document.close()
    return images
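PyMuPDF lays pages out in points (72 per inch), so the zoom factor above is simply dpi / 72; the default dpi=144 renders every page at exactly twice its nominal point size. For example:

# A US-letter page is 612 x 792 points; at dpi=144, zoom = 144 / 72 = 2.0,
# so the rendered pixmap is 1224 x 1584 pixels.
zoom = 144 / 72.0
width_px, height_px = int(612 * zoom), int(792 * zoom)  # 1224, 1584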
def pil_to_pdf_img2pdf(pil_images, output_path):
    if not pil_images:
        return
    image_bytes_list = []
    for img in pil_images:
-        if img.mode != 'RGB':
-            img = img.convert('RGB')
+        if img.mode != "RGB":
+            img = img.convert("RGB")
        img_buffer = io.BytesIO()
-        img.save(img_buffer, format='JPEG', quality=95)
+        img.save(img_buffer, format="JPEG", quality=95)
        img_bytes = img_buffer.getvalue()
        image_bytes_list.append(img_bytes)
    try:
        pdf_bytes = img2pdf.convert(image_bytes_list)
        with open(output_path, "wb") as f:
@@ -119,16 +135,14 @@ def pil_to_pdf_img2pdf(pil_images, output_path):
        print(f"error: {e}")
def re_match(text):
-    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
+    pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
    matches = re.findall(pattern, text, re.DOTALL)
    mathes_image = []
    mathes_other = []
    for a_match in matches:
-        if '<|ref|>image<|/ref|>' in a_match[0]:
+        if "<|ref|>image<|/ref|>" in a_match[0]:
            mathes_image.append(a_match[0])
        else:
            mathes_other.append(a_match[0])
@@ -136,8 +150,6 @@ def re_match(text):
def extract_coordinates_and_label(ref_text, image_width, image_height):
    try:
        label_type = ref_text[1]
        cor_list = eval(ref_text[2])
@@ -149,28 +161,31 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
def draw_bounding_boxes(image, refs, jdx):
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
-    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
+    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    # except IOError:
    font = ImageFont.load_default()
    img_idx = 0
    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result
-                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
-                color_a = color + (20, )
+                color = (
+                    np.random.randint(0, 200),
+                    np.random.randint(0, 200),
+                    np.random.randint(0, 255),
+                )
+                color_a = color + (20,)
                for points in points_list:
                    x1, y1, x2, y2 = points
@@ -180,7 +195,7 @@ def draw_bounding_boxes(image, refs, jdx):
                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)
-                    if label_type == 'image':
+                    if label_type == "image":
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
@@ -188,24 +203,36 @@ def draw_bounding_boxes(image, refs, jdx):
                            print(e)
                            pass
                        img_idx += 1
                    try:
-                        if label_type == 'title':
+                        if label_type == "title":
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
-                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
+                            draw2.rectangle(
+                                [x1, y1, x2, y2],
+                                fill=color_a,
+                                outline=(0, 0, 0, 0),
+                                width=1,
+                            )
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
-                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
+                            draw2.rectangle(
+                                [x1, y1, x2, y2],
+                                fill=color_a,
+                                outline=(0, 0, 0, 0),
+                                width=1,
+                            )
                        text_x = x1
                        text_y = max(0, y1 - 15)
                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
-                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
-                                       fill=(255, 255, 255, 30))
+                        draw.rectangle(
+                            [text_x, text_y, text_x + text_width, text_y + text_height],
+                            fill=(255, 255, 255, 30),
+                        )
                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except:
                        pass
@@ -225,33 +252,41 @@ def process_single_image(image):
    prompt_in = prompt
    cache_item = {
        "prompt": prompt_in,
-        "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
+        "multi_modal_data": {
+            "image": DeepseekOCRProcessor().tokenize_with_images(
+                images=[image], bos=True, eos=True, cropping=CROP_MODE
+            )
+        },
    }
    return cache_item

if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", type=str, help="Path to the input PDF file.")
+    args = parser.parse_args()
+    input_path = args.input if args.input else PDF_INPUT_PATH
    os.makedirs(OUTPUT_PATH, exist_ok=True)
-    os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
-    print(f'{Colors.RED}PDF loading .....{Colors.RESET}')
+    os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
+    print(f"{Colors.RED}PDF loading .....{Colors.RESET}")
-    images = pdf_to_images_high_quality(INPUT_PATH)
+    images = pdf_to_images_high_quality(input_path)
    prompt = PROMPT
    # batch_inputs = []
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
-        batch_inputs = list(tqdm(
-            executor.map(process_single_image, images),
-            total=len(images),
-            desc="Pre-processed images"
-        ))
+        batch_inputs = list(
+            tqdm(
+                executor.map(process_single_image, images),
+                total=len(images),
+                desc="Pre-processed images",
+            )
+        )
    # for image in tqdm(images):
@@ -264,38 +299,39 @@ if __name__ == "__main__":
    # ]
    # batch_inputs.extend(cache_list)
-    outputs_list = llm.generate(
-        batch_inputs,
-        sampling_params=sampling_params
-    )
+    outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
    output_path = OUTPUT_PATH
    os.makedirs(output_path, exist_ok=True)
-    mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
-    mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd')
-    pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
-    contents_det = ''
-    contents = ''
+    json_det_path = (
+        output_path + "/" + input_path.split("/")[-1].replace(".pdf", "_det.json")
+    )
+    json_path = (
+        output_path + "/" + input_path.split("/")[-1].replace(".pdf", ".json")
+    )
+    pdf_out_path = (
+        output_path
+        + "/"
+        + input_path.split("/")[-1].replace(".pdf", "_layouts.pdf")
+    )
+    contents_det = ""
+    contents = ""
    draw_images = []
    jdx = 0
    for output, img in zip(outputs_list, images):
        content = output.outputs[0].text
-        if '<end▁of▁sentence>' in content:  # repeat no eos
-            content = content.replace('<end▁of▁sentence>', '')
+        if "<end▁of▁sentence>" in content:  # repeat no eos
+            content = content.replace("<end▁of▁sentence>", "")
        else:
            if SKIP_REPEAT:
                continue
-        page_num = f'\n<--- Page Split --->'
+        page_num = "\n<--- Page Split --->"
-        contents_det += content + f'\n{page_num}\n'
+        contents_det += content + f"\n{page_num}\n"
        image_draw = img.copy()
@@ -303,28 +339,30 @@ if __name__ == "__main__":
        # print(matches_ref)
        result_image = process_image_with_refs(image_draw, matches_ref, jdx)
        draw_images.append(result_image)
        for idx, a_match_image in enumerate(matches_images):
-            content = content.replace(a_match_image, f'![](images/' + str(jdx) + '_' + str(idx) + '.jpg)\n')
+            content = content.replace(
+                a_match_image, "![](images/" + str(jdx) + "_" + str(idx) + ".jpg)\n"
+            )
        for idx, a_match_other in enumerate(mathes_other):
-            content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
-        contents += content + f'\n{page_num}\n'
+            content = (
+                content.replace(a_match_other, "")
+                .replace("\\coloneqq", ":=")
+                .replace("\\eqqcolon", "=:")
+                .replace("\n\n\n\n", "\n\n")
+                .replace("\n\n\n", "\n\n")
+            )
+        contents += content + f"\n{page_num}\n"
        jdx += 1
-    with open(mmd_det_path, 'w', encoding='utf-8') as afile:
-        afile.write(contents_det)
-    with open(mmd_path, 'w', encoding='utf-8') as afile:
-        afile.write(contents)
+    with open(json_det_path, "w", encoding="utf-8") as afile:
+        json.dump({"parsed": contents_det}, afile, ensure_ascii=False, indent=4)
+    with open(json_path, "w", encoding="utf-8") as afile:
+        json.dump({"parsed": contents}, afile, ensure_ascii=False, indent=4)
    pil_to_pdf_img2pdf(draw_images, pdf_out_path)
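With this commit both runners emit JSON ({"parsed": ...} wrappers) instead of the earlier .mmd files, and outputs are named after the input file. A small sketch of reading a result back, assuming the default paths from config.py above:

import json

# The "parsed" key wraps the concatenated page contents; pages are separated
# by the "<--- Page Split --->" marker written in the loop above.
with open("/workspace/output/2018-0802140959-217049.json", encoding="utf-8") as f:
    parsed = json.load(f)["parsed"]
pages = [p.strip() for p in parsed.split("<--- Page Split --->") if p.strip()]
print(len(pages), "pages")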