205 lines
7.6 KiB
Python
205 lines
7.6 KiB
Python
import asyncio
|
||
import io
|
||
import logging
|
||
import re
|
||
|
||
import fitz
|
||
from config import engine_settings, model_settings
|
||
from PIL import Image, ImageOps
|
||
from vllm import AsyncLLMEngine, SamplingParams
|
||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||
from vllm.model_executor.models.registry import ModelRegistry
|
||
|
||
from services.deepseek_ocr import DeepseekOCRForCausalLM
|
||
from services.process.image_process import DeepseekOCRProcessor
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# --------------------------------------------------------------------------
|
||
# 1. 모델 및 프로세서 초기화
|
||
# --------------------------------------------------------------------------
|
||
|
||
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
||
_engine = None
|
||
|
||
|
||
async def init_engine():
|
||
"""vLLM 엔진을 비동기적으로 초기화합니다."""
|
||
global _engine
|
||
if _engine is not None:
|
||
return
|
||
|
||
engine_args = AsyncEngineArgs(
|
||
model=model_settings.MODEL_PATH,
|
||
hf_overrides={"architectures": engine_settings.ARCHITECTURES},
|
||
block_size=engine_settings.BLOCK_SIZE,
|
||
max_model_len=engine_settings.MAX_MODEL_LEN,
|
||
enforce_eager=engine_settings.ENFORCE_EAGER,
|
||
trust_remote_code=engine_settings.TRUST_REMOTE_CODE,
|
||
tensor_parallel_size=engine_settings.TENSOR_PARALLEL_SIZE,
|
||
gpu_memory_utilization=engine_settings.GPU_MEMORY_UTILIZATION,
|
||
)
|
||
_engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||
|
||
|
||
# 샘플링 파라미터 설정
|
||
sampling_params = SamplingParams(
|
||
temperature=0.0,
|
||
max_tokens=8192,
|
||
skip_special_tokens=False,
|
||
)
|
||
|
||
# 이미지 전처리기
|
||
processor = DeepseekOCRProcessor()
|
||
|
||
# --------------------------------------------------------------------------
|
||
# 2. 후처리 함수 (공식 코드 기반)
|
||
# --------------------------------------------------------------------------
|
||
|
||
|
||
def _postprocess_text(text: str, page_num: int = 0) -> str:
|
||
"""
|
||
모델의 원본 출력에서 태그를 제거/변경하고 텍스트를 정리합니다.
|
||
(test/test.py의 후처리 로직 기반으로 수정)
|
||
"""
|
||
if not text:
|
||
return ""
|
||
|
||
# 1. 정규식으로 모든 ref/det 태그 블록을 찾음
|
||
# 패턴은 (전체 매치, ref 내용, det 내용)을 캡처
|
||
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
|
||
matches = re.findall(pattern, text, re.DOTALL)
|
||
|
||
# 2. 전체 매치된 문자열을 이미지 태그와 기타 태그로 분리
|
||
matches_images = []
|
||
matches_other = []
|
||
for match_tuple in matches:
|
||
full_match_str = match_tuple[0] # 전체 매치된 부분
|
||
ref_content = match_tuple[1] # <|ref|> 안의 내용
|
||
|
||
if "image" in ref_content:
|
||
matches_images.append(full_match_str)
|
||
else:
|
||
matches_other.append(full_match_str)
|
||
|
||
processed_text = text
|
||
|
||
# 3. 이미지 태그는 마크다운 링크로 대체
|
||
for idx, img_tag in enumerate(matches_images):
|
||
img_link = f"\n"
|
||
processed_text = processed_text.replace(img_tag, img_link)
|
||
|
||
# 4. 이미지가 아닌 다른 모든 태그는 제거
|
||
for other_tag in matches_other:
|
||
processed_text = processed_text.replace(other_tag, "")
|
||
|
||
# 5. 특수 문자, 불필요한 토큰 및 추가 공백 정리
|
||
processed_text = (
|
||
processed_text.replace("<|end of sentence|>", "")
|
||
.replace("\\coloneqq", ":=")
|
||
.replace("\\eqqcolon", "=:")
|
||
.replace("\n\n\n", "\n\n")
|
||
.strip()
|
||
)
|
||
|
||
return processed_text
|
||
|
||
|
||
# --------------------------------------------------------------------------
|
||
# 3. 핵심 처리 함수
|
||
# --------------------------------------------------------------------------
|
||
|
||
|
||
async def _process_single_image(image: Image.Image, page_num: int = 0) -> str:
|
||
"""단일 PIL 이미지를 받아 OCR을 수행하고 후처리된 텍스트를 반환합니다."""
|
||
if _engine is None:
|
||
raise RuntimeError("vLLM engine not initialized yet")
|
||
if "<image>" not in model_settings.PROMPT:
|
||
raise ValueError("프롬프트에 '<image>' 토큰이 없어 OCR을 수행할 수 없습니다.")
|
||
|
||
image_features = processor.tokenize_with_images(
|
||
images=[image], bos=True, eos=True, cropping=model_settings.CROP_MODE
|
||
)
|
||
|
||
request = {"prompt": model_settings.PROMPT, "multi_modal_data": {"image": image_features}}
|
||
request_id = f"request-{asyncio.get_running_loop().time()}"
|
||
|
||
raw_output = ""
|
||
async for request_output in _engine.generate(request, sampling_params, request_id):
|
||
if request_output.outputs:
|
||
raw_output = request_output.outputs[0].text
|
||
|
||
# 후처리 적용 (페이지 번호 전달)
|
||
clean_text = _postprocess_text(raw_output, page_num)
|
||
return clean_text
|
||
|
||
|
||
def _pdf_to_images_high_quality(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
|
||
"""PDF 바이트를 받아 페이지별 고품질 PIL 이미지 리스트를 반환합니다."""
|
||
images = []
|
||
pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
|
||
zoom = dpi / 72.0
|
||
matrix = fitz.Matrix(zoom, zoom)
|
||
|
||
for page_num in range(pdf_document.page_count):
|
||
page = pdf_document[page_num]
|
||
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
|
||
Image.MAX_IMAGE_PIXELS = None
|
||
img_data = pixmap.tobytes("png")
|
||
img = Image.open(io.BytesIO(img_data))
|
||
|
||
if img.mode in ("RGBA", "LA"):
|
||
background = Image.new("RGB", img.size, (255, 255, 255))
|
||
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
|
||
img = background
|
||
images.append(img)
|
||
|
||
pdf_document.close()
|
||
return images
|
||
|
||
|
||
async def process_document(file_bytes: bytes, content_type: str, filename: str) -> dict:
|
||
"""
|
||
업로드된 파일(이미지 또는 PDF)을 처리하여 OCR 결과를 반환합니다.
|
||
"""
|
||
# Content-Type이 generic할 경우, 파일 확장자로 타입을 유추
|
||
inferred_content_type = content_type
|
||
if content_type == "application/octet-stream":
|
||
if filename.lower().endswith(".pdf"):
|
||
inferred_content_type = "application/pdf"
|
||
elif any(
|
||
filename.lower().endswith(ext)
|
||
for ext in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]
|
||
):
|
||
inferred_content_type = "image/jpeg" # 구체적인 타입은 중요하지 않음
|
||
|
||
if inferred_content_type.startswith("image/"):
|
||
try:
|
||
image = Image.open(io.BytesIO(file_bytes))
|
||
image = ImageOps.exif_transpose(image).convert("RGB")
|
||
except Exception as e:
|
||
raise ValueError(f"이미지 파일을 여는 데 실패했습니다: {e}")
|
||
|
||
# 단일 이미지는 페이지 번호를 0으로 간주
|
||
result_text = await _process_single_image(image, page_num=0)
|
||
return {"filename": filename, "text": result_text}
|
||
|
||
elif inferred_content_type == "application/pdf":
|
||
try:
|
||
images = _pdf_to_images_high_quality(file_bytes)
|
||
except Exception as e:
|
||
raise ValueError(f"PDF 파일을 처리하는 데 실패했습니다: {e}")
|
||
|
||
# 각 페이지를 비동기적으로 처리 (페이지 번호 전달)
|
||
tasks = [_process_single_image(img, page_num=i) for i, img in enumerate(images)]
|
||
page_results = await asyncio.gather(*tasks)
|
||
|
||
full_text = "\n<--- Page Split --->\n".join(page_results)
|
||
return {"filename": filename, "text": full_text, "page_count": len(images)}
|
||
|
||
else:
|
||
raise ValueError(
|
||
f"지원하지 않는 파일 형식입니다: {content_type}. "
|
||
"이미지(JPEG, PNG 등) 또는 PDF 파일을 업로드해주세요."
|
||
)
|