Files
ocr_performance_lab/tasks.py
2025-08-12 12:25:56 +09:00

216 lines
7.6 KiB
Python

import asyncio
import json
import logging
import os
import tempfile
import time
from datetime import datetime, timezone
import httpx
import redis
from celery import Task
from config.setting import (
REDIS_DB,
REDIS_HOST,
REDIS_PORT,
UPSTAGE_API_KEY,
)
from utils.celery_utils import celery_app
from utils.ocr_processor import ocr_process
from utils.text_extractor import extract_text_from_file
# Redis 클라이언트
redis_client = redis.Redis(
host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
)
# 로깅
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 공통 Task 베이스 클래스 (진행 로그 + 실패/성공 훅)
class BaseTaskWithProgress(Task):
abstract = True
def update_progress(self, request_id: str, status_message: str, step_info=None):
log_entry = {
"status": status_message,
"timestamp": datetime.now(timezone.utc).isoformat(),
"step_info": step_info,
}
redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
logger.info(f"[{request_id}] Task Progress: {status_message}")
def on_failure(self, exc, task_id, args, kwargs, einfo):
request_id = kwargs.get("request_id", "unknown")
self.update_progress(
request_id,
"작업 오류 발생",
{"error": str(exc), "traceback": str(einfo)},
)
logger.error(f"[{request_id}] Task Failed: {exc}")
super().on_failure(exc, task_id, args, kwargs, einfo)
def on_success(self, retval, task_id, args, kwargs):
request_id = kwargs.get("request_id", "unknown")
self.update_progress(request_id, "작업 완료")
logger.info(f"[{request_id}] Task Succeeded")
super().on_success(retval, task_id, args, kwargs)
# presigned URL에서 파일 다운로드 (비동기)
async def download_file_from_presigned_url(file_url: str, save_path: str):
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
resp = await client.get(file_url)
resp.raise_for_status()
with open(save_path, "wb") as f:
f.write(resp.content)
# (Paddle) OCR + 후처리
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def parse_ocr_text(self, presigned_url: str, request_id: str, file_name: str):
self.update_progress(request_id, "Paddle OCR 작업 시작")
suffix = os.path.splitext(file_name)[-1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
tmp_path = tmp_file.name
try:
# 1) 파일 다운로드
self.update_progress(request_id, "파일 다운로드 중 (presigned URL)")
try:
asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
except Exception as e:
raise RuntimeError(f"파일 다운로드 실패: {e}")
self.update_progress(request_id, "파일 다운로드 완료")
# 2) OCR 실행
start_time = time.time()
text, coord, ocr_model = asyncio.run(extract_text_from_file(tmp_path))
end_time = time.time()
self.update_progress(request_id, "텍스트 추출 및 후처리 완료")
# 3) 결과 JSON 생성
result_json = ocr_process(
file_name, # 1
ocr_model, # 2
coord, # 3
text, # 4
start_time, # 5
end_time, # 6
)
return result_json
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
# Upstage 응답 정규화: 가능한 많은 'text'를 모으고, 후보 bbox를 수집
def _normalize_upstage_response(resp_json):
"""
Upstage 문서 디지타이제이션 응답에서 text와 bbox 후보를 추출.
구조가 달라도 dict/list를 재귀 탐색하여 'text' 유사 키와 bbox 유사 키를 모읍니다.
"""
texts = []
boxes = []
def walk(obj):
if isinstance(obj, dict):
for k, v in obj.items():
kl = k.lower()
# text 후보 키
if kl in ("text", "content", "ocr_text", "full_text", "value"):
if isinstance(v, str) and v.strip():
texts.append(v.strip())
# bbox/box 후보 키
if kl in ("bbox", "box", "bounding_box", "boundingbox", "polygon"):
boxes.append(v)
# 재귀
walk(v)
elif isinstance(obj, list):
for item in obj:
walk(item)
walk(resp_json)
merged_text = (
"\n".join(texts) if texts else json.dumps(resp_json, ensure_ascii=False)
)
return merged_text, boxes
# (Upstage) 외부 OCR API 호출 + 후처리
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def call_upstage_ocr_api(self, presigned_url: str, request_id: str, file_name: str):
self.update_progress(request_id, "Upstage OCR 작업 시작")
if not UPSTAGE_API_KEY:
raise ValueError("Upstage API 키가 설정되지 않았습니다.")
url = "https://api.upstage.ai/v1/document-digitization"
headers = {"Authorization": f"Bearer {UPSTAGE_API_KEY}"}
suffix = os.path.splitext(file_name)[-1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
tmp_path = tmp_file.name
try:
# 1) 파일 다운로드
self.update_progress(request_id, "파일 다운로드 중 (presigned URL)")
try:
asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
except Exception as e:
raise RuntimeError(f"파일 다운로드 실패: {e}")
self.update_progress(request_id, "파일 다운로드 완료")
# 2) Upstage API 호출(시간 측정)
start_time = time.time()
with open(tmp_path, "rb") as f:
files = {"document": (file_name, f, "application/octet-stream")}
data = {"model": "ocr"}
try:
with httpx.Client(timeout=120.0, follow_redirects=True) as client:
response = client.post(url, headers=headers, files=files, data=data)
response.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"Upstage API 오류: {e.response.text}")
raise RuntimeError(f"Upstage API 오류: {e.response.status_code}")
except Exception as e:
logger.error(f"Upstage API 호출 중 예외 발생: {e}")
raise RuntimeError("Upstage API 호출 실패")
end_time = time.time()
self.update_progress(request_id, "Upstage API 호출 성공")
# 3) 응답 정규화 → text/coord 추출
resp_json = response.json()
text, coord = _normalize_upstage_response(resp_json)
# 4) 공통 후처리(JSON 스키마 통일)
result_json = ocr_process(
file_name, # 1
"upstage", # 2
coord, # 3
text, # 4
start_time, # 5
end_time, # 6
)
self.update_progress(request_id, "후처리 완료")
return result_json
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
# 결과 Redis 저장 (체인의 두 번째 스텝)
# router 체인: store_ocr_result.s(request_id=request_id, task_id=task_id)
@celery_app.task(bind=True, base=BaseTaskWithProgress, ignore_result=True)
def store_ocr_result(self, result_data: dict, request_id: str, task_id: str):
self.update_progress(request_id, "결과 저장 중")
redis_key = f"ocr_result:{task_id}"
redis_client.set(redis_key, json.dumps(result_data, ensure_ascii=False))
self.update_progress(request_id, "모든 작업 완료")