126 lines
4.7 KiB
Python
126 lines
4.7 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from datetime import datetime
|
|
|
|
import httpx
|
|
import redis
|
|
from celery import Task
|
|
from config.setting import (
|
|
REDIS_DB,
|
|
REDIS_HOST,
|
|
REDIS_PORT,
|
|
UPSTAGE_API_KEY,
|
|
)
|
|
from utils.celery_utils import celery_app
|
|
from utils.ocr_processor import ocr_process
|
|
from utils.text_extractor import extract_text_from_file
|
|
|
|
# ✅ Redis 클라이언트 생성
|
|
redis_client = redis.Redis(
|
|
host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
|
|
)
|
|
|
|
# ✅ 로깅 설정
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ✅ 공통 Task 베이스 클래스 - 상태 로그 기록 및 예외 후킹 제공
|
|
class BaseTaskWithProgress(Task):
|
|
abstract = True
|
|
|
|
def update_progress(self, request_id, status_message, step_info=None):
|
|
log_entry = {
|
|
"status": status_message,
|
|
"timestamp": datetime.now().isoformat(),
|
|
"step_info": step_info,
|
|
}
|
|
redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
|
|
logger.info(f"[{request_id}] Task Progress: {status_message}")
|
|
|
|
def on_failure(self, exc, task_id, args, kwargs, einfo):
|
|
request_id = kwargs.get("request_id", "unknown")
|
|
self.update_progress(
|
|
request_id,
|
|
"작업 오류 발생",
|
|
{"error": str(exc), "traceback": str(einfo)},
|
|
)
|
|
logger.error(f"[{request_id}] Task Failed: {exc}")
|
|
# 실패 시 임시 파일 삭제
|
|
tmp_path = kwargs.get("tmp_path")
|
|
if tmp_path and os.path.exists(tmp_path):
|
|
try:
|
|
os.remove(tmp_path)
|
|
self.update_progress(request_id, "임시 파일 삭제 완료")
|
|
except Exception as e:
|
|
logger.error(f"[{request_id}] 임시 파일 삭제 실패: {e}")
|
|
|
|
super().on_failure(exc, task_id, args, kwargs, einfo)
|
|
|
|
def on_success(self, retval, task_id, args, kwargs):
|
|
request_id = kwargs.get("request_id", "unknown")
|
|
self.update_progress(request_id, "작업 완료")
|
|
logger.info(f"[{request_id}] Task Succeeded")
|
|
super().on_success(retval, task_id, args, kwargs)
|
|
|
|
|
|
# ✅ (Paddle) Step 2: OCR 및 후처리 수행
|
|
@celery_app.task(bind=True, base=BaseTaskWithProgress)
|
|
def parse_ocr_text(self, tmp_path: str, request_id: str, file_name: str):
|
|
self.update_progress(request_id, "Paddle OCR 작업 시작")
|
|
start_time = time.time()
|
|
text, coord, ocr_model = asyncio.run(extract_text_from_file(tmp_path))
|
|
end_time = time.time()
|
|
self.update_progress(request_id, "텍스트 추출 및 후처리 완료")
|
|
result_json = ocr_process(file_name, ocr_model, coord, text, start_time, end_time)
|
|
return {"result": result_json, "tmp_path": tmp_path}
|
|
|
|
|
|
# ✅ (Upstage) Step 2: Upstage OCR API 호출
|
|
@celery_app.task(bind=True, base=BaseTaskWithProgress)
|
|
def call_upstage_ocr_api(self, tmp_path: str, request_id: str, file_name: str):
|
|
self.update_progress(request_id, "Upstage OCR 작업 시작")
|
|
|
|
if not UPSTAGE_API_KEY:
|
|
raise ValueError("Upstage API 키가 설정되지 않았습니다.")
|
|
|
|
url = "https://api.upstage.ai/v1/document-digitization"
|
|
headers = {"Authorization": f"Bearer {UPSTAGE_API_KEY}"}
|
|
|
|
try:
|
|
with open(tmp_path, "rb") as f:
|
|
files = {"document": (file_name, f, "application/octet-stream")}
|
|
data = {"model": "ocr"}
|
|
with httpx.Client() as client:
|
|
response = client.post(url, headers=headers, files=files, data=data)
|
|
response.raise_for_status()
|
|
self.update_progress(request_id, "Upstage API 호출 성공")
|
|
return {"result": response.json(), "tmp_path": tmp_path}
|
|
except httpx.HTTPStatusError as e:
|
|
logger.error(f"Upstage API 오류: {e.response.text}")
|
|
raise RuntimeError(f"Upstage API 오류: {e.response.status_code}")
|
|
except Exception as e:
|
|
logger.error(f"Upstage API 호출 중 예외 발생: {e}")
|
|
raise RuntimeError("Upstage API 호출 실패")
|
|
|
|
|
|
# ✅ Step 3: 결과 Redis 저장 및 임시 파일 삭제
|
|
@celery_app.task(bind=True, base=BaseTaskWithProgress, ignore_result=True)
|
|
def store_ocr_result(self, data: dict, request_id: str, task_id: str):
|
|
self.update_progress(request_id, "결과 저장 중")
|
|
redis_key = f"ocr_result:{task_id}"
|
|
redis_client.set(redis_key, json.dumps(data.get("result", {})))
|
|
|
|
tmp_path = data.get("tmp_path")
|
|
if tmp_path and os.path.exists(tmp_path):
|
|
try:
|
|
os.remove(tmp_path)
|
|
self.update_progress(request_id, "임시 파일 삭제 완료")
|
|
except Exception as e:
|
|
logger.warning(f"[{request_id}] 임시 파일 삭제 실패: {e}")
|
|
|
|
self.update_progress(request_id, "모든 작업 완료")
|