131 lines
4.4 KiB
Python
131 lines
4.4 KiB
Python
import asyncio
|
|
import io
|
|
import logging
|
|
|
|
import fitz
|
|
from config.model_settings import CROP_MODE, MODEL_PATH, PROMPT
|
|
from PIL import Image
|
|
from vllm import AsyncLLMEngine, SamplingParams
|
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
from vllm.model_executor.models.registry import ModelRegistry
|
|
|
|
from services.deepseek_ocr import DeepseekOCRForCausalLM
|
|
from services.process.image_process import DeepseekOCRProcessor
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# --------------------------------------------------------------------------
|
|
# 1. 모델 및 프로세서 초기화
|
|
# --------------------------------------------------------------------------
|
|
|
|
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
|
_engine = None
|
|
|
|
|
|
async def init_engine():
|
|
"""vLLM 엔진을 비동기적으로 초기화합니다."""
|
|
global _engine
|
|
if _engine is not None:
|
|
return
|
|
|
|
engine_args = AsyncEngineArgs(
|
|
model=MODEL_PATH,
|
|
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
|
block_size=256,
|
|
max_model_len=8192,
|
|
enforce_eager=False,
|
|
trust_remote_code=True,
|
|
tensor_parallel_size=1,
|
|
gpu_memory_utilization=0.75,
|
|
)
|
|
_engine = AsyncLLMEngine.from_engine_args(engine_args)
|
|
|
|
|
|
# 샘플링 파라미터 설정
|
|
sampling_params = SamplingParams(
|
|
temperature=0.0,
|
|
max_tokens=8192,
|
|
skip_special_tokens=False,
|
|
)
|
|
|
|
# 이미지 전처리기
|
|
processor = DeepseekOCRProcessor()
|
|
|
|
# --------------------------------------------------------------------------
|
|
# 2. 핵심 처리 함수
|
|
# --------------------------------------------------------------------------
|
|
|
|
|
|
async def _process_single_image(image: Image.Image) -> str:
|
|
"""단일 PIL 이미지를 받아 OCR을 수행하고 텍스트를 반환합니다."""
|
|
if _engine is None:
|
|
raise RuntimeError("vLLM engine not initialized yet")
|
|
if "<image>" not in PROMPT:
|
|
raise ValueError("프롬프트에 '<image>' 토큰이 없어 OCR을 수행할 수 없습니다.")
|
|
|
|
image_features = processor.tokenize_with_images(
|
|
images=[image], bos=True, eos=True, cropping=CROP_MODE
|
|
)
|
|
|
|
request = {"prompt": PROMPT, "multi_modal_data": {"image": image_features}}
|
|
request_id = f"request-{asyncio.get_running_loop().time()}"
|
|
|
|
final_output = ""
|
|
async for request_output in _engine.generate(request, sampling_params, request_id):
|
|
if request_output.outputs:
|
|
final_output = request_output.outputs[0].text
|
|
|
|
return final_output
|
|
|
|
|
|
def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
|
|
"""PDF 바이트를 받아 페이지별 PIL 이미지 리스트를 반환합니다."""
|
|
images = []
|
|
pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
|
|
zoom = dpi / 72.0
|
|
matrix = fitz.Matrix(zoom, zoom)
|
|
|
|
for page_num in range(pdf_document.page_count):
|
|
page = pdf_document[page_num]
|
|
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
|
|
img_data = pixmap.tobytes("png")
|
|
img = Image.open(io.BytesIO(img_data))
|
|
images.append(img)
|
|
|
|
pdf_document.close()
|
|
return images
|
|
|
|
|
|
async def process_document(file_bytes: bytes, content_type: str, filename: str) -> dict:
|
|
"""
|
|
업로드된 파일(이미지 또는 PDF)을 처리하여 OCR 결과를 반환합니다.
|
|
"""
|
|
if content_type.startswith("image/"):
|
|
try:
|
|
image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
|
|
except Exception as e:
|
|
raise ValueError(f"이미지 파일을 여는 데 실패했습니다: {e}")
|
|
|
|
result_text = await _process_single_image(image)
|
|
return {"filename": filename, "text": result_text}
|
|
|
|
elif content_type == "application/pdf":
|
|
try:
|
|
images = _pdf_to_images(file_bytes)
|
|
except Exception as e:
|
|
raise ValueError(f"PDF 파일을 처리하는 데 실패했습니다: {e}")
|
|
|
|
# 각 페이지를 비동기적으로 처리
|
|
tasks = [_process_single_image(img) for img in images]
|
|
page_results = await asyncio.gather(*tasks)
|
|
|
|
# 페이지 구분자를 넣어 전체 텍스트 합치기
|
|
full_text = "\n<--- Page Split --->\n".join(page_results)
|
|
return {"filename": filename, "text": full_text, "page_count": len(images)}
|
|
|
|
else:
|
|
raise ValueError(
|
|
f"지원하지 않는 파일 형식입니다: {content_type}. "
|
|
"이미지(JPEG, PNG 등) 또는 PDF 파일을 업로드해주세요."
|
|
)
|