296 lines
9.9 KiB
Python
296 lines
9.9 KiB
Python
import asyncio
|
|
import io
|
|
import json
|
|
import os
|
|
from typing import Optional
|
|
|
|
from config.setting import (
|
|
DEFAULT_PROMPT_PATH,
|
|
PGN_REDIS_DB,
|
|
PGN_REDIS_HOST,
|
|
PGN_REDIS_PORT,
|
|
)
|
|
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile
|
|
from fastapi.responses import JSONResponse
|
|
from redis import Redis
|
|
from services.gemini_service import GeminiService
|
|
from services.ollama_service import OllamaService
|
|
from utils.checking_keys import create_key, get_api_key
|
|
from utils.text_processor import post_process
|
|
|
|
# Redis 클라이언트
|
|
redis_client = Redis(
|
|
host=PGN_REDIS_HOST, port=PGN_REDIS_PORT, db=PGN_REDIS_DB, decode_responses=True
|
|
)
|
|
|
|
router = APIRouter(prefix="/llm", tags=["target model"])
|
|
|
|
|
|
def clone_upload_file(upload_file: UploadFile) -> io.BytesIO:
|
|
"""UploadFile을 메모리 내에서 복제하여 백그라운드 작업에 전달합니다."""
|
|
file_content = upload_file.file.read()
|
|
upload_file.file.seek(0) # 원본 파일 포인터를 재설정
|
|
return io.BytesIO(file_content)
|
|
|
|
|
|
async def run_gemini_background_task(
|
|
result_id: str,
|
|
input_file_name: str,
|
|
input_file_clone: io.BytesIO,
|
|
model: str,
|
|
source_dir: Optional[str],
|
|
):
|
|
"""Gemini API 호출 및 결과 저장을 처리하는 백그라운드 작업"""
|
|
try:
|
|
# 1. Read the default prompt
|
|
with open(DEFAULT_PROMPT_PATH, "r", encoding="utf-8") as f:
|
|
default_prompt = f.read()
|
|
|
|
# 2. Read and parse the cloned input_file
|
|
input_data = input_file_clone.read()
|
|
input_json = json.loads(input_data)
|
|
parsed_value = input_json.get("parsed", "")
|
|
|
|
# 3. Combine prompt and parsed value
|
|
combined_prompt = f"{default_prompt}\n\n{parsed_value}"
|
|
|
|
# 4. Call Gemini API
|
|
gemini_service = GeminiService()
|
|
gemini_response = await gemini_service.generate_content(
|
|
[combined_prompt], model=model
|
|
)
|
|
|
|
# 5. Post-process the response
|
|
processed_result = post_process(input_json, gemini_response, model)
|
|
|
|
# 6. Save the result to Redis
|
|
redis_key = f"pipeline_result:{result_id}"
|
|
redis_client.set(
|
|
redis_key, json.dumps(processed_result, ensure_ascii=False), ex=3600
|
|
)
|
|
|
|
# 7. Save the result to a local file
|
|
if source_dir:
|
|
output_dir = os.path.join("result", f"gemini-{source_dir}")
|
|
else:
|
|
output_dir = "result"
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
output_filename = f"{input_file_name}"
|
|
output_path = os.path.join(output_dir, output_filename)
|
|
with open(output_path, "w", encoding="utf-8") as f:
|
|
json.dump(processed_result, f, ensure_ascii=False, indent=4)
|
|
|
|
except Exception as e:
|
|
# 에러 발생 시 Redis에 에러 메시지 저장
|
|
redis_key = f"pipeline_result:{result_id}"
|
|
redis_client.set(redis_key, json.dumps({"error": str(e)}), ex=3600)
|
|
|
|
|
|
@router.post(
|
|
"/gemini",
|
|
summary="해외 문서 테스트용 (백그라운드)",
|
|
)
|
|
async def costs_gemini_background(
|
|
request_info: Request,
|
|
input_file: UploadFile = File(...),
|
|
model: Optional[str] = Form(default="gemini-2.5-flash"),
|
|
source_dir: Optional[str] = Form(default=None),
|
|
api_key: str = Depends(get_api_key),
|
|
):
|
|
request_id = create_key(request_info.client.host)
|
|
result_id = create_key(request_id)
|
|
|
|
# 파일 복제
|
|
input_file_clone = clone_upload_file(input_file)
|
|
|
|
# 백그라운드 작업 시작
|
|
asyncio.create_task(
|
|
run_gemini_background_task(
|
|
result_id=result_id,
|
|
input_file_name=input_file.filename,
|
|
input_file_clone=input_file_clone,
|
|
model=model,
|
|
source_dir=source_dir,
|
|
)
|
|
)
|
|
|
|
# 요청 ID와 결과 ID를 매핑하여 Redis에 저장
|
|
redis_client.hset("pipeline_result_mapping", request_id, result_id)
|
|
|
|
return JSONResponse(
|
|
content={
|
|
"message": "문서 처리 작업이 백그라운드에서 시작되었습니다.",
|
|
"request_id": request_id,
|
|
"status_check_url": f"/costs/progress/{request_id}",
|
|
}
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/gemma3",
|
|
summary="Gemma3 모델 테스트용 (백그라운드)",
|
|
)
|
|
async def costs_gemma3_background(
|
|
request_info: Request,
|
|
input_file: UploadFile = File(...),
|
|
model: Optional[str] = Form(default="gemma3:27b"),
|
|
source_dir: Optional[str] = Form(default=None),
|
|
api_key: str = Depends(get_api_key),
|
|
):
|
|
request_id = create_key(request_info.client.host)
|
|
result_id = create_key(request_id)
|
|
|
|
# 파일 복제
|
|
input_file_clone = clone_upload_file(input_file)
|
|
|
|
# 백그라운드 작업 시작
|
|
asyncio.create_task(
|
|
run_gemma3_background_task(
|
|
result_id=result_id,
|
|
input_file_name=input_file.filename,
|
|
input_file_clone=input_file_clone,
|
|
model=model,
|
|
source_dir=source_dir,
|
|
)
|
|
)
|
|
|
|
# 요청 ID와 결과 ID를 매핑하여 Redis에 저장
|
|
redis_client.hset("pipeline_result_mapping", request_id, result_id)
|
|
|
|
return JSONResponse(
|
|
content={
|
|
"message": "Gemma3 문서 처리 작업이 백그라운드에서 시작되었습니다.",
|
|
"request_id": request_id,
|
|
"status_check_url": f"/costs/progress/{request_id}",
|
|
}
|
|
)
|
|
|
|
|
|
@router.get("/progress/{request_id}", summary="작업 진행 상태 및 결과 확인")
|
|
async def get_progress(request_id: str, api_key: str = Depends(get_api_key)):
|
|
"""
|
|
request_id를 사용하여 작업의 진행 상태를 확인하고, 완료 시 결과를 반환합니다.
|
|
"""
|
|
result_id = redis_client.hget("pipeline_result_mapping", request_id)
|
|
if not result_id:
|
|
raise HTTPException(status_code=404, detail="잘못된 요청 ID입니다.")
|
|
|
|
redis_key = f"pipeline_result:{result_id}"
|
|
result = redis_client.get(redis_key)
|
|
|
|
if result:
|
|
# 결과가 Redis에 있으면, JSON으로 파싱하여 반환
|
|
return JSONResponse(content=json.loads(result))
|
|
else:
|
|
# 결과가 아직 없으면, 처리 중임을 알림
|
|
return JSONResponse(
|
|
content={"status": "processing", "message": "작업이 아직 처리 중입니다."},
|
|
status_code=202,
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/gemma3/sync",
|
|
summary="Gemma3 모델 동기 처리",
|
|
)
|
|
async def costs_gemma3_sync(
|
|
request_info: Request,
|
|
input_file: UploadFile = File(...),
|
|
model: Optional[str] = Form(default="gemma3:27b"),
|
|
source_dir: Optional[str] = Form(default=None),
|
|
api_key: str = Depends(get_api_key),
|
|
):
|
|
"""Ollama 동기 처리 및 결과 반환"""
|
|
try:
|
|
# 1. Read the default prompt
|
|
with open(DEFAULT_PROMPT_PATH, "r", encoding="utf-8") as f:
|
|
default_prompt = f.read()
|
|
|
|
# 2. Read and parse the input_file
|
|
input_data = await input_file.read()
|
|
input_json = json.loads(input_data)
|
|
parsed_value = input_json.get("parsed", "")
|
|
|
|
# 3. Combine prompt and parsed value
|
|
combined_prompt = f"{default_prompt}\n\n{parsed_value}"
|
|
|
|
# 4. Call Gemma API
|
|
ollama_service = OllamaService()
|
|
gemma_response = await ollama_service.generate_content(
|
|
combined_prompt, model=model
|
|
)
|
|
|
|
# 5. Post-process the response
|
|
processed_result = post_process(input_json, gemma_response, model)
|
|
|
|
# 6. Save the result to a local file on the server
|
|
if source_dir:
|
|
output_dir = os.path.join("result", f"gemma3-{source_dir}")
|
|
else:
|
|
output_dir = "result"
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
output_filename = f"{input_file.filename}"
|
|
output_path = os.path.join(output_dir, output_filename)
|
|
with open(output_path, "w", encoding="utf-8") as f:
|
|
json.dump(processed_result, f, ensure_ascii=False, indent=4)
|
|
|
|
return JSONResponse(content=processed_result)
|
|
|
|
except Exception as e:
|
|
# Log the exception for debugging
|
|
print(f"Error in gemma3/sync endpoint: {e}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
async def run_gemma3_background_task(
|
|
result_id: str,
|
|
input_file_name: str,
|
|
input_file_clone: io.BytesIO,
|
|
model: str,
|
|
source_dir: Optional[str],
|
|
):
|
|
"""Gemma API 호출 및 결과 저장을 처리하는 백그라운드 작업"""
|
|
try:
|
|
# 1. Read the default prompt
|
|
with open(DEFAULT_PROMPT_PATH, "r", encoding="utf-8") as f:
|
|
default_prompt = f.read()
|
|
|
|
# 2. Read and parse the cloned input_file
|
|
input_data = input_file_clone.read()
|
|
input_json = json.loads(input_data)
|
|
parsed_value = input_json.get("parsed", "")
|
|
|
|
# 3. Combine prompt and parsed value
|
|
combined_prompt = f"{default_prompt}\n\n{parsed_value}"
|
|
|
|
# 4. Call Gemma API
|
|
ollama_service = OllamaService()
|
|
gemma_response = await ollama_service.generate_content(
|
|
combined_prompt, model=model
|
|
)
|
|
|
|
# 5. Post-process the response
|
|
processed_result = post_process(input_json, gemma_response, model)
|
|
|
|
# 6. Save the result to Redis
|
|
redis_key = f"pipeline_result:{result_id}"
|
|
redis_client.set(
|
|
redis_key, json.dumps(processed_result, ensure_ascii=False), ex=3600
|
|
)
|
|
|
|
# 7. Save the result to a local file
|
|
if source_dir:
|
|
output_dir = os.path.join("result", f"gemma3-{source_dir}")
|
|
else:
|
|
output_dir = "result"
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
output_filename = f"{input_file_name}"
|
|
output_path = os.path.join(output_dir, output_filename)
|
|
with open(output_path, "w", encoding="utf-8") as f:
|
|
json.dump(processed_result, f, ensure_ascii=False, indent=4)
|
|
|
|
except Exception as e:
|
|
# 에러 발생 시 Redis에 에러 메시지 저장
|
|
redis_key = f"pipeline_result:{result_id}"
|
|
redis_client.set(redis_key, json.dumps({"error": str(e)}), ex=3600)
|