# llm_asycio/workspace/main.py
import gc
import logging
import os
import sys

import pandas as pd
import torch
from fastapi import FastAPI, UploadFile
from fastapi.responses import JSONResponse, FileResponse
from redis import Redis
from rq import Queue
from tqdm import tqdm
from vllm import LLM, SamplingParams

sys.path.append("/workspace/LLM_asyncio")
from template import LLMInference

app = FastAPI()

# Redis connection and task queue (RQ stores job payloads as bytes,
# so the connection must not use decode_responses=True)
redis_conn = Redis(host="redis-server", port=6379)
queue = Queue("model_tasks", connection=redis_conn)
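
# NOTE: jobs enqueued below are executed by a separate RQ worker process, not by
# this web app. A minimal sketch of starting one (command assumed, not part of
# the original repo):
#   rq worker model_tasks --url redis://redis-server:6379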

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI endpoint: accept a CSV file and a model list and queue an inference job
@app.post("/start-inference/")
async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile):
    # Validate file types
    if not input_csv.filename.endswith(".csv"):
        return JSONResponse(content={"error": "Uploaded file is not a CSV."}, status_code=400)
    if not model_list_txt.filename.endswith(".txt"):
        return JSONResponse(content={"error": "Uploaded model list is not a TXT file."}, status_code=400)

    # Save the uploaded files
    file_path = f"uploaded/{input_csv.filename}"
    model_list_path = f"uploaded/{model_list_txt.filename}"
    os.makedirs("uploaded", exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(await input_csv.read())
    with open(model_list_path, "wb") as f:
        f.write(await model_list_txt.read())
    logger.info(f"Files uploaded: {file_path}, {model_list_path}")

    # Enqueue the job (job_timeout is in seconds, i.e. 30 minutes)
    job = queue.enqueue(run_inference, file_path, model_list_path, job_timeout=1800)
    logger.info(f"Job enqueued: {job.id}")
    return {"job_id": job.id, "status": "queued"}

def chat_formatting(input_sentence: str, model_name: str):
    """Wrap the raw input in the chat template that matches the model family."""
    try:
        if "llama" in model_name:
            hidden_prompt = LLMInference.llama_template()
        elif "gemma" in model_name:
            hidden_prompt = LLMInference.gemma_template()
        elif "exaone" in model_name:
            hidden_prompt = LLMInference.exaone_template()
        else:
            # Unknown model family: fall back to the raw input
            return input_sentence
        return hidden_prompt.format(input_sent=input_sentence)
    except Exception as e:
        logger.error(f"Could not format input sentence: {e}")
        return input_sentence
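
# The chat templates come from the external `template` module, which is not shown
# here. A hypothetical sketch of the assumed interface: each *_template() returns
# a prompt string containing an {input_sent} placeholder, e.g.
#
#   class LLMInference:
#       @staticmethod
#       def llama_template():
#           return "<|user|>\n{input_sent}\n<|assistant|>\n"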

# Model inference function (executed by the RQ worker process, not the web app)
def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
    try:
        logger.info(f"Starting inference for file: {file_path} using models from {model_list_path}")

        # Read the model list (one model name per line, skipping blanks)
        with open(model_list_path, "r") as f:
            model_list = [line.strip() for line in f if line.strip()]
        if not model_list:
            raise ValueError("The model list file is empty.")

        # Read the CSV
        df = pd.read_csv(file_path, encoding="euc-kr")
        if "input" not in df.columns:
            raise ValueError("The input CSV must contain a column named 'input'.")

        # DataFrame that collects rows that fail during inference
        error_rows = pd.DataFrame(columns=df.columns)

        # Run inference with each model in turn
        for model in model_list:
            model_name = model.split("/")[-1]
            try:
                logger.info(f"Loading model: {model}")
                llm = LLM(model)
                torch.cuda.empty_cache()
                logger.info(f"Model {model} loaded successfully.")
            except Exception as e:
                logger.error(f"Error loading model {model}: {e}")
                continue

            sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)
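
            # NOTE: rows are generated one at a time below so that a failing row
            # can be logged and skipped individually; vLLM's llm.generate() also
            # accepts a list of prompts, which would batch each slice in one call.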
            # Run inference batch by batch
            responses = []
            for i in tqdm(range(0, len(df), batch_size), desc=f"Processing {model}"):
                batch = df.iloc[i:i + batch_size]
                batch_responses = []
                for _, row in batch.iterrows():
                    try:
                        original_input = row["input"]
                        formatted_input = chat_formatting(input_sentence=row["input"], model_name=model_name.lower())
                        response = llm.generate(formatted_input, sampling_params)[0].outputs[0].text.strip()
                        logger.info(f"Model: {model}, Input: {original_input}, Output: {response}")
                        batch_responses.append(response)
                    except Exception as e:
                        logger.error(f"Error during inference for model {model}, row {row.name}: {e}")
                        error_rows = pd.concat([error_rows, pd.DataFrame([row])], ignore_index=True)
                        batch_responses.append(None)
                responses.extend(batch_responses)

            # Append this model's outputs as a new column and free the model
            df[model_name] = responses
            del llm
            torch.cuda.empty_cache()
            gc.collect()
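            # NOTE (assumption): deleting the LLM object does not always return
            # all GPU memory to the pool; if usage grows across models, running
            # each model in its own worker process is a safer cleanup strategy.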

        # Save the results
        output_path = file_path.replace("uploaded", "processed").replace(".csv", "_result.csv")
        os.makedirs("processed", exist_ok=True)
        df.to_csv(output_path, index=False, encoding="utf-8")
        logger.info(f"Inference completed. Result saved to: {output_path}")

        # Save the rows that failed, if any
        if not error_rows.empty:
            error_path = file_path.replace("uploaded", "errors").replace(".csv", "_errors.csv")
            os.makedirs("errors", exist_ok=True)
            error_rows.to_csv(error_path, index=False, encoding="utf-8")
            logger.info(f"Error rows saved to: {error_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise

# Download the most recent result file
@app.get("/download-latest", response_class=FileResponse)
def download_latest_file():
    try:
        # Look for result CSVs in the processed directory
        directory = "processed"
        csv_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".csv")]
        if not csv_files:
            return JSONResponse(content={"error": "No CSV files found in the processed directory."}, status_code=404)
        latest_file = max(csv_files, key=os.path.getctime)
        logger.info(f"Downloading latest file: {latest_file}")
        return FileResponse(latest_file, media_type="text/csv", filename=os.path.basename(latest_file))
    except Exception as e:
        logger.error(f"Error during file download: {e}")
        return JSONResponse(content={"error": "Failed to download the latest file."}, status_code=500)
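
# Local entry point; a minimal sketch, assuming uvicorn is installed
# (in production the app would normally be started by an ASGI server directly).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)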