원 레포랑 완전 분리
This commit is contained in:
292
workspace/services/pipeline_runner.py
Normal file
292
workspace/services/pipeline_runner.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
import httpx
|
||||
import redis
|
||||
from config.setting import OCR_API_URL, OCR_REDIS_DB, OCR_REDIS_HOST, OCR_REDIS_PORT
|
||||
from utils.checking_files import token_counter
|
||||
from utils.image_converter import prepare_images_from_file
|
||||
from utils.logging_utils import log_pipeline_status, log_user_request
|
||||
from utils.text_generator import (
|
||||
ClaudeGenerator,
|
||||
GeminiGenerator,
|
||||
GptGenerator,
|
||||
OllamaGenerator,
|
||||
)
|
||||
from utils.text_processor import post_process
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis 클라이언트 생성 (Celery 결과용 DB=1)
|
||||
redis_client = redis.Redis(
|
||||
host=OCR_REDIS_HOST,
|
||||
port=OCR_REDIS_PORT,
|
||||
db=OCR_REDIS_DB,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
|
||||
class PipelineRunner:
|
||||
@staticmethod
|
||||
async def run_pipeline(
|
||||
request_info: str, # ✅ 추가
|
||||
request_id: str,
|
||||
file_path: str,
|
||||
filename: str,
|
||||
prompt: str,
|
||||
prompt_filename: str, # ✅ 추가
|
||||
custom_mode: bool,
|
||||
mode: str,
|
||||
model: str,
|
||||
inner_models: List[str],
|
||||
outer_models: List[str],
|
||||
model_url_map: Dict[str, str],
|
||||
api_key: str,
|
||||
schema_override: Optional[dict] = None,
|
||||
prompt_mode: Literal["general", "extract"] = "extract",
|
||||
):
|
||||
start_time = time.time()
|
||||
|
||||
if mode == "multimodal":
|
||||
# 모델 유효성
|
||||
if model not in outer_models:
|
||||
raise ValueError(
|
||||
f"외부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. outer_models: {outer_models}"
|
||||
)
|
||||
if not (("gpt" in model) or ("gemini" in model)):
|
||||
raise ValueError("멀티모달 E2E는 gpt 계열만 지원합니다.")
|
||||
|
||||
# 입력 파일 → 이미지 바이트 리스트 준비
|
||||
images = await prepare_images_from_file(file_path, filename)
|
||||
|
||||
# 요청 로깅(텍스트가 없으므로 prompt 길이만)
|
||||
context_length = len(prompt)
|
||||
try:
|
||||
log_user_request(
|
||||
request_info=request_info,
|
||||
endpoint=f"/{prompt_mode}/{mode}",
|
||||
input_filename=filename,
|
||||
model=model,
|
||||
prompt_filename=prompt_filename,
|
||||
context_length=context_length,
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to log '/{prompt_mode}/{mode}' request: {e}")
|
||||
|
||||
# 멀티모달 LLM 호출
|
||||
log_pipeline_status(request_id, "멀티모달 LLM 추론 시작")
|
||||
if "gpt" in model:
|
||||
generator = GptGenerator(model=model)
|
||||
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
||||
generator.generate_multimodal, images, prompt, schema_override
|
||||
)
|
||||
elif "gemini" in model:
|
||||
generator = GeminiGenerator(model=model)
|
||||
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
||||
generator.generate_multimodal, images, prompt, schema_override
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
log_pipeline_status(request_id, "LLM 추론 완료 및 후처리 시작")
|
||||
|
||||
# 멀티모달은 OCR 텍스트/좌표 없음
|
||||
text = ""
|
||||
coord = None
|
||||
ocr_model = "bypass(multimodal)"
|
||||
|
||||
json_data = post_process(
|
||||
filename,
|
||||
text,
|
||||
generated_text,
|
||||
coord,
|
||||
ocr_model,
|
||||
llm_model,
|
||||
llm_url,
|
||||
mode,
|
||||
start_time,
|
||||
end_time,
|
||||
prompt_mode,
|
||||
)
|
||||
log_pipeline_status(request_id, "후처리 완료 및 결과 반환")
|
||||
return json_data
|
||||
|
||||
try:
|
||||
# OCR API 요청
|
||||
log_pipeline_status(request_id, "OCR API 호출 시작")
|
||||
async with httpx.AsyncClient() as client:
|
||||
# ✅ presigned URL을 OCR API로 전달
|
||||
ocr_resp = await client.post(
|
||||
OCR_API_URL,
|
||||
json=[
|
||||
{
|
||||
"file_url": file_path, # presigned URL
|
||||
"filename": filename,
|
||||
}
|
||||
],
|
||||
timeout=None,
|
||||
)
|
||||
ocr_resp.raise_for_status()
|
||||
|
||||
# OCR API 응답에서 task_id 추출
|
||||
task_ids_json = ocr_resp.json()
|
||||
print(f"[DEBUG] OCR API 응답: {task_ids_json}")
|
||||
task_ids = [
|
||||
item.get("task_id") for item in task_ids_json.get("results", [])
|
||||
]
|
||||
if not task_ids:
|
||||
raise ValueError("❌ OCR API에서 유효한 task_id를 받지 못했습니다.")
|
||||
task_id = task_ids[0]
|
||||
|
||||
# Redis에서 결과를 5초 간격으로 최대 10회 폴링
|
||||
raw_result = None
|
||||
for attempt in range(10): # 최대 10회 시도
|
||||
redis_key = f"ocr_result:{task_id}"
|
||||
raw_result = redis_client.get(redis_key)
|
||||
if raw_result:
|
||||
logger.info(
|
||||
f"✅ Redis에서 task_id '{task_id}'에 대한 OCR 결과를 찾았습니다."
|
||||
)
|
||||
break
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if not raw_result: # 결과가 없으면 예외 발생
|
||||
error_message = (
|
||||
"❌ OCR API에서 작업을 완료하지 못했습니다. 페이지 수를 줄여주세요."
|
||||
)
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
result_data = json.loads(raw_result)
|
||||
text = result_data["parsed"]
|
||||
coord = result_data.get("fields")
|
||||
ocr_model = result_data.get("ocr_model", "OCR API(pytesseract)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ OCR 처리 중 예외 발생: {e}")
|
||||
raise
|
||||
|
||||
# ✅ 입력 길이 검사
|
||||
log_pipeline_status(request_id, "모델 입력 텍스트 길이 검사 시작")
|
||||
token_count = token_counter(prompt, text)
|
||||
context_length = len(prompt + text)
|
||||
|
||||
# 🔽 로그 기록
|
||||
try:
|
||||
log_user_request(
|
||||
request_info=request_info,
|
||||
endpoint=f"/{prompt_mode}/{mode}",
|
||||
input_filename=filename,
|
||||
model=model,
|
||||
prompt_filename=prompt_filename,
|
||||
context_length=context_length,
|
||||
api_key=api_key,
|
||||
# token_count=token_count,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to log '/{prompt_mode}/{mode}' request: {e}")
|
||||
|
||||
# ✅ 120K 토큰 초과 검사
|
||||
if token_count > 120000:
|
||||
return post_process(
|
||||
filename,
|
||||
text,
|
||||
f"⚠️ 입력 텍스트가 {token_count} 토큰으로 입력 길이를 초과했습니다. 모델 호출 생략합니다.",
|
||||
coord,
|
||||
ocr_model,
|
||||
"N/A",
|
||||
"N/A",
|
||||
mode,
|
||||
start_time,
|
||||
time.time(),
|
||||
prompt_mode,
|
||||
)
|
||||
|
||||
# 2. 내부 모델 처리 (Ollama)
|
||||
if mode in ("inner", "all", "structured"):
|
||||
if model in inner_models:
|
||||
log_pipeline_status(request_id, "내부 LLM 추론 시작")
|
||||
api_url = model_url_map.get(model)
|
||||
if not api_url:
|
||||
raise ValueError(
|
||||
f"❌ 모델 '{model}'이 로드된 Ollama 서버를 찾을 수 없습니다."
|
||||
)
|
||||
|
||||
generator = OllamaGenerator(model=model, api_url=api_url)
|
||||
|
||||
if mode == "structured":
|
||||
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
||||
generator.structured_generate,
|
||||
text,
|
||||
prompt,
|
||||
custom_mode,
|
||||
schema_override,
|
||||
)
|
||||
else:
|
||||
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
||||
generator.generate, text, prompt, custom_mode, prompt_mode
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"내부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. inner_models: {inner_models}"
|
||||
)
|
||||
|
||||
# 3. 외부 모델 처리
|
||||
elif mode in ("outer", "all", "structured"):
|
||||
if model in outer_models:
|
||||
log_pipeline_status(request_id, "외부 LLM 추론 시작")
|
||||
if "claude" in model:
|
||||
generator = ClaudeGenerator(model=model)
|
||||
elif "gemini" in model:
|
||||
generator = GeminiGenerator(model=model)
|
||||
elif "gpt" in model:
|
||||
generator = GptGenerator(model=model)
|
||||
else:
|
||||
raise ValueError(
|
||||
"지원되지 않는 외부 모델입니다. ['gemini', 'claude', 'gpt'] 중 선택하세요."
|
||||
)
|
||||
|
||||
if mode == "structured":
|
||||
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
||||
generator.structured_generate,
|
||||
text,
|
||||
prompt,
|
||||
custom_mode,
|
||||
schema_override,
|
||||
)
|
||||
else:
|
||||
generated_text, llm_model, llm_url = await asyncio.to_thread(
|
||||
generator.generate, text, prompt, custom_mode, prompt_mode
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"외부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. outer_models: {outer_models}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"❌ 지원되지 않는 모드입니다. 'inner', 'outer', 'all', 'structured' 중에서 선택하세요. (입력: {mode})"
|
||||
)
|
||||
|
||||
log_pipeline_status(request_id, "LLM 추론 완료 및 후처리 시작")
|
||||
end_time = time.time()
|
||||
|
||||
# 4. 후처리
|
||||
json_data = post_process(
|
||||
filename,
|
||||
text,
|
||||
generated_text,
|
||||
coord,
|
||||
ocr_model,
|
||||
llm_model,
|
||||
llm_url,
|
||||
mode,
|
||||
start_time,
|
||||
end_time,
|
||||
prompt_mode,
|
||||
)
|
||||
|
||||
log_pipeline_status(request_id, "후처리 완료 및 결과 반환")
|
||||
return json_data
|
||||
Reference in New Issue
Block a user