Compare commits

...

8 Commits

10 changed files with 580 additions and 179 deletions

.gitignore

@@ -171,3 +171,5 @@ venv2
 /workspace/audio
 /workspace/results
 .venv_stt
+config/model

api.py

@@ -9,6 +9,7 @@ from prometheus_fastapi_instrumentator import Instrumentator
 from router import ocr_router
 from utils.celery_utils import celery_app
 from utils.celery_utils import health_check as celery_health_check_task
+from utils.minio_utils import get_minio_client
 from utils.redis_utils import get_redis_client
 logging.basicConfig(
@@ -16,7 +17,7 @@ logging.basicConfig(
 )
-app = FastAPI(title="OCR GATEWAY", description="OCR API 서비스", docs_url="/docs")
+app = FastAPI(title="OCR LAB", description="OCR 성능 비교 분석", docs_url="/docs")
 app.add_middleware(
     CORSMiddleware,
@@ -58,7 +59,6 @@ def redis_health_check():
 @app.get("/health/Celery")
 async def celery_health_check():
     """Celery 워커 상태 확인"""
-    # celery_app = get_celery_app()  # 이제 celery_utils에서 직접 임포트합니다.
     try:
         # 1. 워커들에게 ping 보내기
@@ -124,3 +124,14 @@ async def flower_health_check():
             status_code=500,
             detail=f"An error occurred during Flower health check: {str(e)}",
         )
+@app.get("/health/MinIO")
+def minio_health_check():
+    try:
+        client = get_minio_client()
+        return {"status": "MinIO ok"}
+    except Exception as e:
+        raise HTTPException(
+            status_code=500, detail=f"MinIO health check failed: {str(e)}"
+        )
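For a quick end-to-end check of the new endpoint, a minimal sketch along these lines can be used (illustrative only; it assumes the gateway is reachable on localhost:8892 as mapped in the compose file below):

    # Illustrative check of the new /health/MinIO endpoint; host and port are assumptions.
    import httpx

    resp = httpx.get("http://localhost:8892/health/MinIO", timeout=10.0)
    resp.raise_for_status()
    print(resp.json())  # expected: {"status": "MinIO ok"}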

config/setting.py

@@ -5,7 +5,7 @@ from dotenv import load_dotenv
 load_dotenv()
 # Redis 기본 설정
-REDIS_HOST = "ocr_redis"
+REDIS_HOST = "ocr_perf_redis"
 REDIS_PORT = 6379
 REDIS_DB = 0
@@ -14,8 +14,13 @@ CELERY_BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/0"
 CELERY_RESULT_BACKEND = f"redis://{REDIS_HOST}:{REDIS_PORT}/1"
 # Celery Flower 설정
-CELERY_FLOWER = "http://ocr_celery_flower:5557/api/workers"
+CELERY_FLOWER = "http://ocr_perf_flower:5557/api/workers"
 # Upstage API Key
 UPSTAGE_API_KEY = os.getenv("UPSTAGE_API_KEY")
+# MinIO Settings
+MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT")
+MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY")
+MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY")
+MINIO_BUCKET_NAME = os.getenv("MINIO_BUCKET_NAME")
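The four MinIO values are read with os.getenv and therefore fall back to None when the .env entries are missing; a small startup guard such as this hypothetical sketch would fail fast in that case:

    # Hypothetical guard (not in the diff); the key names mirror the settings added above.
    import os

    required = ["MINIO_ENDPOINT", "MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_BUCKET_NAME"]
    missing = [key for key in required if not os.getenv(key)]
    if missing:
        raise RuntimeError(f"Missing MinIO settings in .env: {', '.join(missing)}")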

@@ -3,7 +3,7 @@ services:
     build:
       context: .
     image: ocr_perf_api
-    container_name: ocr_perf_lab_api
+    container_name: ocr_perf_api
     ports:
       - "8892:8892"
     volumes:
@@ -12,8 +12,8 @@ services:
       - .env
     environment:
       - TZ=Asia/Seoul
-      - CELERY_BROKER_URL=redis://ocr_perf_lab_redis:6379/0
-      - CELERY_RESULT_BACKEND=redis://ocr_perf_lab_redis:6379/1
+      - CELERY_BROKER_URL=redis://ocr_perf_redis:6379/0
+      - CELERY_RESULT_BACKEND=redis://ocr_perf_redis:6379/1
       - TESSDATA_PREFIX=/usr/share/tessdata
     restart: always
     networks:
@@ -26,28 +26,18 @@ services:
           count: all
           capabilities: [gpu]
     depends_on:
-      ocr_perf_lab_redis:
+      ocr_perf_redis:
         condition: service_healthy
-    healthcheck:
-      test:
-        [
-          "CMD-SHELL",
-          "curl -f http://localhost:8892/health/API && curl -f http://localhost:8892/health/Redis && curl -f http://localhost:8892/health/Celery && curl -f http://localhost:8892/health/Flower",
-        ]
-      interval: 60s
-      timeout: 5s
-      retries: 3
-      start_period: 10s
-  ocr_perf_lab_worker:
+  ocr_perf_worker:
     image: ocr_perf_api
-    container_name: ocr_perf_lab_worker
+    container_name: ocr_perf_worker
     volumes:
       - ./:/workspace
     environment:
       - TZ=Asia/Seoul
-      - CELERY_BROKER_URL=redis://ocr_perf_lab_redis:6379/0
-      - CELERY_RESULT_BACKEND=redis://ocr_perf_lab_redis:6379/1
+      - CELERY_BROKER_URL=redis://ocr_perf_redis:6379/0
+      - CELERY_RESULT_BACKEND=redis://ocr_perf_redis:6379/1
       - TESSDATA_PREFIX=/usr/share/tessdata
     command: celery -A tasks worker --loglevel=info --concurrency=4
     networks:
@@ -60,35 +50,27 @@ services:
           count: all
           capabilities: [gpu]
     depends_on:
-      ocr_perf_lab_redis:
+      ocr_perf_redis:
        condition: service_healthy
-  ocr_perf_celery_flower:
+  ocr_perf_flower:
     image: ocr_perf_api
-    container_name: ocr_perf_celery_flower
+    container_name: ocr_perf_flower
     environment:
       - TZ=Asia/Seoul
       - FLOWER_UNAUTHENTICATED_API=true
       - TESSDATA_PREFIX=/usr/share/tessdata
-    entrypoint: celery --broker=redis://ocr_perf_lab_redis:6379/0 flower --port=5557
+    entrypoint: celery --broker=redis://ocr_perf_redis:6379/0 flower --port=5557
     ports:
       - "5557:5557"
     networks:
       - ocr_perf_net
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: all
-              capabilities: [gpu]
     depends_on:
-      ocr_perf_lab_redis:
-        condition: service_healthy
+      - ocr_perf_redis
-  ocr_perf_lab_redis:
+  ocr_perf_redis:
     image: redis:7-alpine
-    container_name: ocr_perf_lab_redis
+    container_name: ocr_perf_redis
     command:
       [
         "redis-server",

@@ -19,3 +19,4 @@ flower
 minio
 opencv-python-headless
 python-dotenv
+pytesseract

@@ -1,126 +1,152 @@
 import json
-import os
-import tempfile
+import logging
 from datetime import datetime
-from typing import List
 from celery import chain
 from celery.result import AsyncResult
+from config.setting import MINIO_BUCKET_NAME
 from fastapi import APIRouter, File, HTTPException, UploadFile
 from fastapi.responses import JSONResponse
 from tasks import (
+    call_paddle_ocr,
+    call_tesseract_ocr,
+    call_tesstrain_ocr,
     call_upstage_ocr_api,
     celery_app,
-    parse_ocr_text,
     store_ocr_result,
 )
 from utils.checking_keys import create_key
+from utils.minio_utils import upload_file_to_minio
 from utils.redis_utils import get_redis_client
+logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/ocr", tags=["OCR"])
 redis_client = get_redis_client()
-async def _process_ocr_request(files: List[UploadFile], ocr_task):
-    results = []
-    for file in files:
-        if not file.filename:
-            raise HTTPException(status_code=400, detail="파일 이름이 없습니다.")
-        tmp_path = ""
-        try:
-            suffix = os.path.splitext(file.filename)[-1]
-            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
-                content = await file.read()
-                tmp_file.write(content)
-                tmp_path = tmp_file.name
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=f"파일 저장 실패: {str(e)}")
-        finally:
-            await file.close()
-        request_id = create_key()
-        task_id = create_key()
-        task_chain = chain(
-            ocr_task.s(
-                tmp_path=tmp_path, request_id=request_id, file_name=file.filename
-            ),
-            store_ocr_result.s(request_id=request_id, task_id=task_id),
-        )
-        task_chain.apply_async(task_id=task_id)
-        try:
-            redis_client.hset("ocr_task_mapping", request_id, task_id)
-        except Exception as e:
-            if tmp_path and os.path.exists(tmp_path):
-                os.remove(tmp_path)
-            raise HTTPException(
-                status_code=500, detail=f"작업 정보 저장 오류: {str(e)}"
-            )
-        try:
-            log_entry = {
-                "status": "작업 접수",
-                "timestamp": datetime.now().isoformat(),
-                "task_id": task_id,
-                "initial_file": file.filename,
-            }
-            redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
-        except Exception:
-            pass
-        results.append(
-            {
-                "message": "OCR 작업이 접수되었습니다.",
-                "request_id": request_id,
-                "task_id": task_id,
-                "status_check_url": f"/ocr/progress/{request_id}",
-                "filename": file.filename,
-            }
-        )
-    return JSONResponse(content={"results": results})
+async def _process_ocr_request(file: UploadFile, ocr_task):
+    if not file.filename:
+        raise HTTPException(status_code=400, detail="파일 이름이 없습니다.")
+    request_id = create_key()
+    task_id = create_key()
+    bucket_name = MINIO_BUCKET_NAME
+    object_name = f"{request_id}/{file.filename}"
+    # MinIO에 파일 업로드 후 presigned URL 생성
+    presigned_url = upload_file_to_minio(
+        file=file, bucket_name=bucket_name, object_name=object_name
+    )
+    logger.info(f"[MinIO] ✅ presigned URL 생성 완료: {presigned_url}")
+    task_chain = chain(
+        ocr_task.s(
+            presigned_url=presigned_url, request_id=request_id, file_name=file.filename
+        ),
+        store_ocr_result.s(request_id=request_id, task_id=task_id),
+    )
+    task_chain.apply_async(task_id=task_id)
+    # Redis에 request_id → task_id 매핑 저장
+    try:
+        redis_client.hset("ocr_task_mapping", request_id, task_id)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"작업 정보 저장 오류: {str(e)}")
+    try:
+        log_entry = {
+            "status": "작업 접수",
+            "timestamp": datetime.now().isoformat(),
+            "initial_file": file.filename,
+        }
+        redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
+    except Exception:
+        pass
+    return JSONResponse(
+        content={
+            "message": "OCR 작업이 접수되었습니다.",
+            "request_id": request_id,
+            "status_check_url": f"/ocr/progress/{request_id}",
+            "filename": file.filename,
+        }
+    )
@router.post("/paddle", summary="[Paddle] 파일 업로드 기반 비동기 OCR") @router.post("/paddle", summary="[Paddle] 파일 업로드 기반 비동기 OCR")
async def ocr_paddle_endpoint(files: List[UploadFile] = File(...)): async def ocr_paddle_endpoint(file: UploadFile = File(...)):
return await _process_ocr_request(files, parse_ocr_text) return await _process_ocr_request(file, call_paddle_ocr)
@router.post("/upstage", summary="[Upstage] 파일 업로드 기반 비동기 OCR") @router.post("/upstage", summary="[Upstage] 파일 업로드 기반 비동기 OCR")
async def ocr_upstage_endpoint(files: List[UploadFile] = File(...)): async def ocr_upstage_endpoint(file: UploadFile = File(...)):
return await _process_ocr_request(files, call_upstage_ocr_api) return await _process_ocr_request(file, call_upstage_ocr_api)
@router.get("/progress/{request_id}", summary="📊 OCR 진행 상태 및 결과 조회") @router.post("/tesseract", summary="[Tesseract] 기본 모델 비동기 OCR")
async def ocr_tesseract_endpoint(file: UploadFile = File(...)):
return await _process_ocr_request(file, call_tesseract_ocr)
@router.post("/tesstrain", summary="[Tesseract] 훈련된 모델 비동기 OCR")
async def ocr_tesstrain_endpoint(file: UploadFile = File(...)):
return await _process_ocr_request(file, call_tesstrain_ocr)
@router.get("/progress/{request_id}", summary="OCR 진행 상태 및 결과 조회")
async def check_progress(request_id: str): async def check_progress(request_id: str):
task_id = redis_client.hget("ocr_task_mapping", request_id) task_id = redis_client.hget("ocr_task_mapping", request_id)
if not task_id: if not task_id:
raise HTTPException(status_code=404, detail=f"ID {request_id} 작업을 찾을 수 없습니다.") raise HTTPException(
status_code=404, detail=f"ID {request_id} 작업을 찾을 수 없습니다."
result = AsyncResult(task_id, app=celery_app) )
status = result.status
# 1) 진행 로그 조회
try: try:
logs = redis_client.lrange(f"ocr_status:{request_id}", 0, -1) logs_raw = redis_client.lrange(f"ocr_status:{request_id}", 0, -1)
parsed_logs = [json.loads(log) for log in logs] parsed_logs = [json.loads(x) for x in logs_raw]
except Exception as e: except Exception as e:
parsed_logs = [{"status": "로그 조회 실패", "error": str(e)}] parsed_logs = [{"status": "로그 조회 실패", "error": str(e)}]
# 2) 로그 기반 파생 상태(dervived_status) 계산
derived_status = None
if parsed_logs:
last = parsed_logs[-1].get("status")
if last in ("모든 작업 완료", "작업 완료"):
derived_status = "SUCCESS"
elif last == "작업 오류 발생":
derived_status = "FAILURE"
# 3) Celery 상태 (가능하면 조회, 실패해도 무시)
celery_status = "PENDING"
try:
result = AsyncResult(task_id, app=celery_app)
celery_status = result.status or "PENDING"
except Exception:
pass
# 4) **상태와 무관하게** 결과 먼저 조회
final_result = None final_result = None
if status == "SUCCESS": try:
try: result_str = redis_client.get(f"ocr_result:{task_id}")
result_str = redis_client.get(f"ocr_result:{task_id}") if result_str:
if result_str: final_result = json.loads(result_str)
final_result = json.loads(result_str) # 결과가 있으면 상태를 SUCCESS로 정규화
except Exception as e: if derived_status is None and celery_status not in ("FAILURE", "REVOKED"):
final_result = {"error": f"결과 조회 실패: {str(e)}"} derived_status = "SUCCESS"
except Exception as e:
# 결과 조회 실패도 노출
final_result = {"error": f"결과 조회 실패: {str(e)}"}
# 5) 최종 표시 상태 선택(로그/결과가 더 신뢰되면 그걸 우선)
display_status = derived_status or celery_status
return JSONResponse( return JSONResponse(
content={ content={
"request_id": request_id, "request_id": request_id,
"task_id": task_id, "task_id": task_id,
"celery_status": status, "celery_status": celery_status, # 원래 Celery 상태(참고용)
"status": display_status, # 사용자가 보기 쉬운 최종 상태
"progress_logs": parsed_logs, "progress_logs": parsed_logs,
"final_result": final_result, "final_result": final_result,
} }
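The endpoints now take a single file and return a request_id plus a status URL; a client-side sketch of the new flow (host, port, and the sample.pdf file are assumptions, not part of the diff) looks like this:

    # Illustrative client for the single-file endpoints and the progress poller.
    import time
    import httpx

    base = "http://localhost:8892"
    with open("sample.pdf", "rb") as f:
        resp = httpx.post(f"{base}/ocr/paddle", files={"file": ("sample.pdf", f)}, timeout=60.0)
    resp.raise_for_status()
    request_id = resp.json()["request_id"]

    while True:
        progress = httpx.get(f"{base}/ocr/progress/{request_id}", timeout=10.0).json()
        if progress["status"] in ("SUCCESS", "FAILURE"):
            print(progress["final_result"])
            break
        time.sleep(2)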

tasks.py

@@ -2,10 +2,13 @@ import asyncio
 import json
 import logging
 import os
+import tempfile
 import time
-from datetime import datetime
+from datetime import datetime, timezone
+from io import BytesIO
 import httpx
+import pytesseract
 import redis
 from celery import Task
 from config.setting import (
@@ -14,28 +17,30 @@ from config.setting import (
     REDIS_PORT,
     UPSTAGE_API_KEY,
 )
+from PIL import Image
+from pdf2image import convert_from_path
 from utils.celery_utils import celery_app
 from utils.ocr_processor import ocr_process
 from utils.text_extractor import extract_text_from_file
-# Redis 클라이언트 생성
+# Redis 클라이언트
 redis_client = redis.Redis(
     host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
 )
-# ✅ 로깅 설정
+# 로깅
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-# 공통 Task 베이스 클래스 - 상태 로그 기록 및 예외 후킹 제공
+# 공통 Task 베이스 클래스 (진행 로그 + 실패/성공 훅)
 class BaseTaskWithProgress(Task):
     abstract = True
-    def update_progress(self, request_id, status_message, step_info=None):
+    def update_progress(self, request_id: str, status_message: str, step_info=None):
         log_entry = {
             "status": status_message,
-            "timestamp": datetime.now().isoformat(),
+            "timestamp": datetime.now(timezone.utc).isoformat(),
             "step_info": step_info,
         }
         redis_client.rpush(f"ocr_status:{request_id}", json.dumps(log_entry))
@@ -49,15 +54,6 @@ class BaseTaskWithProgress(Task):
{"error": str(exc), "traceback": str(einfo)}, {"error": str(exc), "traceback": str(einfo)},
) )
logger.error(f"[{request_id}] Task Failed: {exc}") logger.error(f"[{request_id}] Task Failed: {exc}")
# 실패 시 임시 파일 삭제
tmp_path = kwargs.get("tmp_path")
if tmp_path and os.path.exists(tmp_path):
try:
os.remove(tmp_path)
self.update_progress(request_id, "임시 파일 삭제 완료")
except Exception as e:
logger.error(f"[{request_id}] 임시 파일 삭제 실패: {e}")
super().on_failure(exc, task_id, args, kwargs, einfo) super().on_failure(exc, task_id, args, kwargs, einfo)
def on_success(self, retval, task_id, args, kwargs): def on_success(self, retval, task_id, args, kwargs):
@@ -67,21 +63,107 @@ class BaseTaskWithProgress(Task):
         super().on_success(retval, task_id, args, kwargs)
-# ✅ (Paddle) Step 2: OCR 및 후처리 수행
+# presigned URL에서 파일 다운로드 (비동기)
+async def download_file_from_presigned_url(file_url: str, save_path: str):
+    async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
+        resp = await client.get(file_url)
+        resp.raise_for_status()
+        with open(save_path, "wb") as f:
+            f.write(resp.content)
+# (Paddle) OCR + 후처리
 @celery_app.task(bind=True, base=BaseTaskWithProgress)
-def parse_ocr_text(self, tmp_path: str, request_id: str, file_name: str):
+def call_paddle_ocr(self, presigned_url: str, request_id: str, file_name: str):
     self.update_progress(request_id, "Paddle OCR 작업 시작")
-    start_time = time.time()
-    text, coord, ocr_model = asyncio.run(extract_text_from_file(tmp_path))
-    end_time = time.time()
-    self.update_progress(request_id, "텍스트 추출 및 후처리 완료")
-    result_json = ocr_process(file_name, ocr_model, coord, text, start_time, end_time)
-    return {"result": result_json, "tmp_path": tmp_path}
+    suffix = os.path.splitext(file_name)[-1]
+    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
+        tmp_path = tmp_file.name
+    try:
+        # 1) 파일 다운로드
+        self.update_progress(request_id, "파일 다운로드 중 (presigned URL)")
+        try:
+            asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
+        except Exception as e:
+            raise RuntimeError(f"파일 다운로드 실패: {e}")
+        self.update_progress(request_id, "파일 다운로드 완료")
+        # 2) OCR 실행
+        start_time = time.time()
+        text, coord, ocr_model = asyncio.run(extract_text_from_file(tmp_path))
+        end_time = time.time()
+        self.update_progress(request_id, "텍스트 추출 및 후처리 완료")
+        # 3) 결과 JSON 생성
+        result_json = ocr_process(
+            file_name,  # 1
+            ocr_model,  # 2
+            coord,  # 3
+            text,  # 4
+            start_time,  # 5
+            end_time,  # 6
+        )
+        return result_json
+    finally:
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
-# ✅ (Upstage) Step 2: Upstage OCR API 호출
+# Upstage 응답 정규화: 가능한 많은 'text'를 모으고, 후보 bbox를 수집
+def _normalize_upstage_response(resp_json, return_word_level=False, normalize=False):
+    """
+    Upstage 문서 디지타이제이션 응답에서 text와 bbox 후보를 추출.
+    구조가 달라도 dict/list를 재귀 탐색하여 'text' 유사 키와 bbox 유사 키를 모읍니다.
+    """
+    # 1) 전체 텍스트 추출
+    if isinstance(resp_json, dict) and resp_json.get("text"):
+        full_text = resp_json["text"]
+    else:
+        full_text = ""
+        for p in resp_json.get("pages") or []:
+            t = p.get("text")
+            if t:
+                full_text += t + "\n"
+        full_text = full_text.rstrip("\n")
+    # 2) 좌표/워드 추출
+    coords = []
+    word_items = []
+    pages = resp_json.get("pages") or []
+    for p_idx, page in enumerate(pages, start=1):
+        w = page.get("words") or []
+        pw, ph = page.get("width"), page.get("height")  # 정규화 옵션용
+        for wobj in w:
+            bb = (wobj.get("boundingBox") or {}).get("vertices") or []
+            if len(bb) == 4:
+                poly = [[float(pt.get("x", 0)), float(pt.get("y", 0))] for pt in bb]
+                if normalize and pw and ph:
+                    poly = [[x / float(pw), y / float(ph)] for x, y in poly]
+                if return_word_level:
+                    word_items.append(
+                        {
+                            "page": p_idx,
+                            "text": wobj.get("text", ""),
+                            "confidence": float(wobj.get("confidence") or 0.0),
+                            "box": poly,  # 4x2
+                        }
+                    )
+                else:
+                    coords.append(poly)
+    if return_word_level:
+        return full_text, word_items
+    return full_text, coords
+# (Upstage) 외부 OCR API 호출 + 후처리
 @celery_app.task(bind=True, base=BaseTaskWithProgress)
-def call_upstage_ocr_api(self, tmp_path: str, request_id: str, file_name: str):
+def call_upstage_ocr_api(self, presigned_url: str, request_id: str, file_name: str):
     self.update_progress(request_id, "Upstage OCR 작업 시작")
     if not UPSTAGE_API_KEY:
@@ -90,36 +172,143 @@ def call_upstage_ocr_api(self, tmp_path: str, request_id: str, file_name: str):
url = "https://api.upstage.ai/v1/document-digitization" url = "https://api.upstage.ai/v1/document-digitization"
headers = {"Authorization": f"Bearer {UPSTAGE_API_KEY}"} headers = {"Authorization": f"Bearer {UPSTAGE_API_KEY}"}
suffix = os.path.splitext(file_name)[-1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
tmp_path = tmp_file.name
try: try:
# 1) 파일 다운로드
self.update_progress(request_id, "파일 다운로드 중 (presigned URL)")
try:
asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
except Exception as e:
raise RuntimeError(f"파일 다운로드 실패: {e}")
self.update_progress(request_id, "파일 다운로드 완료")
# 2) Upstage API 호출(시간 측정)
start_time = time.time()
with open(tmp_path, "rb") as f: with open(tmp_path, "rb") as f:
files = {"document": (file_name, f, "application/octet-stream")} files = {"document": (file_name, f, "application/octet-stream")}
data = {"model": "ocr"} data = {"model": "ocr"}
with httpx.Client() as client: try:
response = client.post(url, headers=headers, files=files, data=data) with httpx.Client(timeout=120.0, follow_redirects=True) as client:
response.raise_for_status() response = client.post(url, headers=headers, files=files, data=data)
response.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"Upstage API 오류: {e.response.text}")
raise RuntimeError(f"Upstage API 오류: {e.response.status_code}")
except Exception as e:
logger.error(f"Upstage API 호출 중 예외 발생: {e}")
raise RuntimeError("Upstage API 호출 실패")
end_time = time.time()
self.update_progress(request_id, "Upstage API 호출 성공") self.update_progress(request_id, "Upstage API 호출 성공")
return {"result": response.json(), "tmp_path": tmp_path}
except httpx.HTTPStatusError as e: # 3) 응답 정규화 → text/coord 추출
logger.error(f"Upstage API 오류: {e.response.text}") resp_json = response.json()
raise RuntimeError(f"Upstage API 오류: {e.response.status_code}") text, coord = _normalize_upstage_response(resp_json)
except Exception as e:
logger.error(f"Upstage API 호출 중 예외 발생: {e}") # 4) 공통 후처리(JSON 스키마 통일)
raise RuntimeError("Upstage API 호출 실패") result_json = ocr_process(
file_name, # 1
"upstage", # 2
coord, # 3
text, # 4
start_time, # 5
end_time, # 6
)
self.update_progress(request_id, "후처리 완료")
return result_json
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
-# ✅ Step 3: 결과 Redis 저장 및 임시 파일 삭제
+# (Tesseract) 기본 모델 OCR
+@celery_app.task(bind=True, base=BaseTaskWithProgress)
+def call_tesseract_ocr(self, presigned_url: str, request_id: str, file_name: str):
+    self.update_progress(request_id, "Tesseract (기본) OCR 작업 시작")
+    suffix = os.path.splitext(file_name)[-1]
+    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
+        tmp_path = tmp_file.name
+    try:
+        self.update_progress(request_id, "파일 다운로드 중")
+        asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
+        self.update_progress(request_id, "파일 다운로드 완료")
+        start_time = time.time()
+        if file_name.lower().endswith(".pdf"):
+            images = convert_from_path(tmp_path)
+            text = ""
+            for image in images:
+                text += pytesseract.image_to_string(image, lang="kor")
+        else:
+            with open(tmp_path, "rb") as f:
+                image_bytes = f.read()
+            image = Image.open(BytesIO(image_bytes))
+            text = pytesseract.image_to_string(image, lang="kor")
+        end_time = time.time()
+        self.update_progress(request_id, "Tesseract OCR 완료")
+        # 좌표(coord) 정보는 pytesseract 기본 출력에서 얻기 어려우므로 빈 리스트로 처리
+        result_json = ocr_process(
+            file_name, "tesseract_default", [], text, start_time, end_time
+        )
+        return result_json
+    finally:
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+# (Tesseract) 훈련된 모델 OCR
+@celery_app.task(bind=True, base=BaseTaskWithProgress)
+def call_tesstrain_ocr(self, presigned_url: str, request_id: str, file_name: str):
+    self.update_progress(request_id, "Tesseract (훈련 모델) OCR 작업 시작")
+    TESSDATA_DIR = "/tesseract_trainer/tesstrain/workspace/"
+    MODEL_NAME = "kor_fonts"
+    suffix = os.path.splitext(file_name)[-1]
+    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
+        tmp_path = tmp_file.name
+    try:
+        self.update_progress(request_id, "파일 다운로드 중")
+        asyncio.run(download_file_from_presigned_url(presigned_url, tmp_path))
+        self.update_progress(request_id, "파일 다운로드 완료")
+        start_time = time.time()
+        if file_name.lower().endswith(".pdf"):
+            images = convert_from_path(tmp_path)
+            text = ""
+            custom_config = f"--tessdata-dir {TESSDATA_DIR} -l {MODEL_NAME}"
+            for image in images:
+                text += pytesseract.image_to_string(image, config=custom_config)
+        else:
+            with open(tmp_path, "rb") as f:
+                image_bytes = f.read()
+            image = Image.open(BytesIO(image_bytes))
+            custom_config = f"--tessdata-dir {TESSDATA_DIR} -l {MODEL_NAME}"
+            text = pytesseract.image_to_string(image, config=custom_config)
+        end_time = time.time()
+        self.update_progress(request_id, "Tesseract (훈련 모델) OCR 완료")
+        result_json = ocr_process(
+            file_name, f"tesstrain_{MODEL_NAME}", [], text, start_time, end_time
+        )
+        return result_json
+    finally:
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+# 결과 Redis 저장 (체인의 두 번째 스텝)
+# router 체인: store_ocr_result.s(request_id=request_id, task_id=task_id)
 @celery_app.task(bind=True, base=BaseTaskWithProgress, ignore_result=True)
-def store_ocr_result(self, data: dict, request_id: str, task_id: str):
+def store_ocr_result(self, result_data: dict, request_id: str, task_id: str):
     self.update_progress(request_id, "결과 저장 중")
     redis_key = f"ocr_result:{task_id}"
-    redis_client.set(redis_key, json.dumps(data.get("result", {})))
-    tmp_path = data.get("tmp_path")
-    if tmp_path and os.path.exists(tmp_path):
-        try:
-            os.remove(tmp_path)
-            self.update_progress(request_id, "임시 파일 삭제 완료")
-        except Exception as e:
-            logger.warning(f"[{request_id}] 임시 파일 삭제 실패: {e}")
+    redis_client.set(redis_key, json.dumps(result_data, ensure_ascii=False))
     self.update_progress(request_id, "모든 작업 완료")

utils/minio_utils.py (new file)

@@ -0,0 +1,99 @@
+import logging
+from datetime import timedelta
+from config.setting import (
+    MINIO_ACCESS_KEY,
+    MINIO_BUCKET_NAME,
+    MINIO_ENDPOINT,
+    MINIO_SECRET_KEY,
+)
+from fastapi import UploadFile
+from minio import Minio
+from minio.error import S3Error
+logger = logging.getLogger(__name__)
+def get_minio_client():
+    """MinIO 클라이언트를 생성하고 반환합니다."""
+    try:
+        client = Minio(
+            MINIO_ENDPOINT,
+            access_key=MINIO_ACCESS_KEY,
+            secret_key=MINIO_SECRET_KEY,
+            secure=False,  # 개발 환경에서는 False, 프로덕션에서는 True 사용
+        )
+        # 버킷 존재 여부 확인 및 생성
+        found = client.bucket_exists(MINIO_BUCKET_NAME)
+        if not found:
+            client.make_bucket(MINIO_BUCKET_NAME)
+            logger.info(f"Bucket '{MINIO_BUCKET_NAME}' created.")
+        else:
+            logger.info(f"Bucket '{MINIO_BUCKET_NAME}' already exists.")
+        return client
+    except (S3Error, Exception) as e:
+        logger.error(f"Error connecting to MinIO: {e}")
+        raise
+def upload_file_to_minio(file: UploadFile, bucket_name: str, object_name: str) -> str:
+    """
+    파일을 MinIO에 업로드하고, presigned URL을 반환합니다.
+    Args:
+        file (UploadFile): FastAPI의 UploadFile 객체
+        bucket_name (str): 업로드할 버킷 이름
+        object_name (str): 저장될 객체 이름 (경로 포함 가능)
+    Returns:
+        str: 생성된 presigned URL
+    """
+    minio_client = get_minio_client()
+    try:
+        # 1. 버킷 존재 확인 및 생성
+        found = minio_client.bucket_exists(bucket_name)
+        if not found:
+            minio_client.make_bucket(bucket_name)
+            logger.info(f"✅ 버킷 '{bucket_name}' 생성 완료.")
+        # 2. 파일 업로드
+        file.file.seek(0)  # 파일 포인터를 처음으로 이동
+        minio_client.put_object(
+            bucket_name,
+            object_name,
+            file.file,
+            length=-1,  # 파일 크기를 모를 때 -1로 설정
+            part_size=10 * 1024 * 1024,  # 10MB 단위로 청크 업로드
+        )
+        logger.info(f"'{object_name}' -> '{bucket_name}' 업로드 성공.")
+        # 3. Presigned URL 생성
+        presigned_url = minio_client.presigned_get_object(
+            bucket_name,
+            object_name,
+            expires=timedelta(days=7),  # URL 만료 기간 (예: 7일, 필요에 따라 조절 가능)
+        )
+        logger.info(f"✅ Presigned URL 생성 완료: {presigned_url}")
+        return presigned_url
+    except Exception as e:
+        logger.error(f"❌ MinIO 작업 실패: {e}")
+        raise  # 실패 시 예외를 다시 발생시켜 호출 측에서 처리하도록 함
+def download_file_from_minio(object_name: str, local_path: str):
+    """
+    MinIO에서 객체를 다운로드하여 로컬 파일로 저장합니다.
+    Args:
+        object_name (str): 다운로드할 객체의 이름
+        local_path (str): 파일을 저장할 로컬 경로
+    """
+    client = get_minio_client()
+    try:
+        client.fget_object(MINIO_BUCKET_NAME, object_name, local_path)
+        logger.info(f"'{object_name}' downloaded to '{local_path}' successfully.")
+    except S3Error as e:
+        logger.error(f"Error downloading from MinIO: {e}")
+        raise
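The router uses these helpers as a pair: upload the request's file once, then hand the presigned GET URL to the worker. A hypothetical caller (names are illustrative, not from the diff) mirrors that usage:

    # Hypothetical caller mirroring how the router stages an upload.
    from config.setting import MINIO_BUCKET_NAME
    from fastapi import UploadFile
    from utils.minio_utils import upload_file_to_minio

    def stage_upload(file: UploadFile, request_id: str) -> str:
        object_name = f"{request_id}/{file.filename}"
        # Returns a presigned GET URL (valid for 7 days, per the expires value above).
        return upload_file_to_minio(file=file, bucket_name=MINIO_BUCKET_NAME, object_name=object_name)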

utils/ocr_processor.py

@@ -7,7 +7,7 @@ def ocr_process(filename, ocr_model, coord, text, start_time, end_time):
             "started_at": start_time,
             "ended_at": end_time,
         },
-        "fields": coord,
+        # "fields": coord,
         "parsed": text,
     }

utils/text_extractor.py

@@ -19,13 +19,13 @@ async def extract_text_from_file(file_path):
     images = []
     if ext == ".pdf":
-        # ① 먼저 PDF에서 텍스트 추출 시도
-        text_only = await asyncio.to_thread(extract_text_from_pdf_direct, file_path)
-        if text_only.strip():
-            logger.info(
-                "[UTILS-TEXT] PDF는 텍스트 기반입니다. (OCR 없이 텍스트 추출 완료)"
-            )
-            return text_only, [], "OCR not used"
+        # ① 먼저 PDF에서 텍스트 추출 시도 -> GT를 만들기에 무조건 ocr 과정 거치도록 변경
+        # text_only = await asyncio.to_thread(extract_text_from_pdf_direct, file_path)
+        # if text_only.strip():
+        #     logger.info(
+        #         "[UTILS-TEXT] PDF는 텍스트 기반입니다. (OCR 없이 텍스트 추출 완료)"
+        #     )
+        #     return text_only, [], "OCR not used"
         # ② 텍스트가 없으면 이미지 변환 → OCR 수행
         images = await asyncio.to_thread(convert_from_path, file_path, dpi=400)
@@ -105,6 +105,41 @@ def preprocess_image_for_ocr(pil_img, page_idx=None):
     return Image.fromarray(img)
+def _to_rgb_uint8(img_np: np.ndarray) -> np.ndarray:
+    """
+    입력 이미지를 3채널 RGB, uint8 [0,255] 로 표준화
+    허용 입력: HxW, HxWx1, HxWx3, HxWx4, float[0..1]/[0..255], int 등
+    """
+    if img_np is None:
+        raise ValueError("Input image is None")
+    # dtype/범위 표준화
+    if img_np.dtype != np.uint8:
+        arr = img_np.astype(np.float32)
+        if arr.max() <= 1.0:  # [0,1]로 보이면 스케일업
+            arr *= 255.0
+        arr = np.clip(arr, 0, 255).astype(np.uint8)
+        img_np = arr
+    # 채널 표준화
+    if img_np.ndim == 2:  # HxW
+        img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
+    elif img_np.ndim == 3:
+        h, w, c = img_np.shape
+        if c == 1:
+            img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
+        elif c == 4:
+            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)
+        elif c == 3:
+            pass  # 그대로 사용
+        else:
+            raise ValueError(f"Unsupported channel count: {c}")
+    else:
+        raise ValueError(f"Unsupported ndim: {img_np.ndim}")
+    return img_np
 def extract_text_paddle_ocr(images):
     """
     PaddleOCR를 사용하여 이미지에서 텍스트 추출 및 좌표 정보 반환
@@ -114,15 +149,29 @@ def extract_text_paddle_ocr(images):
         use_doc_orientation_classify=False, use_doc_unwarping=False, lang="korean"
     )
-    full_response = []
     coord_response = []
+    all_text_boxes = []  # (y_center, x_center, text, box) 저장용
     for page_idx, img in enumerate(images):
         print(f"[PaddleOCR] 페이지 {page_idx + 1} OCR로 텍스트 추출 중...")
         img_np = np.array(img)
-        if len(img_np.shape) == 2:  # grayscale → RGB 변환
-            img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
+        # ✅ 채널/타입 표준화 (grayscale/rgba/float 등 대응)
+        try:
+            img_np = _to_rgb_uint8(img_np)
+        except Exception as e:
+            print(f"[PaddleOCR] 페이지 {page_idx + 1} 입력 표준화 실패: {e}")
+            continue  # 문제 페이지 스킵 후 다음 페이지 진행
+        # ✅ 과도한 해상도 안정화 (최대 변 4000px)
+        h, w = img_np.shape[:2]
+        max_side = max(h, w)
+        max_side_limit = 4000
+        if max_side > max_side_limit:
+            scale = max_side_limit / max_side
+            new_size = (int(w * scale), int(h * scale))
+            img_np = cv2.resize(img_np, new_size, interpolation=cv2.INTER_AREA)
+            print(f"[PaddleOCR] Resized to {img_np.shape[1]}x{img_np.shape[0]}")
         results = ocr.predict(input=img_np)
@@ -134,13 +183,50 @@ def extract_text_paddle_ocr(images):
             texts = res_dic.get("rec_texts", [])
             boxes = res_dic.get("rec_boxes", [])
-            full_response.extend(texts)
-            # ndarray → list 변환
-            clean_boxes = [
-                box.tolist() if isinstance(box, np.ndarray) else box for box in boxes
-            ]
-            coord_response.extend(clean_boxes)
+            for text, box in zip(texts, boxes):
+                if isinstance(box, np.ndarray):
+                    box = box.tolist()
+                # ✅ box 정규화
+                if all(isinstance(p, (int, float)) for p in box):
+                    if len(box) % 2 == 0:
+                        box = [[box[i], box[i + 1]] for i in range(0, len(box), 2)]
+                    else:
+                        print(f"[PaddleOCR] 잘못된 box 형식: {box}")
+                        continue
+                coord_response.append(box)
+                # 중심 좌표 계산 (y → 줄 순서, x → 단어 순서)
+                x_coords = [p[0] for p in box]
+                y_coords = [p[1] for p in box]
+                x_center = sum(x_coords) / len(x_coords)
+                y_center = sum(y_coords) / len(y_coords)
+                all_text_boxes.append((y_center, x_center, text))
+    # ✅ 위치 기반 정렬
+    all_text_boxes.sort(key=lambda x: (x[0], x[1]))  # y 먼저, 그 다음 x 정렬
+    # ✅ 줄 단위 그룹핑
+    lines = []
+    current_line = []
+    prev_y = None
+    line_threshold = 15  # 줄 묶음 y 오차 허용값
+    for y, x, text in all_text_boxes:
+        if prev_y is None or abs(y - prev_y) < line_threshold:
+            current_line.append((x, text))
+        else:
+            current_line.sort(key=lambda xx: xx[0])
+            lines.append(" ".join(t for _, t in current_line))
+            current_line = [(x, text)]
+        prev_y = y
+    if current_line:
+        current_line.sort(key=lambda xx: xx[0])
+        lines.append(" ".join(t for _, t in current_line))
+    parsed_text = "\n".join(lines)
     print("[PaddleOCR] 전체 페이지 텍스트 및 좌표 추출 완료")
-    return " ".join(full_response), coord_response
+    return parsed_text, coord_response