Add PKL storage and CSV conversion
@@ -26,13 +26,11 @@ logger = logging.getLogger(__name__)
 # FastAPI endpoint: handle the uploaded CSV file and model list
 @app.post("/start-inference/")
 async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, background_tasks: BackgroundTasks):
-    # Check the file formats
     if not input_csv.filename.endswith(".csv"):
         return JSONResponse(content={"error": "Uploaded file is not a CSV."}, status_code=400)
     if not model_list_txt.filename.endswith(".txt"):
         return JSONResponse(content={"error": "Uploaded model list is not a TXT file."}, status_code=400)

-    # Save the files
     file_path = f"uploaded/{input_csv.filename}"
     model_list_path = f"uploaded/{model_list_txt.filename}"
     os.makedirs("uploaded", exist_ok=True)
@@ -45,7 +43,6 @@ async def process_csv(input_csv: UploadFile, model_list_txt: UploadFile, backgro

     logger.info(f"Files uploaded: {file_path}, {model_list_path}")

-    # Add the job to the queue
     job = queue.enqueue(run_inference, file_path, model_list_path, job_timeout=1800)
     logger.info(f"Job enqueued: {job.id}")
     return {"job_id": job.id, "status": "queued"}
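The `queue` object used above is an RQ queue; its construction sits above these hunks and is not shown. A minimal sketch of the assumed module-level setup, where the Redis host and port are assumptions:

# Hypothetical module-level setup; not part of this diff.
from redis import Redis
from rq import Queue

# RQ queue backed by a local Redis instance (connection details assumed)
queue = Queue(connection=Redis(host="localhost", port=6379))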
@@ -74,6 +71,7 @@ def delete_csv_file(file_path: str):
     except Exception as e:
         logger.error(f"Error deleting CSV file: {e}")

+# Apply the model's chat template
 def chat_formating(input_sentence: str, model_name: str):
     try:
         if "llama" in model_name:
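The body of `chat_formating` is truncated here. A hedged sketch of what such a helper typically does, assuming a Llama-style instruction template; the actual template strings are not visible in this diff:

# Hypothetical sketch -- the real template logic is not shown.
def chat_formating(input_sentence: str, model_name: str):
    try:
        if "llama" in model_name:
            # Llama-style instruction wrapping (assumed format)
            return f"<s>[INST] {input_sentence} [/INST]"
        # Unknown model families fall back to the raw sentence
        return input_sentence
    except Exception:
        return input_sentence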
@@ -94,14 +92,12 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
     try:
         logger.info(f"Starting inference for file: {file_path} using models from {model_list_path}")

-        # Read the model list
         with open(model_list_path, "r") as f:
             model_list = [line.strip() for line in f.readlines()]

         if not model_list:
             raise ValueError("The model list file is empty.")

-        # Read the CSV
         try:
             df = pd.read_csv(file_path, encoding="euc-kr")
         except Exception as e:
@@ -111,10 +107,8 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
         if "input" not in df.columns:
             raise ValueError("The input CSV must contain a column named 'input'.")

-        # Initialize a DataFrame for rows that raised errors
         error_rows = pd.DataFrame(columns=df.columns)

-        # Run inference with each model
         for model in model_list:
             model_name = model.split("/")[-1]
             try:
@@ -128,7 +122,6 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):

                 sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_p=0.9, top_k=50)

-                # Run the inference
                 responses = []
                 for i in tqdm(range(0, len(df), batch_size), desc=f"Processing {model}"):
                     batch = df.iloc[i:i+batch_size]
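The body of the batch loop is elided between hunks. A hedged sketch of what it plausibly does, assuming `llm` is a `vllm.LLM` instance created for the current model; the `batch_responses.append(None)` fragment in the next hunk suggests per-row error handling:

# Hypothetical sketch of the elided loop body; only fragments of it
# (batch_responses.append(None), responses.extend) appear in the diff.
batch_responses = []
for sentence in batch["input"].tolist():
    try:
        prompt = chat_formating(sentence, model_name)
        outputs = llm.generate([prompt], sampling_params)
        batch_responses.append(outputs[0].outputs[0].text)
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        batch_responses.append(None)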
@@ -146,21 +139,18 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
                             batch_responses.append(None)
                     responses.extend(batch_responses)

-                # Append the results
                 df[model_name] = responses

                 del llm
                 torch.cuda.empty_cache()
                 gc.collect()

-        # Save the results
         output_path = file_path.replace("uploaded", "processed").replace(".csv", "_result.csv")
         os.makedirs("processed", exist_ok=True)
         # df.to_csv(output_path, index=False, encoding="utf-8")
         save_to_pkl(df, output_path)
         logger.info(f"Inference completed. Result saved to: {output_path}")

-        # Save the error rows
         if not error_rows.empty:
             error_path = file_path.replace("uploaded", "errors").replace(".csv", "_errors.csv")
             os.makedirs("errors", exist_ok=True)
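`save_to_pkl` is called above but defined outside these hunks. A minimal sketch, assuming it swaps the `.csv` suffix for `.pkl` (the download endpoint below looks for `.pkl` files in `processed` and maps them back to `.csv`):

# Hypothetical sketch of the helper called above; not shown in this diff.
def save_to_pkl(df: pd.DataFrame, output_path: str):
    # Pickle the DataFrame instead of writing CSV, preserving dtypes
    pkl_path = output_path.replace(".csv", ".pkl")
    df.to_pickle(pkl_path)
    logger.info(f"DataFrame pickled to: {pkl_path}")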
@@ -177,12 +167,10 @@ def run_inference(file_path: str, model_list_path: str, batch_size: int = 32):
 @app.get("/download-latest-result", response_class=FileResponse)
 def download_latest_result(background_tasks: BackgroundTasks):
     try:
-        # Directory where the PKL files are stored
         processed_dir = "processed"
         if not os.path.exists(processed_dir):
             return JSONResponse(content={"error": "Processed directory not found."}, status_code=404)

-        # Find the most recently saved PKL file in the directory
         pkl_files = [os.path.join(processed_dir, f) for f in os.listdir(processed_dir) if f.endswith(".pkl")]
         if not pkl_files:
             return JSONResponse(content={"error": "No PKL files found in the processed directory."}, status_code=404)
@@ -190,13 +178,10 @@ def download_latest_result(background_tasks: BackgroundTasks):
         latest_pkl = max(pkl_files, key=os.path.getctime)
         csv_path = latest_pkl.replace(".pkl", ".csv")

-        # Convert the PKL file to CSV
         convert_pkl_to_csv(latest_pkl, csv_path)

-        # Add the file deletion as a background task
         background_tasks.add_task(delete_csv_file, csv_path)

-        # Return the CSV file as the response
         return FileResponse(csv_path, media_type="application/csv", filename=os.path.basename(csv_path))

     except Exception as e:
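`convert_pkl_to_csv` is likewise defined outside these hunks. A minimal sketch, assuming it simply round-trips the pickle through pandas with the same `utf-8` encoding the commented-out `to_csv` call used:

# Hypothetical sketch of the converter used by /download-latest-result.
def convert_pkl_to_csv(pkl_path: str, csv_path: str):
    # Load the pickled DataFrame and write it back out as CSV
    df = pd.read_pickle(pkl_path)
    df.to_csv(csv_path, index=False, encoding="utf-8")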