# Files
# llm_asycio/workspace/main.py
# b24503@hanmaceng.co.kr fc3ead893a Initial commit
# 2025-01-07 09:11:27 +09:00
#
# 183 lines
# 7.6 KiB
# Python
# Executable File

import os
import pandas as pd
from fastapi import FastAPI, UploadFile, BackgroundTasks
from fastapi.responses import JSONResponse, FileResponse
from redis import Redis
from rq import Queue
from vllm import LLM, SamplingParams
import logging
import gc
import torch
from tqdm import tqdm
import sys
sys.path.append("/opt/workspace/")
from template import LLMInference
app = FastAPI()

# Redis connection used to enqueue inference jobs.
# Host/port may be overridden through environment variables; the defaults are
# the previously hard-coded values.
# decode_responses is deliberately left at the redis-py default (False): RQ
# stores job payloads as binary (pickled/compressed) data, and a connection
# that decodes every response to str cannot read those payloads back.
redis_conn = Redis(
    host=os.environ.get("REDIS_HOST", "redis-server"),
    port=int(os.environ.get("REDIS_PORT", "6379")),
)
queue = Queue("model_tasks", connection=redis_conn)

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI endpoint: accept an input CSV and a model-list text file.
@app.post("/start-inference/")
async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, background_tasks: BackgroundTasks):
    """Save the uploads, split the CSV into batches and enqueue one job per batch.

    Args:
        input_csv: Uploaded CSV, decoded as EUC-KR; must contain an ``input``
            column.
        model_list_txt: Uploaded text file listing the models to run.
        background_tasks: Injected by FastAPI; not used by this endpoint.

    Returns:
        ``{"job_ids": [...], "status": "queued"}`` on success, or a 400
        ``JSONResponse`` when the uploads are missing, have the wrong
        extension, cannot be parsed, or lack the ``input`` column.
    """
    # Keep only the final path component of the client-supplied names so an
    # upload cannot escape the "uploaded" directory (e.g. "../../x.csv").
    csv_name = os.path.basename(input_csv.filename or "")
    model_list_name = os.path.basename(model_list_txt.filename or "")
    logger.info(f"file_name: {csv_name},model_list_file: {model_list_name}")

    # Check the file types before writing anything to disk.
    if not csv_name.endswith(".csv") or not model_list_name.endswith(".txt"):
        return JSONResponse(content={"error": "Invalid file format."}, status_code=400)

    os.makedirs("uploaded", exist_ok=True)
    file_path = os.path.join("uploaded", csv_name)
    model_list_path = os.path.join("uploaded", model_list_name)
    with open(file_path, "wb") as f:
        f.write(await input_csv.read())
    with open(model_list_path, "wb") as f:
        f.write(await model_list_txt.read())

    try:
        df = pd.read_csv(file_path, encoding="euc-kr")
    except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        logger.error(f"Failed to parse uploaded CSV {file_path}: {e}")
        return JSONResponse(content={"error": "Failed to parse the input CSV."}, status_code=400)
    # Reject up front: every worker job would otherwise fail on this check.
    if "input" not in df.columns:
        return JSONResponse(
            content={"error": "The input CSV must contain a column named 'input'."},
            status_code=400,
        )

    batch_size = 10
    # Batch files are "<stem>_batch_<start>_<stop>.csv" next to the upload.
    stem = os.path.splitext(file_path)[0]
    job_ids = []
    # Split the data into batch_size chunks and enqueue each one.
    for start in range(0, len(df), batch_size):
        stop = start + batch_size
        batch_file_path = f"{stem}_batch_{start}_{stop}.csv"
        df.iloc[start:stop].to_csv(batch_file_path, index=False, encoding="utf-8")
        job = queue.enqueue(run_inference, batch_file_path, model_list_path, job_timeout=1800)
        job_ids.append(job.id)
    logger.info(f"Jobs enqueued: {job_ids}")
    return {"job_ids": job_ids, "status": "queued"}
def chat_formating(input_sentence: str, model_name: str):
    """Wrap *input_sentence* in the chat template matching *model_name*.

    The template family is picked by a case-insensitive substring match on
    ``model_name``, checked in the order "llama", "gemma", "exaone". This lets
    mixed-case ids such as ``Llama-3.1-8B-Instruct`` or
    ``EXAONE-3.0-7.8B-Instruct`` resolve.

    Args:
        input_sentence: Text substituted for the template's ``input_sent``
            placeholder.
        model_name: Model identifier; ``run_inference`` passes the last path
            segment of the model id.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``model_name`` matches none of the known families.
    """
    lowered = model_name.lower()
    if "llama" in lowered:
        hidden_prompt = LLMInference.llama_template()
    elif "gemma" in lowered:
        hidden_prompt = LLMInference.gemma_template()
    elif "exaone" in lowered:
        hidden_prompt = LLMInference.exaone_template()
    else:
        raise ValueError("Unknown model name: " + model_name)
    formated_sentence = hidden_prompt.format(input_sent=input_sentence)
    logger.info(f"Sentence: {formated_sentence}")
    return formated_sentence
# Model inference (runs inside the RQ workers).
def _batch_output_path(batch_file_path: str, directory: str, suffix: str) -> str:
    """Return ``<directory>/<stem of batch_file_path><suffix>``."""
    stem = os.path.splitext(os.path.basename(batch_file_path))[0]
    return os.path.join(directory, stem + suffix)


def run_inference(batch_file_path: str, model_list_path: str):
    """Run every listed model over one batch CSV and save the results.

    For each model id in ``model_list_path``:
    - The model is loaded with vLLM; models that fail to load are skipped.
    - Each row's ``input`` is formatted with ``chat_formating`` and generated.
    - Responses are stored in a new column named after the last path segment
      of the model id.

    Rows whose formatting or generation fails get ``None`` in that column.
    They are also written to ``errors/<batch>_errors.csv``, together with the
    failing model (``error_model``) and the error text (``error_message``).

    Args:
        batch_file_path: Path of the batch CSV (UTF-8, must have ``input``).
        model_list_path: Text file with one model id per line; blank lines
            are ignored.

    Returns:
        Path of the written ``processed/<batch>_result.csv`` file.

    Raises:
        ValueError: If the model list is empty or the CSV lacks ``input``.
    """
    # Worker identity for log messages.
    worker_id = os.environ.get("HOSTNAME", "Unknown Worker")
    try:
        logger.info(f"Worker {worker_id} started inference for batch file: {batch_file_path}")

        # Read the model list, skipping blank lines.
        with open(model_list_path, "r", encoding="utf-8") as f:
            model_list = [line.strip() for line in f if line.strip()]
        if not model_list:
            raise ValueError("The model list file is empty.")

        # Read the batch data.
        df = pd.read_csv(batch_file_path, encoding="utf-8")
        if "input" not in df.columns:
            raise ValueError("The input CSV must contain a column named 'input'.")

        # The decoding settings are the same for every model, so build them once.
        sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)
        # Failed rows. Initialised before the loop so the per-row error handler
        # and the final save below can always use it.
        error_records = []

        for model in model_list:
            logger.info(f"Worker {worker_id} loading model: {model}")
            try:
                llm = LLM(model)
                torch.cuda.empty_cache()
                logger.info(f"Worker {worker_id} loaded model {model} successfully.")
            except Exception as e:
                # A model that cannot be loaded is skipped; the others still run.
                logger.error(f"Worker {worker_id} error loading model {model}: {e}")
                continue

            # Result column name: last path segment of the model id.
            model_name = model.rstrip("/").split("/")[-1]
            responses = []
            try:
                # Per-worker, per-model progress bar.
                with tqdm(total=len(df), desc=f"[{worker_id}] Model: {model}") as pbar:
                    for _, row in df.iterrows():
                        try:
                            input_text = chat_formating(input_sentence=row["input"], model_name=model_name)
                            response = llm.generate(input_text, sampling_params)[0].outputs[0].text.strip()
                            logger.info(f"Model: {model}, Input: {input_text}, Output: {response}")
                            responses.append(response)
                        except Exception as e:
                            logger.error(f"Worker {worker_id} error during inference for model {model}, row {row.name}: {e}")
                            error_records.append({**row.to_dict(), "error_model": model, "error_message": str(e)})
                            responses.append(None)
                        finally:
                            pbar.update(1)
                # Add this model's outputs as a new column.
                df[model_name] = responses
            finally:
                # Release the model before the next one is loaded, even if
                # something above raised.
                del llm
                torch.cuda.empty_cache()
                gc.collect()

        # Save the batch results.
        output_path = _batch_output_path(batch_file_path, "processed", "_result.csv")
        os.makedirs("processed", exist_ok=True)
        df.to_csv(output_path, index=False, encoding="utf-8")
        logger.info(f"Worker {worker_id} inference completed for batch. Result saved to: {output_path}")

        # Save the failed rows, if any.
        if error_records:
            error_path = _batch_output_path(batch_file_path, "errors", "_errors.csv")
            os.makedirs("errors", exist_ok=True)
            pd.DataFrame(error_records).to_csv(error_path, index=False, encoding="utf-8")
            logger.info(f"Error rows saved to: {error_path}")
        return output_path
    except Exception as e:
        logger.exception(f"Worker {worker_id} error during inference: {e}")
        raise
@app.get("/merge-results/")
def merge_results():
    """Concatenate the per-batch result CSVs into ``processed/final_result.csv``.

    Batch files are merged in natural filename order, so ``_batch_10_20``
    follows ``_batch_0_10``. Any existing ``final_result.csv`` is excluded so
    repeated calls do not duplicate rows.
    NOTE(review): every ``*_result.csv`` in the directory is merged, including
    batches from earlier uploads.

    Returns:
        ``{"final_result_path": ...}`` on success; a 404 ``JSONResponse`` when
        there is nothing to merge; a 500 ``JSONResponse`` on other errors.
    """
    import re

    processed_dir = "processed"
    final_name = "final_result.csv"

    def batch_order(file_name):
        # re.split with one capture group puts digit runs at odd indices;
        # compare those numerically so batch offsets sort correctly.
        parts = re.split(r"(\d+)", file_name)
        return [int(part) if index % 2 else part for index, part in enumerate(parts)]

    try:
        result_names = []
        if os.path.isdir(processed_dir):
            # The merged output also ends with "_result.csv"; skip it so it is
            # not folded back into itself.
            result_names = [
                name for name in os.listdir(processed_dir)
                if name.endswith("_result.csv") and name != final_name
            ]
        if not result_names:
            return JSONResponse(content={"error": "No result files to merge."}, status_code=404)
        result_names.sort(key=batch_order)

        combined_df = pd.concat(
            [pd.read_csv(os.path.join(processed_dir, name), encoding="utf-8") for name in result_names],
            ignore_index=True,
        )
        final_output_path = os.path.join(processed_dir, final_name)
        combined_df.to_csv(final_output_path, index=False, encoding="utf-8")
        logger.info(f"Final merged result saved to: {final_output_path}")
        return {"final_result_path": final_output_path}
    except Exception as e:
        logger.exception(f"Error during merging results: {e}")
        return JSONResponse(content={"error": "Failed to merge results."}, status_code=500)
# Download the most recent result file.
@app.get("/download-latest", response_class=FileResponse)
def download_latest_file():
    """Return the most recently modified CSV in the ``processed`` directory.

    Uses the same relative ``processed`` directory that the result and merge
    code in this module write to.

    Returns:
        A ``FileResponse`` with the CSV; a 404 ``JSONResponse`` when no CSV is
        available; a 500 ``JSONResponse`` on other errors.
    """
    directory = "processed"
    try:
        csv_files = []
        if os.path.isdir(directory):
            csv_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".csv")]
        if not csv_files:
            return JSONResponse(content={"error": "No CSV files found in the processed directory."}, status_code=404)
        # Modification time is the portable "last written" timestamp.
        latest_file = max(csv_files, key=os.path.getmtime)
        logger.info(f"Downloading latest file: {latest_file}")
        return FileResponse(latest_file, media_type="text/csv", filename=os.path.basename(latest_file))
    except Exception as e:
        logger.exception(f"Error during file download: {e}")
        return JSONResponse(content={"error": "Failed to download the latest file."}, status_code=500)