From 9aab12ed50261abbf9c61e6e0f9734724c78acb4 Mon Sep 17 00:00:00 2001 From: kyy Date: Thu, 9 Jan 2025 16:30:41 +0900 Subject: [PATCH] Add PKL storage and CSV conversion --- workspace/main.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/workspace/main.py b/workspace/main.py index e827792..71cbc49 100755 --- a/workspace/main.py +++ b/workspace/main.py @@ -26,13 +26,11 @@ logger = logging.getLogger(__name__) # FastAPI 엔드포인트: CSV 파일 및 모델 리스트 업로드 처리 @app.post("/start-inference/") async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, background_tasks: BackgroundTasks): - # 파일 형식 확인 if not input_csv.filename.endswith(".csv"): return JSONResponse(content={"error": "Uploaded file is not a CSV."}, status_code=400) if not model_list_txt.filename.endswith(".txt"): return JSONResponse(content={"error": "Uploaded model list is not a TXT file."}, status_code=400) - # 파일 저장 file_path = f"uploaded/{input_csv.filename}" model_list_path = f"uploaded/{model_list_txt.filename}" os.makedirs("uploaded", exist_ok=True) @@ -45,7 +43,6 @@ async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, backgro logger.info(f"Files uploaded: {file_path}, {model_list_path}") - # 작업 큐에 추가 job = queue.enqueue(run_inference, file_path, model_list_path, job_timeout=1800) logger.info(f"Job enqueued: {job.id}") return {"job_id": job.id, "status": "queued"} @@ -74,6 +71,7 @@ def delete_csv_file(file_path: str): except Exception as e: logger.error(f"Error deleting CSV file: {e}") +# 모델 템플릿 적용 def chat_formating(input_sentence: str, model_name: str): try: if "llama" in model_name: @@ -94,14 +92,12 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32): try: logger.info(f"Starting inference for file: {file_path} using models from {model_list_path}") - # 모델 리스트 읽기 with open(model_list_path, "r") as f: model_list = [line.strip() for line in f.readlines()] if not model_list: raise ValueError("The model list file is empty.") - # CSV 읽기 try: df = pd.read_csv(file_path, encoding="euc-kr") except Exception as e: @@ -111,10 +107,8 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32): if "input" not in df.columns: raise ValueError("The input CSV must contain a column named 'input'.") - # 에러 발생한 행 저장용 DataFrame 초기화 error_rows = pd.DataFrame(columns=df.columns) - # 각 모델로 추론 for model in model_list: model_name = model.split("/")[-1] try: @@ -128,7 +122,6 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32): sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50) - # 추론 수행 responses = [] for i in tqdm(range(0, len(df), batch_size), desc=f"Processing {model}"): batch = df.iloc[i:i+batch_size] @@ -146,21 +139,18 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32): batch_responses.append(None) responses.extend(batch_responses) - # 결과 추가 df[model_name] = responses del llm torch.cuda.empty_cache() gc.collect() - # 결과 저장 output_path = file_path.replace("uploaded", "processed").replace(".csv", "_result.csv") os.makedirs("processed", exist_ok=True) # df.to_csv(output_path, index=False, encoding="utf-8") save_to_pkl(df, output_path) logger.info(f"Inference completed. Result saved to: {output_path}") - # 에러 행 저장 if not error_rows.empty: error_path = file_path.replace("uploaded", "errors").replace(".csv", "_errors.csv") os.makedirs("errors", exist_ok=True) @@ -177,12 +167,10 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32): @app.get("/download-latest-result", response_class=FileResponse) def download_latest_result(background_tasks: BackgroundTasks): try: - # PKL 파일 저장 디렉토리 processed_dir = "processed" if not os.path.exists(processed_dir): return JSONResponse(content={"error": "Processed directory not found."}, status_code=404) - # 디렉토리 내 가장 최근에 저장된 PKL 파일 찾기 pkl_files = [os.path.join(processed_dir, f) for f in os.listdir(processed_dir) if f.endswith(".pkl")] if not pkl_files: return JSONResponse(content={"error": "No PKL files found in the processed directory."}, status_code=404) @@ -190,13 +178,10 @@ def download_latest_result(background_tasks: BackgroundTasks): latest_pkl = max(pkl_files, key=os.path.getctime) csv_path = latest_pkl.replace(".pkl", ".csv") - # PKL 파일을 CSV로 변환 convert_pkl_to_csv(latest_pkl, csv_path) - # Background task에 파일 삭제 작업 추가 background_tasks.add_task(delete_csv_file, csv_path) - # CSV 파일 응답 반환 return FileResponse(csv_path, media_type="application/csv", filename=os.path.basename(csv_path)) except Exception as e: