first commit
This commit is contained in:
316
utils/text_extractor.py
Normal file
316
utils/text_extractor.py
Normal file
@@ -0,0 +1,316 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import httpx
|
||||
import numpy as np
|
||||
import paddle
|
||||
import pytesseract
|
||||
from config.setting import UPSTAGE_API_KEY, UPSTAGE_API_URL
|
||||
from paddleocr import PaddleOCR, PPStructureV3
|
||||
|
||||
from .file_handler import process_file
|
||||
from .preprocessor import tess_prep_cv2, to_rgb_uint8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# PaddleOCR 및 PPStructure 모델을 전역 변수로 초기화
|
||||
# 이렇게 하면 Celery 워커가 시작될 때 한 번만 모델을 로드합니다.
|
||||
_paddle_ocr_model = None
|
||||
_paddle_structure_model = None
|
||||
|
||||
|
||||
def get_paddle_ocr_model():
|
||||
"""PaddleOCR 모델 인스턴스를 반환합니다 (Singleton)."""
|
||||
global _paddle_ocr_model
|
||||
if _paddle_ocr_model is None:
|
||||
device = os.getenv("PADDLE_DEVICE", "cpu")
|
||||
logger.info(f"Initializing PaddleOCR model on device: {device}")
|
||||
_paddle_ocr_model = PaddleOCR(
|
||||
use_doc_orientation_classify=False,
|
||||
use_doc_unwarping=False,
|
||||
device=device,
|
||||
lang="korean",
|
||||
)
|
||||
logger.info("PaddleOCR model initialized.")
|
||||
return _paddle_ocr_model
|
||||
|
||||
|
||||
def get_paddle_structure_model():
|
||||
"""PPStructure 모델 인스턴스를 반환합니다 (Singleton)."""
|
||||
global _paddle_structure_model
|
||||
if _paddle_structure_model is None:
|
||||
device = os.getenv("PADDLE_DEVICE", "cpu")
|
||||
logger.info(f"Initializing PPStructure model on device: {device}")
|
||||
_paddle_structure_model = PPStructureV3(
|
||||
use_doc_orientation_classify=False,
|
||||
use_doc_unwarping=False,
|
||||
device=device,
|
||||
lang="korean",
|
||||
layout_threshold=0.3, # 레이아웃 인식 실패로 임계값 수정됨
|
||||
)
|
||||
logger.info("PPStructure model initialized.")
|
||||
return _paddle_structure_model
|
||||
|
||||
|
||||
async def extract_text_from_file(file_path, ocr_model):
|
||||
"""
|
||||
파일을 처리하고 OCR 모델을 적용하여 텍스트를 추출합니다.
|
||||
"""
|
||||
images, text_only, needs_ocr = await process_file(file_path, ocr_model)
|
||||
|
||||
if not needs_ocr:
|
||||
return text_only, [], "OCR not used"
|
||||
|
||||
if ocr_model == "tesseract":
|
||||
logger.info(f"[TESSERACT] {ocr_model} 로 이미지에서 텍스트 추출 중...")
|
||||
full_response, coord_response = await asyncio.to_thread(
|
||||
extract_tesseract_ocr, images
|
||||
)
|
||||
elif ocr_model == "pp-ocr":
|
||||
logger.info(f"[PP-OCR] {ocr_model}로 이미지에서 텍스트 추출 중...")
|
||||
full_response, coord_response = await asyncio.to_thread(
|
||||
extract_paddle_ocr, images
|
||||
)
|
||||
elif ocr_model == "pp-structure":
|
||||
logger.info(f"[PP-STRUCTURE] {ocr_model}로 이미지에서 텍스트 추출 중...")
|
||||
full_response, coord_response = await asyncio.to_thread(
|
||||
extract_paddle_structure, images
|
||||
)
|
||||
elif ocr_model == "upstage":
|
||||
logger.info(f"[UPSTAGE] {ocr_model}로 이미지에서 텍스트 추출 중...")
|
||||
full_response, coord_response = await extract_upstage_ocr(file_path)
|
||||
else:
|
||||
logger.error(f"[OCR MODEL] 지원하지 않는 모델입니다. ({ocr_model})")
|
||||
raise ValueError(f"지원하지 않는 OCR 모델입니다: {ocr_model}")
|
||||
|
||||
return full_response, coord_response, ocr_model
|
||||
|
||||
|
||||
# ✅ tesseract
|
||||
def extract_tesseract_ocr(images):
|
||||
"""
|
||||
tesseract를 사용하여 이미지에서 텍스트 추출 및 좌표 정보 반환
|
||||
"""
|
||||
all_texts = []
|
||||
coord_response = []
|
||||
|
||||
for page_idx, img in enumerate(images):
|
||||
logger.info(f"[UTILS-OCR] 페이지 {page_idx + 1} OCR로 텍스트 추출 중...")
|
||||
pre_img = tess_prep_cv2(img)
|
||||
text = pytesseract.image_to_string(
|
||||
pre_img, lang="kor+eng", config="--oem 3 --psm 6"
|
||||
)
|
||||
all_texts.append(text)
|
||||
|
||||
ocr_data = pytesseract.image_to_data(
|
||||
pre_img,
|
||||
output_type=pytesseract.Output.DICT,
|
||||
lang="kor+eng",
|
||||
config="--oem 3 --psm 6",
|
||||
)
|
||||
for i in range(len(ocr_data["text"])):
|
||||
word = ocr_data["text"][i].strip()
|
||||
if word == "":
|
||||
continue
|
||||
x, y, w, h = (
|
||||
ocr_data["left"][i],
|
||||
ocr_data["top"][i],
|
||||
ocr_data["width"][i],
|
||||
ocr_data["height"][i],
|
||||
)
|
||||
coord_response.append(
|
||||
{"text": word, "coords": [x, y, x + w, y + h], "page": page_idx + 1}
|
||||
)
|
||||
|
||||
logger.info(f"[UTILS-OCR] 페이지 {page_idx + 1} 텍스트 및 좌표 추출 완료")
|
||||
|
||||
full_response = "\n".join(all_texts)
|
||||
return full_response, coord_response
|
||||
|
||||
|
||||
# ✅ PaddleOCR
|
||||
def extract_paddle_ocr(images):
|
||||
"""
|
||||
PaddleOCR를 사용하여 이미지에서 텍스트 추출 및 좌표 정보 반환
|
||||
"""
|
||||
ocr = get_paddle_ocr_model()
|
||||
|
||||
full_response = []
|
||||
coord_response = []
|
||||
|
||||
for page_idx, img in enumerate(images):
|
||||
print(f"[PaddleOCR] 페이지 {page_idx + 1} OCR로 텍스트 추출 중...")
|
||||
img_np = np.array(img)
|
||||
|
||||
# ✅ 채널/타입 표준화 (grayscale/rgba/float 등 대응)
|
||||
try:
|
||||
img_np = to_rgb_uint8(img_np)
|
||||
except Exception as e:
|
||||
print(f"[PaddleOCR] 페이지 {page_idx + 1} 입력 표준화 실패: {e}")
|
||||
continue # 문제 페이지 스킵 후 다음 페이지 진행
|
||||
|
||||
# ✅ 과도한 해상도 안정화 (최대 변 4000px)
|
||||
h, w = img_np.shape[:2]
|
||||
max_side = max(h, w)
|
||||
max_side_limit = 4000
|
||||
if max_side > max_side_limit:
|
||||
scale = max_side_limit / max_side
|
||||
new_size = (int(w * scale), int(h * scale))
|
||||
img_np = cv2.resize(img_np, new_size, interpolation=cv2.INTER_AREA)
|
||||
print(f"[PaddleOCR] Resized to {img_np.shape[1]}x{img_np.shape[0]}")
|
||||
|
||||
results = ocr.predict(input=img_np)
|
||||
|
||||
try:
|
||||
if paddle.is_compiled_with_cuda():
|
||||
paddle.device.cuda.synchronize()
|
||||
paddle.device.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"[PaddleOCR] 페이지 {page_idx + 1} OCR 결과 개수: {len(results)}")
|
||||
for res_idx, res in enumerate(results):
|
||||
print(f"[PaddleOCR] 페이지 {page_idx + 1} 결과 {res_idx + 1}개 추출 완료")
|
||||
res_dic = dict(res.items())
|
||||
|
||||
texts = res_dic.get("rec_texts", [])
|
||||
boxes = res_dic.get("rec_boxes", [])
|
||||
|
||||
for text, bbox in zip(texts, boxes):
|
||||
full_response.append(text)
|
||||
coord_response.append(
|
||||
{"text": text, "coords": bbox.tolist(), "page": page_idx + 1}
|
||||
)
|
||||
|
||||
print("[PaddleOCR] 전체 페이지 텍스트 및 좌표 추출 완료")
|
||||
return "\n".join(full_response), coord_response
|
||||
|
||||
|
||||
# ✅ PaddleStructure
|
||||
def extract_paddle_structure(images):
|
||||
"""
|
||||
PaddleSTRUCTURE 사용하여 이미지에서 텍스트 추출 및 좌표 정보 반환
|
||||
"""
|
||||
structure = get_paddle_structure_model()
|
||||
|
||||
full_response = []
|
||||
coord_response = []
|
||||
|
||||
for page_idx, img in enumerate(images):
|
||||
print(f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} OCR로 텍스트 추출 중...")
|
||||
img_np = np.array(img)
|
||||
print(f"[Padddle-IMG]{img}")
|
||||
|
||||
# ✅ 채널/타입 표준화 (grayscale/rgba/float 등 대응)
|
||||
try:
|
||||
img_np = to_rgb_uint8(img_np)
|
||||
except Exception as e:
|
||||
print(f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} 입력 표준화 실패: {e}")
|
||||
continue # 문제 페이지 스킵 후 다음 페이지 진행
|
||||
|
||||
# ✅ 과도한 해상도 안정화 (최대 변 4000px)
|
||||
h, w = img_np.shape[:2]
|
||||
max_side = max(h, w)
|
||||
max_side_limit = 4000
|
||||
if max_side > max_side_limit:
|
||||
scale = max_side_limit / max_side
|
||||
new_size = (int(w * scale), int(h * scale))
|
||||
img_np = cv2.resize(img_np, new_size, interpolation=cv2.INTER_AREA)
|
||||
print(f"[PaddleSTRUCTURE] Resized to {img_np.shape[1]}x{img_np.shape[0]}")
|
||||
|
||||
results = structure.predict(input=img_np)
|
||||
|
||||
try:
|
||||
if paddle.is_compiled_with_cuda():
|
||||
paddle.device.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} OCR 결과 개수: {len(results)}")
|
||||
for res_idx, res in enumerate(results):
|
||||
print(
|
||||
f"[PaddleSTRUCTURE] 페이지 {page_idx + 1} 결과 {res_idx + 1}개 추출 완료"
|
||||
)
|
||||
res_dic = dict(res.items())
|
||||
blocks = res_dic.get("parsing_res_list", []) or []
|
||||
|
||||
for block in blocks:
|
||||
bd = block.to_dict()
|
||||
|
||||
content = bd.get("content", [])
|
||||
bbox = bd.get("bbox", [])
|
||||
|
||||
full_response.append(content)
|
||||
|
||||
coord_response.append(
|
||||
{"text": content, "coords": bbox, "page": page_idx + 1}
|
||||
)
|
||||
|
||||
print("[PaddleSTRUCTURE] 전체 페이지 텍스트 및 좌표 추출 완료")
|
||||
return "\n".join(full_response), coord_response
|
||||
|
||||
|
||||
# ✅ Upstage OCR API
|
||||
async def extract_upstage_ocr(file_path: str):
|
||||
"""
|
||||
Upstage OCR API를 사용하여 이미지에서 텍스트 및 좌표 추출
|
||||
"""
|
||||
if not UPSTAGE_API_KEY:
|
||||
raise ValueError("Upstage API 키가 설정되지 않았습니다.")
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"파일이 존재하지 않습니다: {file_path}")
|
||||
|
||||
url = UPSTAGE_API_URL
|
||||
if not url:
|
||||
url = "https://api.upstage.ai/v1/document-ai/ocr"
|
||||
logger.warning(f"UPSTAGE_API_URL not set in config, using default: {url}")
|
||||
|
||||
headers = {"Authorization": f"Bearer {UPSTAGE_API_KEY}"}
|
||||
data = {"model": "ocr"}
|
||||
filename = Path(file_path).name
|
||||
full_text_parts = []
|
||||
coord_response = []
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"document": (filename, f, "application/octet-stream")}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
|
||||
response = await client.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Upstage API 오류: {e.response.text}")
|
||||
raise RuntimeError(f"Upstage API 오류: {e.response.status_code}")
|
||||
|
||||
try:
|
||||
pages = result.get("pages", [])
|
||||
for page_idx, p in enumerate(pages, start=1):
|
||||
txt = p.get("text")
|
||||
if txt:
|
||||
full_text_parts.append(txt)
|
||||
|
||||
for w in p.get("words", []):
|
||||
verts = (w.get("boundingBox", {}) or {}).get("vertices")
|
||||
if not verts or len(verts) != 4:
|
||||
continue
|
||||
xs = [v.get("x", 0) for v in verts]
|
||||
ys = [v.get("y", 0) for v in verts]
|
||||
coord_response.append(
|
||||
{
|
||||
"text": w.get("text"),
|
||||
"coords": [min(xs), min(ys), max(xs), max(ys)],
|
||||
"page": page_idx,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[UPSTAGE] JSON 파싱 실패: {e} / 원본 result: {result}")
|
||||
return "", []
|
||||
|
||||
full_response = "\n".join(full_text_parts)
|
||||
return full_response, coord_response
|
||||
Reference in New Issue
Block a user