입출력 로직 변경

This commit is contained in:
kyy
2025-10-27 15:36:17 +09:00
parent 758b9afe9a
commit 64550b1fd5
4 changed files with 310 additions and 212 deletions

View File

@@ -1,30 +1,39 @@
import argparse
import io
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor
import fitz
import img2pdf
import io
import re
from tqdm import tqdm
import torch
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
if torch.version.cuda == '11.8':
if torch.version.cuda == "11.8":
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ["VLLM_USE_V1"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from deepseek_ocr import DeepseekOCRForCausalLM
from config import (
CROP_MODE,
MAX_CONCURRENCY,
MODEL_PATH,
NUM_WORKERS,
OUTPUT_PATH,
PDF_INPUT_PATH,
PROMPT,
SKIP_REPEAT,
)
from PIL import Image, ImageDraw, ImageFont
from process.image_process import DeepseekOCRProcessor
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
from deepseek_ocr import DeepseekOCRForCausalLM
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
@@ -34,16 +43,20 @@ llm = LLM(
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=256,
enforce_eager=False,
trust_remote_code=True,
trust_remote_code=True,
max_model_len=8192,
swap_space=0,
max_num_seqs=MAX_CONCURRENCY,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
disable_mm_preprocessor_cache=True
disable_mm_preprocessor_cache=True,
)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids= {128821, 128822})] #window for fastwhitelist_token_ids: <td>,</td>
logits_processors = [
NoRepeatNGramLogitsProcessor(
ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
)
] # window for fastwhitelist_token_ids: <td>,</td>
sampling_params = SamplingParams(
temperature=0.0,
@@ -55,23 +68,24 @@ sampling_params = SamplingParams(
class Colors:
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
RESET = '\033[0m'
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
RESET = "\033[0m"
def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
"""
pdf2images
"""
images = []
pdf_document = fitz.open(pdf_path)
zoom = dpi / 72.0
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
@@ -84,32 +98,34 @@ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
else:
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
if img.mode in ('RGBA', 'LA'):
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
if img.mode in ("RGBA", "LA"):
background = Image.new("RGB", img.size, (255, 255, 255))
background.paste(
img, mask=img.split()[-1] if img.mode == "RGBA" else None
)
img = background
images.append(img)
pdf_document.close()
return images
def pil_to_pdf_img2pdf(pil_images, output_path):
def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images:
return
image_bytes_list = []
for img in pil_images:
if img.mode != 'RGB':
img = img.convert('RGB')
if img.mode != "RGB":
img = img.convert("RGB")
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=95)
img.save(img_buffer, format="JPEG", quality=95)
img_bytes = img_buffer.getvalue()
image_bytes_list.append(img_bytes)
try:
pdf_bytes = img2pdf.convert(image_bytes_list)
with open(output_path, "wb") as f:
@@ -119,16 +135,14 @@ def pil_to_pdf_img2pdf(pil_images, output_path):
print(f"error: {e}")
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
if '<|ref|>image<|/ref|>' in a_match[0]:
if "<|ref|>image<|/ref|>" in a_match[0]:
mathes_image.append(a_match[0])
else:
mathes_other.append(a_match[0])
@@ -136,8 +150,6 @@ def re_match(text):
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
@@ -149,28 +161,31 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
def draw_bounding_boxes(image, refs, jdx):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
# except IOError:
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
try:
result = extract_coordinates_and_label(ref, image_width, image_height)
if result:
label_type, points_list = result
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
color_a = color + (20, )
color = (
np.random.randint(0, 200),
np.random.randint(0, 200),
np.random.randint(0, 255),
)
color_a = color + (20,)
for points in points_list:
x1, y1, x2, y2 = points
@@ -180,7 +195,7 @@ def draw_bounding_boxes(image, refs, jdx):
x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height)
if label_type == 'image':
if label_type == "image":
try:
cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
@@ -188,24 +203,36 @@ def draw_bounding_boxes(image, refs, jdx):
print(e)
pass
img_idx += 1
try:
if label_type == 'title':
if label_type == "title":
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
text_x = x1
text_y = max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30))
draw.rectangle(
[text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30),
)
draw.text((text_x, text_y), label_type, font=font, fill=color)
except:
pass
@@ -225,33 +252,41 @@ def process_single_image(image):
prompt_in = prompt
cache_item = {
"prompt": prompt_in,
"multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
"multi_modal_data": {
"image": DeepseekOCRProcessor().tokenize_with_images(
images=[image], bos=True, eos=True, cropping=CROP_MODE
)
},
}
return cache_item
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Path to the input PDF file.")
args = parser.parse_args()
input_path = args.input if args.input else PDF_INPUT_PATH
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
print(f'{Colors.RED}PDF loading .....{Colors.RESET}')
os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
print(f"{Colors.RED}PDF loading .....{Colors.RESET}")
images = pdf_to_images_high_quality(INPUT_PATH)
images = pdf_to_images_high_quality(input_path)
prompt = PROMPT
# batch_inputs = []
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(tqdm(
executor.map(process_single_image, images),
total=len(images),
desc="Pre-processed images"
))
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(
tqdm(
executor.map(process_single_image, images),
total=len(images),
desc="Pre-processed images",
)
)
# for image in tqdm(images):
@@ -264,38 +299,39 @@ if __name__ == "__main__":
# ]
# batch_inputs.extend(cache_list)
outputs_list = llm.generate(
batch_inputs,
sampling_params=sampling_params
)
outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True)
mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd')
pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
contents_det = ''
contents = ''
json_det_path = (
output_path + "/" + input_path.split("/")[-1].replace(".pdf", "_det.json")
)
json_path = (
output_path + "/" + input_path.split("/")[-1].replace(".pdf", ".json")
)
pdf_out_path = (
output_path
+ "/"
+ input_path.split("/")[-1].replace(".pdf", "_layouts.pdf")
)
contents_det = ""
contents = ""
draw_images = []
jdx = 0
for output, img in zip(outputs_list, images):
content = output.outputs[0].text
if '<end▁of▁sentence>' in content: # repeat no eos
content = content.replace('<end▁of▁sentence>', '')
if "<end▁of▁sentence>" in content: # repeat no eos
content = content.replace("<end▁of▁sentence>", "")
else:
if SKIP_REPEAT:
continue
page_num = f'\n<--- Page Split --->'
page_num = "\n<--- Page Split --->"
contents_det += content + f'\n{page_num}\n'
contents_det += content + f"\n{page_num}\n"
image_draw = img.copy()
@@ -303,28 +339,30 @@ if __name__ == "__main__":
# print(matches_ref)
result_image = process_image_with_refs(image_draw, matches_ref, jdx)
draw_images.append(result_image)
for idx, a_match_image in enumerate(matches_images):
content = content.replace(a_match_image, f'![](images/' + str(jdx) + '_' + str(idx) + '.jpg)\n')
content = content.replace(
a_match_image, "![](images/" + str(jdx) + "_" + str(idx) + ".jpg)\n"
)
for idx, a_match_other in enumerate(mathes_other):
content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
contents += content + f'\n{page_num}\n'
content = (
content.replace(a_match_other, "")
.replace("\\coloneqq", ":=")
.replace("\\eqqcolon", "=:")
.replace("\n\n\n\n", "\n\n")
.replace("\n\n\n", "\n\n")
)
contents += content + f"\n{page_num}\n"
jdx += 1
with open(mmd_det_path, 'w', encoding='utf-8') as afile:
afile.write(contents_det)
with open(mmd_path, 'w', encoding='utf-8') as afile:
afile.write(contents)
with open(json_det_path, "w", encoding="utf-8") as afile:
json.dump({"parsed": contents_det}, afile, ensure_ascii=False, indent=4)
with open(json_path, "w", encoding="utf-8") as afile:
json.dump({"parsed": contents}, afile, ensure_ascii=False, indent=4)
pil_to_pdf_img2pdf(draw_images, pdf_out_path)