Files
ocr_gateway_test/utils/text_extractor.py
2025-10-27 09:18:24 +09:00

317 lines
11 KiB
Python

import asyncio
import logging
import os
from pathlib import Path
import cv2
import httpx
import numpy as np
import paddle
import pytesseract
from config.setting import UPSTAGE_API_KEY, UPSTAGE_API_URL
from paddleocr import PaddleOCR, PPStructureV3
from .file_handler import process_file
from .preprocessor import tess_prep_cv2, to_rgb_uint8
logger = logging.getLogger(__name__)
# PaddleOCR 및 PPStructure 모델을 전역 변수로 초기화
# 이렇게 하면 Celery 워커가 시작될 때 한 번만 모델을 로드합니다.
_paddle_ocr_model = None
_paddle_structure_model = None
def get_paddle_ocr_model():
"""PaddleOCR 모델 인스턴스를 반환합니다 (Singleton)."""
global _paddle_ocr_model
if _paddle_ocr_model is None:
device = os.getenv("PADDLE_DEVICE", "cpu")
logger.info(f"Initializing PaddleOCR model on device: {device}")
_paddle_ocr_model = PaddleOCR(
use_doc_orientation_classify=False,
use_doc_unwarping=False,
device=device,
lang="korean",
)
logger.info("PaddleOCR model initialized.")
return _paddle_ocr_model
def get_paddle_structure_model():
"""PPStructure 모델 인스턴스를 반환합니다 (Singleton)."""
global _paddle_structure_model
if _paddle_structure_model is None:
device = os.getenv("PADDLE_DEVICE", "cpu")
logger.info(f"Initializing PPStructure model on device: {device}")
_paddle_structure_model = PPStructureV3(
use_doc_orientation_classify=False,
use_doc_unwarping=False,
device=device,
lang="korean",
layout_threshold=0.3, # 레이아웃 인식 실패로 임계값 수정됨
)
logger.info("PPStructure model initialized.")
return _paddle_structure_model
async def extract_text_from_file(file_path, ocr_model):
"""
파일을 처리하고 OCR 모델을 적용하여 텍스트를 추출합니다.
"""
images, text_only, needs_ocr = await process_file(file_path, ocr_model)
if not needs_ocr:
return text_only, [], "OCR not used"
if ocr_model == "tesseract":
logger.info(f"[TESSERACT] {ocr_model} 로 이미지에서 텍스트 추출 중...")
full_response, coord_response = await asyncio.to_thread(
extract_tesseract_ocr, images
)
elif ocr_model == "pp-ocr":
logger.info(f"[PP-OCR] {ocr_model}로 이미지에서 텍스트 추출 중...")
full_response, coord_response = await asyncio.to_thread(
extract_paddle_ocr, images
)
elif ocr_model == "pp-structure":
logger.info(f"[PP-STRUCTURE] {ocr_model}로 이미지에서 텍스트 추출 중...")
full_response, coord_response = await asyncio.to_thread(
extract_paddle_structure, images
)
elif ocr_model == "upstage":
logger.info(f"[UPSTAGE] {ocr_model}로 이미지에서 텍스트 추출 중...")
full_response, coord_response = await extract_upstage_ocr(file_path)
else:
logger.error(f"[OCR MODEL] 지원하지 않는 모델입니다. ({ocr_model})")
raise ValueError(f"지원하지 않는 OCR 모델입니다: {ocr_model}")
return full_response, coord_response, ocr_model
# ✅ tesseract
def extract_tesseract_ocr(images):
"""
tesseract를 사용하여 이미지에서 텍스트 추출 및 좌표 정보 반환
"""
all_texts = []
coord_response = []
for page_idx, img in enumerate(images):
logger.info(f"[UTILS-OCR] 페이지 {page_idx + 1} OCR로 텍스트 추출 중...")
pre_img = tess_prep_cv2(img)
text = pytesseract.image_to_string(
pre_img, lang="kor+eng", config="--oem 3 --psm 6"
)
all_texts.append(text)
ocr_data = pytesseract.image_to_data(
pre_img,
output_type=pytesseract.Output.DICT,
lang="kor+eng",
config="--oem 3 --psm 6",
)
for i in range(len(ocr_data["text"])):
word = ocr_data["text"][i].strip()
if word == "":
continue
x, y, w, h = (
ocr_data["left"][i],
ocr_data["top"][i],
ocr_data["width"][i],
ocr_data["height"][i],
)
coord_response.append(
{"text": word, "coords": [x, y, x + w, y + h], "page": page_idx + 1}
)
logger.info(f"[UTILS-OCR] 페이지 {page_idx + 1} 텍스트 및 좌표 추출 완료")
full_response = "\n".join(all_texts)
return full_response, coord_response
# ✅ PaddleOCR
def extract_paddle_ocr(images):
"""
PaddleOCR를 사용하여 이미지에서 텍스트 추출 및 좌표 정보 반환
"""
ocr = get_paddle_ocr_model()
full_response = []
coord_response = []
for page_idx, img in enumerate(images):
print(f"[PaddleOCR] 페이지 {page_idx + 1} OCR로 텍스트 추출 중...")
img_np = np.array(img)
# ✅ 채널/타입 표준화 (grayscale/rgba/float 등 대응)
try:
img_np = to_rgb_uint8(img_np)
except Exception as e:
print(f"[PaddleOCR] 페이지 {page_idx + 1} 입력 표준화 실패: {e}")
continue # 문제 페이지 스킵 후 다음 페이지 진행
# ✅ 과도한 해상도 안정화 (최대 변 4000px)
h, w = img_np.shape[:2]
max_side = max(h, w)
max_side_limit = 4000
if max_side > max_side_limit:
scale = max_side_limit / max_side
new_size = (int(w * scale), int(h * scale))
img_np = cv2.resize(img_np, new_size, interpolation=cv2.INTER_AREA)
print(f"[PaddleOCR] Resized to {img_np.shape[1]}x{img_np.shape[0]}")
results = ocr.predict(input=img_np)
try:
if paddle.is_compiled_with_cuda():
paddle.device.cuda.synchronize()
paddle.device.cuda.empty_cache()
except Exception:
pass
print(f"[PaddleOCR] 페이지 {page_idx + 1} OCR 결과 개수: {len(results)}")
for res_idx, res in enumerate(results):
print(f"[PaddleOCR] 페이지 {page_idx + 1} 결과 {res_idx + 1}개 추출 완료")
res_dic = dict(res.items())
texts = res_dic.get("rec_texts", [])
boxes = res_dic.get("rec_boxes", [])
for text, bbox in zip(texts, boxes):
full_response.append(text)
coord_response.append(
{"text": text, "coords": bbox.tolist(), "page": page_idx + 1}
)
print("[PaddleOCR] 전체 페이지 텍스트 및 좌표 추출 완료")
return "\n".join(full_response), coord_response
# ✅ PaddleStructure
def extract_paddle_structure(images):
"""
PaddleSTRUCTURE 사용하여 이미지에서 텍스트 추출 및 좌표 정보 반환
"""
structure = get_paddle_structure_model()
full_response = []
coord_response = []
for page_idx, img in enumerate(images):
print(f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} OCR로 텍스트 추출 중...")
img_np = np.array(img)
print(f"[Padddle-IMG]{img}")
# ✅ 채널/타입 표준화 (grayscale/rgba/float 등 대응)
try:
img_np = to_rgb_uint8(img_np)
except Exception as e:
print(f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} 입력 표준화 실패: {e}")
continue # 문제 페이지 스킵 후 다음 페이지 진행
# ✅ 과도한 해상도 안정화 (최대 변 4000px)
h, w = img_np.shape[:2]
max_side = max(h, w)
max_side_limit = 4000
if max_side > max_side_limit:
scale = max_side_limit / max_side
new_size = (int(w * scale), int(h * scale))
img_np = cv2.resize(img_np, new_size, interpolation=cv2.INTER_AREA)
print(f"[PaddleSTRUCTURE] Resized to {img_np.shape[1]}x{img_np.shape[0]}")
results = structure.predict(input=img_np)
try:
if paddle.is_compiled_with_cuda():
paddle.device.cuda.empty_cache()
except Exception:
pass
print(f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} OCR 결과 개수: {len(results)}")
for res_idx, res in enumerate(results):
print(
f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} 결과 {res_idx + 1}개 추출 완료"
)
res_dic = dict(res.items())
blocks = res_dic.get("parsing_res_list", []) or []
for block in blocks:
bd = block.to_dict()
content = bd.get("content", [])
bbox = bd.get("bbox", [])
full_response.append(content)
coord_response.append(
{"text": content, "coords": bbox, "page": page_idx + 1}
)
print("[PaddleSTRUCTURE] 전체 페이지 텍스트 및 좌표 추출 완료")
return "\n".join(full_response), coord_response
# ✅ Upstage OCR API
async def extract_upstage_ocr(file_path: str):
"""
Upstage OCR API를 사용하여 이미지에서 텍스트 및 좌표 추출
"""
if not UPSTAGE_API_KEY:
raise ValueError("Upstage API 키가 설정되지 않았습니다.")
if not file_path or not os.path.exists(file_path):
raise FileNotFoundError(f"파일이 존재하지 않습니다: {file_path}")
url = UPSTAGE_API_URL
if not url:
url = "https://api.upstage.ai/v1/document-ai/ocr"
logger.warning(f"UPSTAGE_API_URL not set in config, using default: {url}")
headers = {"Authorization": f"Bearer {UPSTAGE_API_KEY}"}
data = {"model": "ocr"}
filename = Path(file_path).name
full_text_parts = []
coord_response = []
with open(file_path, "rb") as f:
files = {"document": (filename, f, "application/octet-stream")}
try:
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
response = await client.post(
url, headers=headers, files=files, data=data
)
response.raise_for_status()
result = response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Upstage API 오류: {e.response.text}")
raise RuntimeError(f"Upstage API 오류: {e.response.status_code}")
try:
pages = result.get("pages", [])
for page_idx, p in enumerate(pages, start=1):
txt = p.get("text")
if txt:
full_text_parts.append(txt)
for w in p.get("words", []):
verts = (w.get("boundingBox", {}) or {}).get("vertices")
if not verts or len(verts) != 4:
continue
xs = [v.get("x", 0) for v in verts]
ys = [v.get("y", 0) for v in verts]
coord_response.append(
{
"text": w.get("text"),
"coords": [min(xs), min(ys), max(xs), max(ys)],
"page": page_idx,
}
)
except Exception as e:
logger.error(f"[UPSTAGE] JSON 파싱 실패: {e} / 원본 result: {result}")
return "", []
full_response = "\n".join(full_text_parts)
return full_response, coord_response