Compare commits


11 Commits

Author  SHA1        Message                                                         Date
kyy     732e7c8cc0  Add serving configuration variables and files                   2025-11-06 15:12:26 +09:00
kyy     7749492ae7  Add post-processing logic                                       2025-11-06 14:38:54 +09:00
kyy     8a6f2ae2d8  Switch the OCR engine to asynchronous processing                2025-11-06 14:37:15 +09:00
kyy     f5ab36737a  Redefine variable names                                         2025-11-06 14:37:03 +09:00
kyy     d14e3ea64a  Update .gitignore: stop tracking test artifacts (pdf/jpg/json)  2025-11-06 12:03:06 +09:00
kyy     f757a541f8  Legacy mode                                                     2025-11-06 12:02:22 +09:00
kyy     6a3b52fe7c  Fix call paths                                                  2025-11-06 12:01:35 +09:00
kyy     715eaf8c8c  Exclude test files                                              2025-11-06 12:01:01 +09:00
kyy     723fd4333e  deepseek-ocr run test                                           2025-11-06 11:58:17 +09:00
kyy     2c3b417f3b  Apply lint                                                      2025-11-06 11:57:29 +09:00
kyy     f9975620cb  Fix call paths                                                  2025-11-06 11:55:44 +09:00
14 changed files with 795 additions and 130 deletions

.env (new file, 29 lines)

@@ -0,0 +1,29 @@
# --------------------------------------------------------------------------
# vLLM Engine Configuration
# --------------------------------------------------------------------------
# Remove the comment marker (#) and edit the values below to override the default engine settings.
# See config/engine_settings.py for the full list of configurable variables.

# Hugging Face model ID or local path
# MODEL_PATH="deepseek-ai/DeepSeek-OCR"

# Tensor parallel size (for multi-GPU environments)
# TENSOR_PARALLEL_SIZE=1

# Maximum GPU memory utilization (0.0 - 1.0)
# GPU_MEMORY_UTILIZATION=0.15

# KV cache block size
# BLOCK_SIZE=256

# Maximum model length
# MAX_MODEL_LEN=8192

# Whether to force eager mode (True / False)
# ENFORCE_EAGER=False

# Whether to trust remote code (True / False)
# TRUST_REMOTE_CODE=True

# Custom model architectures (comma-separated)
# ARCHITECTURES="DeepseekOCRForCausalLM"
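
config/env_setup.py, which api.py calls before anything else, is not included in this compare. Below is a minimal sketch of what its setup_environment() presumably does, based on the python-dotenv dependency added in requirements.txt; everything beyond the function name is an assumption.

```python
# config/env_setup.py (sketch, not the actual file in this compare)
from dotenv import load_dotenv


def setup_environment() -> None:
    # Load .env from the project root so that config/engine_settings.py,
    # which reads os.getenv() at import time, can see the overrides above.
    # Real environment variables are assumed to take precedence.
    load_dotenv(override=False)
```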

.gitignore (vendored, 16 lines changed)

@@ -4,4 +4,18 @@ __pycache__/
 *.pyc
 *.pyo
 *.pyd
 *.log
+gemini.md
+test/input/
+test/output/
+*.pdf
+*.jpg
+*.png
+*.jpeg
+*.tiff
+*.bmp
+*.gif
+*.svg
+*.json

api.py (23 lines changed)

@@ -1,7 +1,19 @@
 import logging
+from config.env_setup import setup_environment
+# Set up environment variables before anything else
+setup_environment()
+# Basic logging configuration
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
+)
+logger = logging.getLogger("startup")
 from fastapi import FastAPI
 from router import deepseek_router
+from services.ocr_engine import init_engine
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
@@ -14,6 +26,17 @@ app = FastAPI(
 )
+@app.on_event("startup")
+async def startup_event():
+    """FastAPI startup event handler."""
+    logging.info("Application startup...")
+    try:
+        await init_engine()
+        logging.info("vLLM engine initialized successfully.")
+    except Exception as e:
+        logging.error(f"vLLM engine init failed: {e}", exc_info=True)
 @app.get("/health/API", include_in_schema=False)
 async def health_check():
     return {"status": "API ok"}

config/engine_settings.py (new file, 35 lines)

@@ -0,0 +1,35 @@
import os
def _str_to_bool(value: str) -> bool:
    """Convert an environment variable (string) to a boolean."""
    return value.lower() in ("true", "1", "t")


# --------------------------------------------------------------------------
# vLLM Engine Configuration
# Set an environment variable with the same name in the .env file to override
# the defaults below.
# --------------------------------------------------------------------------

# Custom model architectures
# Separate multiple values with commas: "Arch1,Arch2"
_architectures_str = os.getenv("ARCHITECTURES", "DeepseekOCRForCausalLM")
ARCHITECTURES = [arch.strip() for arch in _architectures_str.split(",")]

# KV cache block size
BLOCK_SIZE = int(os.getenv("BLOCK_SIZE", "256"))

# Maximum model length
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))

# Whether to force eager mode
ENFORCE_EAGER = _str_to_bool(os.getenv("ENFORCE_EAGER", "False"))

# Whether to trust remote code
TRUST_REMOTE_CODE = _str_to_bool(os.getenv("TRUST_REMOTE_CODE", "True"))

# Tensor parallel size
TENSOR_PARALLEL_SIZE = int(os.getenv("TENSOR_PARALLEL_SIZE", "1"))

# GPU memory utilization
GPU_MEMORY_UTILIZATION = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.15"))
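
Because these module-level constants call os.getenv() at import time, the .env file has to be loaded before config.engine_settings is imported, which is why api.py calls setup_environment() before its other imports. A quick way to check the override chain (a sketch; it assumes setup_environment() loads .env as described above):

```python
from config.env_setup import setup_environment

setup_environment()  # load .env first, as api.py does

from config import engine_settings  # noqa: E402

# With "GPU_MEMORY_UTILIZATION=0.15" uncommented in .env this prints 0.15,
# otherwise the default defined in engine_settings.py.
print(engine_settings.GPU_MEMORY_UTILIZATION, engine_settings.ARCHITECTURES)
```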


@@ -22,8 +22,8 @@ MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
 # Omnidocbench images path: run_dpsk_ocr_eval_batch.py
-INPUT_PATH = "/workspace/input"
-OUTPUT_PATH = "/workspace/output"
+INPUT_PATH = "/workspace/test/input"
+OUTPUT_PATH = "/workspace/test/output"
 # PROMPT = f"{PROMPT_TEXT.strip()}"
 PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
 # PROMPT = '<image>\nFree OCR.'


@@ -11,3 +11,4 @@ matplotlib
 fastapi
 uvicorn[standard]
 python-multipart
+python-dotenv


@@ -1,20 +1,25 @@
 import logging
 from fastapi import APIRouter, File, HTTPException, UploadFile
 from services.ocr_engine import process_document
 router = APIRouter(prefix="/ocr", tags=["OCR"])
 logger = logging.getLogger(__name__)
-@router.post("", description="Runs Deepseek OCR on the uploaded file and extracts its text.")
-async def perform_ocr(document: UploadFile = File(..., description="PDF or image file to run OCR on")):
+@router.post(
+    "", description="Runs Deepseek OCR on the uploaded file and extracts its text."
+)
+async def perform_ocr(
+    document: UploadFile = File(..., description="PDF or image file to run OCR on"),
+):
     """
     Passes the uploaded file to the OCR engine and returns the extracted text.
     - **document**: file sent as `multipart/form-data`.
     """
-    logger.info(f"Received OCR request for '{document.filename}' (Content-Type: {document.content_type})")
+    logger.info(
+        f"Received OCR request for '{document.filename}' (Content-Type: {document.content_type})"
+    )
     try:
         file_content = await document.read()
@@ -22,7 +27,7 @@ async def perform_ocr(
             raise HTTPException(status_code=400, detail="The file content is empty.")
         # Delegate all processing to the OCR engine
-        result = process_document(
+        result = await process_document(
             file_bytes=file_content,
             content_type=document.content_type,
             filename=document.filename,
@@ -36,6 +41,8 @@ async def perform_ocr(
     except Exception as e:
         # Unexpected internal server error
         logger.exception(f"Unexpected error during OCR processing: {e}")
-        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}")
+        raise HTTPException(
+            status_code=500, detail=f"An internal server error occurred: {e}"
+        )
     finally:
         await document.close()
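
A minimal client call against the endpoint, assuming the service runs locally on port 8000 and that the requests package is available (it is not in the requirements shown here); the sample filename is hypothetical.

```python
import requests

with open("sample.pdf", "rb") as f:  # hypothetical input file
    resp = requests.post(
        "http://localhost:8000/ocr",
        files={"document": ("sample.pdf", f, "application/pdf")},
    )
resp.raise_for_status()
payload = resp.json()
# For PDFs the response also carries page_count; see process_document below.
print(payload["page_count"], payload["text"][:200])
```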

services/__init__.py (new, empty file)


@@ -9,11 +9,17 @@ import torch.nn as nn
 from addict import Dict
 # import time
-from config import BASE_SIZE, CROP_MODE, IMAGE_SIZE, PRINT_NUM_VIS_TOKENS, PROMPT
-from deepencoder.build_linear import MlpProjector
-from deepencoder.clip_sdpa import build_clip_l
-from deepencoder.sam_vary_sdpa import build_sam_vit_b
-from process.image_process import DeepseekOCRProcessor, count_tiles
+from config.model_settings import (
+    BASE_SIZE,
+    CROP_MODE,
+    IMAGE_SIZE,
+    PRINT_NUM_VIS_TOKENS,
+    PROMPT,
+)
+from services.deepencoder.build_linear import MlpProjector
+from services.deepencoder.clip_sdpa import build_clip_l
+from services.deepencoder.sam_vary_sdpa import build_sam_vit_b
+from services.process.image_process import DeepseekOCRProcessor, count_tiles
 from transformers import BatchFeature
 from vllm.config import VllmConfig
 from vllm.model_executor import SamplingMetadata


@@ -1,37 +1,46 @@
 import asyncio
 import io
-from typing import Union
+import logging
+import re
 import fitz
-from config.model_settings import CROP_MODE, MODEL_PATH, PROMPT
-from fastapi import UploadFile
-from PIL import Image
-from process.image_process import DeepseekOCRProcessor
+from config import engine_settings, model_settings
+from PIL import Image, ImageOps
 from vllm import AsyncLLMEngine, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.model_executor.models.registry import ModelRegistry
 from services.deepseek_ocr import DeepseekOCRForCausalLM
+from services.process.image_process import DeepseekOCRProcessor
+logger = logging.getLogger(__name__)
 # --------------------------------------------------------------------------
 # 1. Model and processor initialization
 # --------------------------------------------------------------------------
+# Register the custom model so vLLM can recognize it
 ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
-# vLLM async engine configuration
-engine_args = AsyncEngineArgs(
-    model=MODEL_PATH,
-    hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
-    block_size=256,
-    max_model_len=8192,
-    enforce_eager=False,
-    trust_remote_code=True,
-    tensor_parallel_size=1,
-    gpu_memory_utilization=0.75,
-)
-engine = AsyncLLMEngine.from_engine_args(engine_args)
+_engine = None
+async def init_engine():
+    """Initialize the vLLM engine asynchronously."""
+    global _engine
+    if _engine is not None:
+        return
+    engine_args = AsyncEngineArgs(
+        model=model_settings.MODEL_PATH,
+        hf_overrides={"architectures": engine_settings.ARCHITECTURES},
+        block_size=engine_settings.BLOCK_SIZE,
+        max_model_len=engine_settings.MAX_MODEL_LEN,
+        enforce_eager=engine_settings.ENFORCE_EAGER,
+        trust_remote_code=engine_settings.TRUST_REMOTE_CODE,
+        tensor_parallel_size=engine_settings.TENSOR_PARALLEL_SIZE,
+        gpu_memory_utilization=engine_settings.GPU_MEMORY_UTILIZATION,
+    )
+    _engine = AsyncLLMEngine.from_engine_args(engine_args)
 # Sampling parameter configuration
 sampling_params = SamplingParams(
@@ -44,30 +53,89 @@ sampling_params = SamplingParams(
 processor = DeepseekOCRProcessor()
 # --------------------------------------------------------------------------
-# 2. Core processing functions
+# 2. Processing functions (based on the official code)
 # --------------------------------------------------------------------------
-async def _process_single_image(image: Image.Image) -> str:
-    """Take a single PIL image, run OCR, and return the text."""
-    if "<image>" not in PROMPT:
+def _postprocess_text(text: str, page_num: int = 0) -> str:
+    """
+    Remove or convert the tags in the model's raw output and clean up the text.
+    (Adapted from the post-processing logic in test/test.py.)
+    """
+    if not text:
+        return ""
+    # 1. Find every ref/det tag block with a regex
+    # The pattern captures (full match, ref content, det content)
+    pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
+    matches = re.findall(pattern, text, re.DOTALL)
+    # 2. Split the full matches into image tags and all other tags
+    matches_images = []
+    matches_other = []
+    for match_tuple in matches:
+        full_match_str = match_tuple[0]  # the full matched string
+        ref_content = match_tuple[1]  # content inside <|ref|>
+        if "image" in ref_content:
+            matches_images.append(full_match_str)
+        else:
+            matches_other.append(full_match_str)
+    processed_text = text
+    # 3. Replace image tags with markdown links
+    for idx, img_tag in enumerate(matches_images):
+        img_link = f"![](images/{page_num}_{idx}.jpg)\n"
+        processed_text = processed_text.replace(img_tag, img_link)
+    # 4. Remove every remaining non-image tag
+    for other_tag in matches_other:
+        processed_text = processed_text.replace(other_tag, "")
+    # 5. Clean up special characters, leftover tokens, and extra whitespace
+    processed_text = (
+        processed_text.replace("<end of sentence>", "")
+        .replace("\\coloneqq", ":=")
+        .replace("\\eqqcolon", "=:")
+        .replace("\n\n\n", "\n\n")
+        .strip()
+    )
+    return processed_text
+# --------------------------------------------------------------------------
+# 3. Core processing functions
+# --------------------------------------------------------------------------
+async def _process_single_image(image: Image.Image, page_num: int = 0) -> str:
+    """Run OCR on a single PIL image and return the post-processed text."""
+    if _engine is None:
+        raise RuntimeError("vLLM engine not initialized yet")
+    if "<image>" not in model_settings.PROMPT:
         raise ValueError("Cannot run OCR because the prompt has no '<image>' token.")
     image_features = processor.tokenize_with_images(
-        images=[image], bos=True, eos=True, cropping=CROP_MODE
+        images=[image], bos=True, eos=True, cropping=model_settings.CROP_MODE
     )
-    request = {"prompt": PROMPT, "multi_modal_data": {"image": image_features}}
+    request = {"prompt": model_settings.PROMPT, "multi_modal_data": {"image": image_features}}
     request_id = f"request-{asyncio.get_running_loop().time()}"
-    final_output = ""
-    async for request_output in engine.generate(request, sampling_params, request_id):
+    raw_output = ""
+    async for request_output in _engine.generate(request, sampling_params, request_id):
         if request_output.outputs:
-            final_output = request_output.outputs[0].text
-    return final_output
+            raw_output = request_output.outputs[0].text
+    # Apply post-processing (pass the page number along)
+    clean_text = _postprocess_text(raw_output, page_num)
+    return clean_text
-def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
-    """Convert PDF bytes into a list of per-page PIL images."""
+def _pdf_to_images_high_quality(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
+    """Convert PDF bytes into a list of high-quality per-page PIL images."""
     images = []
     pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
     zoom = dpi / 72.0
@@ -76,37 +144,56 @@ def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
     for page_num in range(pdf_document.page_count):
         page = pdf_document[page_num]
         pixmap = page.get_pixmap(matrix=matrix, alpha=False)
+        Image.MAX_IMAGE_PIXELS = None
         img_data = pixmap.tobytes("png")
         img = Image.open(io.BytesIO(img_data))
+        if img.mode in ("RGBA", "LA"):
+            background = Image.new("RGB", img.size, (255, 255, 255))
+            background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
+            img = background
         images.append(img)
     pdf_document.close()
     return images
 async def process_document(file_bytes: bytes, content_type: str, filename: str) -> dict:
     """
     Process the uploaded file (image or PDF) and return the OCR result.
     """
-    if content_type.startswith("image/"):
+    # If the Content-Type is generic, infer the type from the file extension
+    inferred_content_type = content_type
+    if content_type == "application/octet-stream":
+        if filename.lower().endswith(".pdf"):
+            inferred_content_type = "application/pdf"
+        elif any(
+            filename.lower().endswith(ext)
+            for ext in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]
+        ):
+            inferred_content_type = "image/jpeg"  # the exact subtype does not matter
+    if inferred_content_type.startswith("image/"):
         try:
-            image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
+            image = Image.open(io.BytesIO(file_bytes))
+            image = ImageOps.exif_transpose(image).convert("RGB")
         except Exception as e:
             raise ValueError(f"Failed to open the image file: {e}")
-        result_text = await _process_single_image(image)
+        # A single image is treated as page number 0
+        result_text = await _process_single_image(image, page_num=0)
         return {"filename": filename, "text": result_text}
-    elif content_type == "application/pdf":
+    elif inferred_content_type == "application/pdf":
         try:
-            images = _pdf_to_images(file_bytes)
+            images = _pdf_to_images_high_quality(file_bytes)
         except Exception as e:
             raise ValueError(f"Failed to process the PDF file: {e}")
-        # Process each page asynchronously
-        tasks = [_process_single_image(img) for img in images]
+        # Process each page asynchronously (passing the page number along)
+        tasks = [_process_single_image(img, page_num=i) for i, img in enumerate(images)]
         page_results = await asyncio.gather(*tasks)
+        # Join the page texts with a page separator
         full_text = "\n<--- Page Split --->\n".join(page_results)
         return {"filename": filename, "text": full_text, "page_count": len(images)}
@@ -114,4 +201,4 @@ async def process_document(file_bytes: bytes, content_type: str, filename: str)
     raise ValueError(
         f"Unsupported file type: {content_type}. "
         "Please upload an image (JPEG, PNG, etc.) or a PDF file."
     )
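
To make the post-processing step concrete, here is what _postprocess_text does to a made-up model output; the sample string is invented for illustration, and importing services.ocr_engine pulls in vLLM, so treat this as a sketch rather than a lightweight test.

```python
from services.ocr_engine import _postprocess_text  # heavy import: registers the vLLM model

sample = (
    "<|ref|>title<|/ref|><|det|>[[10, 10, 500, 60]]<|/det|># Quarterly Report\n\n"
    "<|ref|>image<|/ref|><|det|>[[20, 80, 480, 400]]<|/det|>\n"
    "Revenue grew 12%.<end of sentence>"
)
print(_postprocess_text(sample, page_num=3))
# The title ref/det block is stripped, the image block becomes
# "![](images/3_0.jpg)", and the trailing <end of sentence> token is removed.
```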


@@ -3,13 +3,21 @@ from typing import List, Tuple
 import torch
 import torchvision.transforms as T
+from config.model_settings import (
+    BASE_SIZE,
+    IMAGE_SIZE,
+    MAX_CROPS,
+    MIN_CROPS,
+    PROMPT,
+    TOKENIZER,
+)
 from PIL import Image, ImageOps
-from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
+from transformers import AutoProcessor, LlamaTokenizerFast
 from transformers.processing_utils import ProcessorMixin
-from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, MIN_CROPS, MAX_CROPS, PROMPT, TOKENIZER
 def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
-    best_ratio_diff = float('inf')
+    best_ratio_diff = float("inf")
     best_ratio = (1, 1)
     area = width * height
     for ratio in target_ratios:
@@ -25,37 +33,56 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
     return best_ratio
-def count_tiles(orig_width, orig_height, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
+def count_tiles(
+    orig_width,
+    orig_height,
+    min_num=MIN_CROPS,
+    max_num=MAX_CROPS,
+    image_size=640,
+    use_thumbnail=False,
+):
     aspect_ratio = orig_width / orig_height
     # calculate the existing image aspect ratio
     target_ratios = set(
-        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-        i * j <= max_num and i * j >= min_num)
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
     # print(target_ratios)
     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
     # find the closest aspect ratio to the target
     target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
     return target_aspect_ratio
-def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
+def dynamic_preprocess(
+    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
+):
     orig_width, orig_height = image.size
     aspect_ratio = orig_width / orig_height
     # calculate the existing image aspect ratio
     target_ratios = set(
-        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-        i * j <= max_num and i * j >= min_num)
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
     # print(target_ratios)
     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
     # find the closest aspect ratio to the target
     target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
     # print(target_aspect_ratio)
     # calculate the target width and height
@@ -71,7 +98,7 @@ def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=6
             (i % (target_width // image_size)) * image_size,
             (i // (target_width // image_size)) * image_size,
             ((i % (target_width // image_size)) + 1) * image_size,
-            ((i // (target_width // image_size)) + 1) * image_size
+            ((i // (target_width // image_size)) + 1) * image_size,
         )
         # split the image
         split_img = resized_img.crop(box)
@@ -83,15 +110,13 @@ def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=6
     return processed_images, target_aspect_ratio
 class ImageTransform:
-    def __init__(self,
-                 mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
-                 std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
-                 normalize: bool = True):
+    def __init__(
+        self,
+        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        normalize: bool = True,
+    ):
         self.mean = mean
         self.std = std
         self.normalize = normalize
@@ -129,28 +154,28 @@ class DeepseekOCRProcessor(ProcessorMixin):
         ignore_id: int = -100,
         **kwargs,
     ):
         # self.candidate_resolutions = candidate_resolutions # placeholder no use
         self.image_size = IMAGE_SIZE
         self.base_size = BASE_SIZE
         # self.patch_size = patch_size
         self.patch_size = 16
         self.image_mean = image_mean
         self.image_std = image_std
         self.normalize = normalize
         # self.downsample_ratio = downsample_ratio
         self.downsample_ratio = 4
-        self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
+        self.image_transform = ImageTransform(
+            mean=image_mean, std=image_std, normalize=normalize
+        )
         self.tokenizer = tokenizer
         # self.tokenizer = add_special_token(tokenizer)
-        self.tokenizer.padding_side = 'left' # must set thispadding side with make a difference in batch inference
+        self.tokenizer.padding_side = "left"  # must set thispadding side with make a difference in batch inference
         # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
         if self.tokenizer.pad_token is None:
-            self.tokenizer.add_special_tokens({'pad_token': pad_token})
+            self.tokenizer.add_special_tokens({"pad_token": pad_token})
         # add image token
         # image_token_id = self.tokenizer.vocab.get(image_token)
@@ -186,9 +211,6 @@ class DeepseekOCRProcessor(ProcessorMixin):
             **kwargs,
         )
     # def select_best_resolution(self, image_size):
     #     # used for cropping
     #     original_width, original_height = image_size
@@ -264,13 +286,21 @@ class DeepseekOCRProcessor(ProcessorMixin):
             - num_image_tokens (List[int]): the number of image tokens
         """
-        assert (prompt is not None and images is not None
-                ), "prompt and images must be used at the same time."
+        assert (
+            prompt is not None and images is not None
+        ), "prompt and images must be used at the same time."
         sft_format = prompt
-        input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, _ = images[0]
+        (
+            input_ids,
+            pixel_values,
+            images_crop,
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens,
+            _,
+        ) = images[0]
         return {
             "input_ids": input_ids,
@@ -281,7 +311,6 @@ class DeepseekOCRProcessor(ProcessorMixin):
"num_image_tokens": num_image_tokens, "num_image_tokens": num_image_tokens,
} }
# prepare = BatchFeature( # prepare = BatchFeature(
# data=dict( # data=dict(
# input_ids=input_ids, # input_ids=input_ids,
@@ -341,7 +370,12 @@ class DeepseekOCRProcessor(ProcessorMixin):
         conversation = PROMPT
         assert conversation.count(self.image_token) == len(images)
         text_splits = conversation.split(self.image_token)
-        images_list, images_crop_list, images_seq_mask, images_spatial_crop = [], [], [], []
+        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
+            [],
+            [],
+            [],
+            [],
+        )
         image_shapes = []
         num_image_tokens = []
         tokenized_str = []
@@ -368,7 +402,9 @@ class DeepseekOCRProcessor(ProcessorMixin):
                 # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
                 # print('image ', image.size)
                 # print('open_size:', image.size)
-                images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=IMAGE_SIZE)
+                images_crop_raw, crop_ratio = dynamic_preprocess(
+                    image, image_size=IMAGE_SIZE
+                )
                 # print('crop_ratio: ', crop_ratio)
             else:
                 # best_width, best_height = self.image_size, self.image_size
@@ -383,8 +419,11 @@ class DeepseekOCRProcessor(ProcessorMixin):
                 # print('directly resize')
                 image = image.resize((self.image_size, self.image_size))
-            global_view = ImageOps.pad(image, (self.base_size, self.base_size),
-                                       color=tuple(int(x * 255) for x in self.image_transform.mean))
+            global_view = ImageOps.pad(
+                image,
+                (self.base_size, self.base_size),
+                color=tuple(int(x * 255) for x in self.image_transform.mean),
+            )
             images_list.append(self.image_transform(global_view))
             """record height / width crop num"""
@@ -392,9 +431,6 @@ class DeepseekOCRProcessor(ProcessorMixin):
             num_width_tiles, num_height_tiles = crop_ratio
             images_spatial_crop.append([num_width_tiles, num_height_tiles])
             if num_width_tiles > 1 or num_height_tiles > 1:
                 """process the local views"""
                 # local_view = ImageOps.pad(image, (best_width, best_height),
@@ -421,15 +457,22 @@ class DeepseekOCRProcessor(ProcessorMixin):
# """add image tokens""" # """add image tokens"""
"""add image tokens""" """add image tokens"""
num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio) num_queries = math.ceil(
num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio) (self.image_size // self.patch_size) / self.downsample_ratio
)
num_queries_base = math.ceil(
(self.base_size // self.patch_size) / self.downsample_ratio
)
tokenized_image = (
tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base [self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id] tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1: if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += ([self.image_token_id] * (num_queries * num_width_tiles) + [self.image_token_id]) * ( tokenized_image += (
num_queries * num_height_tiles) [self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image) images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image)) num_image_tokens.append(len(tokenized_image))
@@ -447,10 +490,9 @@ class DeepseekOCRProcessor(ProcessorMixin):
         tokenized_str = tokenized_str + [self.eos_id]
         images_seq_mask = images_seq_mask + [False]
-        assert len(tokenized_str) == len(
-            images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+        assert (
+            len(tokenized_str) == len(images_seq_mask)
+        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
         masked_tokenized_str = []
         for token_index in tokenized_str:
@@ -459,17 +501,21 @@ class DeepseekOCRProcessor(ProcessorMixin):
             else:
                 masked_tokenized_str.append(self.ignore_id)
-        assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
-            (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
-             f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
+        assert (
+            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+        ), (
+            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+        )
         input_ids = torch.LongTensor(tokenized_str)
         target_ids = torch.LongTensor(masked_tokenized_str)
         images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
         # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
-        target_ids[(input_ids < 0) |
-                   (input_ids == self.image_token_id)] = self.ignore_id
+        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+            self.ignore_id
+        )
         input_ids[input_ids < 0] = self.pad_id
         inference_mode = True
@@ -484,19 +530,32 @@ class DeepseekOCRProcessor(ProcessorMixin):
         if len(images_list) == 0:
             pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
             images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
-            images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
+            images_crop = torch.zeros(
+                (1, 3, self.image_size, self.image_size)
+            ).unsqueeze(0)
         else:
             pixel_values = torch.stack(images_list, dim=0)
             images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
             if images_crop_list:
                 images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
             else:
-                images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
+                images_crop = torch.zeros(
+                    (1, 3, self.image_size, self.image_size)
+                ).unsqueeze(0)
         input_ids = input_ids.unsqueeze(0)
-        return [[input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, image_shapes]]
+        return [
+            [
+                input_ids,
+                pixel_values,
+                images_crop,
+                images_seq_mask,
+                images_spatial_crop,
+                num_image_tokens,
+                image_shapes,
+            ]
+        ]
 AutoProcessor.register("DeepseekVLV2Processor", DeepseekOCRProcessor)
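
The grid that count_tiles and dynamic_preprocess pick can be illustrated without the module itself. The snippet below reimplements a simplified version of the ratio search (it drops the area-based tie-breaking in find_closest_aspect_ratio), and MIN_CROPS=2 / MAX_CROPS=6 are assumed values, since config.model_settings is not shown in this compare.

```python
def pick_grid(width: int, height: int, min_num: int = 2, max_num: int = 6):
    aspect_ratio = width / height
    candidates = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda r: r[0] * r[1],
    )
    # Simplified scoring: the grid whose aspect ratio is closest wins.
    return min(candidates, key=lambda r: abs(aspect_ratio - r[0] / r[1]))


# A portrait A4 scan ends up split into a 2 x 3 grid of 640-px tiles.
print(pick_grid(2480, 3508))  # -> (2, 3)
```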


@@ -1,40 +1,47 @@
+from typing import List
 import torch
 from transformers import LogitsProcessor
-from transformers.generation.logits_process import _calc_banned_ngram_tokens
-from typing import List, Set
 class NoRepeatNGramLogitsProcessor(LogitsProcessor):
-    def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
+    def __init__(
+        self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None
+    ):
         if not isinstance(ngram_size, int) or ngram_size <= 0:
-            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+            raise ValueError(
+                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
+            )
         if not isinstance(window_size, int) or window_size <= 0:
-            raise ValueError(f"`window_size` has to be a strictly positive integer, but is {window_size}")
+            raise ValueError(
+                f"`window_size` has to be a strictly positive integer, but is {window_size}"
+            )
         self.ngram_size = ngram_size
         self.window_size = window_size
         self.whitelist_token_ids = whitelist_token_ids or set()
-    def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
+    def __call__(
+        self, input_ids: List[int], scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
         if len(input_ids) < self.ngram_size:
             return scores
-        current_prefix = tuple(input_ids[-(self.ngram_size - 1):])
+        current_prefix = tuple(input_ids[-(self.ngram_size - 1) :])
         search_start = max(0, len(input_ids) - self.window_size)
         search_end = len(input_ids) - self.ngram_size + 1
         banned_tokens = set()
         for i in range(search_start, search_end):
-            ngram = tuple(input_ids[i:i + self.ngram_size])
+            ngram = tuple(input_ids[i : i + self.ngram_size])
             if ngram[:-1] == current_prefix:
                 banned_tokens.add(ngram[-1])
         banned_tokens = banned_tokens - self.whitelist_token_ids
         if banned_tokens:
             scores = scores.clone()
             for token in banned_tokens:
                 scores[token] = -float("inf")
         return scores
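
This processor is plugged into SamplingParams(logits_processors=[...]) in test/test.py; in isolation its effect is easy to see with made-up token ids (the vocabulary size of 10 below is arbitrary).

```python
import torch

proc = NoRepeatNGramLogitsProcessor(ngram_size=3, window_size=100)

# The history already contains the 3-gram (5, 6, 7) and currently ends with
# the prefix (5, 6), so token 7 is banned for the next step.
history = [1, 5, 6, 7, 2, 5, 6]
scores = proc(history, torch.zeros(10))
print(scores[7])  # tensor(-inf)
```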

test/__init__.py (new, empty file)

test/test.py (new file, 397 lines)

@@ -0,0 +1,397 @@
import io
import json
import os
import re
import time
import config.model_settings as config
import fitz
import img2pdf
import numpy as np
from config.env_setup import setup_environment
from PIL import Image, ImageDraw, ImageFont, ImageOps
from services.deepseek_ocr import DeepseekOCRForCausalLM
from services.process.image_process import DeepseekOCRProcessor
from services.process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
setup_environment()
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
class Colors:
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
RESET = "\033[0m"
# --- PDF/Image Processing Functions (from run_dpsk_ocr_*.py) ---
def pdf_to_images_high_quality(pdf_path, dpi=144):
images = []
pdf_document = fitz.open(pdf_path)
zoom = dpi / 72.0
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
Image.MAX_IMAGE_PIXELS = None
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
if img.mode in ("RGBA", "LA"):
background = Image.new("RGB", img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
img = background
images.append(img)
pdf_document.close()
return images
def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images:
return
image_bytes_list = []
for img in pil_images:
if img.mode != "RGB":
img = img.convert("RGB")
img_buffer = io.BytesIO()
img.save(img_buffer, format="JPEG", quality=95)
image_bytes_list.append(img_buffer.getvalue())
try:
pdf_bytes = img2pdf.convert(image_bytes_list)
with open(output_path, "wb") as f:
f.write(pdf_bytes)
except Exception as e:
print(f"Error creating PDF: {e}")
def re_match(text):
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = [m[0] for m in matches if "<|ref|>image<|/ref|>" in m[0]]
mathes_other = [m[0] for m in matches if "<|ref|>image<|/ref|>" not in m[0]]
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
return (label_type, cor_list)
except Exception as e:
print(f"Error extracting coordinates: {e}")
return None
def draw_bounding_boxes(image, refs, jdx=None):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
result = extract_coordinates_and_label(ref, image_width, image_height)
if not result:
continue
label_type, points_list = result
color = (
np.random.randint(0, 200),
np.random.randint(0, 200),
np.random.randint(0, 255),
)
color_a = color + (20,)
for points in points_list:
x1, y1, x2, y2 = [
int(p / 999 * (image_width if i % 2 == 0 else image_height))
for i, p in enumerate(points)
]
if label_type == "image":
try:
cropped = image.crop((x1, y1, x2, y2))
img_filename = (
f"{jdx}_{img_idx}.jpg" if jdx is not None else f"{img_idx}.jpg"
)
cropped.save(
os.path.join(config.OUTPUT_PATH, "images", img_filename)
)
img_idx += 1
except Exception as e:
print(f"Error cropping image: {e}")
width = 4 if label_type == "title" else 2
draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
draw2.rectangle(
[x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1
)
text_x, text_y = x1, max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width, text_height = (
text_bbox[2] - text_bbox[0],
text_bbox[3] - text_bbox[1],
)
draw.rectangle(
[text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30),
)
draw.text((text_x, text_y), label_type, font=font, fill=color)
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts, jdx=None):
return draw_bounding_boxes(image, ref_texts, jdx)
def load_image(image_path):
try:
image = Image.open(image_path).convert("RGB")
return ImageOps.exif_transpose(image)
except Exception as e:
print(f"Error loading image {image_path}: {e}")
return None
# --- Main OCR Processing Logic ---
def process_pdf(llm, sampling_params, pdf_path):
print(f"{Colors.GREEN}Processing PDF: {pdf_path}{Colors.RESET}")
base_name = os.path.basename(pdf_path)
file_name_without_ext = os.path.splitext(base_name)[0]
images = pdf_to_images_high_quality(pdf_path)
if not images:
print(
f"{Colors.YELLOW}Could not extract images from {pdf_path}. Skipping.{Colors.RESET}"
)
return
batch_inputs = []
processor = DeepseekOCRProcessor()
for image in tqdm(images, desc="Pre-processing PDF pages"):
batch_inputs.append(
{
"prompt": config.PROMPT,
"multi_modal_data": {
"image": processor.tokenize_with_images(
images=[image], bos=True, eos=True, cropping=config.CROP_MODE
)
},
}
)
start_time = time.time()
outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
end_time = time.time()
contents_det = ""
contents = ""
draw_images = []
for i, (output, img) in enumerate(zip(outputs_list, images)):
content = output.outputs[0].text
if "<end of sentence>" in content:
content = content.replace("<end of sentence>", "")
elif config.SKIP_REPEAT:
continue
page_num_separator = "\n<--- Page Split --->\n"
contents_det += content + page_num_separator
matches_ref, matches_images, mathes_other = re_match(content)
result_image = process_image_with_refs(img.copy(), matches_ref, jdx=i)
draw_images.append(result_image)
for idx, match in enumerate(matches_images):
content = content.replace(match, f"![](images/{i}_{idx}.jpg)\n")
for match in mathes_other:
content = (
content.replace(match, "")
.replace("\\coloneqq", ":=")
.replace("\\eqqcolon", "=:")
.replace("\n\n\n", "\n\n")
)
contents += content + page_num_separator
# Save results
result_json_path = os.path.join(
f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
)
result_pdf_path = os.path.join(
config.OUTPUT_PATH, f"{file_name_without_ext}_layouts.pdf"
)
duration = end_time - start_time
output_data = {
"filename": base_name,
"model": {"ocr_model": "deepseek-ocr"},
"time": {
"duration_sec": f"{duration:.2f}",
"started_at": start_time,
"ended_at": end_time,
},
"parsed": contents,
}
with open(result_json_path, "w", encoding="utf-8") as f:
json.dump(output_data, f, ensure_ascii=False, indent=4)
pil_to_pdf_img2pdf(draw_images, result_pdf_path)
print(
f"{Colors.GREEN}Finished processing {pdf_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
)
def process_image(llm, sampling_params, image_path):
print(f"{Colors.GREEN}Processing Image: {image_path}{Colors.RESET}")
base_name = os.path.basename(image_path)
file_name_without_ext = os.path.splitext(base_name)[0]
image = load_image(image_path)
if image is None:
return
processor = DeepseekOCRProcessor()
image_features = processor.tokenize_with_images(
images=[image], bos=True, eos=True, cropping=config.CROP_MODE
)
request = {
"prompt": config.PROMPT,
"multi_modal_data": {"image": image_features},
}
start_time = time.time()
outputs = llm.generate([request], sampling_params)
end_time = time.time()
result_out = outputs[0].outputs[0].text
print(result_out)
# Save results
result_json_path = os.path.join(
f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
)
result_image_path = os.path.join(
config.OUTPUT_PATH, f"{file_name_without_ext}_result_with_boxes.jpg"
)
matches_ref, matches_images, mathes_other = re_match(result_out)
result_image = process_image_with_refs(image.copy(), matches_ref)
processed_text = result_out
for idx, match in enumerate(matches_images):
processed_text = processed_text.replace(match, f"![](images/{idx}.jpg)\n")
for match in mathes_other:
processed_text = (
processed_text.replace(match, "")
.replace("\\coloneqq", ":=")
.replace("\\eqqcolon", "=:")
.replace("\n\n\n", "\n\n")
)
duration = end_time - start_time
output_data = {
"filename": base_name,
"model": {"ocr_model": "deepseek-ocr"},
"time": {
"duration_sec": f"{duration:.2f}",
"started_at": start_time,
"ended_at": end_time,
},
"parsed": processed_text,
}
with open(result_json_path, "w", encoding="utf-8") as f:
json.dump(output_data, f, ensure_ascii=False, indent=4)
result_image.save(result_image_path)
print(
f"{Colors.GREEN}Finished processing {image_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
)
def main():
# --- Model Initialization ---
print(f"{Colors.BLUE}Initializing model...{Colors.RESET}")
llm = LLM(
model=config.MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=256,
enforce_eager=False,
trust_remote_code=True,
max_model_len=8192,
swap_space=0,
max_num_seqs=config.MAX_CONCURRENCY,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
disable_mm_preprocessor_cache=True,
)
logits_processors = [
NoRepeatNGramLogitsProcessor(
ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
)
]
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
include_stop_str_in_output=True,
)
print(f"{Colors.BLUE}Model initialized successfully.{Colors.RESET}")
# --- File Processing ---
input_dir = config.INPUT_PATH
output_dir = config.OUTPUT_PATH
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "result"), exist_ok=True)
if not os.path.isdir(input_dir):
print(
f"{Colors.RED}Error: Input directory not found at '{input_dir}'{Colors.RESET}"
)
return
print(f"Scanning for files in '{input_dir}'...")
for filename in sorted(os.listdir(input_dir)):
input_path = os.path.join(input_dir, filename)
if not os.path.isfile(input_path):
continue
file_extension = os.path.splitext(filename)[1].lower()
try:
if file_extension == ".pdf":
process_pdf(llm, sampling_params, input_path)
elif file_extension in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]:
process_image(llm, sampling_params, input_path)
else:
print(
f"{Colors.YELLOW}Skipping unsupported file type: {filename}{Colors.RESET}"
)
except Exception as e:
print(
f"{Colors.RED}An error occurred while processing {filename}: {e}{Colors.RESET}"
)
if __name__ == "__main__":
main()
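
After a batch run, every input file in INPUT_PATH ends up as a JSON document under OUTPUT_PATH/result, plus a layout rendering (<name>_layouts.pdf for PDFs, <name>_result_with_boxes.jpg for images) and cropped figures under OUTPUT_PATH/images. A sketch of inspecting one result; the filename invoice.json is hypothetical.

```python
import json

import config.model_settings as config

with open(f"{config.OUTPUT_PATH}/result/invoice.json", encoding="utf-8") as f:
    result = json.load(f)

print(result["time"]["duration_sec"], "seconds")
print(result["parsed"][:300])
```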