import asyncio
import json
import logging
import os
import re

import httpx
from anthropic import AsyncAnthropic
from dotenv import load_dotenv
from google.generativeai import GenerativeModel  # Gemini
from openai import AsyncOpenAI

from services.prompt import ONLY_GEMINI_PROMPT_TEMPLATE, SUMMARY_PROMPT_TEMPLATE

load_dotenv()

logger = logging.getLogger(__name__)

# In-memory store of per-task model results, keyed by task_id.
tasks_store = {}

ask_gpt_name = "gpt-4.1-mini"
ask_ollama_qwen_name = "qwen3:custom"
ask_gemini_name = "gemini-2.5-flash"
ask_claude_name = "claude-3-7-sonnet-latest"


def parse_json_safe(text: str):
    """Try to parse response text as JSON; fall back to the raw text if it isn't valid."""
    try:
        # Strip a ```json ... ``` code fence if present. (Regexes are used here
        # because str.strip("```json") removes a character *set*, not a prefix,
        # and can eat legitimate JSON characters.)
        if text.startswith("```"):
            text = re.sub(r"^```(?:json)?\s*", "", text)
            text = re.sub(r"\s*```\s*$", "", text)
        return json.loads(text)
    except Exception:
        return {"raw_text": text}
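
# Illustrative behavior (not a test suite):
#   parse_json_safe('```json\n{"a": 1}\n```')  -> {"a": 1}
#   parse_json_safe('{"a": 1}')                -> {"a": 1}
#   parse_json_safe('not json')                -> {"raw_text": "not json"}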


async def ask_gpt4(text: str):
    """Summarize `text` with the OpenAI model and return (model_name, parsed_result)."""
    try:
        client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        response = await client.chat.completions.create(
            model=ask_gpt_name,
            messages=[
                {
                    "role": "user",
                    "content": SUMMARY_PROMPT_TEMPLATE.format(context=text),
                }
            ],
            temperature=0,
        )
        return ask_gpt_name, parse_json_safe(response.choices[0].message.content)
    except Exception as e:
        logger.error(f"ask_gpt4 error: {e}")
        return ask_gpt_name, {"error": str(e)}


def fix_incomplete_json(text: str) -> str:
    """Append missing closing braces when the model's JSON output was cut off."""
    open_braces = text.count("{")
    close_braces = text.count("}")
    if open_braces > close_braces:
        text += "}" * (open_braces - close_braces)
    return text
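
# Note: fix_incomplete_json is currently unused in this module. A sketch of one
# way it could be wired in (an assumption, not how the code works today) is as
# a fallback inside parse_json_safe:
#
#     try:
#         return json.loads(text)
#     except json.JSONDecodeError:
#         return json.loads(fix_incomplete_json(text))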


async def ask_ollama_qwen(text: str):
    """Summarize `text` via a local Ollama server and return (model_name, parsed_result)."""
    try:
        async with httpx.AsyncClient() as client:
            res = await client.post(
                "http://172.16.10.176:11434/api/generate",
                json={
                    "model": ask_ollama_qwen_name,
                    "prompt": SUMMARY_PROMPT_TEMPLATE.format(context=text),
                },
                timeout=300,
            )
        raw_text = res.text

        # 1. Strip <think> tags emitted by the model.
        raw_text = re.sub(r"</?think>", "", raw_text)

        # 2. Try to parse each line as JSON (Ollama streams newline-delimited JSON).
        json_objects = []
        for line in raw_text.splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
                json_objects.append(obj)
            except json.JSONDecodeError:
                # Skip unparseable lines (these could also be logged).
                pass

        # 3. Concatenate the "response" fragments from the streamed chunks.
        full_response = "".join(obj.get("response", "") for obj in json_objects)

        # 4. Extract just the JSON portion of the combined response.
        json_match = re.search(r"\{.*\}", full_response, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            try:
                parsed_json = json.loads(json_str)
                return ask_ollama_qwen_name, parsed_json
            except json.JSONDecodeError:
                return ask_ollama_qwen_name, {
                    "error": "Invalid JSON in response",
                    "raw_text": full_response,
                }
        else:
            return ask_ollama_qwen_name, {
                "error": "No JSON found in response",
                "raw_text": full_response,
            }

    except Exception as e:
        logger.error(f"ask_ollama_qwen error: {e}")
        return ask_ollama_qwen_name, {"error": str(e)}


async def ask_gemini(text: str):
    """Summarize `text` with Gemini and return (model_name, parsed_result)."""
    try:
        model = GenerativeModel(model_name=ask_gemini_name)
        # generate_content() is synchronous and would block the event loop,
        # so use the async variant to keep all models running concurrently.
        response = await model.generate_content_async(
            SUMMARY_PROMPT_TEMPLATE.format(context=text)
        )
        return ask_gemini_name, parse_json_safe(response.text)
    except Exception as e:
        logger.error(f"ask_gemini error: {e}")
        return ask_gemini_name, {"error": str(e)}


async def dialog_ask_gemini(text: str):
    """Run the Gemini-only dialog prompt and return (model_name, parsed_result)."""
    try:
        model = GenerativeModel(model_name=ask_gemini_name)
        response = await model.generate_content_async(
            ONLY_GEMINI_PROMPT_TEMPLATE.format(context=text)
        )
        return ask_gemini_name, parse_json_safe(response.text)
    except Exception as e:
        logger.error(f"dialog_ask_gemini error: {e}")
        return ask_gemini_name, {"error": str(e)}


async def ask_claude(text: str):
    """Summarize `text` with Claude and return (model_name, parsed_result)."""
    try:
        client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
        response = await client.messages.create(
            model=ask_claude_name,
            messages=[
                {
                    "role": "user",
                    "content": SUMMARY_PROMPT_TEMPLATE.format(context=text),
                }
            ],
            max_tokens=12800,
            stream=False,
        )
        raw = response.content[0].text
        return ask_claude_name, parse_json_safe(raw)
    except Exception as e:
        logger.error(f"ask_claude error: {e}")
        return ask_claude_name, {"error": str(e)}


async def total_summation(text: str) -> dict:
    """Run all four models concurrently and return {model_name: result}."""
    tasks = [ask_gpt4(text), ask_ollama_qwen(text), ask_gemini(text), ask_claude(text)]
    results = await asyncio.gather(*tasks)
    return dict(results)
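
# Usage sketch (assuming an async caller; all names are this module's own):
#
#     summaries = await total_summation(meeting_text)
#     gpt_summary = summaries["gpt-4.1-mini"]
#
# Each value is either the parsed JSON summary or an {"error": ...} dict,
# so callers should check for "error" before using a result.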


async def run_model_task(model_func, text, key, task_id):
    """Run one model and record its result (or failure) in tasks_store."""
    try:
        model_name, result = await model_func(text)
        tasks_store[task_id][key] = {
            "status": "completed",
            "model_name": model_name,
            "result": result,
        }
    except Exception as e:
        tasks_store[task_id][key] = {
            "status": "failed",
            "error": str(e),
        }


async def run_all_models(text: str, task_id: str):
    """Kick off all models for `task_id`, tracking per-model status in tasks_store."""
    # Initialize per-model status before the tasks start.
    tasks_store[task_id] = {
        "gpt4": {"status": "pending", "result": None},
        "qwen3": {"status": "pending", "result": None},
        "gemini": {"status": "pending", "result": None},
        "claude": {"status": "pending", "result": None},
        "finished": False,
    }

    await asyncio.gather(
        run_model_task(ask_gpt4, text, "gpt4", task_id),
        run_model_task(ask_ollama_qwen, text, "qwen3", task_id),
        run_model_task(ask_gemini, text, "gemini", task_id),
        run_model_task(ask_claude, text, "claude", task_id),
    )

    tasks_store[task_id]["finished"] = True
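

# Usage sketch: run_all_models is designed to be scheduled as a background job
# (e.g. FastAPI BackgroundTasks -- an assumption about the surrounding app)
# while clients poll tasks_store[task_id]. A minimal standalone demo, assuming
# the required API keys are set and the Ollama host is reachable:

if __name__ == "__main__":
    import uuid

    async def _demo():
        task_id = str(uuid.uuid4())
        await run_all_models("Sample meeting transcript to summarize.", task_id)
        print(json.dumps(tasks_store[task_id], ensure_ascii=False, indent=2, default=str))

    asyncio.run(_demo())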