Compare commits

...

6 Commits

Author  SHA1        Message                 Date
kyy     f757a541f8  Legacy mode             2025-11-06 12:02:22 +09:00
kyy     6a3b52fe7c  Fix call paths          2025-11-06 12:01:35 +09:00
kyy     715eaf8c8c  Exclude test files      2025-11-06 12:01:01 +09:00
kyy     723fd4333e  deepseek-ocr run test   2025-11-06 11:58:17 +09:00
kyy     2c3b417f3b  Apply lint              2025-11-06 11:57:29 +09:00
kyy     f9975620cb  Fix call paths          2025-11-06 11:55:44 +09:00
15 changed files with 646 additions and 108 deletions

.gitignore vendored (14 additions)
View File

@@ -5,3 +5,17 @@ __pycache__/
 *.pyo
 *.pyd
 *.log
+gemini.md
+test/input/
+test/output/
+*.pdf
+*.jpg
+*.png
+*.jpeg
+*.tiff
+*.bmp
+*.gif
+*.svg
+*.json

api.py (23 additions)
View File

@@ -1,7 +1,19 @@
 import logging
+
+from config.env_setup import setup_environment
+
+# 환경 변수 설정을 최우선으로 호출
+setup_environment()
+
+# 로깅 기본 설정
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
+)
+logger = logging.getLogger("startup")
+
 from fastapi import FastAPI
 from router import deepseek_router
+from services.ocr_engine import init_engine

 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
@@ -14,6 +26,17 @@ app = FastAPI(
 )

+
+@app.on_event("startup")
+async def startup_event():
+    """FastAPI startup event handler."""
+    logging.info("Application startup...")
+    try:
+        await init_engine()
+        logging.info("vLLM engine initialized successfully.")
+    except Exception as e:
+        logging.error(f"vLLM engine init failed: {e}", exc_info=True)
+
+
 @app.get("/health/API", include_in_schema=False)
 async def health_check():
     return {"status": "API ok"}

View File

@@ -22,8 +22,8 @@ MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
 # Omnidocbench images path: run_dpsk_ocr_eval_batch.py
-INPUT_PATH = "/workspace/input"
-OUTPUT_PATH = "/workspace/output"
+INPUT_PATH = "/workspace/test/input"
+OUTPUT_PATH = "/workspace/test/output"

 # PROMPT = f"{PROMPT_TEXT.strip()}"
 PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
 # PROMPT = '<image>\nFree OCR.'

View File

@@ -1,20 +1,25 @@
 import logging

 from fastapi import APIRouter, File, HTTPException, UploadFile
 from services.ocr_engine import process_document

 router = APIRouter(prefix="/ocr", tags=["OCR"])
 logger = logging.getLogger(__name__)


-@router.post("", description="요청된 파일에서 Deepseek OCR을 수행하고 텍스트를 추출합니다.")
-async def perform_ocr(document: UploadFile = File(..., description="OCR을 수행할 PDF 또는 이미지 파일")):
+@router.post(
+    "", description="요청된 파일에서 Deepseek OCR을 수행하고 텍스트를 추출합니다."
+)
+async def perform_ocr(
+    document: UploadFile = File(..., description="OCR을 수행할 PDF 또는 이미지 파일"),
+):
     """
     클라이언트로부터 받은 파일을 OCR 엔진에 전달하고, 추출된 텍스트를 반환합니다.

     - **document**: `multipart/form-data` 형식으로 전송된 파일.
     """
-    logger.info(f"'{document.filename}' 파일에 대한 OCR 요청 수신 (Content-Type: {document.content_type})")
+    logger.info(
+        f"'{document.filename}' 파일에 대한 OCR 요청 수신 (Content-Type: {document.content_type})"
+    )

     try:
         file_content = await document.read()
@@ -36,6 +41,8 @@ async def perform_ocr(document: UploadFile = File(..., description="OCR을 수
     except Exception as e:
         # 예상치 못한 서버 내부 오류
         logger.exception(f"OCR 처리 중 예상치 못한 오류 발생: {e}")
-        raise HTTPException(status_code=500, detail=f"서버 내부 오류가 발생했습니다: {e}")
+        raise HTTPException(
+            status_code=500, detail=f"서버 내부 오류가 발생했습니다: {e}"
+        )
     finally:
         await document.close()
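
For reference, the reworked endpoint still takes a single multipart field named `document`. A minimal client sketch (the base URL and the `requests` dependency are assumptions):

import requests

# the field name "document" matches the UploadFile parameter above
with open("sample.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/ocr",
        files={"document": ("sample.pdf", f, "application/pdf")},
    )
resp.raise_for_status()
print(resp.json())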

services/__init__.py (new file, 0 additions)
View File

View File

@@ -9,11 +9,17 @@ import torch.nn as nn
 from addict import Dict

 # import time
-from config import BASE_SIZE, CROP_MODE, IMAGE_SIZE, PRINT_NUM_VIS_TOKENS, PROMPT
-from deepencoder.build_linear import MlpProjector
-from deepencoder.clip_sdpa import build_clip_l
-from deepencoder.sam_vary_sdpa import build_sam_vit_b
-from process.image_process import DeepseekOCRProcessor, count_tiles
+from config.model_settings import (
+    BASE_SIZE,
+    CROP_MODE,
+    IMAGE_SIZE,
+    PRINT_NUM_VIS_TOKENS,
+    PROMPT,
+)
+from services.deepencoder.build_linear import MlpProjector
+from services.deepencoder.clip_sdpa import build_clip_l
+from services.deepencoder.sam_vary_sdpa import build_sam_vit_b
+from services.process.image_process import DeepseekOCRProcessor, count_tiles
 from transformers import BatchFeature
 from vllm.config import VllmConfig
 from vllm.model_executor import SamplingMetadata

View File

@@ -1,37 +1,45 @@
 import asyncio
 import io
-from typing import Union
+import logging

 import fitz
 from config.model_settings import CROP_MODE, MODEL_PATH, PROMPT
-from fastapi import UploadFile
 from PIL import Image
-from process.image_process import DeepseekOCRProcessor
 from vllm import AsyncLLMEngine, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.model_executor.models.registry import ModelRegistry

 from services.deepseek_ocr import DeepseekOCRForCausalLM
+from services.process.image_process import DeepseekOCRProcessor
+
+logger = logging.getLogger(__name__)

 # --------------------------------------------------------------------------
 # 1. 모델 및 프로세서 초기화
 # --------------------------------------------------------------------------
+# VLLM이 커스텀 모델을 인식하도록 등록
 ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)

+_engine = None
+
+
+async def init_engine():
+    """vLLM 엔진을 비동기적으로 초기화합니다."""
+    global _engine
+    if _engine is not None:
+        return
+    engine_args = AsyncEngineArgs(
+        model=MODEL_PATH,
+        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
+        block_size=256,
+        max_model_len=8192,
+        enforce_eager=False,
+        trust_remote_code=True,
+        tensor_parallel_size=1,
+        gpu_memory_utilization=0.75,
+    )
+    _engine = AsyncLLMEngine.from_engine_args(engine_args)

-# VLLM 비동기 엔진 설정
-engine_args = AsyncEngineArgs(
-    model=MODEL_PATH,
-    hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
-    block_size=256,
-    max_model_len=8192,
-    enforce_eager=False,
-    trust_remote_code=True,
-    tensor_parallel_size=1,
-    gpu_memory_utilization=0.75,
-)
-engine = AsyncLLMEngine.from_engine_args(engine_args)

 # 샘플링 파라미터 설정
 sampling_params = SamplingParams(
@@ -47,8 +55,11 @@ processor = DeepseekOCRProcessor()
 # 2. 핵심 처리 함수
 # --------------------------------------------------------------------------
 async def _process_single_image(image: Image.Image) -> str:
     """단일 PIL 이미지를 받아 OCR을 수행하고 텍스트를 반환합니다."""
+    if _engine is None:
+        raise RuntimeError("vLLM engine not initialized yet")
+
     if "<image>" not in PROMPT:
         raise ValueError("프롬프트에 '<image>' 토큰이 없어 OCR을 수행할 수 없습니다.")
@@ -60,12 +71,13 @@ async def _process_single_image(image: Image.Image) -> str:
     request_id = f"request-{asyncio.get_running_loop().time()}"

     final_output = ""
-    async for request_output in engine.generate(request, sampling_params, request_id):
+    async for request_output in _engine.generate(request, sampling_params, request_id):
         if request_output.outputs:
             final_output = request_output.outputs[0].text
     return final_output


 def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
     """PDF 바이트를 받아 페이지별 PIL 이미지 리스트를 반환합니다."""
     images = []
@@ -83,6 +95,7 @@ def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
     pdf_document.close()
     return images

+
 async def process_document(file_bytes: bytes, content_type: str, filename: str) -> dict:
     """
     업로드된 파일(이미지 또는 PDF)을 처리하여 OCR 결과를 반환합니다.
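
The module now exposes a lazy `init_engine()` alongside `process_document(file_bytes, content_type, filename)`. A minimal call sketch outside FastAPI (the sample path is illustrative; a GPU and the model weights are still required):

import asyncio

from services.ocr_engine import init_engine, process_document


async def run_once(path: str) -> dict:
    # mirror the FastAPI startup hook: the engine must exist before OCR calls
    await init_engine()
    with open(path, "rb") as f:
        data = f.read()
    return await process_document(data, "application/pdf", path)


if __name__ == "__main__":
    print(asyncio.run(run_once("sample.pdf")))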

View File

@@ -3,13 +3,21 @@ from typing import List, Tuple
 import torch
 import torchvision.transforms as T
+from config.model_settings import (
+    BASE_SIZE,
+    IMAGE_SIZE,
+    MAX_CROPS,
+    MIN_CROPS,
+    PROMPT,
+    TOKENIZER,
+)
 from PIL import Image, ImageOps
-from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
+from transformers import AutoProcessor, LlamaTokenizerFast
 from transformers.processing_utils import ProcessorMixin
-from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, MIN_CROPS, MAX_CROPS, PROMPT, TOKENIZER


 def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
-    best_ratio_diff = float('inf')
+    best_ratio_diff = float("inf")
     best_ratio = (1, 1)
     area = width * height
     for ratio in target_ratios:
@@ -25,37 +33,56 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
     return best_ratio


-def count_tiles(orig_width, orig_height, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
+def count_tiles(
+    orig_width,
+    orig_height,
+    min_num=MIN_CROPS,
+    max_num=MAX_CROPS,
+    image_size=640,
+    use_thumbnail=False,
+):
     aspect_ratio = orig_width / orig_height

     # calculate the existing image aspect ratio
     target_ratios = set(
-        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-        i * j <= max_num and i * j >= min_num)
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
     # print(target_ratios)
     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

     # find the closest aspect ratio to the target
     target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )

     return target_aspect_ratio


-def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
+def dynamic_preprocess(
+    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
+):
     orig_width, orig_height = image.size
     aspect_ratio = orig_width / orig_height

     # calculate the existing image aspect ratio
     target_ratios = set(
-        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-        i * j <= max_num and i * j >= min_num)
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
     # print(target_ratios)
     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

     # find the closest aspect ratio to the target
     target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
     # print(target_aspect_ratio)

     # calculate the target width and height
@@ -71,7 +98,7 @@ def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=6
             (i % (target_width // image_size)) * image_size,
             (i // (target_width // image_size)) * image_size,
             ((i % (target_width // image_size)) + 1) * image_size,
-            ((i // (target_width // image_size)) + 1) * image_size
+            ((i // (target_width // image_size)) + 1) * image_size,
         )
         # split the image
         split_img = resized_img.crop(box)
@@ -83,15 +110,13 @@ def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=6
     return processed_images, target_aspect_ratio


 class ImageTransform:
-    def __init__(self,
-                 mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
-                 std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
-                 normalize: bool = True):
+    def __init__(
+        self,
+        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        normalize: bool = True,
+    ):
         self.mean = mean
         self.std = std
         self.normalize = normalize
@@ -129,7 +154,6 @@ class DeepseekOCRProcessor(ProcessorMixin):
         ignore_id: int = -100,
         **kwargs,
     ):
         # self.candidate_resolutions = candidate_resolutions # placeholder no use
         self.image_size = IMAGE_SIZE
         self.base_size = BASE_SIZE
@@ -141,16 +165,17 @@ class DeepseekOCRProcessor(ProcessorMixin):
         # self.downsample_ratio = downsample_ratio
         self.downsample_ratio = 4

-        self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
+        self.image_transform = ImageTransform(
+            mean=image_mean, std=image_std, normalize=normalize
+        )
         self.tokenizer = tokenizer
         # self.tokenizer = add_special_token(tokenizer)
-        self.tokenizer.padding_side = 'left' # must set thispadding side with make a difference in batch inference
+        self.tokenizer.padding_side = "left"  # must set thispadding side with make a difference in batch inference
         # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
         if self.tokenizer.pad_token is None:
-            self.tokenizer.add_special_tokens({'pad_token': pad_token})
+            self.tokenizer.add_special_tokens({"pad_token": pad_token})

         # add image token
         # image_token_id = self.tokenizer.vocab.get(image_token)
@@ -186,9 +211,6 @@ class DeepseekOCRProcessor(ProcessorMixin):
             **kwargs,
         )

     # def select_best_resolution(self, image_size):
     #     # used for cropping
     #     original_width, original_height = image_size
@@ -264,13 +286,21 @@ class DeepseekOCRProcessor(ProcessorMixin):
         - num_image_tokens (List[int]): the number of image tokens
         """

-        assert (prompt is not None and images is not None
-                ), "prompt and images must be used at the same time."
+        assert (
+            prompt is not None and images is not None
+        ), "prompt and images must be used at the same time."

         sft_format = prompt
-        input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, _ = images[0]
+        (
+            input_ids,
+            pixel_values,
+            images_crop,
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens,
+            _,
+        ) = images[0]

         return {
             "input_ids": input_ids,
@@ -281,7 +311,6 @@ class DeepseekOCRProcessor(ProcessorMixin):
"num_image_tokens": num_image_tokens, "num_image_tokens": num_image_tokens,
} }
# prepare = BatchFeature( # prepare = BatchFeature(
# data=dict( # data=dict(
# input_ids=input_ids, # input_ids=input_ids,
@@ -341,7 +370,12 @@ class DeepseekOCRProcessor(ProcessorMixin):
         conversation = PROMPT

         assert conversation.count(self.image_token) == len(images)

         text_splits = conversation.split(self.image_token)
-        images_list, images_crop_list, images_seq_mask, images_spatial_crop = [], [], [], []
+        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
+            [],
+            [],
+            [],
+            [],
+        )
         image_shapes = []
         num_image_tokens = []
         tokenized_str = []
@@ -368,7 +402,9 @@ class DeepseekOCRProcessor(ProcessorMixin):
                 # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
                 # print('image ', image.size)
                 # print('open_size:', image.size)
-                images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=IMAGE_SIZE)
+                images_crop_raw, crop_ratio = dynamic_preprocess(
+                    image, image_size=IMAGE_SIZE
+                )
                 # print('crop_ratio: ', crop_ratio)
             else:
                 # best_width, best_height = self.image_size, self.image_size
@@ -383,8 +419,11 @@ class DeepseekOCRProcessor(ProcessorMixin):
                 # print('directly resize')
                 image = image.resize((self.image_size, self.image_size))

-            global_view = ImageOps.pad(image, (self.base_size, self.base_size),
-                                       color=tuple(int(x * 255) for x in self.image_transform.mean))
+            global_view = ImageOps.pad(
+                image,
+                (self.base_size, self.base_size),
+                color=tuple(int(x * 255) for x in self.image_transform.mean),
+            )
             images_list.append(self.image_transform(global_view))

             """record height / width crop num"""
@@ -392,9 +431,6 @@ class DeepseekOCRProcessor(ProcessorMixin):
             num_width_tiles, num_height_tiles = crop_ratio
             images_spatial_crop.append([num_width_tiles, num_height_tiles])

             if num_width_tiles > 1 or num_height_tiles > 1:
                 """process the local views"""
                 # local_view = ImageOps.pad(image, (best_width, best_height),
@@ -421,15 +457,22 @@ class DeepseekOCRProcessor(ProcessorMixin):
# """add image tokens""" # """add image tokens"""
"""add image tokens""" """add image tokens"""
num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio) num_queries = math.ceil(
num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio) (self.image_size // self.patch_size) / self.downsample_ratio
)
num_queries_base = math.ceil(
(self.base_size // self.patch_size) / self.downsample_ratio
)
tokenized_image = (
tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base [self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id] tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1: if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += ([self.image_token_id] * (num_queries * num_width_tiles) + [self.image_token_id]) * ( tokenized_image += (
num_queries * num_height_tiles) [self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image) images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image)) num_image_tokens.append(len(tokenized_image))
@@ -447,10 +490,9 @@ class DeepseekOCRProcessor(ProcessorMixin):
         tokenized_str = tokenized_str + [self.eos_id]
         images_seq_mask = images_seq_mask + [False]

-        assert len(tokenized_str) == len(
-            images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+        assert (
+            len(tokenized_str) == len(images_seq_mask)
+        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

         masked_tokenized_str = []
         for token_index in tokenized_str:
@@ -459,17 +501,21 @@ class DeepseekOCRProcessor(ProcessorMixin):
             else:
                 masked_tokenized_str.append(self.ignore_id)

-        assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
-            (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
-             f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
+        assert (
+            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+        ), (
+            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+        )

         input_ids = torch.LongTensor(tokenized_str)
         target_ids = torch.LongTensor(masked_tokenized_str)
         images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

         # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
-        target_ids[(input_ids < 0) |
-                   (input_ids == self.image_token_id)] = self.ignore_id
+        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+            self.ignore_id
+        )
         input_ids[input_ids < 0] = self.pad_id

         inference_mode = True
@@ -484,19 +530,32 @@ class DeepseekOCRProcessor(ProcessorMixin):
         if len(images_list) == 0:
             pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
             images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
-            images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
+            images_crop = torch.zeros(
+                (1, 3, self.image_size, self.image_size)
+            ).unsqueeze(0)
         else:
             pixel_values = torch.stack(images_list, dim=0)
             images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)

             if images_crop_list:
                 images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
             else:
-                images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
+                images_crop = torch.zeros(
+                    (1, 3, self.image_size, self.image_size)
+                ).unsqueeze(0)

         input_ids = input_ids.unsqueeze(0)

-        return [[input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, image_shapes]]
+        return [
+            [
+                input_ids,
+                pixel_values,
+                images_crop,
+                images_seq_mask,
+                images_spatial_crop,
+                num_image_tokens,
+                image_shapes,
+            ]
+        ]


 AutoProcessor.register("DeepseekVLV2Processor", DeepseekOCRProcessor)
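
count_tiles() and dynamic_preprocess() pick the tile grid whose aspect ratio best matches the input and then cut the image into image_size crops. A small usage sketch (the repo-local imports and the sample resolution are assumptions, not part of this diff):

from PIL import Image

from services.process.image_process import count_tiles, dynamic_preprocess

# grid chosen for a landscape page, e.g. (2, 1) width x height tiles
grid = count_tiles(1280, 720, image_size=640)
print(grid)

# split a blank test image into 640x640 crops according to that grid
img = Image.new("RGB", (1280, 720), "white")
tiles, ratio = dynamic_preprocess(img, image_size=640)
print(len(tiles), ratio)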

View File

@@ -1,32 +1,39 @@
+from typing import List
+
 import torch
 from transformers import LogitsProcessor
-from transformers.generation.logits_process import _calc_banned_ngram_tokens
-from typing import List, Set


 class NoRepeatNGramLogitsProcessor(LogitsProcessor):
-    def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
+    def __init__(
+        self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None
+    ):
         if not isinstance(ngram_size, int) or ngram_size <= 0:
-            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+            raise ValueError(
+                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
+            )
         if not isinstance(window_size, int) or window_size <= 0:
-            raise ValueError(f"`window_size` has to be a strictly positive integer, but is {window_size}")
+            raise ValueError(
+                f"`window_size` has to be a strictly positive integer, but is {window_size}"
+            )
         self.ngram_size = ngram_size
         self.window_size = window_size
         self.whitelist_token_ids = whitelist_token_ids or set()

-    def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
+    def __call__(
+        self, input_ids: List[int], scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
         if len(input_ids) < self.ngram_size:
             return scores

-        current_prefix = tuple(input_ids[-(self.ngram_size - 1):])
+        current_prefix = tuple(input_ids[-(self.ngram_size - 1) :])

         search_start = max(0, len(input_ids) - self.window_size)
         search_end = len(input_ids) - self.ngram_size + 1

         banned_tokens = set()
         for i in range(search_start, search_end):
-            ngram = tuple(input_ids[i:i + self.ngram_size])
+            ngram = tuple(input_ids[i : i + self.ngram_size])
             if ngram[:-1] == current_prefix:
                 banned_tokens.add(ngram[-1])

test/__init__.py (new file, 0 additions)
View File

Binary file not shown.

Binary file not shown.

test/output/images/0_0.jpg (new binary file)

Binary file not shown.


View File

@@ -0,0 +1,12 @@
{
    "filename": "2016-08556-211156.pdf",
    "model": {
        "ocr_model": "deepseek-ocr"
    },
    "time": {
        "duration_sec": "6.59",
        "started_at": 1762395627.3137395,
        "ended_at": 1762395633.9023185
    },
"parsed": "\n수신자 한국수출입은행장 \n\n참조 EDCF Operations Department 2 \n\n제 목 방글라데시 반다주리 상수도 사업 컨설턴트 고용을 위한 문제유발 행위 불개입 확약서 \n\n1. 귀 은행의 무궁한 발전을 기원합니다. \n\n2. 표제 사업 컨설턴트 고용을 위한 제안요청서 조항에 따라 입찰 참여를 위한 \"문제유발행위 불개입 확약 서\"를 \n\n첨부와 같이 제출하오니, 참조해주시기 바랍니다. \n\n* 첨부: 문제유발행위 불개입 확약서 원본 1부. \n\n주식회사 삼안 대표이사 \n\n![](images/0_0.jpg)\n\n \n\n수신처 : Ms. Jiyoon Park, Sr. Loan Officer \n\n문서번호 201609-4495 (2016-09-22) \n\n서울 광진구 광나루로56실 85 프라임센터 34층 해외사업실 \n\n전화 02)6488-8095 \n\n담당 : 김관영 \n\nFAX 02)6488-8080 \n\n/ http://www.samaneng.com \n\n이메일 shkim5@samaneng.com<end▁of▁sentence>\n<--- Page Split --->\n"
}

test/test.py (new file, 397 additions)
View File

@@ -0,0 +1,397 @@
import io
import json
import os
import re
import time

import config.model_settings as config
import fitz
import img2pdf
import numpy as np
from config.env_setup import setup_environment
from PIL import Image, ImageDraw, ImageFont, ImageOps
from services.deepseek_ocr import DeepseekOCRForCausalLM
from services.process.image_process import DeepseekOCRProcessor
from services.process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry

setup_environment()

ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)


class Colors:
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    RESET = "\033[0m"


# --- PDF/Image Processing Functions (from run_dpsk_ocr_*.py) ---
def pdf_to_images_high_quality(pdf_path, dpi=144):
    images = []
    pdf_document = fitz.open(pdf_path)
    zoom = dpi / 72.0
    matrix = fitz.Matrix(zoom, zoom)
    for page_num in range(pdf_document.page_count):
        page = pdf_document[page_num]
        pixmap = page.get_pixmap(matrix=matrix, alpha=False)
        Image.MAX_IMAGE_PIXELS = None
        img_data = pixmap.tobytes("png")
        img = Image.open(io.BytesIO(img_data))
        if img.mode in ("RGBA", "LA"):
            background = Image.new("RGB", img.size, (255, 255, 255))
            background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
            img = background
        images.append(img)
    pdf_document.close()
    return images


def pil_to_pdf_img2pdf(pil_images, output_path):
    if not pil_images:
        return

    image_bytes_list = []
    for img in pil_images:
        if img.mode != "RGB":
            img = img.convert("RGB")
        img_buffer = io.BytesIO()
        img.save(img_buffer, format="JPEG", quality=95)
        image_bytes_list.append(img_buffer.getvalue())

    try:
        pdf_bytes = img2pdf.convert(image_bytes_list)
        with open(output_path, "wb") as f:
            f.write(pdf_bytes)
    except Exception as e:
        print(f"Error creating PDF: {e}")


def re_match(text):
    pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
    matches = re.findall(pattern, text, re.DOTALL)

    mathes_image = [m[0] for m in matches if "<|ref|>image<|/ref|>" in m[0]]
    mathes_other = [m[0] for m in matches if "<|ref|>image<|/ref|>" not in m[0]]
    return matches, mathes_image, mathes_other


def extract_coordinates_and_label(ref_text, image_width, image_height):
    try:
        label_type = ref_text[1]
        cor_list = eval(ref_text[2])
        return (label_type, cor_list)
    except Exception as e:
        print(f"Error extracting coordinates: {e}")
        return None


def draw_bounding_boxes(image, refs, jdx=None):
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    font = ImageFont.load_default()

    img_idx = 0
    for i, ref in enumerate(refs):
        result = extract_coordinates_and_label(ref, image_width, image_height)
        if not result:
            continue
        label_type, points_list = result
        color = (
            np.random.randint(0, 200),
            np.random.randint(0, 200),
            np.random.randint(0, 255),
        )
        color_a = color + (20,)
        for points in points_list:
            x1, y1, x2, y2 = [
                int(p / 999 * (image_width if i % 2 == 0 else image_height))
                for i, p in enumerate(points)
            ]
            if label_type == "image":
                try:
                    cropped = image.crop((x1, y1, x2, y2))
                    img_filename = (
                        f"{jdx}_{img_idx}.jpg" if jdx is not None else f"{img_idx}.jpg"
                    )
                    cropped.save(
                        os.path.join(config.OUTPUT_PATH, "images", img_filename)
                    )
                    img_idx += 1
                except Exception as e:
                    print(f"Error cropping image: {e}")

            width = 4 if label_type == "title" else 2
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            draw2.rectangle(
                [x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1
            )

            text_x, text_y = x1, max(0, y1 - 15)
            text_bbox = draw.textbbox((0, 0), label_type, font=font)
            text_width, text_height = (
                text_bbox[2] - text_bbox[0],
                text_bbox[3] - text_bbox[1],
            )
            draw.rectangle(
                [text_x, text_y, text_x + text_width, text_y + text_height],
                fill=(255, 255, 255, 30),
            )
            draw.text((text_x, text_y), label_type, font=font, fill=color)

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw


def process_image_with_refs(image, ref_texts, jdx=None):
    return draw_bounding_boxes(image, ref_texts, jdx)


def load_image(image_path):
    try:
        image = Image.open(image_path).convert("RGB")
        return ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None


# --- Main OCR Processing Logic ---
def process_pdf(llm, sampling_params, pdf_path):
    print(f"{Colors.GREEN}Processing PDF: {pdf_path}{Colors.RESET}")
    base_name = os.path.basename(pdf_path)
    file_name_without_ext = os.path.splitext(base_name)[0]

    images = pdf_to_images_high_quality(pdf_path)
    if not images:
        print(
            f"{Colors.YELLOW}Could not extract images from {pdf_path}. Skipping.{Colors.RESET}"
        )
        return

    batch_inputs = []
    processor = DeepseekOCRProcessor()
    for image in tqdm(images, desc="Pre-processing PDF pages"):
        batch_inputs.append(
            {
                "prompt": config.PROMPT,
                "multi_modal_data": {
                    "image": processor.tokenize_with_images(
                        images=[image], bos=True, eos=True, cropping=config.CROP_MODE
                    )
                },
            }
        )

    start_time = time.time()
    outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
    end_time = time.time()

    contents_det = ""
    contents = ""
    draw_images = []

    for i, (output, img) in enumerate(zip(outputs_list, images)):
        content = output.outputs[0].text
        if "<end of sentence>" in content:
            content = content.replace("<end of sentence>", "")
        elif config.SKIP_REPEAT:
            continue

        page_num_separator = "\n<--- Page Split --->\n"
        contents_det += content + page_num_separator

        matches_ref, matches_images, mathes_other = re_match(content)
        result_image = process_image_with_refs(img.copy(), matches_ref, jdx=i)
        draw_images.append(result_image)

        for idx, match in enumerate(matches_images):
            content = content.replace(match, f"![](images/{i}_{idx}.jpg)\n")
        for match in mathes_other:
            content = (
                content.replace(match, "")
                .replace("\\coloneqq", ":=")
                .replace("\\eqqcolon", "=:")
                .replace("\n\n\n", "\n\n")
            )
        contents += content + page_num_separator

    # Save results
    json_path = os.path.join(
        f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
    )
    pdf_out_path = os.path.join(
        config.OUTPUT_PATH, f"{file_name_without_ext}_layouts.pdf"
    )

    duration = end_time - start_time
    output_data = {
        "filename": base_name,
        "model": {"ocr_model": "deepseek-ocr"},
        "time": {
            "duration_sec": f"{duration:.2f}",
            "started_at": start_time,
            "ended_at": end_time,
        },
        "parsed": contents,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)

    pil_to_pdf_img2pdf(draw_images, pdf_out_path)
    print(
        f"{Colors.GREEN}Finished processing {pdf_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
    )


def process_image(llm, sampling_params, image_path):
    print(f"{Colors.GREEN}Processing Image: {image_path}{Colors.RESET}")
    base_name = os.path.basename(image_path)
    file_name_without_ext = os.path.splitext(base_name)[0]

    image = load_image(image_path)
    if image is None:
        return

    processor = DeepseekOCRProcessor()
    image_features = processor.tokenize_with_images(
        images=[image], bos=True, eos=True, cropping=config.CROP_MODE
    )

    request = {
        "prompt": config.PROMPT,
        "multi_modal_data": {"image": image_features},
    }

    start_time = time.time()
    outputs = llm.generate([request], sampling_params)
    end_time = time.time()

    result_out = outputs[0].outputs[0].text
    print(result_out)

    # Save results
    output_json_path = os.path.join(
        f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
    )
    result_image_path = os.path.join(
        config.OUTPUT_PATH, f"{file_name_without_ext}_result_with_boxes.jpg"
    )

    matches_ref, matches_images, mathes_other = re_match(result_out)
    result_image = process_image_with_refs(image.copy(), matches_ref)

    processed_text = result_out
    for idx, match in enumerate(matches_images):
        processed_text = processed_text.replace(match, f"![](images/{idx}.jpg)\n")
    for match in mathes_other:
        processed_text = (
            processed_text.replace(match, "")
            .replace("\\coloneqq", ":=")
            .replace("\\eqqcolon", "=:")
            .replace("\n\n\n", "\n\n")
        )

    duration = end_time - start_time
    output_data = {
        "filename": base_name,
        "model": {"ocr_model": "deepseek-ocr"},
        "time": {
            "duration_sec": f"{duration:.2f}",
            "started_at": start_time,
            "ended_at": end_time,
        },
        "parsed": processed_text,
    }
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)

    result_image.save(result_image_path)
    print(
        f"{Colors.GREEN}Finished processing {image_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
    )


def main():
    # --- Model Initialization ---
    print(f"{Colors.BLUE}Initializing model...{Colors.RESET}")
    llm = LLM(
        model=config.MODEL_PATH,
        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
        block_size=256,
        enforce_eager=False,
        trust_remote_code=True,
        max_model_len=8192,
        swap_space=0,
        max_num_seqs=config.MAX_CONCURRENCY,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.9,
        disable_mm_preprocessor_cache=True,
    )

    logits_processors = [
        NoRepeatNGramLogitsProcessor(
            ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
        )
    ]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=8192,
        logits_processors=logits_processors,
        skip_special_tokens=False,
        include_stop_str_in_output=True,
    )
    print(f"{Colors.BLUE}Model initialized successfully.{Colors.RESET}")

    # --- File Processing ---
    input_dir = config.INPUT_PATH
    output_dir = config.OUTPUT_PATH
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "result"), exist_ok=True)

    if not os.path.isdir(input_dir):
        print(
            f"{Colors.RED}Error: Input directory not found at '{input_dir}'{Colors.RESET}"
        )
        return

    print(f"Scanning for files in '{input_dir}'...")
    for filename in sorted(os.listdir(input_dir)):
        input_path = os.path.join(input_dir, filename)
        if not os.path.isfile(input_path):
            continue

        file_extension = os.path.splitext(filename)[1].lower()
        try:
            if file_extension == ".pdf":
                process_pdf(llm, sampling_params, input_path)
            elif file_extension in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]:
                process_image(llm, sampling_params, input_path)
            else:
                print(
                    f"{Colors.YELLOW}Skipping unsupported file type: {filename}{Colors.RESET}"
                )
        except Exception as e:
            print(
                f"{Colors.RED}An error occurred while processing {filename}: {e}{Colors.RESET}"
            )


if __name__ == "__main__":
    main()