Compare commits

3 Commits · 4ced8db541...main

| Author | SHA1 | Date |
|---|---|---|
|  | 75b9c50e7b |  |
|  | 9aab12ed50 |  |
|  | dcdc3157c3 |  |
docker-compose.yml (13 lines changed, Executable file → Normal file)
```diff
@@ -3,8 +3,8 @@ version: "3.8"
 services:
   llm-asyncio:
     build:
-      context: .
-      dockerfile: Dockerfile
+      context: .
+      dockerfile: Dockerfile
     shm_size: "1000gb"
     volumes:
       - ./workspace:/opt/workspace/
@@ -42,11 +42,11 @@ services:

   worker:
     build:
-      context: .
-      dockerfile: Dockerfile
+      context: .
+      dockerfile: Dockerfile
     shm_size: "1000gb"
     volumes:
-      - ./workspace:/opt/workspace/
+      - ./workspace:/opt/workspace/
       - ./cache:/root/.cache/
       - ../model:/opt/model/
     environment:
@@ -65,9 +65,10 @@ services:
         /bin/bash -c "
           python /opt/workspace/worker.py
         "
     restart: always
     tty: true
+    scale: 2

 networks:
   llm-network:
     driver: bridge
```
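The `worker` service is now scaled to two replicas, each running `/opt/workspace/worker.py` against the shared Redis queue. The worker script itself is not part of this compare view; a minimal sketch of what such an RQ consumer could look like, assuming the Redis host is a compose service named `redis` and the default queue is used (both are assumptions, not taken from these commits):

```python
# Hypothetical worker.py: consumes jobs enqueued by the FastAPI app.
# Host "redis" and queue name "default" are assumptions for this sketch.
from redis import Redis
from rq import Queue, Worker

redis_conn = Redis(host="redis", port=6379)

if __name__ == "__main__":
    # Each scaled container runs one blocking worker loop.
    Worker([Queue("default", connection=redis_conn)], connection=redis_conn).work()
```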
```diff
@@ -2,6 +2,7 @@ import os
 import pandas as pd
 from fastapi import FastAPI, UploadFile, BackgroundTasks
 from fastapi.responses import JSONResponse, FileResponse
 import shutil
 from redis import Redis
 from rq import Queue
 from vllm import LLM, SamplingParams
```
```diff
@@ -9,10 +10,8 @@ import logging
 import gc
 import torch
 from tqdm import tqdm

-import sys
-sys.path.append("/workspace/LLM_asyncio")
-from template import LLMInference
+import pickle

 app = FastAPI()
```
```diff
@@ -27,18 +26,16 @@ logger = logging.getLogger(__name__)
 # FastAPI endpoint: handle the CSV file and model-list uploads
 @app.post("/start-inference/")
 async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, background_tasks: BackgroundTasks):
     # Check the file formats
     if not input_csv.filename.endswith(".csv"):
         return JSONResponse(content={"error": "Uploaded file is not a CSV."}, status_code=400)
     if not model_list_txt.filename.endswith(".txt"):
         return JSONResponse(content={"error": "Uploaded model list is not a TXT file."}, status_code=400)

     # Save the files
     file_path = f"uploaded/{input_csv.filename}"
     model_list_path = f"uploaded/{model_list_txt.filename}"
     os.makedirs("uploaded", exist_ok=True)

-    with open(file_path, "wb") as f:
+    with open(file_path, "wb") as f:
         f.write(await input_csv.read())

     with open(model_list_path, "wb") as f:
```
```diff
@@ -46,11 +43,35 @@ async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, backgro

     logger.info(f"Files uploaded: {file_path}, {model_list_path}")

     # Add the job to the queue
     job = queue.enqueue(run_inference, file_path, model_list_path, job_timeout=1800)
     logger.info(f"Job enqueued: {job.id}")
     return {"job_id": job.id, "status": "queued"}

+# Convert CSV (DataFrame) to PKL
+def save_to_pkl(dataframe: pd.DataFrame, output_path: str):
+    pkl_path = output_path.replace(".csv", ".pkl")
+    with open(pkl_path, "wb") as pkl_file:
+        pickle.dump(dataframe, pkl_file)
+    logger.info(f"Data saved as PKL: {pkl_path}")
+    return pkl_path
+
+# Convert PKL back to CSV
+def convert_pkl_to_csv(pkl_path: str, csv_path: str):
+    with open(pkl_path, "rb") as pkl_file:
+        dataframe = pickle.load(pkl_file)
+    dataframe.to_csv(csv_path, index=False, encoding="utf-8")
+    logger.info(f"PKL converted to CSV: {csv_path}")
+    return csv_path
+
+# Delete a CSV file (used as a background task)
+def delete_csv_file(file_path: str):
+    try:
+        os.remove(file_path)
+        logger.info(f"CSV file deleted: {file_path}")
+    except Exception as e:
+        logger.error(f"Error deleting CSV file: {e}")

 # Apply the model template
 def chat_formating(input_sentence: str, model_name: str):
     try:
         if "llama" in model_name:
```
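As a quick check of the new helpers, `save_to_pkl` and `convert_pkl_to_csv` round-trip a DataFrame; the file names below are made up for the example:

```python
import os
import pandas as pd

os.makedirs("processed", exist_ok=True)

# Hypothetical file names, for illustration only.
df = pd.DataFrame({"input": ["hello", "world"]})
pkl_path = save_to_pkl(df, "processed/sample_result.csv")        # writes processed/sample_result.pkl
csv_path = convert_pkl_to_csv(pkl_path, "processed/sample_result.csv")

restored = pd.read_csv(csv_path, encoding="utf-8")
assert list(restored["input"]) == list(df["input"])
```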
```diff
@@ -71,22 +92,23 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
     try:
         logger.info(f"Starting inference for file: {file_path} using models from {model_list_path}")

         # Read the model list
         with open(model_list_path, "r") as f:
             model_list = [line.strip() for line in f.readlines()]

         if not model_list:
             raise ValueError("The model list file is empty.")

         # Read the CSV
-        df = pd.read_csv(file_path, encoding="euc-kr")
+        try:
+            df = pd.read_csv(file_path, encoding="euc-kr")
+        except Exception as e:
+            df = pd.read_csv(file_path, encoding="utf-8")
+            logger.info(f"Failed to read {file_path} as {e}")

         if "input" not in df.columns:
             raise ValueError("The input CSV must contain a column named 'input'.")

         # Initialize a DataFrame to hold rows that raised errors
         error_rows = pd.DataFrame(columns=df.columns)

         # Run inference with each model
         for model in model_list:
             model_name = model.split("/")[-1]
             try:
```
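The same EUC-KR-then-UTF-8 fallback can be factored into a small helper that tries encodings in order; a minimal sketch (the helper name is ours, not from the commit):

```python
import pandas as pd

def read_csv_with_fallback(path: str, encodings=("euc-kr", "utf-8")) -> pd.DataFrame:
    """Try each encoding in order and re-raise the last failure."""
    last_error = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except (UnicodeDecodeError, ValueError) as e:
            last_error = e
    raise last_error
```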
```diff
@@ -100,7 +122,6 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):

             sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)

             # Run the inference
             responses = []
             for i in tqdm(range(0, len(df), batch_size), desc=f"Processing {model}"):
                 batch = df.iloc[i:i+batch_size]
```
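For reference, the batched generation this loop drives looks roughly like the following in vLLM; the model name is a placeholder, since run_inference loads each entry of the uploaded model list instead:

```python
from vllm import LLM, SamplingParams

# Placeholder model; run_inference iterates over the uploaded model list.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)

prompts = ["Summarize: vLLM batches prompts for throughput."]
outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out.outputs[0].text)  # first candidate completion for each prompt
```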
```diff
@@ -118,20 +139,18 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
                         batch_responses.append(None)
                 responses.extend(batch_responses)

                 # Append the results
                 df[model_name] = responses

                 del llm
                 torch.cuda.empty_cache()
                 gc.collect()

         # Save the results
         output_path = file_path.replace("uploaded", "processed").replace(".csv", "_result.csv")
         os.makedirs("processed", exist_ok=True)
-        df.to_csv(output_path, index=False, encoding="utf-8")
+        # df.to_csv(output_path, index=False, encoding="utf-8")
+        save_to_pkl(df, output_path)
         logger.info(f"Inference completed. Result saved to: {output_path}")

         # Save the error rows
         if not error_rows.empty:
             error_path = file_path.replace("uploaded", "errors").replace(".csv", "_errors.csv")
             os.makedirs("errors", exist_ok=True)
```
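The `del llm` / `torch.cuda.empty_cache()` / `gc.collect()` sequence is what lets one process load several models back to back; a stripped-down sketch of the pattern (model names are placeholders):

```python
import gc
import torch
from vllm import LLM

for model in ["model-a", "model-b"]:  # placeholder names
    llm = LLM(model=model)
    # ... run llm.generate(...) for this model ...
    del llm                   # drop the Python reference to the engine
    torch.cuda.empty_cache()  # return freed blocks to the CUDA allocator
    gc.collect()              # collect anything still reachable through cycles
```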
```diff
@@ -143,22 +162,28 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
     except Exception as e:
         logger.error(f"Error during inference: {e}")
         raise

-# Download the result file
-@app.get("/download-latest", response_class=FileResponse)
-def download_latest_file():
+# Convert PKL to CSV, serve it for download, then delete the CSV
+@app.get("/download-latest-result", response_class=FileResponse)
+def download_latest_result(background_tasks: BackgroundTasks):
     try:
-        # Path of the processed directory
-        directory = "processed"
-        csv_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".csv")]
+        processed_dir = "processed"
+        if not os.path.exists(processed_dir):
+            return JSONResponse(content={"error": "Processed directory not found."}, status_code=404)

-        if not csv_files:
-            return JSONResponse(content={"error": "No CSV files found in the processed directory."}, status_code=404)
+        pkl_files = [os.path.join(processed_dir, f) for f in os.listdir(processed_dir) if f.endswith(".pkl")]
+        if not pkl_files:
+            return JSONResponse(content={"error": "No PKL files found in the processed directory."}, status_code=404)

-        latest_file = max(csv_files, key=os.path.getctime)
-        logger.info(f"Downloading latest file: {latest_file}")
-        return FileResponse(latest_file, media_type="application/csv", filename=os.path.basename(latest_file))
+        latest_pkl = max(pkl_files, key=os.path.getctime)
+        csv_path = latest_pkl.replace(".pkl", ".csv")
+
+        convert_pkl_to_csv(latest_pkl, csv_path)
+
+        background_tasks.add_task(delete_csv_file, csv_path)
+
+        return FileResponse(csv_path, media_type="application/csv", filename=os.path.basename(csv_path))
     except Exception as e:
         logger.error(f"Error during file download: {e}")
-        return JSONResponse(content={"error": "Failed to download the latest file."}, status_code=500)
+        return JSONResponse(content={"error": "Failed to download the result file."}, status_code=500)
```
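Taken together, a client exercises the two endpoints like this; the base URL and local file names are assumptions for the example:

```python
import requests

BASE = "http://localhost:8000"  # assumed host/port for the FastAPI app

# Enqueue a job: a CSV with an 'input' column plus a model list, one model per line.
with open("input.csv", "rb") as csv_f, open("models.txt", "rb") as txt_f:
    r = requests.post(
        f"{BASE}/start-inference/",
        files={"input_csv": csv_f, "model_list_txt": txt_f},
    )
print(r.json())  # {"job_id": "...", "status": "queued"}

# Once the job finishes, fetch the latest result (served as CSV, then deleted server-side).
r = requests.get(f"{BASE}/download-latest-result")
with open("result.csv", "wb") as f:
    f.write(r.content)
```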