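"""FastAPI + RQ batch-inference service.

Accepts an input CSV and a model-list text file, splits the CSV into batches,
and enqueues one RQ job per batch. Each worker loads the listed models with
vLLM, appends one response column per model, and writes per-batch result
files that can later be merged and downloaded through the API.
"""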

import gc
import logging
import os
import sys

import pandas as pd
import torch
from fastapi import FastAPI, UploadFile, BackgroundTasks
from fastapi.responses import JSONResponse, FileResponse
from redis import Redis
from rq import Queue
from tqdm import tqdm
from vllm import LLM, SamplingParams

# The prompt-template module lives outside this package
sys.path.append("/opt/workspace/")
from template import LLMInference


app = FastAPI()

# Redis connection and task queue (RQ stores pickled job data, so the
# connection must return raw bytes rather than decoded strings)
redis_conn = Redis(host="redis-server", port=6379)
queue = Queue("model_tasks", connection=redis_conn)

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
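
# RQ workers are started separately against the same queue, e.g. (assuming the
# standard rq CLI): rq worker --url redis://redis-server:6379 model_tasks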


# FastAPI endpoint: accept the input CSV and the model-list file uploads
@app.post("/start-inference/")
async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, background_tasks: BackgroundTasks):
    logger.info(f"file_name: {input_csv.filename}, model_list_file: {model_list_txt.filename}")

    # Validate the file extensions before saving the uploads
    if not input_csv.filename.endswith(".csv") or not model_list_txt.filename.endswith(".txt"):
        return JSONResponse(content={"error": "Invalid file format."}, status_code=400)

    file_path = f"uploaded/{input_csv.filename}"
    model_list_path = f"uploaded/{model_list_txt.filename}"
    os.makedirs("uploaded", exist_ok=True)

    with open(file_path, "wb") as f:
        f.write(await input_csv.read())

    with open(model_list_path, "wb") as f:
        f.write(await model_list_txt.read())

    df = pd.read_csv(file_path, encoding="euc-kr")
    batch_size = 10
    job_ids = []

    # Split the data into batch_size chunks and enqueue one job per chunk
    for i in range(0, len(df), batch_size):
        batch_file_path = file_path.replace(".csv", f"_batch_{i}_{i+batch_size}.csv")
        df.iloc[i:i+batch_size].to_csv(batch_file_path, index=False, encoding="utf-8")
        job = queue.enqueue(run_inference, batch_file_path, model_list_path, job_timeout=1800)
        job_ids.append(job.id)

    logger.info(f"Jobs enqueued: {job_ids}")
    return {"job_ids": job_ids, "status": "queued"}


def chat_formatting(input_sentence: str, model_name: str):
    # Select the prompt template that matches the model family
    if "llama" in model_name:
        hidden_prompt = LLMInference.llama_template()
    elif "gemma" in model_name:
        hidden_prompt = LLMInference.gemma_template()
    elif "exaone" in model_name:
        hidden_prompt = LLMInference.exaone_template()
    else:
        raise ValueError("Unknown model name: " + model_name)

    formatted_sentence = hidden_prompt.format(input_sent=input_sentence)
    logger.info(f"Sentence: {formatted_sentence}")
    return formatted_sentence
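
# Note: the model list file holds one model identifier per line, and the
# basename after the last "/" must contain "llama", "gemma", or "exaone" so
# chat_formatting can select a template (an illustrative entry:
# "meta-llama/Llama-3.1-8B-Instruct").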


# Model inference task (executed by the RQ workers)
def run_inference(batch_file_path: str, model_list_path: str):
    # Identify the worker; HOSTNAME is set inside the worker containers
    worker_id = os.environ.get("HOSTNAME", "Unknown Worker")
    try:
        logger.info(f"Worker {worker_id} started inference for batch file: {batch_file_path}")

        # Read the model list, skipping empty lines
        with open(model_list_path, "r") as f:
            model_list = [line.strip() for line in f if line.strip()]

        if not model_list:
            raise ValueError("The model list file is empty.")

        # Read the batch data
        df = pd.read_csv(batch_file_path, encoding="utf-8")
        if "input" not in df.columns:
            raise ValueError("The input CSV must contain a column named 'input'.")

        # Rows that fail during inference are collected here
        error_rows = pd.DataFrame(columns=df.columns)

        # Run inference with each model in turn
        for model in model_list:
            logger.info(f"Worker {worker_id} loading model: {model}")
            try:
                llm = LLM(model)
                torch.cuda.empty_cache()
                logger.info(f"Worker {worker_id} loaded model {model} successfully.")
            except Exception as e:
                logger.error(f"Worker {worker_id} error loading model {model}: {e}")
                continue

            sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)
            responses = []

            # tqdm: show per-model progress for this worker
            with tqdm(total=len(df), desc=f"[{worker_id}] Model: {model}") as pbar:
                model_name = model.split("/")[-1]
                for _, row in df.iterrows():
                    try:
                        input_text = chat_formatting(input_sentence=row["input"], model_name=model_name)
                        response = llm.generate(input_text, sampling_params)[0].outputs[0].text.strip()
                        logger.info(f"Model: {model}, Input: {input_text}, Output: {response}")
                        responses.append(response)
                    except Exception as e:
                        logger.error(f"Worker {worker_id} error during inference for model {model}, row {row.name}: {e}")
                        error_rows = pd.concat([error_rows, pd.DataFrame([row])], ignore_index=True)
                        responses.append(None)
                    finally:
                        pbar.update(1)

            # Append this model's outputs as a new column
            df[model_name] = responses

            # Release GPU memory before loading the next model
            del llm
            torch.cuda.empty_cache()
            gc.collect()

        # Save the batch result
        output_path = batch_file_path.replace("uploaded", "processed").replace(".csv", "_result.csv")
        os.makedirs("processed", exist_ok=True)
        df.to_csv(output_path, index=False, encoding="utf-8")
        logger.info(f"Worker {worker_id} inference completed for batch. Result saved to: {output_path}")

        # Save any failed rows
        if not error_rows.empty:
            error_path = batch_file_path.replace("uploaded", "errors").replace(".csv", "_errors.csv")
            os.makedirs("errors", exist_ok=True)
            error_rows.to_csv(error_path, index=False, encoding="utf-8")
            logger.info(f"Error rows saved to: {error_path}")

        return output_path

    except Exception as e:
        logger.error(f"Worker {worker_id} error during inference: {e}")
        raise
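
# Each batch job writes its own *_result.csv under processed/; the endpoint
# below concatenates those files into a single final_result.csv.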


@app.get("/merge-results/")
def merge_results():
    try:
        processed_dir = "processed"
        all_files = [os.path.join(processed_dir, f) for f in os.listdir(processed_dir) if f.endswith("_result.csv")]
        combined_df = pd.concat([pd.read_csv(f, encoding="utf-8") for f in all_files], ignore_index=True)

        final_output_path = os.path.join(processed_dir, "final_result.csv")
        combined_df.to_csv(final_output_path, index=False, encoding="utf-8")

        logger.info(f"Final merged result saved to: {final_output_path}")
        return {"final_result_path": final_output_path}
    except Exception as e:
        logger.error(f"Error during merging results: {e}")
        return JSONResponse(content={"error": "Failed to merge results."}, status_code=500)


# Download the latest result file
@app.get("/download-latest", response_class=FileResponse)
def download_latest_file():
    try:
        # Processed-results directory (the same "processed" directory that
        # run_inference and /merge-results/ use)
        directory = "processed"

        csv_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".csv")]

        if not csv_files:
            return JSONResponse(content={"error": "No CSV files found in the processed directory."}, status_code=404)

        # Pick the most recently created CSV file
        latest_file = max(csv_files, key=os.path.getctime)

        logger.info(f"Downloading latest file: {latest_file}")

        return FileResponse(latest_file, media_type="text/csv", filename=os.path.basename(latest_file))
    except Exception as e:
        logger.error(f"Error during file download: {e}")
        return JSONResponse(content={"error": "Failed to download the latest file."}, status_code=500)
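

# This file is marked executable, so a direct-run entry point may be intended.
# A minimal sketch, assuming uvicorn is available (host and port are assumptions):
#
#     if __name__ == "__main__":
#         import uvicorn
#         uvicorn.run(app, host="0.0.0.0", port=8000)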