Files
llm-gateway-sub-backup/workspace/services/pipeline_runner.py
2025-08-11 18:56:38 +09:00

293 lines
11 KiB
Python

import asyncio
import json
import logging
import time
from typing import Dict, List, Literal, Optional
import httpx
import redis
from config.setting import OCR_API_URL, OCR_REDIS_DB, OCR_REDIS_HOST, OCR_REDIS_PORT
from utils.checking_files import token_counter
from utils.image_converter import prepare_images_from_file
from utils.logging_utils import log_pipeline_status, log_user_request
from utils.text_generator import (
ClaudeGenerator,
GeminiGenerator,
GptGenerator,
OllamaGenerator,
)
from utils.text_processor import post_process
logger = logging.getLogger(__name__)
# Redis 클라이언트 생성 (Celery 결과용 DB=1)
redis_client = redis.Redis(
host=OCR_REDIS_HOST,
port=OCR_REDIS_PORT,
db=OCR_REDIS_DB,
decode_responses=True,
)
class PipelineRunner:
@staticmethod
async def run_pipeline(
request_info: str, # ✅ 추가
request_id: str,
file_path: str,
filename: str,
prompt: str,
prompt_filename: str, # ✅ 추가
custom_mode: bool,
mode: str,
model: str,
inner_models: List[str],
outer_models: List[str],
model_url_map: Dict[str, str],
api_key: str,
schema_override: Optional[dict] = None,
prompt_mode: Literal["general", "extract"] = "extract",
):
start_time = time.time()
if mode == "multimodal":
# 모델 유효성
if model not in outer_models:
raise ValueError(
f"외부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. outer_models: {outer_models}"
)
if not (("gpt" in model) or ("gemini" in model)):
raise ValueError("멀티모달 E2E는 gpt 계열만 지원합니다.")
# 입력 파일 → 이미지 바이트 리스트 준비
images = await prepare_images_from_file(file_path, filename)
# 요청 로깅(텍스트가 없으므로 prompt 길이만)
context_length = len(prompt)
try:
log_user_request(
request_info=request_info,
endpoint=f"/{prompt_mode}/{mode}",
input_filename=filename,
model=model,
prompt_filename=prompt_filename,
context_length=context_length,
api_key=api_key,
)
except Exception as e:
logger.info(f"Failed to log '/{prompt_mode}/{mode}' request: {e}")
# 멀티모달 LLM 호출
log_pipeline_status(request_id, "멀티모달 LLM 추론 시작")
if "gpt" in model:
generator = GptGenerator(model=model)
generated_text, llm_model, llm_url = await asyncio.to_thread(
generator.generate_multimodal, images, prompt, schema_override
)
elif "gemini" in model:
generator = GeminiGenerator(model=model)
generated_text, llm_model, llm_url = await asyncio.to_thread(
generator.generate_multimodal, images, prompt, schema_override
)
end_time = time.time()
log_pipeline_status(request_id, "LLM 추론 완료 및 후처리 시작")
# 멀티모달은 OCR 텍스트/좌표 없음
text = ""
coord = None
ocr_model = "bypass(multimodal)"
json_data = post_process(
filename,
text,
generated_text,
coord,
ocr_model,
llm_model,
llm_url,
mode,
start_time,
end_time,
prompt_mode,
)
log_pipeline_status(request_id, "후처리 완료 및 결과 반환")
return json_data
try:
# OCR API 요청
log_pipeline_status(request_id, "OCR API 호출 시작")
async with httpx.AsyncClient() as client:
# ✅ presigned URL을 OCR API로 전달
ocr_resp = await client.post(
OCR_API_URL,
json=[
{
"file_url": file_path, # presigned URL
"filename": filename,
}
],
timeout=None,
)
ocr_resp.raise_for_status()
# OCR API 응답에서 task_id 추출
task_ids_json = ocr_resp.json()
print(f"[DEBUG] OCR API 응답: {task_ids_json}")
task_ids = [
item.get("task_id") for item in task_ids_json.get("results", [])
]
if not task_ids:
raise ValueError("❌ OCR API에서 유효한 task_id를 받지 못했습니다.")
task_id = task_ids[0]
# Redis에서 결과를 5초 간격으로 최대 10회 폴링
raw_result = None
for attempt in range(10): # 최대 10회 시도
redis_key = f"ocr_result:{task_id}"
raw_result = redis_client.get(redis_key)
if raw_result:
logger.info(
f"✅ Redis에서 task_id '{task_id}'에 대한 OCR 결과를 찾았습니다."
)
break
await asyncio.sleep(5)
if not raw_result: # 결과가 없으면 예외 발생
error_message = (
"❌ OCR API에서 작업을 완료하지 못했습니다. 페이지 수를 줄여주세요."
)
logger.error(error_message)
raise ValueError(error_message)
result_data = json.loads(raw_result)
text = result_data["parsed"]
coord = result_data.get("fields")
ocr_model = result_data.get("ocr_model", "OCR API(pytesseract)")
except Exception as e:
logger.error(f"❌ OCR 처리 중 예외 발생: {e}")
raise
# ✅ 입력 길이 검사
log_pipeline_status(request_id, "모델 입력 텍스트 길이 검사 시작")
token_count = token_counter(prompt, text)
context_length = len(prompt + text)
# 🔽 로그 기록
try:
log_user_request(
request_info=request_info,
endpoint=f"/{prompt_mode}/{mode}",
input_filename=filename,
model=model,
prompt_filename=prompt_filename,
context_length=context_length,
api_key=api_key,
# token_count=token_count,
)
except Exception as e:
logger.info(f"Failed to log '/{prompt_mode}/{mode}' request: {e}")
# ✅ 120K 토큰 초과 검사
if token_count > 120000:
return post_process(
filename,
text,
f"⚠️ 입력 텍스트가 {token_count} 토큰으로 입력 길이를 초과했습니다. 모델 호출 생략합니다.",
coord,
ocr_model,
"N/A",
"N/A",
mode,
start_time,
time.time(),
prompt_mode,
)
# 2. 내부 모델 처리 (Ollama)
if mode in ("inner", "all", "structured"):
if model in inner_models:
log_pipeline_status(request_id, "내부 LLM 추론 시작")
api_url = model_url_map.get(model)
if not api_url:
raise ValueError(
f"❌ 모델 '{model}'이 로드된 Ollama 서버를 찾을 수 없습니다."
)
generator = OllamaGenerator(model=model, api_url=api_url)
if mode == "structured":
generated_text, llm_model, llm_url = await asyncio.to_thread(
generator.structured_generate,
text,
prompt,
custom_mode,
schema_override,
)
else:
generated_text, llm_model, llm_url = await asyncio.to_thread(
generator.generate, text, prompt, custom_mode, prompt_mode
)
else:
raise ValueError(
f"내부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. inner_models: {inner_models}"
)
# 3. 외부 모델 처리
elif mode in ("outer", "all", "structured"):
if model in outer_models:
log_pipeline_status(request_id, "외부 LLM 추론 시작")
if "claude" in model:
generator = ClaudeGenerator(model=model)
elif "gemini" in model:
generator = GeminiGenerator(model=model)
elif "gpt" in model:
generator = GptGenerator(model=model)
else:
raise ValueError(
"지원되지 않는 외부 모델입니다. ['gemini', 'claude', 'gpt'] 중 선택하세요."
)
if mode == "structured":
generated_text, llm_model, llm_url = await asyncio.to_thread(
generator.structured_generate,
text,
prompt,
custom_mode,
schema_override,
)
else:
generated_text, llm_model, llm_url = await asyncio.to_thread(
generator.generate, text, prompt, custom_mode, prompt_mode
)
else:
raise ValueError(
f"외부 모델 리스트에 '{model}'이 포함되어 있지 않습니다. outer_models: {outer_models}"
)
else:
raise ValueError(
f"❌ 지원되지 않는 모드입니다. 'inner', 'outer', 'all', 'structured' 중에서 선택하세요. (입력: {mode})"
)
log_pipeline_status(request_id, "LLM 추론 완료 및 후처리 시작")
end_time = time.time()
# 4. 후처리
json_data = post_process(
filename,
text,
generated_text,
coord,
ocr_model,
llm_model,
llm_url,
mode,
start_time,
end_time,
prompt_mode,
)
log_pipeline_status(request_id, "후처리 완료 및 결과 반환")
return json_data