315 lines
12 KiB
Python
315 lines
12 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from io import BytesIO
|
|
|
|
import httpx
|
|
import pytesseract
|
|
import redis
|
|
from celery import Task
|
|
from config.setting import (
|
|
REDIS_DB,
|
|
REDIS_HOST,
|
|
REDIS_PORT,
|
|
UPSTAGE_API_KEY,
|
|
)
|
|
from PIL import Image
|
|
from pdf2image import convert_from_path
|
|
from utils.celery_utils import celery_app
|
|
from utils.ocr_processor import ocr_process
|
|
from utils.text_extractor import extract_text_from_file
|
|
|
|
# Logging: INFO-level basic configuration plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Redis client shared by the tasks in this module (progress lists and results).
# decode_responses=True: values read back come out as str instead of bytes.
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    decode_responses=True,
)
|
|
|
|
|
|
# Common Task base class (progress log + failure/success hooks)
class BaseTaskWithProgress(Task):
    """Celery base task that records progress entries in Redis.

    Each entry is a JSON object ``{"status", "timestamp", "step_info"}``
    appended (RPUSH) to the Redis list ``ocr_status:<request_id>``. The
    success and failure hooks append a final entry for the invocation's
    ``request_id``.
    """

    abstract = True

    def update_progress(self, request_id: str, status_message: str, step_info=None):
        """Append one progress entry for ``request_id`` and log it at INFO.

        Args:
            request_id: suffix of the Redis list key ``ocr_status:<request_id>``.
            status_message: text stored under ``"status"``.
            step_info: optional extra payload stored under ``"step_info"``;
                must be JSON-serializable.
        """
        log_entry = {
            "status": status_message,
            # Timezone-aware UTC timestamp in ISO-8601 form.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "step_info": step_info,
        }
        redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
        logger.info("[%s] Task Progress: %s", request_id, status_message)

    def _resolve_request_id(self, args, kwargs):
        """Return the ``request_id`` this invocation was called with, or "unknown".

        Keyword arguments are checked first. Otherwise the positional ``args``
        are bound against the task's ``run`` signature. That way, tasks invoked
        positionally (e.g. ``task.delay(url, request_id, name)``) are still
        attributed to their own progress list instead of ``ocr_status:unknown``.
        """
        import inspect  # only needed on the hook path

        kwargs = kwargs or {}
        if "request_id" in kwargs:
            return kwargs["request_id"]
        try:
            bound = inspect.signature(self.run).bind_partial(*(args or ()), **kwargs)
        except (TypeError, ValueError):
            # No introspectable signature, or the arguments do not fit it.
            return "unknown"
        return bound.arguments.get("request_id", "unknown")

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """Record the failure in the progress list, then defer to Celery's hook.

        A Redis error while recording is logged instead of raised. Otherwise it
        would replace the task's real exception and skip the parent hook.
        """
        request_id = self._resolve_request_id(args, kwargs)
        try:
            self.update_progress(
                request_id,
                "작업 오류 발생",
                {"error": str(exc), "traceback": str(einfo)},
            )
        except redis.RedisError:
            logger.exception("[%s] Failed to record failure progress", request_id)
        logger.error("[%s] Task Failed: %s", request_id, exc)
        super().on_failure(exc, task_id, args, kwargs, einfo)

    def on_success(self, retval, task_id, args, kwargs):
        """Record a completion entry in the progress list, then defer to Celery."""
        request_id = self._resolve_request_id(args, kwargs)
        self.update_progress(request_id, "작업 완료")
        logger.info("[%s] Task Succeeded", request_id)
        super().on_success(retval, task_id, args, kwargs)
|
|
|
|
|
|
# Download a file from a presigned URL (async)
async def download_file_from_presigned_url(file_url: str, save_path: str):
    """Stream the body served at ``file_url`` into ``save_path``.

    The response is written chunk by chunk rather than buffered whole in
    memory, so large documents do not inflate worker memory.

    Args:
        file_url: presigned URL to GET; redirects are followed.
        save_path: local file path; opened in "wb" mode (overwritten).

    Raises:
        httpx.HTTPStatusError: on a non-2xx final response (nothing is written).
        httpx.HTTPError: on transport errors or timeouts (60 s per httpx
            timeout phase).
    """
    async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
        async with client.stream("GET", file_url) as resp:
            # Check the status before touching the file so errors leave it as-is.
            resp.raise_for_status()
            with open(save_path, "wb") as f:
                async for chunk in resp.aiter_bytes():
                    f.write(chunk)
|
|
|
|
|
|
# (Paddle) OCR + post-processing
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def call_paddle_ocr(self, presigned_url: str, request_id: str, file_name: str):
    """Download a document, extract its text, and return the shaped OCR result.

    Steps:
    1. Download ``presigned_url`` to a temp file that keeps ``file_name``'s extension.
    2. Run ``extract_text_from_file`` on it.
    3. Pass the output to ``ocr_process``.

    The temp file is always removed afterwards.

    Args:
        presigned_url: URL the input document is downloaded from.
        request_id: identifier under which progress entries are recorded.
        file_name: original file name; passed to ``ocr_process`` and used for
            the temp file suffix.

    Returns:
        The value produced by ``ocr_process``. This is presumably
        JSON-serializable, since ``store_ocr_result`` dumps it; verify.

    Raises:
        RuntimeError: if the download fails. The underlying error is chained
            as ``__cause__``.
    """
    self.update_progress(request_id, "Paddle OCR 작업 시작")

    suffix = os.path.splitext(file_name)[-1]
    # delete=False: the handle is closed immediately and only the path is
    # reused by the downloader; removal happens in the finally block.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_path = tmp_file.name

    try:
        # 1) Download the file
        self.update_progress(request_id, "파일 다운로드 중 (presigned URL)")
        try:
            asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
        except Exception as e:
            raise RuntimeError(f"파일 다운로드 실패: {e}") from e
        self.update_progress(request_id, "파일 다운로드 완료")

        # 2) Run OCR (timed). extract_text_from_file returns a coroutine
        #    yielding (text, coord, ocr_model).
        start_time = time.time()
        text, coord, ocr_model = asyncio.run(extract_text_from_file(tmp_path))
        end_time = time.time()
        self.update_progress(request_id, "텍스트 추출 및 후처리 완료")

        # 3) Build the result JSON; ocr_process takes these positionally.
        result_json = ocr_process(
            file_name,
            ocr_model,
            coord,
            text,
            start_time,
            end_time,
        )
        return result_json
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
|
# Upstage 응답 정규화: 가능한 많은 'text'를 모으고, 후보 bbox를 수집
|
|
def _normalize_upstage_response(resp_json, return_word_level=False, normalize=False):
|
|
"""
|
|
Upstage 문서 디지타이제이션 응답에서 text와 bbox 후보를 추출.
|
|
구조가 달라도 dict/list를 재귀 탐색하여 'text' 유사 키와 bbox 유사 키를 모읍니다.
|
|
"""
|
|
# 1) 전체 텍스트 추출
|
|
if isinstance(resp_json, dict) and resp_json.get("text"):
|
|
full_text = resp_json["text"]
|
|
else:
|
|
full_text = ""
|
|
for p in resp_json.get("pages") or []:
|
|
t = p.get("text")
|
|
if t:
|
|
full_text += t + "\n"
|
|
full_text = full_text.rstrip("\n")
|
|
|
|
# 2) 좌표/워드 추출
|
|
coords = []
|
|
word_items = []
|
|
pages = resp_json.get("pages") or []
|
|
for p_idx, page in enumerate(pages, start=1):
|
|
w = page.get("words") or []
|
|
pw, ph = page.get("width"), page.get("height") # 정규화 옵션용
|
|
for wobj in w:
|
|
bb = (wobj.get("boundingBox") or {}).get("vertices") or []
|
|
if len(bb) == 4:
|
|
poly = [[float(pt.get("x", 0)), float(pt.get("y", 0))] for pt in bb]
|
|
|
|
if normalize and pw and ph:
|
|
poly = [[x / float(pw), y / float(ph)] for x, y in poly]
|
|
|
|
if return_word_level:
|
|
word_items.append(
|
|
{
|
|
"page": p_idx,
|
|
"text": wobj.get("text", ""),
|
|
"confidence": float(wobj.get("confidence") or 0.0),
|
|
"box": poly, # 4x2
|
|
}
|
|
)
|
|
else:
|
|
coords.append(poly)
|
|
|
|
if return_word_level:
|
|
return full_text, word_items
|
|
return full_text, coords
|
|
|
|
|
|
# (Upstage) external OCR API call + post-processing
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def call_upstage_ocr_api(self, presigned_url: str, request_id: str, file_name: str):
    """Download a document, OCR it via the Upstage API, and return the shaped result.

    The file is POSTed as multipart ``document`` with ``model=ocr``. The JSON
    response is reduced to text and coordinates by ``_normalize_upstage_response``
    and then shaped by ``ocr_process``. The temp file is always removed.

    Args:
        presigned_url: URL the input document is downloaded from.
        request_id: identifier under which progress entries are recorded.
        file_name: original file name; sent as the upload's filename and
            used for the temp file suffix.

    Returns:
        The value produced by ``ocr_process``.

    Raises:
        ValueError: if ``UPSTAGE_API_KEY`` is not configured.
        RuntimeError: if the download or the API call fails. The underlying
            error is chained as ``__cause__``.
    """
    self.update_progress(request_id, "Upstage OCR 작업 시작")

    if not UPSTAGE_API_KEY:
        raise ValueError("Upstage API 키가 설정되지 않았습니다.")

    url = "https://api.upstage.ai/v1/document-digitization"
    headers = {"Authorization": f"Bearer {UPSTAGE_API_KEY}"}

    suffix = os.path.splitext(file_name)[-1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_path = tmp_file.name

    try:
        # 1) Download the file
        self.update_progress(request_id, "파일 다운로드 중 (presigned URL)")
        try:
            asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
        except Exception as e:
            raise RuntimeError(f"파일 다운로드 실패: {e}") from e
        self.update_progress(request_id, "파일 다운로드 완료")

        # 2) Call the Upstage API (timed)
        start_time = time.time()
        with open(tmp_path, "rb") as f:
            files = {"document": (file_name, f, "application/octet-stream")}
            data = {"model": "ocr"}
            try:
                with httpx.Client(timeout=120.0, follow_redirects=True) as client:
                    response = client.post(url, headers=headers, files=files, data=data)
                    response.raise_for_status()
            except httpx.HTTPStatusError as e:
                # The response body goes to the log; only the status code goes
                # into the raised message.
                logger.error("Upstage API 오류: %s", e.response.text)
                raise RuntimeError(f"Upstage API 오류: {e.response.status_code}") from e
            except Exception as e:
                logger.exception("Upstage API 호출 중 예외 발생: %s", e)
                raise RuntimeError("Upstage API 호출 실패") from e
        end_time = time.time()
        self.update_progress(request_id, "Upstage API 호출 성공")

        # 3) Normalize the response -> text / coordinates
        resp_json = response.json()
        text, coord = _normalize_upstage_response(resp_json)

        # 4) Common post-processing (unified JSON schema); positional order matters.
        result_json = ocr_process(
            file_name,
            "upstage",
            coord,
            text,
            start_time,
            end_time,
        )
        self.update_progress(request_id, "후처리 완료")
        return result_json
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
|
# (Tesseract) default-model OCR
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def call_tesseract_ocr(self, presigned_url: str, request_id: str, file_name: str):
    """Download a document and OCR it with stock Tesseract using the ``kor`` language.

    Input handling depends on the file type:
    - PDFs (by ``.pdf`` file-name suffix) are rasterized with ``convert_from_path``.
      Each page is OCR'd and the page texts are concatenated.
    - Any other file is opened as a single image with PIL.

    Coordinates are not produced, so an empty list is passed to ``ocr_process``.

    Args:
        presigned_url: URL the input document is downloaded from.
        request_id: identifier under which progress entries are recorded.
        file_name: original file name; selects PDF vs image handling and
            the temp file suffix.

    Returns:
        The value produced by ``ocr_process``.

    Raises:
        RuntimeError: if the download fails (matching the other OCR tasks);
            the underlying error is chained as ``__cause__``.
    """
    self.update_progress(request_id, "Tesseract (기본) OCR 작업 시작")

    suffix = os.path.splitext(file_name)[-1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_path = tmp_file.name

    try:
        self.update_progress(request_id, "파일 다운로드 중")
        try:
            asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
        except Exception as e:
            raise RuntimeError(f"파일 다운로드 실패: {e}") from e
        self.update_progress(request_id, "파일 다운로드 완료")

        start_time = time.time()
        if file_name.lower().endswith(".pdf"):
            pages = convert_from_path(tmp_path)
            text = "".join(pytesseract.image_to_string(page, lang="kor") for page in pages)
        else:
            with open(tmp_path, "rb") as f:
                image_bytes = f.read()
            # Close the PIL image once OCR is done.
            with Image.open(BytesIO(image_bytes)) as image:
                text = pytesseract.image_to_string(image, lang="kor")
        end_time = time.time()
        self.update_progress(request_id, "Tesseract OCR 완료")

        # pytesseract's plain-text output carries no coordinates -> empty coord list.
        return ocr_process(file_name, "tesseract_default", [], text, start_time, end_time)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
|
# (Tesseract) trained-model OCR
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def call_tesstrain_ocr(self, presigned_url: str, request_id: str, file_name: str):
    """Download a document and OCR it with the custom-trained Tesseract model.

    The model ``MODEL_NAME`` is loaded from ``TESSDATA_DIR`` via the
    ``--tessdata-dir``/``-l`` options. PDFs (by ``.pdf`` suffix) are rasterized
    and each page is OCR'd; other files are opened as a single PIL image.
    Coordinates are not produced, so an empty list is passed to ``ocr_process``.

    Args:
        presigned_url: URL the input document is downloaded from.
        request_id: identifier under which progress entries are recorded.
        file_name: original file name; selects PDF vs image handling and
            the temp file suffix.

    Returns:
        The value produced by ``ocr_process``.

    Raises:
        RuntimeError: if the download fails (matching the other OCR tasks);
            the underlying error is chained as ``__cause__``.
    """
    self.update_progress(request_id, "Tesseract (훈련 모델) OCR 작업 시작")

    TESSDATA_DIR = "/tesseract_trainer/tesstrain/workspace/"
    MODEL_NAME = "kor_fonts"
    # Same Tesseract options for both the PDF and the image path.
    custom_config = f"--tessdata-dir {TESSDATA_DIR} -l {MODEL_NAME}"

    suffix = os.path.splitext(file_name)[-1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_path = tmp_file.name

    try:
        self.update_progress(request_id, "파일 다운로드 중")
        try:
            asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
        except Exception as e:
            raise RuntimeError(f"파일 다운로드 실패: {e}") from e
        self.update_progress(request_id, "파일 다운로드 완료")

        start_time = time.time()
        if file_name.lower().endswith(".pdf"):
            pages = convert_from_path(tmp_path)
            text = "".join(
                pytesseract.image_to_string(page, config=custom_config) for page in pages
            )
        else:
            with open(tmp_path, "rb") as f:
                image_bytes = f.read()
            # Close the PIL image once OCR is done.
            with Image.open(BytesIO(image_bytes)) as image:
                text = pytesseract.image_to_string(image, config=custom_config)
        end_time = time.time()
        self.update_progress(request_id, "Tesseract (훈련 모델) OCR 완료")

        return ocr_process(
            file_name, f"tesstrain_{MODEL_NAME}", [], text, start_time, end_time
        )
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
|
# Store the result in Redis (second step of the chain)
# Router chain: store_ocr_result.s(request_id=request_id, task_id=task_id)
@celery_app.task(bind=True, base=BaseTaskWithProgress, ignore_result=True)
def store_ocr_result(self, result_data: dict, request_id: str, task_id: str):
    """Write the preceding chain step's OCR result to Redis.

    The result is serialized to JSON with ``ensure_ascii=False``, which leaves
    non-ASCII text unescaped. It is stored under ``ocr_result:<task_id>``.
    This task's own return value is not kept (``ignore_result=True``).

    Args:
        result_data: return value of the previous task in the chain.
        request_id: identifier under which progress entries are recorded.
        task_id: suffix of the Redis key the result is stored under.

    NOTE(review): the key is written without an expiry; confirm whether
    results are cleaned up elsewhere.
    """
    self.update_progress(request_id, "결과 저장 중")
    serialized = json.dumps(result_data, ensure_ascii=False)
    redis_client.set(f"ocr_result:{task_id}", serialized)
    self.update_progress(request_id, "모든 작업 완료")
|