FastAPI 기반 OCR 서비스 구현 및 Docker 환경 구성
This commit is contained in:
8
.gitignore
vendored
8
.gitignore
vendored
@@ -1,3 +1,7 @@
|
|||||||
# Cache directories
|
|
||||||
.cache/
|
.cache/
|
||||||
__pycache__/
|
.tmp/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
*.log
|
||||||
11
Dockerfile
11
Dockerfile
@@ -11,8 +11,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
|||||||
TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
|
TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
|
||||||
TORCH_CUDA_ARCH_LIST="8.0"
|
TORCH_CUDA_ARCH_LIST="8.0"
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
|
|
||||||
# 필수 빌드 도구 설치
|
# 필수 빌드 도구 설치
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
git build-essential ninja-build \
|
git build-essential ninja-build \
|
||||||
@@ -36,4 +34,11 @@ RUN pip install vllm==0.8.5
|
|||||||
RUN pip cache purge && \
|
RUN pip cache purge && \
|
||||||
pip install --no-cache-dir --no-build-isolation --no-binary=flash-attn flash-attn==2.7.3
|
pip install --no-cache-dir --no-build-isolation --no-binary=flash-attn flash-attn==2.7.3
|
||||||
|
|
||||||
WORKDIR /workspace
|
# API 서버 실행 포트 노출
|
||||||
|
EXPOSE 11635
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Uvicorn으로 FastAPI 서버 실행
|
||||||
|
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "11635"]
|
||||||
@@ -3,13 +3,15 @@ services:
|
|||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: Dockerfile
|
dockerfile: Dockerfile
|
||||||
image: deepseek-ocr-vllm:cu126
|
image: deepseek-ocr-api:torch2.6.0-cuda12.6-cudnn9-vllm0.8.5
|
||||||
container_name: deepseek_ocr_vllm
|
container_name: deepseek_ocr_vllm
|
||||||
working_dir: /workspace
|
working_dir: /workspace
|
||||||
volumes:
|
volumes:
|
||||||
- ./:/workspace
|
- ./:/workspace
|
||||||
|
ports:
|
||||||
|
- "11635:11635"
|
||||||
gpus: all
|
gpus: all
|
||||||
shm_size: "8g"
|
shm_size: "16g"
|
||||||
ipc: "host"
|
ipc: "host"
|
||||||
environment:
|
environment:
|
||||||
- HF_HOME=/workspace/.cache/huggingface
|
- HF_HOME=/workspace/.cache/huggingface
|
||||||
@@ -18,4 +20,10 @@ services:
|
|||||||
- PIP_DISABLE_PIP_VERSION_CHECK=1
|
- PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||||
- PYTHONUNBUFFERED=1
|
- PYTHONUNBUFFERED=1
|
||||||
tty: true
|
tty: true
|
||||||
entrypoint: ["/bin/bash"]
|
restart: always
|
||||||
|
networks:
|
||||||
|
- llm_gateway_local_net
|
||||||
|
|
||||||
|
networks:
|
||||||
|
llm_gateway_local_net:
|
||||||
|
external: true
|
||||||
|
|||||||
117
services/ocr_engine.py
Normal file
117
services/ocr_engine.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import fitz
|
||||||
|
from config.model_settings import CROP_MODE, MODEL_PATH, PROMPT
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from PIL import Image
|
||||||
|
from process.image_process import DeepseekOCRProcessor
|
||||||
|
from vllm import AsyncLLMEngine, SamplingParams
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.model_executor.models.registry import ModelRegistry
|
||||||
|
|
||||||
|
from services.deepseek_ocr import DeepseekOCRForCausalLM
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
# 1. 모델 및 프로세서 초기화
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# VLLM이 커스텀 모델을 인식하도록 등록
|
||||||
|
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
||||||
|
|
||||||
|
# VLLM 비동기 엔진 설정
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=MODEL_PATH,
|
||||||
|
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
||||||
|
block_size=256,
|
||||||
|
max_model_len=8192,
|
||||||
|
enforce_eager=False,
|
||||||
|
trust_remote_code=True,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
gpu_memory_utilization=0.75,
|
||||||
|
)
|
||||||
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
|
# 샘플링 파라미터 설정
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=8192,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 이미지 전처리기
|
||||||
|
processor = DeepseekOCRProcessor()
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
# 2. 핵심 처리 함수
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _process_single_image(image: Image.Image) -> str:
    """Run OCR on a single PIL image and return the generated text.

    Args:
        image: The picture (or rendered PDF page) to recognise.

    Returns:
        The text of the last output streamed by the engine, or an empty
        string when the engine never yields an output.

    Raises:
        ValueError: If ``PROMPT`` does not contain the ``<image>`` token.
    """
    import uuid  # local import: only needed to mint request ids

    if "<image>" not in PROMPT:
        raise ValueError("프롬프트에 '<image>' 토큰이 없어 OCR을 수행할 수 없습니다.")

    image_features = processor.tokenize_with_images(
        images=[image], bos=True, eos=True, cropping=CROP_MODE
    )

    request = {"prompt": PROMPT, "multi_modal_data": {"image": image_features}}
    # A clock-derived id is not guaranteed unique when several pages are
    # submitted together through asyncio.gather; a random UUID is.
    request_id = f"request-{uuid.uuid4().hex}"

    final_output = ""
    async for request_output in engine.generate(request, sampling_params, request_id):
        # Each streamed item carries the text generated so far; keep the latest.
        if request_output.outputs:
            final_output = request_output.outputs[0].text

    return final_output
|
||||||
|
|
||||||
|
def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
    """Render every page of an in-memory PDF to a PIL image.

    Args:
        pdf_bytes: Raw content of the PDF file.
        dpi: Rendering resolution; the zoom factor is ``dpi / 72``.

    Returns:
        One PIL image per page, in page order.
    """
    zoom = dpi / 72.0
    matrix = fitz.Matrix(zoom, zoom)

    images = []
    pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        for page_num in range(pdf_document.page_count):
            page = pdf_document[page_num]
            pixmap = page.get_pixmap(matrix=matrix, alpha=False)
            # Round-trip through PNG bytes to hand the page over to Pillow.
            img_data = pixmap.tobytes("png")
            images.append(Image.open(io.BytesIO(img_data)))
    finally:
        # Close the document even when a page fails to render.
        pdf_document.close()

    return images
|
||||||
|
|
||||||
|
async def process_document(file_bytes: bytes, content_type: str, filename: str) -> dict:
    """Run OCR on an uploaded image or PDF and return the result.

    Args:
        file_bytes: Raw content of the uploaded file.
        content_type: MIME type reported for the upload; a missing value
            (``None`` or ``""``) is treated as an unsupported type.
        filename: Original file name, echoed back in the result.

    Returns:
        ``{"filename", "text"}`` for images, and
        ``{"filename", "text", "page_count"}`` for PDFs, where the page texts
        are joined with a ``<--- Page Split --->`` separator line.

    Raises:
        ValueError: If the file cannot be decoded or its type is unsupported.
    """
    # Guard against a missing MIME type so the caller gets the documented
    # ValueError instead of an AttributeError on None.
    media_type = content_type or ""

    if media_type.startswith("image/"):
        try:
            image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"이미지 파일을 여는 데 실패했습니다: {e}") from e

        result_text = await _process_single_image(image)
        return {"filename": filename, "text": result_text}

    if media_type == "application/pdf":
        try:
            images = _pdf_to_images(file_bytes)
        except Exception as e:
            raise ValueError(f"PDF 파일을 처리하는 데 실패했습니다: {e}") from e

        # Submit every page concurrently; gather keeps the page order.
        tasks = [_process_single_image(img) for img in images]
        page_results = await asyncio.gather(*tasks)

        # Join the page texts with a page separator.
        full_text = "\n<--- Page Split --->\n".join(page_results)
        return {"filename": filename, "text": full_text, "page_count": len(images)}

    raise ValueError(
        f"지원하지 않는 파일 형식입니다: {content_type}. "
        "이미지(JPEG, PNG 등) 또는 PDF 파일을 업로드해주세요."
    )
|
||||||
71
services/ocr_gateway_post.py
Normal file
71
services/ocr_gateway_post.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ✅ DeepSeek OCR API
|
||||||
|
async def extract_deepseek_ocr(file_path: str):
    """Send a file to the deepseek_ocr_vllm container and collect text and word boxes.

    Args:
        file_path: Path of the image or PDF to upload.

    Returns:
        A ``(full_text, coords)`` tuple: the page texts joined with newlines,
        and a list of ``{"text", "coords": [min_x, min_y, max_x, max_y],
        "page"}`` dicts. ``("", [])`` is returned when the response cannot
        be parsed.

    Raises:
        FileNotFoundError: If ``file_path`` is empty or does not exist.
        RuntimeError: If the API answers with an error status or the
            container cannot be reached.
    """
    import mimetypes  # local import: only needed to label the upload

    # FastAPI endpoint of the deepseek_ocr_vllm container. The default port is
    # 11635, the port the OCR container's uvicorn server is started on.
    # TODO: confirm the final endpoint name ('/ocr') and adjust if needed.
    DEEPSEEK_API_URL = os.getenv("DEEPSEEK_API_URL", "http://deepseek_ocr_vllm:11635/ocr")

    if not file_path or not os.path.exists(file_path):
        raise FileNotFoundError(f"파일이 존재하지 않습니다: {file_path}")

    filename = Path(file_path).name
    # Send the real MIME type: the OCR service dispatches on image/* vs
    # application/pdf, so a generic octet-stream upload would be rejected.
    mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    full_text_parts = []
    coord_response = []

    with open(file_path, "rb") as f:
        # TODO: confirm the file parameter name ('document') used by the endpoint.
        files = {"document": (filename, f, mime_type)}
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(DEEPSEEK_API_URL, files=files)
                response.raise_for_status()
                result = response.json()
        except httpx.HTTPStatusError as e:
            logger.error(f"DeepSeek API 오류: {e.response.text}")
            raise RuntimeError(f"DeepSeek API 오류: {e.response.status_code}") from e
        except httpx.ConnectError as e:
            logger.error(f"DeepSeek 컨테이너 연결 실패: {e}")
            raise RuntimeError(
                f"DeepSeek 컨테이너({DEEPSEEK_API_URL})에 연결할 수 없습니다."
            ) from e

    # TODO: adapt the JSON parsing below to the real API response format.
    try:
        # Example parser that assumes an Upstage-like response layout.
        pages = result.get("pages", [])
        for page_idx, p in enumerate(pages, start=1):
            txt = p.get("text")
            if txt:
                full_text_parts.append(txt)

            for w in p.get("words", []):
                verts = (w.get("boundingBox", {}) or {}).get("vertices")
                if not verts or len(verts) != 4:
                    continue
                # Collapse the four vertices into an axis-aligned box.
                xs = [v.get("x", 0) for v in verts]
                ys = [v.get("y", 0) for v in verts]
                coord_response.append(
                    {
                        "text": w.get("text"),
                        "coords": [min(xs), min(ys), max(xs), max(ys)],
                        "page": page_idx,
                    }
                )
        logger.warning("DeepSeek OCR의 실제 응답 형식에 맞게 파싱 로직을 구현해야 합니다.")

    except Exception as e:
        logger.error(f"[DEEPSEEK] JSON 파싱 실패: {e} / 원본 result: {result}")
        return "", []

    logger.info("[DEEPSEEK] 텍스트 추출 완료")
    full_response = "\n".join(full_text_parts)
    return full_response, coord_response
|
||||||
@@ -1,320 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
from config import env_setup
|
|
||||||
|
|
||||||
env_setup.setup_environment()
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from config.model_settings import CROP_MODE, INPUT_PATH, MODEL_PATH, OUTPUT_PATH, PROMPT
|
|
||||||
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
|
||||||
from process.image_process import DeepseekOCRProcessor
|
|
||||||
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
|
|
||||||
from tqdm import tqdm
|
|
||||||
from vllm import AsyncLLMEngine, SamplingParams
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
||||||
from vllm.model_executor.models.registry import ModelRegistry
|
|
||||||
|
|
||||||
from deepseek_ocr import DeepseekOCRForCausalLM
|
|
||||||
|
|
||||||
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
|
||||||
|
|
||||||
|
|
||||||
def load_image(image_path):
    """Open an image and apply its EXIF orientation.

    Args:
        image_path: Path of the image file.

    Returns:
        The EXIF-transposed image. If that fails, the plainly opened image;
        ``None`` when the file cannot be opened at all.
    """
    try:
        image = Image.open(image_path)
        return ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"error: {e}")
        try:
            # Fall back to the image without orientation correction.
            return Image.open(image_path)
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            return None
|
|
||||||
|
|
||||||
|
|
||||||
def re_match(text):
    """Find every ``<|ref|>…<|/ref|><|det|>…<|/det|>`` span in *text*.

    Returns:
        A 3-tuple: the raw ``re.findall`` tuples ``(whole, label, det)``,
        the whole spans whose label is ``image``, and all other whole spans.
    """
    grounding_pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
    found = re.findall(grounding_pattern, text, re.DOTALL)

    image_marker = "<|ref|>image<|/ref|>"
    image_spans = [hit[0] for hit in found if image_marker in hit[0]]
    other_spans = [hit[0] for hit in found if image_marker not in hit[0]]
    return found, image_spans, other_spans
|
|
||||||
|
|
||||||
|
|
||||||
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Split one grounding match into its label and coordinate list.

    Args:
        ref_text: A match tuple whose item 1 is the label and item 2 the
            textual coordinate list.
        image_width: Unused; kept for interface compatibility.
        image_height: Unused; kept for interface compatibility.

    Returns:
        ``(label_type, cor_list)``, or ``None`` when the tuple is malformed
        or the coordinates are not a valid Python literal.
    """
    import ast  # local import: literal parser for the coordinate text

    try:
        label_type = ref_text[1]
        # The coordinate text comes from model output, so parse it as a
        # literal instead of executing it with eval().
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
|
|
||||||
|
|
||||||
|
|
||||||
def draw_bounding_boxes(image, refs):
    """Draw a labelled box for every grounding reference on a copy of *image*.

    Regions labelled ``image`` are additionally cropped from the original and
    saved as ``{OUTPUT_PATH}/images/{n}.jpg``. Drawing is best-effort: a
    reference or box that fails is skipped.

    Args:
        image: Source PIL image; it is not modified.
        refs: Match tuples as produced by ``re_match``.

    Returns:
        A copy of *image* with outlines, translucent fills and labels drawn.
    """
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Translucent fills go on a separate RGBA layer that is pasted on at the end.
    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()
    img_idx = 0

    for ref in refs:
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if not result:
                continue
            label_type, points_list = result

            color = (
                np.random.randint(0, 200),
                np.random.randint(0, 200),
                np.random.randint(0, 255),
            )
            color_a = color + (20,)
            # Titles get a thicker outline than every other label.
            outline_width = 4 if label_type == "title" else 2

            for points in points_list:
                x1, y1, x2, y2 = points

                # Coordinates are divided by 999 and scaled to pixel size
                # (presumably the model emits a 0-999 grid; confirm).
                x1 = int(x1 / 999 * image_width)
                y1 = int(y1 / 999 * image_height)
                x2 = int(x2 / 999 * image_width)
                y2 = int(y2 / 999 * image_height)

                if label_type == "image":
                    try:
                        cropped = image.crop((x1, y1, x2, y2))
                        cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
                    except Exception as e:
                        print(e)
                    # The index advances even when the crop could not be saved.
                    img_idx += 1

                try:
                    draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                    draw2.rectangle(
                        [x1, y1, x2, y2],
                        fill=color_a,
                        outline=(0, 0, 0, 0),
                        width=1,
                    )

                    text_x = x1
                    text_y = max(0, y1 - 15)

                    text_bbox = draw.textbbox((0, 0), label_type, font=font)
                    text_width = text_bbox[2] - text_bbox[0]
                    text_height = text_bbox[3] - text_bbox[1]
                    draw.rectangle(
                        [text_x, text_y, text_x + text_width, text_y + text_height],
                        fill=(255, 255, 255, 30),
                    )

                    draw.text((text_x, text_y), label_type, font=font, fill=color)
                except Exception:
                    # Best-effort drawing: skip a box that cannot be drawn.
                    pass
        except Exception:
            # Skip a malformed reference (narrowed from a bare except).
            continue

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
|
|
||||||
|
|
||||||
|
|
||||||
def process_image_with_refs(image, ref_texts):
    """Return a copy of *image* annotated with the boxes described by *ref_texts*."""
    return draw_bounding_boxes(image, ref_texts)
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_generate(image=None, prompt=""):
    """Generate text for *prompt*, printing it to stdout as it streams.

    Args:
        image: Pre-tokenised image features; only used when the prompt
            contains the ``<image>`` token.
        prompt: The prompt text; must not be empty.

    Returns:
        The final generated text, or ``""`` if the engine produced no output.

    Raises:
        AssertionError: If *prompt* is empty.
    """
    # Validate before building the engine so a bad call fails fast. An
    # explicit raise also survives ``python -O``, which strips ``assert``.
    if image and "<image>" in prompt:
        request = {"prompt": prompt, "multi_modal_data": {"image": image}}
    elif prompt:
        request = {"prompt": prompt}
    else:
        raise AssertionError("prompt is none!!!")

    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
        block_size=256,
        max_model_len=8192,
        enforce_eager=False,
        trust_remote_code=True,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    logits_processors = [
        NoRepeatNGramLogitsProcessor(
            ngram_size=30, window_size=90, whitelist_token_ids={128821, 128822}
        )
    ]  # whitelist: <td>, </td>

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=8192,
        logits_processors=logits_processors,
        skip_special_tokens=False,
    )

    request_id = f"request-{int(time.time())}"

    printed_length = 0
    # Pre-initialised so an empty stream returns "" instead of raising
    # UnboundLocalError.
    final_output = ""
    async for request_output in engine.generate(request, sampling_params, request_id):
        if request_output.outputs:
            full_text = request_output.outputs[0].text
            # Print only the part that has not been shown yet.
            new_text = full_text[printed_length:]
            print(new_text, end="", flush=True)
            printed_length = len(full_text)
            final_output = full_text
    print("\n")

    return final_output
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
os.makedirs(OUTPUT_PATH, exist_ok=True)
|
|
||||||
os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
|
|
||||||
|
|
||||||
image = load_image(INPUT_PATH).convert("RGB")
|
|
||||||
|
|
||||||
if "<image>" in PROMPT:
|
|
||||||
image_features = DeepseekOCRProcessor().tokenize_with_images(
|
|
||||||
images=[image], bos=True, eos=True, cropping=CROP_MODE
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_features = ""
|
|
||||||
|
|
||||||
prompt = PROMPT
|
|
||||||
|
|
||||||
result_out = asyncio.run(stream_generate(image_features, prompt))
|
|
||||||
|
|
||||||
save_results = 1
|
|
||||||
|
|
||||||
if save_results and "<image>" in prompt:
|
|
||||||
print("=" * 15 + "save results:" + "=" * 15)
|
|
||||||
|
|
||||||
image_draw = image.copy()
|
|
||||||
|
|
||||||
outputs = result_out
|
|
||||||
|
|
||||||
with open(f"{OUTPUT_PATH}/result_ori.mmd", "w", encoding="utf-8") as afile:
|
|
||||||
afile.write(outputs)
|
|
||||||
|
|
||||||
matches_ref, matches_images, mathes_other = re_match(outputs)
|
|
||||||
# print(matches_ref)
|
|
||||||
result = process_image_with_refs(image_draw, matches_ref)
|
|
||||||
|
|
||||||
for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
|
|
||||||
outputs = outputs.replace(
|
|
||||||
a_match_image, " + ".jpg)\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
|
|
||||||
outputs = (
|
|
||||||
outputs.replace(a_match_other, "")
|
|
||||||
.replace("\\coloneqq", ":=")
|
|
||||||
.replace("\\eqqcolon", "=:")
|
|
||||||
)
|
|
||||||
|
|
||||||
# if 'structural formula' in conversation[0]['content']:
|
|
||||||
# outputs = '<smiles>' + outputs + '</smiles>'
|
|
||||||
with open(f"{OUTPUT_PATH}/result.mmd", "w", encoding="utf-8") as afile:
|
|
||||||
afile.write(outputs)
|
|
||||||
|
|
||||||
if "line_type" in outputs:
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from matplotlib.patches import Circle
|
|
||||||
|
|
||||||
lines = eval(outputs)["Line"]["line"]
|
|
||||||
|
|
||||||
line_type = eval(outputs)["Line"]["line_type"]
|
|
||||||
# print(lines)
|
|
||||||
|
|
||||||
endpoints = eval(outputs)["Line"]["line_endpoint"]
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
|
|
||||||
ax.set_xlim(-15, 15)
|
|
||||||
ax.set_ylim(-15, 15)
|
|
||||||
|
|
||||||
for idx, line in enumerate(lines):
|
|
||||||
try:
|
|
||||||
p0 = eval(line.split(" -- ")[0])
|
|
||||||
p1 = eval(line.split(" -- ")[-1])
|
|
||||||
|
|
||||||
if line_type[idx] == "--":
|
|
||||||
ax.plot(
|
|
||||||
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ax.plot(
|
|
||||||
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.scatter(p0[0], p0[1], s=5, color="k")
|
|
||||||
ax.scatter(p1[0], p1[1], s=5, color="k")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
for endpoint in endpoints:
|
|
||||||
label = endpoint.split(": ")[0]
|
|
||||||
(x, y) = eval(endpoint.split(": ")[1])
|
|
||||||
ax.annotate(
|
|
||||||
label,
|
|
||||||
(x, y),
|
|
||||||
xytext=(1, 1),
|
|
||||||
textcoords="offset points",
|
|
||||||
fontsize=5,
|
|
||||||
fontweight="light",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if "Circle" in eval(outputs).keys():
|
|
||||||
circle_centers = eval(outputs)["Circle"]["circle_center"]
|
|
||||||
radius = eval(outputs)["Circle"]["radius"]
|
|
||||||
|
|
||||||
for center, r in zip(circle_centers, radius):
|
|
||||||
center = eval(center.split(": ")[1])
|
|
||||||
circle = Circle(
|
|
||||||
center,
|
|
||||||
radius=r,
|
|
||||||
fill=False,
|
|
||||||
edgecolor="black",
|
|
||||||
linewidth=0.8,
|
|
||||||
)
|
|
||||||
ax.add_patch(circle)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
plt.savefig(f"{OUTPUT_PATH}/geo.jpg")
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
result.save(f"{OUTPUT_PATH}/result_with_boxes.jpg")
|
|
||||||
@@ -1,352 +0,0 @@
|
|||||||
import io
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
from config import env_setup
|
|
||||||
|
|
||||||
env_setup.setup_environment()
|
|
||||||
|
|
||||||
import fitz
|
|
||||||
import img2pdf
|
|
||||||
import numpy as np
|
|
||||||
from config.model_settings import (
|
|
||||||
CROP_MODE,
|
|
||||||
INPUT_PATH,
|
|
||||||
MAX_CONCURRENCY,
|
|
||||||
MODEL_PATH,
|
|
||||||
NUM_WORKERS,
|
|
||||||
OUTPUT_PATH,
|
|
||||||
PROMPT,
|
|
||||||
SKIP_REPEAT,
|
|
||||||
)
|
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
|
||||||
from process.image_process import DeepseekOCRProcessor
|
|
||||||
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
|
|
||||||
from tqdm import tqdm
|
|
||||||
from vllm import LLM, SamplingParams
|
|
||||||
from vllm.model_executor.models.registry import ModelRegistry
|
|
||||||
|
|
||||||
from deepseek_ocr import DeepseekOCRForCausalLM
|
|
||||||
|
|
||||||
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
|
||||||
|
|
||||||
|
|
||||||
llm = LLM(
|
|
||||||
model=MODEL_PATH,
|
|
||||||
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
|
||||||
block_size=256,
|
|
||||||
enforce_eager=False,
|
|
||||||
trust_remote_code=True,
|
|
||||||
max_model_len=8192,
|
|
||||||
swap_space=0,
|
|
||||||
max_num_seqs=MAX_CONCURRENCY,
|
|
||||||
tensor_parallel_size=1,
|
|
||||||
gpu_memory_utilization=0.9,
|
|
||||||
disable_mm_preprocessor_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_processors = [
|
|
||||||
NoRepeatNGramLogitsProcessor(
|
|
||||||
ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
|
|
||||||
)
|
|
||||||
] # window for fast;whitelist_token_ids: <td>,</td>
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=8192,
|
|
||||||
logits_processors=logits_processors,
|
|
||||||
skip_special_tokens=False,
|
|
||||||
include_stop_str_in_output=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Colors:
    """ANSI escape sequences used to colour console messages."""

    # Foreground colours
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"

    # Restores the terminal's default attributes
    RESET = "\033[0m"
|
|
||||||
|
|
||||||
|
|
||||||
def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
    """Render every page of a PDF file to a PIL image.

    Args:
        pdf_path: Path of the PDF to render.
        dpi: Rendering resolution; the zoom factor is ``dpi / 72``.
        image_format: ``"PNG"`` (case-insensitive) keeps the decoded image
            as is; any other value also flattens RGBA/LA images onto a
            white RGB background.

    Returns:
        One PIL image per page, in page order.
    """
    zoom = dpi / 72.0
    matrix = fitz.Matrix(zoom, zoom)
    flatten_alpha = image_format.upper() != "PNG"

    # Lift Pillow's pixel-count limit once, rather than on every page.
    Image.MAX_IMAGE_PIXELS = None

    images = []
    pdf_document = fitz.open(pdf_path)
    try:
        for page_num in range(pdf_document.page_count):
            page = pdf_document[page_num]
            pixmap = page.get_pixmap(matrix=matrix, alpha=False)

            # Pages are always encoded as PNG, whatever image_format says.
            img_data = pixmap.tobytes("png")
            img = Image.open(io.BytesIO(img_data))

            if flatten_alpha and img.mode in ("RGBA", "LA"):
                background = Image.new("RGB", img.size, (255, 255, 255))
                background.paste(
                    img, mask=img.split()[-1] if img.mode == "RGBA" else None
                )
                img = background

            images.append(img)
    finally:
        # Close the document even when a page fails to render.
        pdf_document.close()

    return images
|
|
||||||
|
|
||||||
|
|
||||||
def pil_to_pdf_img2pdf(pil_images, output_path):
    """Write *pil_images* to *output_path* as a PDF, one JPEG page per image.

    Does nothing for an empty input. Conversion or write errors are printed
    and not raised.
    """
    if not pil_images:
        return

    def _encode_jpeg(picture):
        # JPEG pages are written from RGB images; convert anything else first.
        rgb = picture if picture.mode == "RGB" else picture.convert("RGB")
        sink = io.BytesIO()
        rgb.save(sink, format="JPEG", quality=95)
        return sink.getvalue()

    jpeg_pages = [_encode_jpeg(picture) for picture in pil_images]

    try:
        document = img2pdf.convert(jpeg_pages)
        with open(output_path, "wb") as out:
            out.write(document)
    except Exception as e:
        print(f"error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def re_match(text):
    """Collect the grounding spans (``<|ref|>…<|/ref|><|det|>…<|/det|>``) in *text*.

    Returns:
        ``(matches, image_spans, other_spans)``: the raw ``re.findall``
        tuples ``(whole, label, det)``, the whole spans labelled ``image``,
        and the remaining whole spans.
    """
    matches = re.findall(
        r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)", text, re.DOTALL
    )

    image_spans, other_spans = [], []
    for whole, _label, _det in matches:
        bucket = image_spans if "<|ref|>image<|/ref|>" in whole else other_spans
        bucket.append(whole)
    return matches, image_spans, other_spans
|
|
||||||
|
|
||||||
|
|
||||||
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Return the label and parsed coordinate list of one grounding match.

    Args:
        ref_text: A match tuple; item 1 is the label, item 2 the coordinate
            list as text.
        image_width: Unused; kept for interface compatibility.
        image_height: Unused; kept for interface compatibility.

    Returns:
        ``(label_type, cor_list)``, or ``None`` if the tuple is too short or
        the coordinate text is not a valid Python literal.
    """
    import ast  # local import: literal parser for the coordinate text

    try:
        label_type = ref_text[1]
        # Model output is parsed as a literal; eval() would execute whatever
        # expression the recognised text happens to contain.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
|
|
||||||
|
|
||||||
|
|
||||||
def draw_bounding_boxes(image, refs, jdx):
    """Draw a labelled box for every grounding reference on a copy of *image*.

    Regions labelled ``image`` are additionally cropped from the original and
    saved as ``{OUTPUT_PATH}/images/{jdx}_{n}.jpg``. Drawing is best-effort:
    a reference or box that fails is skipped.

    Args:
        image: Source PIL image; it is not modified.
        refs: Match tuples as produced by ``re_match``.
        jdx: Page index used as the prefix of saved crop file names.

    Returns:
        A copy of *image* with outlines, translucent fills and labels drawn.
    """
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Translucent fills go on a separate RGBA layer that is pasted on at the end.
    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()
    img_idx = 0

    for ref in refs:
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if not result:
                continue
            label_type, points_list = result

            color = (
                np.random.randint(0, 200),
                np.random.randint(0, 200),
                np.random.randint(0, 255),
            )
            color_a = color + (20,)
            # Titles get a thicker outline than every other label.
            outline_width = 4 if label_type == "title" else 2

            for points in points_list:
                x1, y1, x2, y2 = points

                # Coordinates are divided by 999 and scaled to pixel size
                # (presumably the model emits a 0-999 grid; confirm).
                x1 = int(x1 / 999 * image_width)
                y1 = int(y1 / 999 * image_height)
                x2 = int(x2 / 999 * image_width)
                y2 = int(y2 / 999 * image_height)

                if label_type == "image":
                    try:
                        cropped = image.crop((x1, y1, x2, y2))
                        cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
                    except Exception as e:
                        print(e)
                    # The index advances even when the crop could not be saved.
                    img_idx += 1

                try:
                    draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                    draw2.rectangle(
                        [x1, y1, x2, y2],
                        fill=color_a,
                        outline=(0, 0, 0, 0),
                        width=1,
                    )

                    text_x = x1
                    text_y = max(0, y1 - 15)

                    text_bbox = draw.textbbox((0, 0), label_type, font=font)
                    text_width = text_bbox[2] - text_bbox[0]
                    text_height = text_bbox[3] - text_bbox[1]
                    draw.rectangle(
                        [text_x, text_y, text_x + text_width, text_y + text_height],
                        fill=(255, 255, 255, 30),
                    )

                    draw.text((text_x, text_y), label_type, font=font, fill=color)
                except Exception:
                    # Best-effort drawing: skip a box that cannot be drawn.
                    pass
        except Exception:
            # Skip a malformed reference (narrowed from a bare except).
            continue

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
|
|
||||||
|
|
||||||
|
|
||||||
def process_image_with_refs(image, ref_texts, jdx):
    """Return a copy of *image* annotated with the boxes from *ref_texts* for page *jdx*."""
    return draw_bounding_boxes(image, ref_texts, jdx)
|
|
||||||
|
|
||||||
|
|
||||||
def process_single_image(image):
    """Build the generation request for a single page image.

    Args:
        image: PIL image of the page.

    Returns:
        A dict with the ``prompt`` and the ``multi_modal_data`` holding the
        tokenised image features.
    """
    # Use the configured PROMPT directly. The module-level ``prompt`` is only
    # assigned (to PROMPT) under ``if __name__ == "__main__"``, so reading it
    # raised NameError whenever this function was imported and called.
    image_features = DeepseekOCRProcessor().tokenize_with_images(
        images=[image], bos=True, eos=True, cropping=CROP_MODE
    )
    return {
        "prompt": PROMPT,
        "multi_modal_data": {"image": image_features},
    }
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
os.makedirs(OUTPUT_PATH, exist_ok=True)
|
|
||||||
os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
|
|
||||||
|
|
||||||
print(f"{Colors.RED}PDF loading .....{Colors.RESET}")
|
|
||||||
|
|
||||||
images = pdf_to_images_high_quality(INPUT_PATH)
|
|
||||||
|
|
||||||
prompt = PROMPT
|
|
||||||
|
|
||||||
# batch_inputs = []
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
|
|
||||||
batch_inputs = list(
|
|
||||||
tqdm(
|
|
||||||
executor.map(process_single_image, images),
|
|
||||||
total=len(images),
|
|
||||||
desc="Pre-processed images",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# for image in tqdm(images):
|
|
||||||
|
|
||||||
# prompt_in = prompt
|
|
||||||
# cache_list = [
|
|
||||||
# {
|
|
||||||
# "prompt": prompt_in,
|
|
||||||
# "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# batch_inputs.extend(cache_list)
|
|
||||||
|
|
||||||
outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
|
|
||||||
|
|
||||||
output_path = OUTPUT_PATH
|
|
||||||
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
|
|
||||||
mmd_det_path = (
|
|
||||||
output_path + "/" + INPUT_PATH.split("/")[-1].replace(".pdf", "_det.mmd")
|
|
||||||
)
|
|
||||||
mmd_path = output_path + "/" + INPUT_PATH.split("/")[-1].replace("pdf", "mmd")
|
|
||||||
pdf_out_path = (
|
|
||||||
output_path + "/" + INPUT_PATH.split("/")[-1].replace(".pdf", "_layouts.pdf")
|
|
||||||
)
|
|
||||||
contents_det = ""
|
|
||||||
contents = ""
|
|
||||||
draw_images = []
|
|
||||||
jdx = 0
|
|
||||||
for output, img in zip(outputs_list, images):
|
|
||||||
content = output.outputs[0].text
|
|
||||||
|
|
||||||
if "<|end▁of▁sentence|>" in content: # repeat no eos
|
|
||||||
content = content.replace("<|end▁of▁sentence|>", "")
|
|
||||||
else:
|
|
||||||
if SKIP_REPEAT:
|
|
||||||
continue
|
|
||||||
|
|
||||||
page_num = "\n<--- Page Split --->"
|
|
||||||
|
|
||||||
contents_det += content + f"\n{page_num}\n"
|
|
||||||
|
|
||||||
image_draw = img.copy()
|
|
||||||
|
|
||||||
matches_ref, matches_images, mathes_other = re_match(content)
|
|
||||||
# print(matches_ref)
|
|
||||||
result_image = process_image_with_refs(image_draw, matches_ref, jdx)
|
|
||||||
|
|
||||||
draw_images.append(result_image)
|
|
||||||
|
|
||||||
for idx, a_match_image in enumerate(matches_images):
|
|
||||||
content = content.replace(
|
|
||||||
a_match_image, " + "_" + str(idx) + ".jpg)\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, a_match_other in enumerate(mathes_other):
|
|
||||||
content = (
|
|
||||||
content.replace(a_match_other, "")
|
|
||||||
.replace("\\coloneqq", ":=")
|
|
||||||
.replace("\\eqqcolon", "=:")
|
|
||||||
.replace("\n\n\n\n", "\n\n")
|
|
||||||
.replace("\n\n\n", "\n\n")
|
|
||||||
)
|
|
||||||
|
|
||||||
contents += content + f"\n{page_num}\n"
|
|
||||||
|
|
||||||
jdx += 1
|
|
||||||
|
|
||||||
with open(mmd_det_path, "w", encoding="utf-8") as afile:
|
|
||||||
afile.write(contents_det)
|
|
||||||
|
|
||||||
with open(mmd_path, "w", encoding="utf-8") as afile:
|
|
||||||
afile.write(contents)
|
|
||||||
|
|
||||||
pil_to_pdf_img2pdf(draw_images, pdf_out_path)
|
|
||||||
Reference in New Issue
Block a user