Compare commits

2 Commits

Author SHA1 Message Date
kyy
64550b1fd5 입출력 로직 변경 2025-10-27 15:36:17 +09:00
kyy
758b9afe9a 이미지 파일 제외 2025-10-27 15:35:00 +09:00
5 changed files with 314 additions and 213 deletions

3
.gitignore vendored
View File

@@ -6,3 +6,6 @@ model_services/dotc.ocr/dots
__pycache__/ __pycache__/
output/ output/
*.pdf *.pdf
*.jpg
*.png
*.jpeg

View File

@@ -22,7 +22,8 @@ MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
# Omnidocbench images path: run_dpsk_ocr_eval_batch.py # Omnidocbench images path: run_dpsk_ocr_eval_batch.py
INPUT_PATH = "/workspace/2018-0802140959-217049.pdf" PDF_INPUT_PATH = "/workspace/2018-0802140959-217049.pdf"
IMAGE_INPUT_PATH = "/workspace/20250730180509-798-917-821.jpg"
OUTPUT_PATH = "/workspace/output/" OUTPUT_PATH = "/workspace/output/"
PROMPT = "<image>\n<|grounding|>Convert the document to markdown." PROMPT = "<image>\n<|grounding|>Convert the document to markdown."

View File

@@ -0,0 +1,24 @@
import os
import argparse
import subprocess
def main():
parser = argparse.ArgumentParser(description="Run OCR based on file type.")
parser.add_argument("input_path", type=str, help="Path to the input file (PDF or image).")
args = parser.parse_args()
input_path = args.input_path
file_extension = os.path.splitext(input_path)[1].lower()
if file_extension == '.pdf':
print(f"Detected PDF file. Running PDF OCR script for: {input_path}")
subprocess.run(["python", "run_dpsk_ocr_pdf.py", "--input", input_path])
elif file_extension in ['.jpg', '.jpeg', '.png', '.bmp', '.gif']:
print(f"Detected image file. Running image OCR script for: {input_path}")
subprocess.run(["python", "run_dpsk_ocr_image.py", "--input", input_path])
else:
print(f"Unsupported file type: {file_extension}")
print("Please provide a PDF or an image file (.jpg, .jpeg, .png, .bmp, .gif).")
if __name__ == "__main__":
main()

View File

@@ -1,32 +1,35 @@
import argparse
import asyncio import asyncio
import re
import os import os
import json
import re
import torch import torch
if torch.version.cuda == '11.8':
if torch.version.cuda == "11.8":
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas" os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0' os.environ["VLLM_USE_V1"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = '0' os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import time
import numpy as np
from config import CROP_MODE, IMAGE_INPUT_PATH, MODEL_PATH, OUTPUT_PATH, PROMPT
from PIL import Image, ImageDraw, ImageFont, ImageOps
from process.image_process import DeepseekOCRProcessor
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from tqdm import tqdm
from vllm import AsyncLLMEngine, SamplingParams from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.registry import ModelRegistry from vllm.model_executor.models.registry import ModelRegistry
import time
from deepseek_ocr import DeepseekOCRForCausalLM from deepseek_ocr import DeepseekOCRForCausalLM
from PIL import Image, ImageDraw, ImageFont, ImageOps
import numpy as np
from tqdm import tqdm
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM) ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
def load_image(image_path):
def load_image(image_path):
try: try:
image = Image.open(image_path) image = Image.open(image_path)
@@ -43,14 +46,13 @@ def load_image(image_path):
def re_match(text): def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
matches = re.findall(pattern, text, re.DOTALL) matches = re.findall(pattern, text, re.DOTALL)
mathes_image = [] mathes_image = []
mathes_other = [] mathes_other = []
for a_match in matches: for a_match in matches:
if '<|ref|>image<|/ref|>' in a_match[0]: if "<|ref|>image<|/ref|>" in a_match[0]:
mathes_image.append(a_match[0]) mathes_image.append(a_match[0])
else: else:
mathes_other.append(a_match[0]) mathes_other.append(a_match[0])
@@ -58,8 +60,6 @@ def re_match(text):
def extract_coordinates_and_label(ref_text, image_width, image_height): def extract_coordinates_and_label(ref_text, image_width, image_height):
try: try:
label_type = ref_text[1] label_type = ref_text[1]
cor_list = eval(ref_text[2]) cor_list = eval(ref_text[2])
@@ -71,12 +71,11 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
def draw_bounding_boxes(image, refs): def draw_bounding_boxes(image, refs):
image_width, image_height = image.size image_width, image_height = image.size
img_draw = image.copy() img_draw = image.copy()
draw = ImageDraw.Draw(img_draw) draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay) draw2 = ImageDraw.Draw(overlay)
# except IOError: # except IOError:
@@ -90,9 +89,13 @@ def draw_bounding_boxes(image, refs):
if result: if result:
label_type, points_list = result label_type, points_list = result
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) color = (
np.random.randint(0, 200),
np.random.randint(0, 200),
np.random.randint(0, 255),
)
color_a = color + (20, ) color_a = color + (20,)
for points in points_list: for points in points_list:
x1, y1, x2, y2 = points x1, y1, x2, y2 = points
@@ -102,7 +105,7 @@ def draw_bounding_boxes(image, refs):
x2 = int(x2 / 999 * image_width) x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height) y2 = int(y2 / 999 * image_height)
if label_type == 'image': if label_type == "image":
try: try:
cropped = image.crop((x1, y1, x2, y2)) cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg") cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
@@ -112,12 +115,22 @@ def draw_bounding_boxes(image, refs):
img_idx += 1 img_idx += 1
try: try:
if label_type == 'title': if label_type == "title":
draw.rectangle([x1, y1, x2, y2], outline=color, width=4) draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
else: else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2) draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
text_x = x1 text_x = x1
text_y = max(0, y1 - 15) text_y = max(0, y1 - 15)
@@ -125,8 +138,10 @@ def draw_bounding_boxes(image, refs):
text_bbox = draw.textbbox((0, 0), label_type, font=font) text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0] text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1] text_height = text_bbox[3] - text_bbox[1]
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height], draw.rectangle(
fill=(255, 255, 255, 30)) [text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30),
)
draw.text((text_x, text_y), label_type, font=font, fill=color) draw.text((text_x, text_y), label_type, font=font, fill=color)
except: except:
@@ -142,11 +157,7 @@ def process_image_with_refs(image, ref_texts):
return result_image return result_image
async def stream_generate(image=None, prompt=""):
async def stream_generate(image=None, prompt=''):
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model=MODEL_PATH, model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]}, hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
@@ -159,7 +170,11 @@ async def stream_generate(image=None, prompt=''):
) )
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90, whitelist_token_ids= {128821, 128822})] #whitelist: <td>, </td> logits_processors = [
NoRepeatNGramLogitsProcessor(
ngram_size=30, window_size=90, whitelist_token_ids={128821, 128822}
)
] # whitelist: <td>, </td>
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.0, temperature=0.0,
@@ -167,137 +182,157 @@ async def stream_generate(image=None, prompt=''):
logits_processors=logits_processors, logits_processors=logits_processors,
skip_special_tokens=False, skip_special_tokens=False,
# ignore_eos=False, # ignore_eos=False,
) )
request_id = f"request-{int(time.time())}" request_id = f"request-{int(time.time())}"
printed_length = 0 printed_length = 0
if image and '<image>' in prompt: if image and "<image>" in prompt:
request = { request = {"prompt": prompt, "multi_modal_data": {"image": image}}
"prompt": prompt,
"multi_modal_data": {"image": image}
}
elif prompt: elif prompt:
request = { request = {"prompt": prompt}
"prompt": prompt
}
else: else:
assert False, f'prompt is none!!!' assert False, "prompt is none!!!"
async for request_output in engine.generate( async for request_output in engine.generate(request, sampling_params, request_id):
request, sampling_params, request_id
):
if request_output.outputs: if request_output.outputs:
full_text = request_output.outputs[0].text full_text = request_output.outputs[0].text
new_text = full_text[printed_length:] new_text = full_text[printed_length:]
print(new_text, end='', flush=True) print(new_text, end="", flush=True)
printed_length = len(full_text) printed_length = len(full_text)
final_output = full_text final_output = full_text
print('\n') print("\n")
return final_output return final_output
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Path to the input image file.")
args = parser.parse_args()
input_path = args.input if args.input else IMAGE_INPUT_PATH
os.makedirs(OUTPUT_PATH, exist_ok=True) os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True) os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
image = load_image(INPUT_PATH).convert('RGB') image = load_image(input_path).convert("RGB")
if "<image>" in PROMPT:
if '<image>' in PROMPT: image_features = DeepseekOCRProcessor().tokenize_with_images(
images=[image], bos=True, eos=True, cropping=CROP_MODE
image_features = DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE) )
else: else:
image_features = '' image_features = ""
prompt = PROMPT prompt = PROMPT
result_out = asyncio.run(stream_generate(image_features, prompt)) result_out = asyncio.run(stream_generate(image_features, prompt))
save_results = 1 save_results = 1
if save_results and '<image>' in prompt: if save_results and "<image>" in prompt:
print('='*15 + 'save results:' + '='*15) print("=" * 15 + "save results:" + "=" * 15)
base_name = os.path.basename(input_path)
file_name_without_ext = os.path.splitext(base_name)[0]
output_json_det_path = f'{OUTPUT_PATH}/{file_name_without_ext}_det.json'
output_json_path = f'{OUTPUT_PATH}/{file_name_without_ext}.json'
image_draw = image.copy() image_draw = image.copy()
outputs = result_out outputs = result_out
with open(f'{OUTPUT_PATH}/result_ori.mmd', 'w', encoding = 'utf-8') as afile: with open(output_json_det_path, "w", encoding="utf-8") as afile:
afile.write(outputs) json.dump({"parsed": outputs}, afile, ensure_ascii=False, indent=4)
matches_ref, matches_images, mathes_other = re_match(outputs) matches_ref, matches_images, mathes_other = re_match(outputs)
# print(matches_ref) # print(matches_ref)
result = process_image_with_refs(image_draw, matches_ref) result = process_image_with_refs(image_draw, matches_ref)
for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")): for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
outputs = outputs.replace(a_match_image, f'![](images/' + str(idx) + '.jpg)\n') outputs = outputs.replace(
a_match_image, "![](images/" + str(idx) + ".jpg)\n"
)
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")): for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:') outputs = (
outputs.replace(a_match_other, "")
.replace("\\coloneqq", ":=")
.replace("\\eqqcolon", "=:")
)
# if 'structural formula' in conversation[0]['content']: # if 'structural formula' in conversation[0]['content']:
# outputs = '<smiles>' + outputs + '</smiles>' # outputs = '<smiles>' + outputs + '</smiles>'
with open(f'{OUTPUT_PATH}/result.mmd', 'w', encoding = 'utf-8') as afile: with open(output_json_path, "w", encoding="utf-8") as afile:
afile.write(outputs) json.dump({"parsed": outputs}, afile, ensure_ascii=False, indent=4)
if 'line_type' in outputs: if "line_type" in outputs:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.patches import Circle from matplotlib.patches import Circle
lines = eval(outputs)['Line']['line']
line_type = eval(outputs)['Line']['line_type'] lines = eval(outputs)["Line"]["line"]
line_type = eval(outputs)["Line"]["line_type"]
# print(lines) # print(lines)
endpoints = eval(outputs)['Line']['line_endpoint'] endpoints = eval(outputs)["Line"]["line_endpoint"]
fig, ax = plt.subplots(figsize=(3,3), dpi=200) fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
ax.set_xlim(-15, 15) ax.set_xlim(-15, 15)
ax.set_ylim(-15, 15) ax.set_ylim(-15, 15)
for idx, line in enumerate(lines): for idx, line in enumerate(lines):
try: try:
p0 = eval(line.split(' -- ')[0]) p0 = eval(line.split(" -- ")[0])
p1 = eval(line.split(' -- ')[-1]) p1 = eval(line.split(" -- ")[-1])
if line_type[idx] == '--': if line_type[idx] == "--":
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k') ax.plot(
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
)
else: else:
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k') ax.plot(
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
)
ax.scatter(p0[0], p0[1], s=5, color = 'k') ax.scatter(p0[0], p0[1], s=5, color="k")
ax.scatter(p1[0], p1[1], s=5, color = 'k') ax.scatter(p1[0], p1[1], s=5, color="k")
except: except:
pass pass
for endpoint in endpoints: for endpoint in endpoints:
label = endpoint.split(": ")[0]
label = endpoint.split(': ')[0] (x, y) = eval(endpoint.split(": ")[1])
(x, y) = eval(endpoint.split(': ')[1]) ax.annotate(
ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points', label,
fontsize=5, fontweight='light') (x, y),
xytext=(1, 1),
textcoords="offset points",
fontsize=5,
fontweight="light",
)
try: try:
if 'Circle' in eval(outputs).keys(): if "Circle" in eval(outputs).keys():
circle_centers = eval(outputs)['Circle']['circle_center'] circle_centers = eval(outputs)["Circle"]["circle_center"]
radius = eval(outputs)['Circle']['radius'] radius = eval(outputs)["Circle"]["radius"]
for center, r in zip(circle_centers, radius): for center, r in zip(circle_centers, radius):
center = eval(center.split(': ')[1]) center = eval(center.split(": ")[1])
circle = Circle(center, radius=r, fill=False, edgecolor='black', linewidth=0.8) circle = Circle(
center,
radius=r,
fill=False,
edgecolor="black",
linewidth=0.8,
)
ax.add_patch(circle) ax.add_patch(circle)
except: except:
pass pass
plt.savefig(f"{OUTPUT_PATH}/geo.jpg")
plt.savefig(f'{OUTPUT_PATH}/geo.jpg')
plt.close() plt.close()
result.save(f'{OUTPUT_PATH}/result_with_boxes.jpg') result.save(f"{OUTPUT_PATH}/result_with_boxes.jpg")

View File

@@ -1,30 +1,39 @@
import os import argparse
import fitz
import img2pdf
import io import io
import json
import os
import re import re
from tqdm import tqdm
import torch
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import fitz
import img2pdf
import torch
from tqdm import tqdm
if torch.version.cuda == '11.8': if torch.version.cuda == "11.8":
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas" os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0' os.environ["VLLM_USE_V1"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = '0' os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
from PIL import Image, ImageDraw, ImageFont
import numpy as np import numpy as np
from deepseek_ocr import DeepseekOCRForCausalLM from config import (
CROP_MODE,
MAX_CONCURRENCY,
MODEL_PATH,
NUM_WORKERS,
OUTPUT_PATH,
PDF_INPUT_PATH,
PROMPT,
SKIP_REPEAT,
)
from PIL import Image, ImageDraw, ImageFont
from process.image_process import DeepseekOCRProcessor
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams from deepseek_ocr import DeepseekOCRForCausalLM
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM) ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
@@ -40,10 +49,14 @@ llm = LLM(
max_num_seqs=MAX_CONCURRENCY, max_num_seqs=MAX_CONCURRENCY,
tensor_parallel_size=1, tensor_parallel_size=1,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
disable_mm_preprocessor_cache=True disable_mm_preprocessor_cache=True,
) )
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids= {128821, 128822})] #window for fastwhitelist_token_ids: <td>,</td> logits_processors = [
NoRepeatNGramLogitsProcessor(
ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
)
] # window for fastwhitelist_token_ids: <td>,</td>
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.0, temperature=0.0,
@@ -55,11 +68,12 @@ sampling_params = SamplingParams(
class Colors: class Colors:
RED = '\033[31m' RED = "\033[31m"
GREEN = '\033[32m' GREEN = "\033[32m"
YELLOW = '\033[33m' YELLOW = "\033[33m"
BLUE = '\033[34m' BLUE = "\033[34m"
RESET = '\033[0m' RESET = "\033[0m"
def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"): def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
""" """
@@ -84,9 +98,11 @@ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
else: else:
img_data = pixmap.tobytes("png") img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data)) img = Image.open(io.BytesIO(img_data))
if img.mode in ('RGBA', 'LA'): if img.mode in ("RGBA", "LA"):
background = Image.new('RGB', img.size, (255, 255, 255)) background = Image.new("RGB", img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) background.paste(
img, mask=img.split()[-1] if img.mode == "RGBA" else None
)
img = background img = background
images.append(img) images.append(img)
@@ -94,19 +110,19 @@ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
pdf_document.close() pdf_document.close()
return images return images
def pil_to_pdf_img2pdf(pil_images, output_path):
def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images: if not pil_images:
return return
image_bytes_list = [] image_bytes_list = []
for img in pil_images: for img in pil_images:
if img.mode != 'RGB': if img.mode != "RGB":
img = img.convert('RGB') img = img.convert("RGB")
img_buffer = io.BytesIO() img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=95) img.save(img_buffer, format="JPEG", quality=95)
img_bytes = img_buffer.getvalue() img_bytes = img_buffer.getvalue()
image_bytes_list.append(img_bytes) image_bytes_list.append(img_bytes)
@@ -119,16 +135,14 @@ def pil_to_pdf_img2pdf(pil_images, output_path):
print(f"error: {e}") print(f"error: {e}")
def re_match(text): def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
matches = re.findall(pattern, text, re.DOTALL) matches = re.findall(pattern, text, re.DOTALL)
mathes_image = [] mathes_image = []
mathes_other = [] mathes_other = []
for a_match in matches: for a_match in matches:
if '<|ref|>image<|/ref|>' in a_match[0]: if "<|ref|>image<|/ref|>" in a_match[0]:
mathes_image.append(a_match[0]) mathes_image.append(a_match[0])
else: else:
mathes_other.append(a_match[0]) mathes_other.append(a_match[0])
@@ -136,8 +150,6 @@ def re_match(text):
def extract_coordinates_and_label(ref_text, image_width, image_height): def extract_coordinates_and_label(ref_text, image_width, image_height):
try: try:
label_type = ref_text[1] label_type = ref_text[1]
cor_list = eval(ref_text[2]) cor_list = eval(ref_text[2])
@@ -149,12 +161,11 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
def draw_bounding_boxes(image, refs, jdx): def draw_bounding_boxes(image, refs, jdx):
image_width, image_height = image.size image_width, image_height = image.size
img_draw = image.copy() img_draw = image.copy()
draw = ImageDraw.Draw(img_draw) draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay) draw2 = ImageDraw.Draw(overlay)
# except IOError: # except IOError:
@@ -168,9 +179,13 @@ def draw_bounding_boxes(image, refs, jdx):
if result: if result:
label_type, points_list = result label_type, points_list = result
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) color = (
np.random.randint(0, 200),
np.random.randint(0, 200),
np.random.randint(0, 255),
)
color_a = color + (20, ) color_a = color + (20,)
for points in points_list: for points in points_list:
x1, y1, x2, y2 = points x1, y1, x2, y2 = points
@@ -180,7 +195,7 @@ def draw_bounding_boxes(image, refs, jdx):
x2 = int(x2 / 999 * image_width) x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height) y2 = int(y2 / 999 * image_height)
if label_type == 'image': if label_type == "image":
try: try:
cropped = image.crop((x1, y1, x2, y2)) cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg") cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
@@ -190,12 +205,22 @@ def draw_bounding_boxes(image, refs, jdx):
img_idx += 1 img_idx += 1
try: try:
if label_type == 'title': if label_type == "title":
draw.rectangle([x1, y1, x2, y2], outline=color, width=4) draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
else: else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2) draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
text_x = x1 text_x = x1
text_y = max(0, y1 - 15) text_y = max(0, y1 - 15)
@@ -203,8 +228,10 @@ def draw_bounding_boxes(image, refs, jdx):
text_bbox = draw.textbbox((0, 0), label_type, font=font) text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0] text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1] text_height = text_bbox[3] - text_bbox[1]
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height], draw.rectangle(
fill=(255, 255, 255, 30)) [text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30),
)
draw.text((text_x, text_y), label_type, font=font, fill=color) draw.text((text_x, text_y), label_type, font=font, fill=color)
except: except:
@@ -225,33 +252,41 @@ def process_single_image(image):
prompt_in = prompt prompt_in = prompt
cache_item = { cache_item = {
"prompt": prompt_in, "prompt": prompt_in,
"multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)}, "multi_modal_data": {
"image": DeepseekOCRProcessor().tokenize_with_images(
images=[image], bos=True, eos=True, cropping=CROP_MODE
)
},
} }
return cache_item return cache_item
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Path to the input PDF file.")
args = parser.parse_args()
input_path = args.input if args.input else PDF_INPUT_PATH
os.makedirs(OUTPUT_PATH, exist_ok=True) os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True) os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
print(f'{Colors.RED}PDF loading .....{Colors.RESET}') print(f"{Colors.RED}PDF loading .....{Colors.RESET}")
images = pdf_to_images_high_quality(INPUT_PATH)
images = pdf_to_images_high_quality(input_path)
prompt = PROMPT prompt = PROMPT
# batch_inputs = [] # batch_inputs = []
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor: with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(tqdm( batch_inputs = list(
executor.map(process_single_image, images), tqdm(
total=len(images), executor.map(process_single_image, images),
desc="Pre-processed images" total=len(images),
)) desc="Pre-processed images",
)
)
# for image in tqdm(images): # for image in tqdm(images):
@@ -264,38 +299,39 @@ if __name__ == "__main__":
# ] # ]
# batch_inputs.extend(cache_list) # batch_inputs.extend(cache_list)
outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
outputs_list = llm.generate(
batch_inputs,
sampling_params=sampling_params
)
output_path = OUTPUT_PATH output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
json_det_path = (
mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd') output_path + "/" + input_path.split("/")[-1].replace(".pdf", "_det.json")
mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd') )
pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf') json_path = (
contents_det = '' output_path + "/" + input_path.split("/")[-1].replace(".pdf", ".json")
contents = '' )
pdf_out_path = (
output_path
+ "/"
+ input_path.split("/")[-1].replace(".pdf", "_layouts.pdf")
)
contents_det = ""
contents = ""
draw_images = [] draw_images = []
jdx = 0 jdx = 0
for output, img in zip(outputs_list, images): for output, img in zip(outputs_list, images):
content = output.outputs[0].text content = output.outputs[0].text
if '<end▁of▁sentence>' in content: # repeat no eos if "<end▁of▁sentence>" in content: # repeat no eos
content = content.replace('<end▁of▁sentence>', '') content = content.replace("<end▁of▁sentence>", "")
else: else:
if SKIP_REPEAT: if SKIP_REPEAT:
continue continue
page_num = "\n<--- Page Split --->"
page_num = f'\n<--- Page Split --->' contents_det += content + f"\n{page_num}\n"
contents_det += content + f'\n{page_num}\n'
image_draw = img.copy() image_draw = img.copy()
@@ -303,28 +339,30 @@ if __name__ == "__main__":
# print(matches_ref) # print(matches_ref)
result_image = process_image_with_refs(image_draw, matches_ref, jdx) result_image = process_image_with_refs(image_draw, matches_ref, jdx)
draw_images.append(result_image) draw_images.append(result_image)
for idx, a_match_image in enumerate(matches_images): for idx, a_match_image in enumerate(matches_images):
content = content.replace(a_match_image, f'![](images/' + str(jdx) + '_' + str(idx) + '.jpg)\n') content = content.replace(
a_match_image, "![](images/" + str(jdx) + "_" + str(idx) + ".jpg)\n"
)
for idx, a_match_other in enumerate(mathes_other): for idx, a_match_other in enumerate(mathes_other):
content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n') content = (
content.replace(a_match_other, "")
.replace("\\coloneqq", ":=")
contents += content + f'\n{page_num}\n' .replace("\\eqqcolon", "=:")
.replace("\n\n\n\n", "\n\n")
.replace("\n\n\n", "\n\n")
)
contents += content + f"\n{page_num}\n"
jdx += 1 jdx += 1
with open(mmd_det_path, 'w', encoding='utf-8') as afile: with open(json_det_path, "w", encoding="utf-8") as afile:
afile.write(contents_det) json.dump({"parsed": contents_det}, afile, ensure_ascii=False, indent=4)
with open(mmd_path, 'w', encoding='utf-8') as afile:
afile.write(contents)
with open(json_path, "w", encoding="utf-8") as afile:
json.dump({"parsed": contents}, afile, ensure_ascii=False, indent=4)
pil_to_pdf_img2pdf(draw_images, pdf_out_path) pil_to_pdf_img2pdf(draw_images, pdf_out_path)