Files
ocr_gateway_test/tasks.py
2025-10-27 09:18:24 +09:00

158 lines
5.6 KiB
Python

import asyncio
import json
import logging
import os
import tempfile
import time
from datetime import datetime
import httpx
import redis
from celery import Task, chain
from config.setting import REDIS_DB, REDIS_HOST, REDIS_PORT
from utils.celery_utils import celery_app
from utils.ocr_processor import ocr_process
from utils.text_extractor import extract_text_from_file
# ✅ Redis 클라이언트 생성
redis_client = redis.Redis(
host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
)
# ✅ 로깅 설정
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ✅ 공통 Task 베이스 클래스 - 상태 로그 기록 및 예외 후킹 제공
class BaseTaskWithProgress(Task):
"""
Celery Task를 상속한 공통 Task 베이스 클래스입니다.
주요 목적은:
- update_progress()로 Redis에 작업 진행상황 저장
- on_failure, on_success 메서드를 오버라이딩하여 자동 상태 기록
주요 기능:
- update_progress: 단계별 상태를 ocr_status:{request_id}에 rpush
- on_failure: 예외 발생 시 에러 로그 저장
- on_success: 작업 성공 시 성공 로그 저장
"""
abstract = True
def update_progress(self, request_id, status_message, step_info=None):
log_entry = {
"status": status_message,
"timestamp": datetime.now().isoformat(),
"step_info": step_info,
}
redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
logger.info(f"[{request_id}] Task Progress: {status_message}")
def on_failure(self, exc, task_id, args, kwargs, einfo):
request_id = kwargs.get("request_id", "unknown")
self.update_progress(
request_id,
"작업 오류 발생",
{"error": str(exc), "traceback": str(einfo)},
)
logger.error(f"[{request_id}] Task Failed: {exc}")
super().on_failure(exc, task_id, args, kwargs, einfo)
def on_success(self, retval, task_id, args, kwargs):
request_id = kwargs.get("request_id", "unknown")
self.update_progress(request_id, "작업 완료")
logger.info(f"[{request_id}] Task Succeeded")
super().on_success(retval, task_id, args, kwargs)
# ✅ Step 1: presigned URL에서 파일 다운로드
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def fetch_file_from_url(
self, file_url: str, file_name: str, request_id: str, task_id: str
):
self.update_progress(request_id, "파일 다운로드 중")
suffix = os.path.splitext(file_name)[-1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
tmp_path = tmp_file.name
try:
asyncio.run(
download_file_from_presigned_url(file_url, tmp_path)
) # 비동기 다운로드 함수 호출
except Exception as e:
raise RuntimeError(f"파일 다운로드 실패: {e}")
self.update_progress(request_id, "파일 다운로드 완료")
return tmp_path
# ✅ Step 2: OCR 및 후처리 수행
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def parse_ocr_text(
self, tmp_path: str, request_id: str, file_name: str, ocr_model: str = "upstage"
):
self.update_progress(request_id, "OCR 작업 시작")
start_time = time.time()
text, coord, ocr_model = asyncio.run(extract_text_from_file(tmp_path, ocr_model))
end_time = time.time()
self.update_progress(request_id, "텍스트 추출 및 후처리 완료")
result_json = ocr_process(file_name, ocr_model, coord, text, start_time, end_time)
return {"result": result_json, "tmp_path": tmp_path}
# ✅ Step 3: 결과 Redis 저장 및 임시 파일 삭제
@celery_app.task(bind=True, base=BaseTaskWithProgress)
def store_ocr_result(self, data: dict, request_id: str, task_id: str):
self.update_progress(request_id, "결과 저장 중")
redis_key = f"ocr_result:{task_id}"
redis_client.set(redis_key, json.dumps(data["result"]))
try:
os.remove(data["tmp_path"])
except Exception:
logger.warning(f"[{request_id}] 임시 파일 삭제 실패")
self.update_progress(request_id, "모든 작업 완료")
# ✅ 실제 presigned URL에서 파일 다운로드 수행
async def download_file_from_presigned_url(file_url: str, save_path: str):
async with httpx.AsyncClient() as client:
resp = await client.get(file_url)
resp.raise_for_status()
with open(save_path, "wb") as f:
f.write(resp.content)
# ✅ 전체 OCR 체인 실행 함수
def run_ocr_pipeline(file_url, file_name, request_id, task_id, ocr_model):
chain(
fetch_file_from_url.s(
file_url=file_url, file_name=file_name, request_id=request_id, task_id=task_id
) # ✅ Step 1: presigned URL에서 파일 다운로드
| parse_ocr_text.s(
request_id=request_id, file_name=file_name, ocr_model=ocr_model
) # ✅ Step 2: OCR 및 후처리 수행
| store_ocr_result.s(
request_id=request_id, task_id=task_id
) # ✅ Step 3: 결과 Redis 저장 및 임시 파일 삭제
).apply_async(task_id=task_id)
# ✅ 결과 조회 함수: Redis에서 task_id로 OCR 결과 조회
def get_ocr_result(task_id: str):
redis_key = f"ocr_result:{task_id}"
result = redis_client.get(redis_key)
if result:
return json.loads(result)
return None
# ✅ 상태 로그 조회 함수: Redis에서 request_id 기반 상태 로그 조회
def get_ocr_status_log(request_id: str):
redis_key = f"ocr_status:{request_id}"
logs = redis_client.lrange(redis_key, 0, -1)
return [json.loads(entry) for entry in logs]