293 lines
11 KiB
Python
293 lines
11 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import Dict, List, Literal, Optional
|
|
|
|
import httpx
|
|
import redis
|
|
from config.setting import OCR_API_URL, OCR_REDIS_DB, OCR_REDIS_HOST, OCR_REDIS_PORT
|
|
from utils.checking_files import token_counter
|
|
from utils.image_converter import prepare_images_from_file
|
|
from utils.logging_utils import log_pipeline_status, log_user_request
|
|
from utils.text_generator import (
|
|
ClaudeGenerator,
|
|
GeminiGenerator,
|
|
GptGenerator,
|
|
OllamaGenerator,
|
|
)
|
|
from utils.text_processor import post_process
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Redis 클라이언트 생성 (Celery 결과용 DB=1)
|
|
redis_client = redis.Redis(
|
|
host=OCR_REDIS_HOST,
|
|
port=OCR_REDIS_PORT,
|
|
db=OCR_REDIS_DB,
|
|
decode_responses=True,
|
|
)
|
|
|
|
|
|
class PipelineRunner:
|
|
@staticmethod
|
|
async def run_pipeline(
|
|
request_info: str, # ✅ 추가
|
|
request_id: str,
|
|
file_path: str,
|
|
filename: str,
|
|
prompt: str,
|
|
prompt_filename: str, # ✅ 추가
|
|
custom_mode: bool,
|
|
mode: str,
|
|
model: str,
|
|
inner_models: List[str],
|
|
outer_models: List[str],
|
|
model_url_map: Dict[str, str],
|
|
api_key: str,
|
|
schema_override: Optional[dict] = None,
|
|
prompt_mode: Literal["general", "extract"] = "extract",
|
|
):
|
|
start_time = time.time()
|
|
|
|
if mode == "multimodal":
|
|
# 모델 유효성
|
|
if model not in outer_models:
|
|
raise ValueError(
|
|
f"외부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. outer_models: {outer_models}"
|
|
)
|
|
if not (("gpt" in model) or ("gemini" in model)):
|
|
raise ValueError("멀티모달 E2E는 gpt 계열만 지원합니다.")
|
|
|
|
# 입력 파일 → 이미지 바이트 리스트 준비
|
|
images = await prepare_images_from_file(file_path, filename)
|
|
|
|
# 요청 로깅(텍스트가 없으므로 prompt 길이만)
|
|
context_length = len(prompt)
|
|
try:
|
|
log_user_request(
|
|
request_info=request_info,
|
|
endpoint=f"/{prompt_mode}/{mode}",
|
|
input_filename=filename,
|
|
model=model,
|
|
prompt_filename=prompt_filename,
|
|
context_length=context_length,
|
|
api_key=api_key,
|
|
)
|
|
except Exception as e:
|
|
logger.info(f"Failed to log '/{prompt_mode}/{mode}' request: {e}")
|
|
|
|
# 멀티모달 LLM 호출
|
|
log_pipeline_status(request_id, "멀티모달 LLM 추론 시작")
|
|
if "gpt" in model:
|
|
generator = GptGenerator(model=model)
|
|
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
|
generator.generate_multimodal, images, prompt, schema_override
|
|
)
|
|
elif "gemini" in model:
|
|
generator = GeminiGenerator(model=model)
|
|
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
|
generator.generate_multimodal, images, prompt, schema_override
|
|
)
|
|
|
|
end_time = time.time()
|
|
log_pipeline_status(request_id, "LLM 추론 완료 및 후처리 시작")
|
|
|
|
# 멀티모달은 OCR 텍스트/좌표 없음
|
|
text = ""
|
|
coord = None
|
|
ocr_model = "bypass(multimodal)"
|
|
|
|
json_data = post_process(
|
|
filename,
|
|
text,
|
|
generated_text,
|
|
coord,
|
|
ocr_model,
|
|
llm_model,
|
|
llm_url,
|
|
mode,
|
|
start_time,
|
|
end_time,
|
|
prompt_mode,
|
|
)
|
|
log_pipeline_status(request_id, "후처리 완료 및 결과 반환")
|
|
return json_data
|
|
|
|
try:
|
|
# OCR API 요청
|
|
log_pipeline_status(request_id, "OCR API 호출 시작")
|
|
async with httpx.AsyncClient() as client:
|
|
# ✅ presigned URL을 OCR API로 전달
|
|
ocr_resp = await client.post(
|
|
OCR_API_URL,
|
|
json=[
|
|
{
|
|
"file_url": file_path, # presigned URL
|
|
"filename": filename,
|
|
}
|
|
],
|
|
timeout=None,
|
|
)
|
|
ocr_resp.raise_for_status()
|
|
|
|
# OCR API 응답에서 task_id 추출
|
|
task_ids_json = ocr_resp.json()
|
|
print(f"[DEBUG] OCR API 응답: {task_ids_json}")
|
|
task_ids = [
|
|
item.get("task_id") for item in task_ids_json.get("results", [])
|
|
]
|
|
if not task_ids:
|
|
raise ValueError("❌ OCR API에서 유효한 task_id를 받지 못했습니다.")
|
|
task_id = task_ids[0]
|
|
|
|
# Redis에서 결과를 5초 간격으로 최대 10회 폴링
|
|
raw_result = None
|
|
for attempt in range(10): # 최대 10회 시도
|
|
redis_key = f"ocr_result:{task_id}"
|
|
raw_result = redis_client.get(redis_key)
|
|
if raw_result:
|
|
logger.info(
|
|
f"✅ Redis에서 task_id '{task_id}'에 대한 OCR 결과를 찾았습니다."
|
|
)
|
|
break
|
|
await asyncio.sleep(5)
|
|
|
|
if not raw_result: # 결과가 없으면 예외 발생
|
|
error_message = (
|
|
"❌ OCR API에서 작업을 완료하지 못했습니다. 페이지 수를 줄여주세요."
|
|
)
|
|
logger.error(error_message)
|
|
raise ValueError(error_message)
|
|
|
|
result_data = json.loads(raw_result)
|
|
text = result_data["parsed"]
|
|
coord = result_data.get("fields")
|
|
ocr_model = result_data.get("ocr_model", "OCR API(pytesseract)")
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ OCR 처리 중 예외 발생: {e}")
|
|
raise
|
|
|
|
# ✅ 입력 길이 검사
|
|
log_pipeline_status(request_id, "모델 입력 텍스트 길이 검사 시작")
|
|
token_count = token_counter(prompt, text)
|
|
context_length = len(prompt + text)
|
|
|
|
# 🔽 로그 기록
|
|
try:
|
|
log_user_request(
|
|
request_info=request_info,
|
|
endpoint=f"/{prompt_mode}/{mode}",
|
|
input_filename=filename,
|
|
model=model,
|
|
prompt_filename=prompt_filename,
|
|
context_length=context_length,
|
|
api_key=api_key,
|
|
# token_count=token_count,
|
|
)
|
|
except Exception as e:
|
|
logger.info(f"Failed to log '/{prompt_mode}/{mode}' request: {e}")
|
|
|
|
# ✅ 120K 토큰 초과 검사
|
|
if token_count > 120000:
|
|
return post_process(
|
|
filename,
|
|
text,
|
|
f"⚠️ 입력 텍스트가 {token_count} 토큰으로 입력 길이를 초과했습니다. 모델 호출 생략합니다.",
|
|
coord,
|
|
ocr_model,
|
|
"N/A",
|
|
"N/A",
|
|
mode,
|
|
start_time,
|
|
time.time(),
|
|
prompt_mode,
|
|
)
|
|
|
|
# 2. 내부 모델 처리 (Ollama)
|
|
if mode in ("inner", "all", "structured"):
|
|
if model in inner_models:
|
|
log_pipeline_status(request_id, "내부 LLM 추론 시작")
|
|
api_url = model_url_map.get(model)
|
|
if not api_url:
|
|
raise ValueError(
|
|
f"❌ 모델 '{model}'이 로드된 Ollama 서버를 찾을 수 없습니다."
|
|
)
|
|
|
|
generator = OllamaGenerator(model=model, api_url=api_url)
|
|
|
|
if mode == "structured":
|
|
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
|
generator.structured_generate,
|
|
text,
|
|
prompt,
|
|
custom_mode,
|
|
schema_override,
|
|
)
|
|
else:
|
|
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
|
generator.generate, text, prompt, custom_mode, prompt_mode
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
f"내부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. inner_models: {inner_models}"
|
|
)
|
|
|
|
# 3. 외부 모델 처리
|
|
elif mode in ("outer", "all", "structured"):
|
|
if model in outer_models:
|
|
log_pipeline_status(request_id, "외부 LLM 추론 시작")
|
|
if "claude" in model:
|
|
generator = ClaudeGenerator(model=model)
|
|
elif "gemini" in model:
|
|
generator = GeminiGenerator(model=model)
|
|
elif "gpt" in model:
|
|
generator = GptGenerator(model=model)
|
|
else:
|
|
raise ValueError(
|
|
"지원되지 않는 외부 모델입니다. ['gemini', 'claude', 'gpt'] 중 선택하세요."
|
|
)
|
|
|
|
if mode == "structured":
|
|
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
|
generator.structured_generate,
|
|
text,
|
|
prompt,
|
|
custom_mode,
|
|
schema_override,
|
|
)
|
|
else:
|
|
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
|
generator.generate, text, prompt, custom_mode, prompt_mode
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
f"외부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. outer_models: {outer_models}"
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
f"❌ 지원되지 않는 모드입니다. 'inner', 'outer', 'all', 'structured' 중에서 선택하세요. (입력: {mode})"
|
|
)
|
|
|
|
log_pipeline_status(request_id, "LLM 추론 완료 및 후처리 시작")
|
|
end_time = time.time()
|
|
|
|
# 4. 후처리
|
|
json_data = post_process(
|
|
filename,
|
|
text,
|
|
generated_text,
|
|
coord,
|
|
ocr_model,
|
|
llm_model,
|
|
llm_url,
|
|
mode,
|
|
start_time,
|
|
end_time,
|
|
prompt_mode,
|
|
)
|
|
|
|
log_pipeline_status(request_id, "후처리 완료 및 결과 반환")
|
|
return json_data
|