158 lines
5.6 KiB
Python
158 lines
5.6 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from datetime import datetime
|
|
|
|
import httpx
|
|
import redis
|
|
from celery import Task, chain
|
|
from config.setting import REDIS_DB, REDIS_HOST, REDIS_PORT
|
|
from utils.celery_utils import celery_app
|
|
from utils.ocr_processor import ocr_process
|
|
from utils.text_extractor import extract_text_from_file
|
|
|
|
# ✅ Redis 클라이언트 생성
|
|
redis_client = redis.Redis(
|
|
host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
|
|
)
|
|
|
|
# ✅ 로깅 설정
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ✅ 공통 Task 베이스 클래스 - 상태 로그 기록 및 예외 후킹 제공
|
|
class BaseTaskWithProgress(Task):
|
|
"""
|
|
Celery Task를 상속한 공통 Task 베이스 클래스입니다.
|
|
주요 목적은:
|
|
|
|
- update_progress()로 Redis에 작업 진행상황 저장
|
|
- on_failure, on_success 메서드를 오버라이딩하여 자동 상태 기록
|
|
|
|
주요 기능:
|
|
- update_progress: 단계별 상태를 ocr_status:{request_id}에 rpush
|
|
- on_failure: 예외 발생 시 에러 로그 저장
|
|
- on_success: 작업 성공 시 성공 로그 저장
|
|
"""
|
|
|
|
abstract = True
|
|
|
|
def update_progress(self, request_id, status_message, step_info=None):
|
|
log_entry = {
|
|
"status": status_message,
|
|
"timestamp": datetime.now().isoformat(),
|
|
"step_info": step_info,
|
|
}
|
|
redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
|
|
logger.info(f"[{request_id}] Task Progress: {status_message}")
|
|
|
|
def on_failure(self, exc, task_id, args, kwargs, einfo):
|
|
request_id = kwargs.get("request_id", "unknown")
|
|
self.update_progress(
|
|
request_id,
|
|
"작업 오류 발생",
|
|
{"error": str(exc), "traceback": str(einfo)},
|
|
)
|
|
logger.error(f"[{request_id}] Task Failed: {exc}")
|
|
super().on_failure(exc, task_id, args, kwargs, einfo)
|
|
|
|
def on_success(self, retval, task_id, args, kwargs):
|
|
request_id = kwargs.get("request_id", "unknown")
|
|
self.update_progress(request_id, "작업 완료")
|
|
logger.info(f"[{request_id}] Task Succeeded")
|
|
super().on_success(retval, task_id, args, kwargs)
|
|
|
|
|
|
# ✅ Step 1: presigned URL에서 파일 다운로드
|
|
@celery_app.task(bind=True, base=BaseTaskWithProgress)
|
|
def fetch_file_from_url(
|
|
self, file_url: str, file_name: str, request_id: str, task_id: str
|
|
):
|
|
self.update_progress(request_id, "파일 다운로드 중")
|
|
suffix = os.path.splitext(file_name)[-1]
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
|
tmp_path = tmp_file.name
|
|
|
|
try:
|
|
asyncio.run(
|
|
download_file_from_presigned_url(file_url, tmp_path)
|
|
) # 비동기 다운로드 함수 호출
|
|
except Exception as e:
|
|
raise RuntimeError(f"파일 다운로드 실패: {e}")
|
|
|
|
self.update_progress(request_id, "파일 다운로드 완료")
|
|
return tmp_path
|
|
|
|
|
|
# ✅ Step 2: OCR 및 후처리 수행
|
|
@celery_app.task(bind=True, base=BaseTaskWithProgress)
|
|
def parse_ocr_text(
|
|
self, tmp_path: str, request_id: str, file_name: str, ocr_model: str = "upstage"
|
|
):
|
|
self.update_progress(request_id, "OCR 작업 시작")
|
|
start_time = time.time()
|
|
text, coord, ocr_model = asyncio.run(extract_text_from_file(tmp_path, ocr_model))
|
|
end_time = time.time()
|
|
self.update_progress(request_id, "텍스트 추출 및 후처리 완료")
|
|
result_json = ocr_process(file_name, ocr_model, coord, text, start_time, end_time)
|
|
return {"result": result_json, "tmp_path": tmp_path}
|
|
|
|
|
|
# ✅ Step 3: 결과 Redis 저장 및 임시 파일 삭제
|
|
@celery_app.task(bind=True, base=BaseTaskWithProgress)
|
|
def store_ocr_result(self, data: dict, request_id: str, task_id: str):
|
|
self.update_progress(request_id, "결과 저장 중")
|
|
redis_key = f"ocr_result:{task_id}"
|
|
redis_client.set(redis_key, json.dumps(data["result"]))
|
|
|
|
try:
|
|
os.remove(data["tmp_path"])
|
|
except Exception:
|
|
logger.warning(f"[{request_id}] 임시 파일 삭제 실패")
|
|
|
|
self.update_progress(request_id, "모든 작업 완료")
|
|
|
|
|
|
# ✅ 실제 presigned URL에서 파일 다운로드 수행
|
|
async def download_file_from_presigned_url(file_url: str, save_path: str):
|
|
async with httpx.AsyncClient() as client:
|
|
resp = await client.get(file_url)
|
|
resp.raise_for_status()
|
|
with open(save_path, "wb") as f:
|
|
f.write(resp.content)
|
|
|
|
|
|
# ✅ 전체 OCR 체인 실행 함수
|
|
def run_ocr_pipeline(file_url, file_name, request_id, task_id, ocr_model):
|
|
chain(
|
|
fetch_file_from_url.s(
|
|
file_url=file_url, file_name=file_name, request_id=request_id, task_id=task_id
|
|
) # ✅ Step 1: presigned URL에서 파일 다운로드
|
|
| parse_ocr_text.s(
|
|
request_id=request_id, file_name=file_name, ocr_model=ocr_model
|
|
) # ✅ Step 2: OCR 및 후처리 수행
|
|
| store_ocr_result.s(
|
|
request_id=request_id, task_id=task_id
|
|
) # ✅ Step 3: 결과 Redis 저장 및 임시 파일 삭제
|
|
).apply_async(task_id=task_id)
|
|
|
|
|
|
# ✅ 결과 조회 함수: Redis에서 task_id로 OCR 결과 조회
|
|
def get_ocr_result(task_id: str):
|
|
redis_key = f"ocr_result:{task_id}"
|
|
result = redis_client.get(redis_key)
|
|
if result:
|
|
return json.loads(result)
|
|
return None
|
|
|
|
|
|
# ✅ 상태 로그 조회 함수: Redis에서 request_id 기반 상태 로그 조회
|
|
def get_ocr_status_log(request_id: str):
|
|
redis_key = f"ocr_status:{request_id}"
|
|
logs = redis_client.lrange(redis_key, 0, -1)
|
|
return [json.loads(entry) for entry in logs]
|