diff --git a/.gitignore b/.gitignore
index 910afc8..1ffe179 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
-# Cache directories
 .cache/
-__pycache__/
\ No newline at end of file
+.tmp/
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.log
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 745d230..b5bfcaf 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,8 +11,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
     TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
     TORCH_CUDA_ARCH_LIST="8.0"
 
-WORKDIR /workspace
-
 # Install required build tools
 RUN apt-get update && apt-get install -y --no-install-recommends \
     git build-essential ninja-build \
@@ -36,4 +34,11 @@ RUN pip install vllm==0.8.5
 RUN pip cache purge && \
     pip install --no-cache-dir --no-build-isolation --no-binary=flash-attn flash-attn==2.7.3
 
-WORKDIR /workspace
\ No newline at end of file
+# Expose the API server port
+EXPOSE 11635
+
+WORKDIR /workspace
+COPY . .
+
+# Run the FastAPI server with Uvicorn
+CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "11635"]
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
index b536973..c268c76 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -3,13 +3,15 @@ services:
     build:
       context: .
       dockerfile: Dockerfile
-    image: deepseek-ocr-vllm:cu126
+    image: deepseek-ocr-api:torch2.6.0-cuda12.6-cudnn9-vllm0.8.5
     container_name: deepseek_ocr_vllm
     working_dir: /workspace
    volumes:
       - ./:/workspace
+    ports:
+      - "11635:11635"
     gpus: all
-    shm_size: "8g"
+    shm_size: "16g"
     ipc: "host"
     environment:
       - HF_HOME=/workspace/.cache/huggingface
@@ -18,4 +20,10 @@ services:
       - PIP_DISABLE_PIP_VERSION_CHECK=1
       - PYTHONUNBUFFERED=1
     tty: true
-    entrypoint: ["/bin/bash"]
+    restart: always
+    networks:
+      - llm_gateway_local_net
+
+networks:
+  llm_gateway_local_net:
+    external: true
diff --git a/services/ocr_engine.py b/services/ocr_engine.py
new file mode 100644
index 0000000..70107cd
--- /dev/null
+++ b/services/ocr_engine.py
@@ -0,0 +1,117 @@
+import asyncio
+import io
+import uuid
+
+import fitz
+from config.model_settings import CROP_MODE, MODEL_PATH, PROMPT
+from PIL import Image
+from process.image_process import DeepseekOCRProcessor
+from vllm import AsyncLLMEngine, SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.model_executor.models.registry import ModelRegistry
+
+from services.deepseek_ocr import DeepseekOCRForCausalLM
+
+# --------------------------------------------------------------------------
+# 1. Model and processor initialization
+# --------------------------------------------------------------------------
+
+# Register the custom model class so vLLM can resolve it
+ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
+
+# Async vLLM engine configuration
+engine_args = AsyncEngineArgs(
+    model=MODEL_PATH,
+    hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
+    block_size=256,
+    max_model_len=8192,
+    enforce_eager=False,
+    trust_remote_code=True,
+    tensor_parallel_size=1,
+    gpu_memory_utilization=0.75,
+)
+engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+# Sampling parameters
+sampling_params = SamplingParams(
+    temperature=0.0,
+    max_tokens=8192,
+    skip_special_tokens=False,
+)
+
+# Image preprocessor
+processor = DeepseekOCRProcessor()
+
+# --------------------------------------------------------------------------
+# 2. Core processing functions
+# --------------------------------------------------------------------------
+
+async def _process_single_image(image: Image.Image) -> str:
+    """Run OCR on a single PIL image and return the extracted text."""
+    if "<image>" not in PROMPT:
+        raise ValueError("Cannot run OCR: the prompt is missing the '<image>' token.")
+
+    image_features = processor.tokenize_with_images(
+        images=[image], bos=True, eos=True, cropping=CROP_MODE
+    )
+
+    request = {"prompt": PROMPT, "multi_modal_data": {"image": image_features}}
+    # uuid4 keeps request ids unique even when pages are OCRed concurrently
+    request_id = f"request-{uuid.uuid4()}"
+
+    final_output = ""
+    async for request_output in engine.generate(request, sampling_params, request_id):
+        if request_output.outputs:
+            final_output = request_output.outputs[0].text
+
+    return final_output
+
+def _pdf_to_images(pdf_bytes: bytes, dpi: int = 144) -> list[Image.Image]:
+    """Convert PDF bytes into a list of per-page PIL images."""
+    images = []
+    pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
+    zoom = dpi / 72.0
+    matrix = fitz.Matrix(zoom, zoom)
+
+    for page_num in range(pdf_document.page_count):
+        page = pdf_document[page_num]
+        pixmap = page.get_pixmap(matrix=matrix, alpha=False)
+        img_data = pixmap.tobytes("png")
+        img = Image.open(io.BytesIO(img_data))
+        images.append(img)
+
+    pdf_document.close()
+    return images
+
+async def process_document(file_bytes: bytes, content_type: str, filename: str) -> dict:
+    """
+    Process an uploaded file (image or PDF) and return the OCR result.
+    """
+    if content_type.startswith("image/"):
+        try:
+            image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
+        except Exception as e:
+            raise ValueError(f"Failed to open the image file: {e}")
+
+        result_text = await _process_single_image(image)
+        return {"filename": filename, "text": result_text}
+
+    elif content_type == "application/pdf":
+        try:
+            images = _pdf_to_images(file_bytes)
+        except Exception as e:
+            raise ValueError(f"Failed to process the PDF file: {e}")
+
+        # OCR each page concurrently
+        tasks = [_process_single_image(img) for img in images]
+        page_results = await asyncio.gather(*tasks)
+
+        # Join the page texts with a page separator
+        full_text = "\n<--- Page Split --->\n".join(page_results)
+        return {"filename": filename, "text": full_text, "page_count": len(images)}
+
+    else:
+        raise ValueError(
+            f"Unsupported file type: {content_type}. "
+            "Please upload an image (JPEG, PNG, etc.) or a PDF file."
+        )
\ No newline at end of file
diff --git a/services/ocr_gateway_post.py b/services/ocr_gateway_post.py
new file mode 100644
index 0000000..def8b7d
--- /dev/null
+++ b/services/ocr_gateway_post.py
@@ -0,0 +1,71 @@
+import logging
+import os
+from pathlib import Path
+
+import httpx
+
+logger = logging.getLogger(__name__)
+
+
+# ✅ DeepSeek OCR API
+async def extract_deepseek_ocr(file_path: str):
+    """
+    Call the deepseek_ocr_vllm container to extract text and coordinates from an image.
+    """
+    # FastAPI endpoint URL of the deepseek_ocr_vllm container (uvicorn serves on 11635)
+    # TODO: Confirm the actual endpoint name ('/ocr') and adjust if necessary.
+    DEEPSEEK_API_URL = os.getenv("DEEPSEEK_API_URL", "http://deepseek_ocr_vllm:11635/ocr")
+
+    if not file_path or not os.path.exists(file_path):
+        raise FileNotFoundError(f"File does not exist: {file_path}")
+
+    filename = Path(file_path).name
+    full_text_parts = []
+    coord_response = []
+
+    with open(file_path, "rb") as f:
+        # TODO: Verify the file parameter name ('document') expected by the FastAPI endpoint.
+        files = {"document": (filename, f, "application/octet-stream")}
+        try:
+            async with httpx.AsyncClient(timeout=60.0) as client:
+                response = await client.post(DEEPSEEK_API_URL, files=files)
+                response.raise_for_status()
+                result = response.json()
+        except httpx.HTTPStatusError as e:
+            logger.error(f"DeepSeek API error: {e.response.text}")
+            raise RuntimeError(f"DeepSeek API error: {e.response.status_code}")
+        except httpx.ConnectError as e:
+            logger.error(f"Failed to connect to the DeepSeek container: {e}")
+            raise RuntimeError(f"Cannot connect to the DeepSeek container ({DEEPSEEK_API_URL}).")
+
+        # TODO: Adjust the JSON parsing logic to match the actual API response format.
+        try:
+            # The block below is an example assuming an Upstage-like response format.
+            pages = result.get("pages", [])
+            for page_idx, p in enumerate(pages, start=1):
+                txt = p.get("text")
+                if txt:
+                    full_text_parts.append(txt)
+
+                for w in p.get("words", []):
+                    verts = (w.get("boundingBox", {}) or {}).get("vertices")
+                    if not verts or len(verts) != 4:
+                        continue
+                    xs = [v.get("x", 0) for v in verts]
+                    ys = [v.get("y", 0) for v in verts]
+                    coord_response.append(
+                        {
+                            "text": w.get("text"),
+                            "coords": [min(xs), min(ys), max(xs), max(ys)],
+                            "page": page_idx,
+                        }
+                    )
+            logger.warning("Parsing logic must be implemented to match DeepSeek OCR's actual response format.")
+
+        except Exception as e:
+            logger.error(f"[DEEPSEEK] JSON parsing failed: {e} / raw result: {result}")
+            return "", []
+
+    logger.info("[DEEPSEEK] Text extraction complete")
+    full_response = "\n".join(full_text_parts)
+    return full_response, coord_response
\ No newline at end of file
diff --git a/services/run_dpsk_ocr_image.py b/services/run_dpsk_ocr_image.py
deleted file mode 100644
index 5dfd884..0000000
--- a/services/run_dpsk_ocr_image.py
+++ /dev/null
@@ -1,320 +0,0 @@
-import asyncio
-import os
-import re
-
-from config import env_setup
-
-env_setup.setup_environment()
-
-import time
-
-import numpy as np
-from config.model_settings import CROP_MODE, INPUT_PATH, MODEL_PATH, OUTPUT_PATH, PROMPT
-from PIL import Image, ImageDraw, ImageFont, ImageOps
-from process.image_process import DeepseekOCRProcessor
-from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
-from tqdm import tqdm
-from vllm import AsyncLLMEngine, SamplingParams
-from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.model_executor.models.registry import ModelRegistry
-
-from deepseek_ocr import DeepseekOCRForCausalLM
-
-ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
-
-
-def load_image(image_path):
-    try:
-        image = Image.open(image_path)
-
-        corrected_image = ImageOps.exif_transpose(image)
-
-        return corrected_image
-
-    except Exception as e:
-        print(f"error: {e}")
-        try:
-            return Image.open(image_path)
-        except:
-            return None
-
-
-def re_match(text):
-    pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
-    matches = re.findall(pattern, text, re.DOTALL)
-
-    mathes_image = []
-    mathes_other = []
-    for a_match in matches:
-        if "<|ref|>image<|/ref|>" in a_match[0]:
-            mathes_image.append(a_match[0])
-        else:
-            mathes_other.append(a_match[0])
-    return matches, mathes_image, mathes_other
-
-
-def extract_coordinates_and_label(ref_text, image_width, image_height):
-    try:
-        label_type = ref_text[1]
-        cor_list = eval(ref_text[2])
-    except Exception as e:
-        print(e)
-        return None
-
-    return (label_type, cor_list)
-
-
-def draw_bounding_boxes(image, refs):
-    image_width, image_height = image.size
-    img_draw = image.copy()
-    draw = ImageDraw.Draw(img_draw)
-
-    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
-    draw2 = ImageDraw.Draw(overlay)
-
-    # except IOError:
-    font = ImageFont.load_default()
-
-    img_idx = 0
-
-    for i, ref in enumerate(refs):
-        try:
-            result = extract_coordinates_and_label(ref, image_width, image_height)
-            if result:
-                label_type, points_list = result
-
-                color = (
-                    np.random.randint(0, 200),
-                    np.random.randint(0, 200),
-                    np.random.randint(0, 255),
-                )
-
-                color_a = color + (20,)
-                for points in points_list:
-                    x1, y1, x2, y2 = points
-
-                    x1 = int(x1 / 999 * image_width)
-                    y1 = int(y1 / 999 * image_height)
-
-                    x2 = int(x2 / 999 * image_width)
-                    y2 = int(y2 / 999 * image_height)
-
-                    if label_type == "image":
-                        try:
-                            cropped = image.crop((x1, y1, x2, y2))
-                            cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
-                        except Exception as e:
-                            print(e)
-                            pass
-                        img_idx += 1
-
-                    try:
-                        if label_type == "title":
-                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
-                            draw2.rectangle(
-                                [x1, y1, x2, y2],
-                                fill=color_a,
-                                outline=(0, 0, 0, 0),
-                                width=1,
-                            )
-                        else:
-                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
-                            draw2.rectangle(
-                                [x1, y1, x2, y2],
-                                fill=color_a,
-                                outline=(0, 0, 0, 0),
-                                width=1,
-                            )
-
-                        text_x = x1
-                        text_y = max(0, y1 - 15)
-
-                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
-                        text_width = text_bbox[2] - text_bbox[0]
-                        text_height = text_bbox[3] - text_bbox[1]
-                        draw.rectangle(
-                            [text_x, text_y, text_x + text_width, text_y + text_height],
-                            fill=(255, 255, 255, 30),
-                        )
-
-                        draw.text((text_x, text_y), label_type, font=font, fill=color)
-                    except:
-                        pass
-        except:
-            continue
-    img_draw.paste(overlay, (0, 0), overlay)
-    return img_draw
-
-
-def process_image_with_refs(image, ref_texts):
-    result_image = draw_bounding_boxes(image, ref_texts)
-    return result_image
-
-
-async def stream_generate(image=None, prompt=""):
-    engine_args = AsyncEngineArgs(
-        model=MODEL_PATH,
-        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
-        block_size=256,
-        max_model_len=8192,
-        enforce_eager=False,
-        trust_remote_code=True,
-        tensor_parallel_size=1,
-        gpu_memory_utilization=0.75,
-    )
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
-
-    logits_processors = [
-        NoRepeatNGramLogitsProcessor(
-            ngram_size=30, window_size=90, whitelist_token_ids={128821, 128822}
-        )
-    ]  # whitelist: <td>, </td>
-
-    sampling_params = SamplingParams(
-        temperature=0.0,
-        max_tokens=8192,
-        logits_processors=logits_processors,
-        skip_special_tokens=False,
-        # ignore_eos=False,
-    )
-
-    request_id = f"request-{int(time.time())}"
-
-    printed_length = 0
-
-    if image and "<image>" in prompt:
-        request = {"prompt": prompt, "multi_modal_data": {"image": image}}
-    elif prompt:
-        request = {"prompt": prompt}
-    else:
-        assert False, "prompt is none!!!"
-    async for request_output in engine.generate(request, sampling_params, request_id):
-        if request_output.outputs:
-            full_text = request_output.outputs[0].text
-            new_text = full_text[printed_length:]
-            print(new_text, end="", flush=True)
-            printed_length = len(full_text)
-            final_output = full_text
-    print("\n")
-
-    return final_output
-
-
-if __name__ == "__main__":
-    os.makedirs(OUTPUT_PATH, exist_ok=True)
-    os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
-
-    image = load_image(INPUT_PATH).convert("RGB")
-
-    if "<image>" in PROMPT:
-        image_features = DeepseekOCRProcessor().tokenize_with_images(
-            images=[image], bos=True, eos=True, cropping=CROP_MODE
-        )
-    else:
-        image_features = ""
-
-    prompt = PROMPT
-
-    result_out = asyncio.run(stream_generate(image_features, prompt))
-
-    save_results = 1
-
-    if save_results and "<image>" in prompt:
-        print("=" * 15 + "save results:" + "=" * 15)
-
-        image_draw = image.copy()
-
-        outputs = result_out
-
-        with open(f"{OUTPUT_PATH}/result_ori.mmd", "w", encoding="utf-8") as afile:
-            afile.write(outputs)
-
-        matches_ref, matches_images, mathes_other = re_match(outputs)
-        # print(matches_ref)
-        result = process_image_with_refs(image_draw, matches_ref)
-
-        for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
-            outputs = outputs.replace(
-                a_match_image, "![](images/" + str(idx) + ".jpg)\n"
-            )
-
-        for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
-            outputs = (
-                outputs.replace(a_match_other, "")
-                .replace("\\coloneqq", ":=")
-                .replace("\\eqqcolon", "=:")
-            )
-
-        # if 'structural formula' in conversation[0]['content']:
-        #     outputs = '<smiles>' + outputs + '</smiles>'
-        with open(f"{OUTPUT_PATH}/result.mmd", "w", encoding="utf-8") as afile:
-            afile.write(outputs)
-
-        if "line_type" in outputs:
-            import matplotlib.pyplot as plt
-            from matplotlib.patches import Circle
-
-            lines = eval(outputs)["Line"]["line"]
-
-            line_type = eval(outputs)["Line"]["line_type"]
-            # print(lines)
-
-            endpoints = eval(outputs)["Line"]["line_endpoint"]
-
-            fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
-            ax.set_xlim(-15, 15)
-            ax.set_ylim(-15, 15)
-
-            for idx, line in enumerate(lines):
-                try:
-                    p0 = eval(line.split(" -- ")[0])
-                    p1 = eval(line.split(" -- ")[-1])
-
-                    if line_type[idx] == "--":
-                        ax.plot(
-                            [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
-                        )
-                    else:
-                        ax.plot(
-                            [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
-                        )
-
-                    ax.scatter(p0[0], p0[1], s=5, color="k")
-                    ax.scatter(p1[0], p1[1], s=5, color="k")
-                except:
-                    pass
-
-            for endpoint in endpoints:
-                label = endpoint.split(": ")[0]
-                (x, y) = eval(endpoint.split(": ")[1])
-                ax.annotate(
-                    label,
-                    (x, y),
-                    xytext=(1, 1),
-                    textcoords="offset points",
-                    fontsize=5,
-                    fontweight="light",
-                )
-
-            try:
-                if "Circle" in eval(outputs).keys():
-                    circle_centers = eval(outputs)["Circle"]["circle_center"]
-                    radius = eval(outputs)["Circle"]["radius"]
-
-                    for center, r in zip(circle_centers, radius):
-                        center = eval(center.split(": ")[1])
-                        circle = Circle(
-                            center,
-                            radius=r,
-                            fill=False,
-                            edgecolor="black",
-                            linewidth=0.8,
-                        )
-                        ax.add_patch(circle)
-            except:
-                pass
-
-            plt.savefig(f"{OUTPUT_PATH}/geo.jpg")
-            plt.close()
-
-    result.save(f"{OUTPUT_PATH}/result_with_boxes.jpg")
diff --git a/services/run_dpsk_ocr_pdf.py b/services/run_dpsk_ocr_pdf.py
deleted file mode 100644
index e2872b2..0000000
--- a/services/run_dpsk_ocr_pdf.py
+++ /dev/null
@@ -1,352 +0,0 @@
-import io
-import os
-import re
-from concurrent.futures import ThreadPoolExecutor
-
-from config import env_setup
-
-env_setup.setup_environment()
-
-import fitz
-import img2pdf
-import numpy as np
-from config.model_settings import (
-    CROP_MODE,
-    INPUT_PATH,
-    MAX_CONCURRENCY,
-    MODEL_PATH,
-    NUM_WORKERS,
-    OUTPUT_PATH,
-    PROMPT,
-    SKIP_REPEAT,
-)
-from PIL import Image, ImageDraw, ImageFont
-from process.image_process import DeepseekOCRProcessor
-from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
-from tqdm import tqdm
-from vllm import LLM, SamplingParams
-from vllm.model_executor.models.registry import ModelRegistry
-
-from deepseek_ocr import DeepseekOCRForCausalLM
-
-ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
-
-
-llm = LLM(
-    model=MODEL_PATH,
-    hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
-    block_size=256,
-    enforce_eager=False,
-    trust_remote_code=True,
-    max_model_len=8192,
-    swap_space=0,
-    max_num_seqs=MAX_CONCURRENCY,
-    tensor_parallel_size=1,
-    gpu_memory_utilization=0.9,
-    disable_mm_preprocessor_cache=True,
-)
-
-logits_processors = [
-    NoRepeatNGramLogitsProcessor(
-        ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
-    )
-]  # window for fast; whitelist_token_ids: <td>, </td>
-
-sampling_params = SamplingParams(
-    temperature=0.0,
-    max_tokens=8192,
-    logits_processors=logits_processors,
-    skip_special_tokens=False,
-    include_stop_str_in_output=True,
-)
-
-
-class Colors:
-    RED = "\033[31m"
-    GREEN = "\033[32m"
-    YELLOW = "\033[33m"
-    BLUE = "\033[34m"
-    RESET = "\033[0m"
-
-
-def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
-    """
-    pdf2images
-    """
-    images = []
-
-    pdf_document = fitz.open(pdf_path)
-
-    zoom = dpi / 72.0
-    matrix = fitz.Matrix(zoom, zoom)
-
-    for page_num in range(pdf_document.page_count):
-        page = pdf_document[page_num]
-
-        pixmap = page.get_pixmap(matrix=matrix, alpha=False)
-        Image.MAX_IMAGE_PIXELS = None
-
-        if image_format.upper() == "PNG":
-            img_data = pixmap.tobytes("png")
-            img = Image.open(io.BytesIO(img_data))
-        else:
-            img_data = pixmap.tobytes("png")
-            img = Image.open(io.BytesIO(img_data))
-            if img.mode in ("RGBA", "LA"):
-                background = Image.new("RGB", img.size, (255, 255, 255))
-                background.paste(
-                    img, mask=img.split()[-1] if img.mode == "RGBA" else None
-                )
-                img = background
-
-        images.append(img)
-
-    pdf_document.close()
-    return images
-
-
-def pil_to_pdf_img2pdf(pil_images, output_path):
-    if not pil_images:
-        return
-
-    image_bytes_list = []
-
-    for img in pil_images:
-        if img.mode != "RGB":
-            img = img.convert("RGB")
-
-        img_buffer = io.BytesIO()
-        img.save(img_buffer, format="JPEG", quality=95)
-        img_bytes = img_buffer.getvalue()
-        image_bytes_list.append(img_bytes)
-
-    try:
-        pdf_bytes = img2pdf.convert(image_bytes_list)
-        with open(output_path, "wb") as f:
-            f.write(pdf_bytes)
-
-    except Exception as e:
-        print(f"error: {e}")
-
-
-def re_match(text):
-    pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
-    matches = re.findall(pattern, text, re.DOTALL)
-
-    mathes_image = []
-    mathes_other = []
-    for a_match in matches:
-        if "<|ref|>image<|/ref|>" in a_match[0]:
-            mathes_image.append(a_match[0])
-        else:
-            mathes_other.append(a_match[0])
-    return matches, mathes_image, mathes_other
-
-
-def extract_coordinates_and_label(ref_text, image_width, image_height):
-    try:
-        label_type = ref_text[1]
-        cor_list = eval(ref_text[2])
-    except Exception as e:
-        print(e)
-        return None
-
-    return (label_type, cor_list)
-
-
-def draw_bounding_boxes(image, refs, jdx):
-    image_width, image_height = image.size
-    img_draw = image.copy()
-    draw = ImageDraw.Draw(img_draw)
-
-    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
-    draw2 = ImageDraw.Draw(overlay)
-
-    # except IOError:
-    font = ImageFont.load_default()
-
-    img_idx = 0
-
-    for i, ref in enumerate(refs):
-        try:
-            result = extract_coordinates_and_label(ref, image_width, image_height)
-            if result:
-                label_type, points_list = result
-
-                color = (
-                    np.random.randint(0, 200),
-                    np.random.randint(0, 200),
-                    np.random.randint(0, 255),
-                )
-
-                color_a = color + (20,)
-                for points in points_list:
-                    x1, y1, x2, y2 = points
-
-                    x1 = int(x1 / 999 * image_width)
-                    y1 = int(y1 / 999 * image_height)
-
-                    x2 = int(x2 / 999 * image_width)
-                    y2 = int(y2 / 999 * image_height)
-
-                    if label_type == "image":
-                        try:
-                            cropped = image.crop((x1, y1, x2, y2))
-                            cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
-                        except Exception as e:
-                            print(e)
-                            pass
-                        img_idx += 1
-
-                    try:
-                        if label_type == "title":
-                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
-                            draw2.rectangle(
-                                [x1, y1, x2, y2],
-                                fill=color_a,
-                                outline=(0, 0, 0, 0),
-                                width=1,
-                            )
-                        else:
-                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
-                            draw2.rectangle(
-                                [x1, y1, x2, y2],
-                                fill=color_a,
-                                outline=(0, 0, 0, 0),
-                                width=1,
-                            )
-
-                        text_x = x1
-                        text_y = max(0, y1 - 15)
-
-                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
-                        text_width = text_bbox[2] - text_bbox[0]
-                        text_height = text_bbox[3] - text_bbox[1]
-                        draw.rectangle(
-                            [text_x, text_y, text_x + text_width, text_y + text_height],
-                            fill=(255, 255, 255, 30),
-                        )
-
-                        draw.text((text_x, text_y), label_type, font=font, fill=color)
-                    except:
-                        pass
-        except:
-            continue
-    img_draw.paste(overlay, (0, 0), overlay)
-    return img_draw
-
-
-def process_image_with_refs(image, ref_texts, jdx):
-    result_image = draw_bounding_boxes(image, ref_texts, jdx)
-    return result_image
-
-
-def process_single_image(image):
-    """single image"""
-    prompt_in = prompt
-    cache_item = {
-        "prompt": prompt_in,
-        "multi_modal_data": {
-            "image": DeepseekOCRProcessor().tokenize_with_images(
-                images=[image], bos=True, eos=True, cropping=CROP_MODE
-            )
-        },
-    }
-    return cache_item
-
-
-if __name__ == "__main__":
-    os.makedirs(OUTPUT_PATH, exist_ok=True)
-    os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
-
-    print(f"{Colors.RED}PDF loading .....{Colors.RESET}")
-
-    images = pdf_to_images_high_quality(INPUT_PATH)
-
-    prompt = PROMPT
-
-    # batch_inputs = []
-
-    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
-        batch_inputs = list(
-            tqdm(
-                executor.map(process_single_image, images),
-                total=len(images),
-                desc="Pre-processed images",
-            )
-        )
-
-    # for image in tqdm(images):
-
-    #     prompt_in = prompt
-    #     cache_list = [
-    #         {
-    #             "prompt": prompt_in,
-    #             "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
-    #         }
-    #     ]
-    #     batch_inputs.extend(cache_list)
-
-    outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
-
-    output_path = OUTPUT_PATH
-
-    os.makedirs(output_path, exist_ok=True)
-
-    mmd_det_path = (
-        output_path + "/" + INPUT_PATH.split("/")[-1].replace(".pdf", "_det.mmd")
-    )
-    mmd_path = output_path + "/" + INPUT_PATH.split("/")[-1].replace("pdf", "mmd")
-    pdf_out_path = (
-        output_path + "/" + INPUT_PATH.split("/")[-1].replace(".pdf", "_layouts.pdf")
-    )
-    contents_det = ""
-    contents = ""
-    draw_images = []
-    jdx = 0
-    for output, img in zip(outputs_list, images):
-        content = output.outputs[0].text
-
-        if "<|end▁of▁sentence|>" in content:  # repeat no eos
-            content = content.replace("<|end▁of▁sentence|>", "")
-        else:
-            if SKIP_REPEAT:
-                continue
-
-        page_num = "\n<--- Page Split --->"
-
-        contents_det += content + f"\n{page_num}\n"
-
-        image_draw = img.copy()
-
-        matches_ref, matches_images, mathes_other = re_match(content)
-        # print(matches_ref)
-        result_image = process_image_with_refs(image_draw, matches_ref, jdx)
-
-        draw_images.append(result_image)
-
-        for idx, a_match_image in enumerate(matches_images):
-            content = content.replace(
-                a_match_image, "![](images/" + str(jdx) + "_" + str(idx) + ".jpg)\n"
-            )
-
-        for idx, a_match_other in enumerate(mathes_other):
-            content = (
-                content.replace(a_match_other, "")
-                .replace("\\coloneqq", ":=")
-                .replace("\\eqqcolon", "=:")
-                .replace("\n\n\n\n", "\n\n")
-                .replace("\n\n\n", "\n\n")
-            )
-
-        contents += content + f"\n{page_num}\n"
-
-        jdx += 1
-
-    with open(mmd_det_path, "w", encoding="utf-8") as afile:
-        afile.write(contents_det)
-
-    with open(mmd_path, "w", encoding="utf-8") as afile:
-        afile.write(contents)
-
-    pil_to_pdf_img2pdf(draw_images, pdf_out_path)
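
Reviewer note: the Dockerfile CMD boots "uvicorn api:app", but api.py itself is not part of this diff. The sketch below is illustrative only, not the actual implementation: a minimal FastAPI app that could wire services.ocr_engine.process_document to the HTTP surface the gateway expects. The "/ocr" route and the "document" multipart field name are assumptions carried over from the TODOs in services/ocr_gateway_post.py and must match whatever api.py actually defines.

# api.py (illustrative sketch; route path and field name are assumptions)
from fastapi import FastAPI, File, HTTPException, UploadFile

from services.ocr_engine import process_document

app = FastAPI(title="DeepSeek-OCR API")


@app.post("/ocr")
async def ocr(document: UploadFile = File(...)):
    # Read the upload into memory and hand it to the OCR engine.
    file_bytes = await document.read()
    try:
        return await process_document(
            file_bytes=file_bytes,
            content_type=document.content_type or "",
            filename=document.filename or "upload",
        )
    except ValueError as e:
        # Unsupported content type or unreadable file: report as a client error.
        raise HTTPException(status_code=400, detail=str(e))

If api.py follows this shape, the pieces line up end to end: extract_deepseek_ocr POSTs the file under the "document" key to http://deepseek_ocr_vllm:11635/ocr, and the JSON body returned by process_document ({"filename": ..., "text": ..., plus "page_count" for PDFs}) is what the TODO-marked parsing block in services/ocr_gateway_post.py would then consume in place of the assumed Upstage-like format.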