# llm_macro/workspace/routers/costs_router.py
import asyncio
import io
import json
import os
from typing import Optional
from config.setting import (
    DEFAULT_PROMPT_PATH,
    PGN_REDIS_DB,
    PGN_REDIS_HOST,
    PGN_REDIS_PORT,
)
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse
from redis import Redis
from services.gemini_service import GeminiService
from services.ollama_service import OllamaService
from utils.checking_keys import create_key, get_api_key
from utils.text_processor import post_process

# Redis client; decode_responses=True makes get()/hget() return str, which is
# what json.loads() below expects.
redis_client = Redis(
    host=PGN_REDIS_HOST, port=PGN_REDIS_PORT, db=PGN_REDIS_DB, decode_responses=True
)

router = APIRouter(prefix="/llm", tags=["target model"])

# Keep strong references to fire-and-forget tasks: asyncio itself holds only a
# weak reference, so an otherwise unreferenced task can be garbage-collected
# before it finishes.
_background_tasks = set()


def clone_upload_file(upload_file: UploadFile) -> io.BytesIO:
    """Copy an UploadFile's bytes in memory for use by a background task.

    The original UploadFile is closed when the request finishes, so the
    background task is handed its own in-memory copy instead.
    """
    file_content = upload_file.file.read()
    upload_file.file.seek(0)  # reset the original file pointer
    return io.BytesIO(file_content)


async def run_gemini_background_task(
    result_id: str,
    input_file_name: str,
    input_file_clone: io.BytesIO,
    model: str,
    source_dir: Optional[str],
):
    """Background task that calls the Gemini API and stores the result."""
    try:
        # 1. Read the default prompt
        with open(DEFAULT_PROMPT_PATH, "r", encoding="utf-8") as f:
            default_prompt = f.read()
        # 2. Read and parse the cloned input_file
        input_data = input_file_clone.read()
        input_json = json.loads(input_data)
        parsed_value = input_json.get("parsed", "")
        # 3. Combine prompt and parsed value
        combined_prompt = f"{default_prompt}\n\n{parsed_value}"
        # 4. Call Gemini API
        gemini_service = GeminiService()
        gemini_response = await gemini_service.generate_content(
            [combined_prompt], model=model
        )
        # 5. Post-process the response
        processed_result = post_process(input_json, gemini_response, model)
        # 6. Save the result to Redis (1-hour TTL)
        redis_key = f"pipeline_result:{result_id}"
        redis_client.set(
            redis_key, json.dumps(processed_result, ensure_ascii=False), ex=3600
        )
        # 7. Save the result to a local file
        if source_dir:
            output_dir = os.path.join("result", f"gemini-{source_dir}")
        else:
            output_dir = "result"
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, input_file_name)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(processed_result, f, ensure_ascii=False, indent=4)
    except Exception as e:
        # On error, store the error message in Redis instead
        redis_key = f"pipeline_result:{result_id}"
        redis_client.set(redis_key, json.dumps({"error": str(e)}), ex=3600)
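

# The background tasks and the sync endpoint in this module expect the uploaded
# file to be JSON carrying the parser output under a "parsed" key; this is
# inferred from the json.loads(...)/.get("parsed", "") calls, and the value
# shown here is purely illustrative:
#
#     {"parsed": "plain text extracted from the source document"}
#
# The whole parsed dict is also passed to post_process() together with the
# model response.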


@router.post(
    "/gemini",
    summary="For testing overseas documents (background)",
)
async def costs_gemini_background(
    request_info: Request,
    input_file: UploadFile = File(...),
    model: Optional[str] = Form(default="gemini-2.5-flash"),
    source_dir: Optional[str] = Form(default=None),
    api_key: str = Depends(get_api_key),
):
    request_id = create_key(request_info.client.host)
    result_id = create_key(request_id)
    # Clone the file so the background task can read it after the request ends
    input_file_clone = clone_upload_file(input_file)
    # Start the background task, keeping a strong reference to it
    task = asyncio.create_task(
        run_gemini_background_task(
            result_id=result_id,
            input_file_name=input_file.filename,
            input_file_clone=input_file_clone,
            model=model,
            source_dir=source_dir,
        )
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    # Map the request ID to the result ID in Redis
    redis_client.hset("pipeline_result_mapping", request_id, result_id)
    return JSONResponse(
        content={
            "message": "Document processing has started in the background.",
            "request_id": request_id,
            "status_check_url": f"/llm/progress/{request_id}",
        }
    )


@router.post(
    "/gemma3",
    summary="Gemma3 model test (background)",
)
async def costs_gemma3_background(
    request_info: Request,
    input_file: UploadFile = File(...),
    model: Optional[str] = Form(default="gemma3:27b"),
    source_dir: Optional[str] = Form(default=None),
    api_key: str = Depends(get_api_key),
):
    request_id = create_key(request_info.client.host)
    result_id = create_key(request_id)
    # Clone the file so the background task can read it after the request ends
    input_file_clone = clone_upload_file(input_file)
    # Start the background task, keeping a strong reference to it
    task = asyncio.create_task(
        run_gemma3_background_task(
            result_id=result_id,
            input_file_name=input_file.filename,
            input_file_clone=input_file_clone,
            model=model,
            source_dir=source_dir,
        )
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    # Map the request ID to the result ID in Redis
    redis_client.hset("pipeline_result_mapping", request_id, result_id)
    return JSONResponse(
        content={
            "message": "Gemma3 document processing has started in the background.",
            "request_id": request_id,
            "status_check_url": f"/llm/progress/{request_id}",
        }
    )


@router.get("/progress/{request_id}", summary="Check job progress and result")
async def get_progress(request_id: str, api_key: str = Depends(get_api_key)):
    """
    Check a job's progress by request_id and return the result once complete.
    """
    result_id = redis_client.hget("pipeline_result_mapping", request_id)
    if not result_id:
        raise HTTPException(status_code=404, detail="Invalid request ID.")
    redis_key = f"pipeline_result:{result_id}"
    result = redis_client.get(redis_key)
    if result:
        # The result is in Redis: parse it as JSON and return it
        return JSONResponse(content=json.loads(result))
    else:
        # No result yet: report that the job is still processing
        return JSONResponse(
            content={"status": "processing", "message": "The job is still being processed."},
            status_code=202,
        )
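

# A rough client-side flow for the background endpoints above, as a hedged
# sketch: httpx, the base URL, and the "X-API-Key" header name are assumptions
# (the real header is whatever utils.checking_keys.get_api_key validates).
#
#     import time
#     import httpx
#
#     BASE = "http://localhost:8000"
#     headers = {"X-API-Key": "..."}  # hypothetical header name
#
#     with open("doc.json", "rb") as f:
#         resp = httpx.post(f"{BASE}/llm/gemini",
#                           files={"input_file": f}, headers=headers)
#     request_id = resp.json()["request_id"]
#
#     # Poll until the 202 "processing" response turns into a result
#     while True:
#         resp = httpx.get(f"{BASE}/llm/progress/{request_id}", headers=headers)
#         if resp.status_code != 202:
#             break
#         time.sleep(2)
#     print(resp.json())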


@router.post(
    "/gemma3/sync",
    summary="Gemma3 model, synchronous processing",
)
async def costs_gemma3_sync(
    request_info: Request,
    input_file: UploadFile = File(...),
    model: Optional[str] = Form(default="gemma3:27b"),
    source_dir: Optional[str] = Form(default=None),
    api_key: str = Depends(get_api_key),
):
    """Process synchronously via Ollama and return the result."""
    try:
        # 1. Read the default prompt
        with open(DEFAULT_PROMPT_PATH, "r", encoding="utf-8") as f:
            default_prompt = f.read()
        # 2. Read and parse the input_file
        input_data = await input_file.read()
        input_json = json.loads(input_data)
        parsed_value = input_json.get("parsed", "")
        # 3. Combine prompt and parsed value
        combined_prompt = f"{default_prompt}\n\n{parsed_value}"
        # 4. Call the Gemma model through Ollama
        ollama_service = OllamaService()
        gemma_response = await ollama_service.generate_content(
            combined_prompt, model=model
        )
        # 5. Post-process the response
        processed_result = post_process(input_json, gemma_response, model)
        # 6. Save the result to a local file on the server
        if source_dir:
            output_dir = os.path.join("result", f"gemma3-{source_dir}")
        else:
            output_dir = "result"
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, input_file.filename)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(processed_result, f, ensure_ascii=False, indent=4)
        return JSONResponse(content=processed_result)
    except Exception as e:
        # Log the exception for debugging
        print(f"Error in gemma3/sync endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def run_gemma3_background_task(
    result_id: str,
    input_file_name: str,
    input_file_clone: io.BytesIO,
    model: str,
    source_dir: Optional[str],
):
    """Background task that calls the Gemma model (via Ollama) and stores the result."""
    try:
        # 1. Read the default prompt
        with open(DEFAULT_PROMPT_PATH, "r", encoding="utf-8") as f:
            default_prompt = f.read()
        # 2. Read and parse the cloned input_file
        input_data = input_file_clone.read()
        input_json = json.loads(input_data)
        parsed_value = input_json.get("parsed", "")
        # 3. Combine prompt and parsed value
        combined_prompt = f"{default_prompt}\n\n{parsed_value}"
        # 4. Call the Gemma model through Ollama
        ollama_service = OllamaService()
        gemma_response = await ollama_service.generate_content(
            combined_prompt, model=model
        )
        # 5. Post-process the response
        processed_result = post_process(input_json, gemma_response, model)
        # 6. Save the result to Redis (1-hour TTL)
        redis_key = f"pipeline_result:{result_id}"
        redis_client.set(
            redis_key, json.dumps(processed_result, ensure_ascii=False), ex=3600
        )
        # 7. Save the result to a local file
        if source_dir:
            output_dir = os.path.join("result", f"gemma3-{source_dir}")
        else:
            output_dir = "result"
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, input_file_name)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(processed_result, f, ensure_ascii=False, indent=4)
    except Exception as e:
        # On error, store the error message in Redis instead
        redis_key = f"pipeline_result:{result_id}"
        redis_client.set(redis_key, json.dumps({"error": str(e)}), ex=3600)
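

# For reference, the Redis layout this module uses:
#   pipeline_result_mapping       hash:   request_id -> result_id (no TTL set)
#   pipeline_result:<result_id>   string: result JSON, or {"error": ...} on
#                                 failure, both with a 3600-second TTL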