# File: llm_asycio/workspace/main.py
# (Header recovered from a file-browser export: Python, executable, 205 lines, 8.1 KiB.)

import os
import pandas as pd
from fastapi import FastAPI, UploadFile, BackgroundTasks
from fastapi.responses import JSONResponse, FileResponse
import shutil
from redis import Redis
from rq import Queue
from vllm import LLM, SamplingParams
import logging
import gc
import torch
from tqdm import tqdm
from template import LLMInference
import pickle
app = FastAPI()
# Redis connection backing the RQ job queue.
# RQ stores its job payloads as pickled *bytes*, so this client must not be created with
# decode_responses=True: with it, reading a job back tries to UTF-8-decode binary data
# and fails. Host/port default to the previous hard-coded values and can be overridden
# with the REDIS_HOST / REDIS_PORT environment variables.
redis_conn = Redis(
    host=os.getenv("REDIS_HOST", "redis-server"),
    port=int(os.getenv("REDIS_PORT", "6379")),
)
queue = Queue("model_tasks", connection=redis_conn)
# Logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI endpoint: accept an input CSV plus a model-list TXT and enqueue an inference job.
@app.post("/start-inference/")
async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, background_tasks: BackgroundTasks):
    """Save both uploads under ``uploaded/`` and enqueue ``run_inference`` on the RQ queue.

    Args:
        input_csv: Uploaded CSV file (the worker later requires an ``input`` column).
        model_list_txt: Uploaded text file listing one model identifier per line.
        background_tasks: Injected by FastAPI; not used by this endpoint.

    Returns:
        ``{"job_id": <rq job id>, "status": "queued"}`` on success, or a 400
        ``JSONResponse`` when an upload has no name or the wrong extension.
    """
    # Client-supplied filenames are untrusted: keep only the final path component so a
    # name such as "../../etc/evil.csv" cannot write outside the upload directory.
    csv_name = os.path.basename(input_csv.filename or "")
    model_list_name = os.path.basename(model_list_txt.filename or "")

    # Validate file types. The check stays case-sensitive, as before; accepting ".CSV"
    # would require reviewing the suffix-based path handling in the inference/result code.
    if not csv_name.endswith(".csv"):
        return JSONResponse(content={"error": "Uploaded file is not a CSV."}, status_code=400)
    if not model_list_name.endswith(".txt"):
        return JSONResponse(content={"error": "Uploaded model list is not a TXT file."}, status_code=400)

    # Save the uploaded files.
    os.makedirs("uploaded", exist_ok=True)
    file_path = os.path.join("uploaded", csv_name)
    model_list_path = os.path.join("uploaded", model_list_name)
    with open(file_path, "wb") as f:
        f.write(await input_csv.read())
    with open(model_list_path, "wb") as f:
        f.write(await model_list_txt.read())
    logger.info(f"Files uploaded: {file_path}, {model_list_path}")

    # Enqueue the job; an RQ worker will call run_inference(file_path, model_list_path).
    job = queue.enqueue(run_inference, file_path, model_list_path, job_timeout=1800)
    logger.info(f"Job enqueued: {job.id}")
    return {"job_id": job.id, "status": "queued"}
# Persist a DataFrame as a pickle alongside the given (".csv") output path.
def save_to_pkl(dataframe: pd.DataFrame, output_path: str):
    """Pickle *dataframe* to *output_path* with its extension replaced by ``.pkl``.

    Only the final extension is swapped (via ``os.path.splitext``). A ".csv" appearing
    elsewhere in the path, such as in a directory name, is left untouched. A path that
    does not end in ".csv" gets ".pkl" instead of being overwritten under its own name.

    Args:
        dataframe: The DataFrame to serialize.
        output_path: Nominal output path, typically ending in ``.csv``.

    Returns:
        The path of the written ``.pkl`` file.
    """
    pkl_path = os.path.splitext(output_path)[0] + ".pkl"
    with open(pkl_path, "wb") as pkl_file:
        pickle.dump(dataframe, pkl_file)
    logger.info(f"Data saved as PKL: {pkl_path}")
    return pkl_path
# Turn a pickled DataFrame back into a CSV file.
def convert_pkl_to_csv(pkl_path: str, csv_path: str):
    """Load the pickled DataFrame at *pkl_path* and write it out as CSV at *csv_path*.

    The CSV is UTF-8 encoded and omits the DataFrame index.
    NOTE(review): unpickling runs arbitrary code for a crafted file, so only call this
    on pickles this service produced itself.

    Returns:
        *csv_path*, unchanged.
    """
    with open(pkl_path, "rb") as source:
        payload = source.read()
    frame = pickle.loads(payload)
    frame.to_csv(csv_path, encoding="utf-8", index=False)
    logger.info(f"PKL converted to CSV: {csv_path}")
    return csv_path
# Best-effort removal of a temporary CSV file (used as a post-response background task).
def delete_csv_file(file_path: str):
    """Remove *file_path*; any failure is logged instead of raised.

    Returns:
        None in all cases.
    """
    try:
        os.remove(file_path)
    except Exception as err:
        logger.error(f"Error deleting CSV file: {err}")
    else:
        logger.info(f"CSV file deleted: {file_path}")
def chat_formating(input_sentence: str, model_name: str):
    """Wrap *input_sentence* in the chat template for *model_name*'s model family.

    The family is picked by substring match on the lower-cased model name. "exaone"
    takes precedence over "gemma", which takes precedence over "llama". Templates come
    from ``LLMInference`` and are filled through their ``{input_sent}`` placeholder.

    Args:
        input_sentence: Raw user input for one row.
        model_name: Model name (any case), e.g. the last segment of a model id.

    Returns:
        The formatted prompt. *input_sentence* is returned unchanged when the model
        family has no known template, or when building the prompt fails (logged).
    """
    try:
        name = model_name.lower()
        if "exaone" in name:
            hidden_prompt = LLMInference.exaone_template()
        elif "gemma" in name:
            hidden_prompt = LLMInference.gemma_template()
        elif "llama" in name:
            hidden_prompt = LLMInference.llama_template()
        else:
            # No template for this family: pass the sentence through as-is rather than
            # raising (and logging an error) once per row.
            return input_sentence
        return hidden_prompt.format(input_sent=input_sentence)
    except Exception as e:
        logger.error(f"Not formatting input sentence: {e}")
        return input_sentence
# Model inference (executed by the RQ worker).
def _read_model_list(model_list_path: str) -> list:
    """Return the stripped, non-blank model identifiers listed one per line."""
    # utf-8-sig drops a leading BOM (common from Windows editors) that would otherwise
    # become part of the first model id.
    with open(model_list_path, "r", encoding="utf-8-sig") as f:
        return [line.strip() for line in f if line.strip()]


def _read_input_csv(file_path: str) -> pd.DataFrame:
    """Read *file_path* as UTF-8 (BOM tolerated), falling back to CP949 (EUC-KR superset)."""
    # UTF-8 goes first because its decoder is strict. The legacy Korean codecs accept
    # many UTF-8 byte sequences and could silently yield mojibake if tried first.
    # utf-8-sig also strips an Excel BOM that would otherwise rename the "input" column.
    try:
        return pd.read_csv(file_path, encoding="utf-8-sig")
    except UnicodeDecodeError as e:
        logger.info(f"Could not read {file_path} as UTF-8 ({e}); retrying as CP949.")
        return pd.read_csv(file_path, encoding="cp949")


def _generate_batch(llm, batch: pd.DataFrame, model: str, model_name: str, sampling_params, failed_rows: list) -> list:
    """Return one response per row of *batch*, with None for rows that failed.

    All prompts of the batch are sent in a single ``llm.generate`` call so the engine
    can process them together (vLLM documents that outputs come back in prompt order;
    the length is checked as well). If the batched call fails, the rows are retried
    one at a time so a bad input only costs its own row. Each failed row is appended
    to *failed_rows*.
    """
    prompts = [chat_formating(input_sentence=text, model_name=model_name.lower()) for text in batch["input"]]
    try:
        outputs = llm.generate(prompts, sampling_params)
        if len(outputs) != len(prompts):
            raise RuntimeError(f"expected {len(prompts)} outputs, got {len(outputs)}")
        responses = [output.outputs[0].text.strip() for output in outputs]
    except Exception as e:
        logger.warning(f"Batched generation failed for model {model}: {e}; retrying row by row.")
    else:
        for text, response in zip(batch["input"], responses):
            logger.info(f"Model: {model}, Input: {text}, Output: {response}")
        return responses

    # Per-row fallback: isolate the failing input(s).
    responses = []
    for (_, row), prompt in zip(batch.iterrows(), prompts):
        try:
            response = llm.generate(prompt, sampling_params)[0].outputs[0].text.strip()
            logger.info(f"Model: {model}, Input: {row['input']}, Output: {response}")
            responses.append(response)
        except Exception as e:
            logger.error(f"Error during inference for model {model}, row {row.name}: {e}")
            failed_rows.append(row)
            responses.append(None)
    return responses


def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
    """Run every listed model over the ``input`` column of *file_path*.

    Each model's responses are stored in a new column named after the last path
    segment of the model id. Models that fail to load are logged and skipped. The
    resulting DataFrame is pickled via ``save_to_pkl``. Rows that failed for any model
    are written to ``errors/<name>_errors.csv``.

    Args:
        file_path: Uploaded CSV; must contain an ``input`` column.
        model_list_path: Text file with one model id per line (blank lines ignored).
        batch_size: Number of prompts sent to the model per ``generate`` call.

    Returns:
        ``processed/<name>_result.csv``. NOTE: the data is handed to ``save_to_pkl``
        with this path; this function itself writes no file at the returned path.

    Raises:
        ValueError: if *batch_size* < 1, the model list is empty, or the CSV has no
            ``input`` column. Any other error is logged with traceback and re-raised.
    """
    try:
        logger.info(f"Starting inference for file: {file_path} using models from {model_list_path}")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        # Read the model list.
        model_list = _read_model_list(model_list_path)
        if not model_list:
            raise ValueError("The model list file is empty.")

        # Read the CSV.
        df = _read_input_csv(file_path)
        if "input" not in df.columns:
            raise ValueError("The input CSV must contain a column named 'input'.")

        # Rows (as Series) whose inference failed, across all models.
        failed_rows = []

        # Run inference with each model.
        for model in model_list:
            model_name = model.split("/")[-1]
            try:
                logger.info(f"Loading model: {model}")
                llm = LLM(model)
                torch.cuda.empty_cache()
                logger.info(f"Model {model} loaded successfully.")
            except Exception as e:
                logger.error(f"Error loading model {model}: {e}")
                continue

            sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)
            try:
                responses = []
                for start in tqdm(range(0, len(df), batch_size), desc=f"Processing {model}"):
                    batch = df.iloc[start:start + batch_size]
                    responses.extend(_generate_batch(llm, batch, model, model_name, sampling_params, failed_rows))
                # Add this model's results as a new column.
                df[model_name] = responses
            finally:
                # Release this model before loading the next one.
                # NOTE(review): vLLM may keep GPU memory reserved after `del` — confirm
                # that loading several models in sequence fits on the device.
                del llm
                torch.cuda.empty_cache()
                gc.collect()

        # Save results (paths derived from the file's base name, not substring replace).
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        os.makedirs("processed", exist_ok=True)
        output_path = os.path.join("processed", f"{base_name}_result.csv")
        save_to_pkl(df, output_path)
        logger.info(f"Inference completed. Result saved to: {output_path}")

        # Save the rows that failed (built once instead of concatenating per failure).
        if failed_rows:
            error_rows = pd.concat([pd.DataFrame([row]) for row in failed_rows], ignore_index=True)
            os.makedirs("errors", exist_ok=True)
            error_path = os.path.join("errors", f"{base_name}_errors.csv")
            error_rows.to_csv(error_path, index=False, encoding="utf-8")
            logger.info(f"Error rows saved to: {error_path}")
        return output_path
    except Exception as e:
        logger.exception(f"Error during inference: {e}")
        raise
# Convert the newest pickled result to CSV, send it, then delete the temporary CSV.
@app.get("/download-latest-result", response_class=FileResponse)
def download_latest_result(background_tasks: BackgroundTasks):
    """Serve the most recently written ``processed/*.pkl`` result as a CSV download.

    The pickle is converted into a uniquely named temporary CSV inside ``processed/``.
    Concurrent downloads therefore never overwrite or delete each other's file, and
    the client still receives it under the result's own name (``<name>.csv``). A
    background task deletes the temporary CSV after the response is sent.

    NOTE(review): this returns the latest result of *any* job, not of a specific job
    id, so a caller may receive another job's output.

    Returns:
        A ``FileResponse`` with the CSV. Otherwise a ``JSONResponse`` error: 404 when
        there is nothing to download, 500 when conversion fails.
    """
    import tempfile  # only this endpoint needs it

    processed_dir = "processed"
    csv_path = None
    try:
        if not os.path.exists(processed_dir):
            return JSONResponse(content={"error": "Processed directory not found."}, status_code=404)
        # Find the most recently written PKL file in the directory.
        pkl_files = [os.path.join(processed_dir, f) for f in os.listdir(processed_dir) if f.endswith(".pkl")]
        if not pkl_files:
            return JSONResponse(content={"error": "No PKL files found in the processed directory."}, status_code=404)
        # mtime is the last write; on Unix, ctime is the last metadata change instead.
        latest_pkl = max(pkl_files, key=os.path.getmtime)
        download_name = os.path.splitext(os.path.basename(latest_pkl))[0] + ".csv"

        # Convert into a private temp file so parallel requests don't collide.
        fd, csv_path = tempfile.mkstemp(suffix=".csv", dir=processed_dir)
        os.close(fd)
        convert_pkl_to_csv(latest_pkl, csv_path)

        # Delete the temporary CSV once the response has been sent.
        background_tasks.add_task(delete_csv_file, csv_path)
        return FileResponse(csv_path, media_type="text/csv", filename=download_name)
    except Exception as e:
        logger.exception(f"Error during file download: {e}")
        # Don't leave a partially written temporary CSV behind.
        if csv_path is not None and os.path.exists(csv_path):
            delete_csv_file(csv_path)
        return JSONResponse(content={"error": "Failed to download the result file."}, status_code=500)