import gc
import logging
import os
import sys

import pandas as pd
import torch
from fastapi import FastAPI, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from redis import Redis
from rq import Queue
from tqdm import tqdm
from vllm import LLM, SamplingParams

sys.path.append("/workspace/LLM_asyncio")
from template import LLMInference

app = FastAPI()

# Redis connection and task queue.
# RQ stores pickled job payloads, so the connection must return raw bytes;
# decode_responses=True would make redis-py try to UTF-8-decode them and fail.
redis_conn = Redis(host="redis-server", port=6379)
queue = Queue("model_tasks", connection=redis_conn)

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# FastAPI endpoint: accept the input CSV and the model list, then enqueue the job
@app.post("/start-inference/")
async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile):
    # Validate file extensions
    if not input_csv.filename.endswith(".csv"):
        return JSONResponse(content={"error": "Uploaded file is not a CSV."}, status_code=400)
    if not model_list_txt.filename.endswith(".txt"):
        return JSONResponse(content={"error": "Uploaded model list is not a TXT file."}, status_code=400)

    # Persist the uploaded files
    file_path = f"uploaded/{input_csv.filename}"
    model_list_path = f"uploaded/{model_list_txt.filename}"
    os.makedirs("uploaded", exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(await input_csv.read())
    with open(model_list_path, "wb") as f:
        f.write(await model_list_txt.read())
    logger.info(f"Files uploaded: {file_path}, {model_list_path}")

    # Hand the work off to the RQ queue so the request returns immediately
    job = queue.enqueue(run_inference, file_path, model_list_path, job_timeout=1800)
    logger.info(f"Job enqueued: {job.id}")
    return {"job_id": job.id, "status": "queued"}


def chat_formatting(input_sentence: str, model_name: str) -> str:
    """Wrap the raw input sentence in the chat template for the given model family."""
    try:
        if "llama" in model_name:
            hidden_prompt = LLMInference.llama_template()
        elif "gemma" in model_name:
            hidden_prompt = LLMInference.gemma_template()
        elif "exaone" in model_name:
            hidden_prompt = LLMInference.exaone_template()
        else:
            # Unknown model family: fall back to the raw input
            return input_sentence
        return hidden_prompt.format(input_sent=input_sentence)
    except Exception as e:
        logger.error(f"Failed to format input sentence, using it as-is: {e}")
        return input_sentence


# Worker-side inference function, executed by an RQ worker
def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
    try:
        logger.info(f"Starting inference for file: {file_path} using models from {model_list_path}")

        # Read the model list, skipping blank lines
        with open(model_list_path, "r") as f:
            model_list = [line.strip() for line in f if line.strip()]
        if not model_list:
            raise ValueError("The model list file is empty.")

        # Read the input CSV (assumed to be EUC-KR encoded)
        df = pd.read_csv(file_path, encoding="euc-kr")
        if "input" not in df.columns:
            raise ValueError("The input CSV must contain a column named 'input'.")

        # Collect rows that fail during inference (built as a list to avoid
        # repeated pd.concat calls inside the loop)
        error_rows = []

        # Run inference with each model in turn
        for model in model_list:
            model_name = model.split("/")[-1]
            try:
                logger.info(f"Loading model: {model}")
                llm = LLM(model)
                torch.cuda.empty_cache()
                logger.info(f"Model {model} loaded successfully.")
            except Exception as e:
                logger.error(f"Error loading model {model}: {e}")
                continue

            sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)

            # Generate responses batch by batch
            responses = []
            for i in tqdm(range(0, len(df), batch_size), desc=f"Processing {model}"):
                batch = df.iloc[i:i + batch_size]
                batch_responses = []
                for _, row in batch.iterrows():
                    try:
                        original_input = row["input"]
                        formatted_input = chat_formatting(input_sentence=row["input"], model_name=model_name.lower())
                        response = llm.generate(formatted_input, sampling_params)[0].outputs[0].text.strip()
                        logger.info(f"Model: {model}, Input: {original_input}, Output: {response}")
                        batch_responses.append(response)
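                        # NOTE: llm.generate also accepts a list of prompts, so passing
                        # the whole batch in a single call would let vLLM schedule the
                        # requests together instead of generating row by row.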
                    except Exception as e:
                        logger.error(f"Error during inference for model {model}, row {row.name}: {e}")
                        error_rows.append(row)
                        batch_responses.append(None)
                responses.extend(batch_responses)

            # Attach this model's outputs as a new column
            df[model_name] = responses

            # Release GPU memory before loading the next model
            del llm
            torch.cuda.empty_cache()
            gc.collect()

        # Save the combined results
        output_path = file_path.replace("uploaded", "processed").replace(".csv", "_result.csv")
        os.makedirs("processed", exist_ok=True)
        df.to_csv(output_path, index=False, encoding="utf-8")
        logger.info(f"Inference completed. Result saved to: {output_path}")

        # Save the rows that failed, if any
        if error_rows:
            error_path = file_path.replace("uploaded", "errors").replace(".csv", "_errors.csv")
            os.makedirs("errors", exist_ok=True)
            pd.DataFrame(error_rows).to_csv(error_path, index=False, encoding="utf-8")
            logger.info(f"Error rows saved to: {error_path}")

        return output_path
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise


# Download the most recent result file
@app.get("/download-latest", response_class=FileResponse)
def download_latest_file():
    try:
        # Directory holding processed result files
        directory = "processed"
        csv_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".csv")]
        if not csv_files:
            return JSONResponse(content={"error": "No CSV files found in the processed directory."}, status_code=404)
        latest_file = max(csv_files, key=os.path.getctime)
        logger.info(f"Downloading latest file: {latest_file}")
        return FileResponse(latest_file, media_type="text/csv", filename=os.path.basename(latest_file))
    except Exception as e:
        logger.error(f"Error during file download: {e}")
        return JSONResponse(content={"error": "Failed to download the latest file."}, status_code=500)
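

# Usage sketch: the POST endpoint only enqueues jobs, so an RQ worker has to run
# alongside the API to actually execute run_inference. Assuming this module is
# importable on the worker's PYTHONPATH, a worker can be started with:
#
#   rq worker model_tasks --url redis://redis-server:6379
#
# The API server itself can be launched with uvicorn; the host and port below
# are assumptions, adjust them for your deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)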