Files
deepseek_ocr/services/ocr_engine.py

205 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import io
import logging
import re
import fitz
from config import engine_settings, model_settings
from PIL import Image, ImageOps
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.registry import ModelRegistry
from services.deepseek_ocr import DeepseekOCRForCausalLM
from services.process.image_process import DeepseekOCRProcessor
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------
# 1. 모델 및 프로세서 초기화
# --------------------------------------------------------------------------
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
_engine = None
async def init_engine():
"""vLLM 엔진을 비동기적으로 초기화합니다."""
global _engine
if _engine is not None:
return
engine_args = AsyncEngineArgs(
model=model_settings.MODEL_PATH,
hf_overrides={"architectures": engine_settings.ARCHITECTURES},
block_size=engine_settings.BLOCK_SIZE,
max_model_len=engine_settings.MAX_MODEL_LEN,
enforce_eager=engine_settings.ENFORCE_EAGER,
trust_remote_code=engine_settings.TRUST_REMOTE_CODE,
tensor_parallel_size=engine_settings.TENSOR_PARALLEL_SIZE,
gpu_memory_utilization=engine_settings.GPU_MEMORY_UTILIZATION,
)
_engine = AsyncLLMEngine.from_engine_args(engine_args)
# 샘플링 파라미터 설정
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
skip_special_tokens=False,
)
# 이미지 전처리기
processor = DeepseekOCRProcessor()
# --------------------------------------------------------------------------
# 2. 후처리 함수 (공식 코드 기반)
# --------------------------------------------------------------------------
def _postprocess_text(text: str, page_num: int = 0) -> str:
"""
모델의 원본 출력에서 태그를 제거/변경하고 텍스트를 정리합니다.
(test/test.py의 후처리 로직 기반으로 수정)
"""
if not text:
return ""
# 1. 정규식으로 모든 ref/det 태그 블록을 찾음
# 패턴은 (전체 매치, ref 내용, det 내용)을 캡처
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
matches = re.findall(pattern, text, re.DOTALL)
# 2. 전체 매치된 문자열을 이미지 태그와 기타 태그로 분리
matches_images = []
matches_other = []
for match_tuple in matches:
full_match_str = match_tuple[0] # 전체 매치된 부분
ref_content = match_tuple[1] # <|ref|> 안의 내용
if "image" in ref_content:
matches_images.append(full_match_str)
else:
matches_other.append(full_match_str)
processed_text = text
# 3. 이미지 태그는 마크다운 링크로 대체
for idx, img_tag in enumerate(matches_images):
img_link = f"![](images/{page_num}_{idx}.jpg)\n"
processed_text = processed_text.replace(img_tag, img_link)
# 4. 이미지가 아닌 다른 모든 태그는 제거
for other_tag in matches_other:
processed_text = processed_text.replace(other_tag, "")
# 5. 특수 문자, 불필요한 토큰 및 추가 공백 정리
processed_text = (
processed_text.replace("<end of sentence>", "")
.replace("\\coloneqq", ":=")
.replace("\\eqqcolon", "=:")
.replace("\n\n\n", "\n\n")
.strip()
)
return processed_text
# --------------------------------------------------------------------------
# 3. 핵심 처리 함수
# --------------------------------------------------------------------------
async def _process_single_image(image: Image.Image, page_num: int = 0) -> str:
"""단일 PIL 이미지를 받아 OCR을 수행하고 후처리된 텍스트를 반환합니다."""
if _engine is None:
raise RuntimeError("vLLM engine not initialized yet")
if "<image>" not in model_settings.PROMPT:
raise ValueError("프롬프트에 '<image>' 토큰이 없어 OCR을 수행할 수 없습니다.")
image_features = processor.tokenize_with_images(
images=[image], bos=True, eos=True, cropping=model_settings.CROP_MODE
)
request = {"prompt": model_settings.PROMPT, "multi_modal_data": {"image": image_features}}
request_id = f"request-{asyncio.get_running_loop().time()}"
raw_output = ""
async for request_output in _engine.generate(request, sampling_params, request_id):
if request_output.outputs:
raw_output = request_output.outputs[0].text
# 후처리 적용 (페이지 번호 전달)
clean_text = _postprocess_text(raw_output, page_num)
return clean_text
def _pdf_to_images_high_quality(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
"""PDF 바이트를 받아 페이지별 고품질 PIL 이미지 리스트를 반환합니다."""
images = []
pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
zoom = dpi / 72.0
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
Image.MAX_IMAGE_PIXELS = None
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
if img.mode in ("RGBA", "LA"):
background = Image.new("RGB", img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
img = background
images.append(img)
pdf_document.close()
return images
async def process_document(file_bytes: bytes, content_type: str, filename: str) -> dict:
"""
업로드된 파일(이미지 또는 PDF)을 처리하여 OCR 결과를 반환합니다.
"""
# Content-Type이 generic할 경우, 파일 확장자로 타입을 유추
inferred_content_type = content_type
if content_type == "application/octet-stream":
if filename.lower().endswith(".pdf"):
inferred_content_type = "application/pdf"
elif any(
filename.lower().endswith(ext)
for ext in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]
):
inferred_content_type = "image/jpeg" # 구체적인 타입은 중요하지 않음
if inferred_content_type.startswith("image/"):
try:
image = Image.open(io.BytesIO(file_bytes))
image = ImageOps.exif_transpose(image).convert("RGB")
except Exception as e:
raise ValueError(f"이미지 파일을 여는 데 실패했습니다: {e}")
# 단일 이미지는 페이지 번호를 0으로 간주
result_text = await _process_single_image(image, page_num=0)
return {"filename": filename, "text": result_text}
elif inferred_content_type == "application/pdf":
try:
images = _pdf_to_images_high_quality(file_bytes)
except Exception as e:
raise ValueError(f"PDF 파일을 처리하는 데 실패했습니다: {e}")
# 각 페이지를 비동기적으로 처리 (페이지 번호 전달)
tasks = [_process_single_image(img, page_num=i) for i, img in enumerate(images)]
page_results = await asyncio.gather(*tasks)
full_text = "\n<--- Page Split --->\n".join(page_results)
return {"filename": filename, "text": full_text, "page_count": len(images)}
else:
raise ValueError(
f"지원하지 않는 파일 형식입니다: {content_type}. "
"이미지(JPEG, PNG 등) 또는 PDF 파일을 업로드해주세요."
)