149 lines
5.6 KiB
Python
149 lines
5.6 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
import time
|
|
|
|
import requests
|
|
|
|
# API 키
|
|
API_KEY = "sk-dade0cb396c744ec431357cedd5784c2"
|
|
|
|
# 지원하는 파일 확장자
|
|
SUPPORTED_EXTENSIONS = (".pdf", ".png", ".jpg", ".jpeg", ".docx", ".json")
|
|
|
|
|
|
def send_sync_request(
|
|
api_endpoint, file_path, model_name, prompt_file_path, source_dir, result_file_path
|
|
):
|
|
"""지정된 파일과 옵션으로 API에 동기 요청을 보내고 결과를 저장합니다."""
|
|
print(
|
|
f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] '{os.path.basename(file_path)}' 파일 동기 전송 중..."
|
|
)
|
|
try:
|
|
headers = {"x-api-key": API_KEY}
|
|
with open(file_path, "rb") as input_file:
|
|
files = {"input_file": (os.path.basename(file_path), input_file)}
|
|
data = {"source_dir": source_dir}
|
|
if model_name:
|
|
data["model"] = model_name
|
|
|
|
if prompt_file_path:
|
|
try:
|
|
prompt_file_opened = open(prompt_file_path, "rb")
|
|
files["prompt_file"] = (
|
|
os.path.basename(prompt_file_path),
|
|
prompt_file_opened,
|
|
)
|
|
except FileNotFoundError:
|
|
print(
|
|
f" [오류] 프롬프트 파일을 찾을 수 없습니다: {prompt_file_path}"
|
|
)
|
|
if (
|
|
"prompt_file_opened" in locals()
|
|
and not prompt_file_opened.closed
|
|
):
|
|
prompt_file_opened.close()
|
|
return
|
|
|
|
# 동기 요청이므로 작업이 끝날 때까지 기다립니다. 타임아웃을 넉넉하게 설정합니다.
|
|
response = requests.post(
|
|
api_endpoint, headers=headers, files=files, data=data, timeout=300
|
|
) # 5분 타임아웃
|
|
|
|
if "prompt_file" in files and not files["prompt_file"][1].closed:
|
|
files["prompt_file"][1].close()
|
|
|
|
response.raise_for_status()
|
|
|
|
result_data = response.json()
|
|
|
|
# 서버로부터 받은 최종 결과를 파일에 저장합니다.
|
|
os.makedirs(os.path.dirname(result_file_path), exist_ok=True)
|
|
with open(result_file_path, "w", encoding="utf-8") as f:
|
|
json.dump(result_data, f, ensure_ascii=False, indent=4)
|
|
|
|
print(f" [성공] 결과 저장 완료: {os.path.basename(result_file_path)}")
|
|
|
|
except requests.exceptions.Timeout:
|
|
print(" [오류] 요청 시간이 초과되었습니다 (Timeout).")
|
|
except requests.exceptions.RequestException as e:
|
|
print(f" [오류] API 요청 중 문제가 발생했습니다: {e}")
|
|
except Exception as e:
|
|
print(f" [오류] 알 수 없는 오류가 발생했습니다: {e}")
|
|
|
|
|
|
def main(model_name, prompt_file_path, endpoint_name):
|
|
"""데이터 디렉토리를 순회하며 지원하는 각 파일에 대해 API 요청을 보냅니다."""
|
|
api_endpoint = f"http://localhost:8889/costs/{endpoint_name}/sync"
|
|
base_data_directory = "ocr_results"
|
|
target_dirs = ["pp-ocr", "pp-structure", "upstage"]
|
|
total_file_count = 0
|
|
skipped_file_count = 0
|
|
|
|
for dir_name in target_dirs:
|
|
data_directory = os.path.join(base_data_directory, dir_name)
|
|
result_dir = os.path.join("..", "result", f"{endpoint_name}-{dir_name}")
|
|
|
|
print(f"대상 디렉토리: {data_directory}")
|
|
print(f"결과 확인 디렉토리: {os.path.abspath(result_dir)}")
|
|
print("-" * 30)
|
|
|
|
if not os.path.isdir(data_directory):
|
|
print(f"[오류] 디렉토리를 찾을 수 없습니다: {data_directory}")
|
|
continue
|
|
|
|
file_count = 0
|
|
for root, _, files in os.walk(data_directory):
|
|
for file in files:
|
|
if file.lower().endswith(SUPPORTED_EXTENSIONS):
|
|
result_file_path = os.path.join(result_dir, file)
|
|
if os.path.exists(result_file_path):
|
|
print(f" [건너뛰기] 이미 결과가 존재합니다: {file}")
|
|
skipped_file_count += 1
|
|
continue
|
|
|
|
file_path = os.path.join(root, file)
|
|
send_sync_request(
|
|
api_endpoint,
|
|
file_path,
|
|
model_name,
|
|
prompt_file_path,
|
|
dir_name,
|
|
result_file_path,
|
|
)
|
|
# 서버 부하 감소를 위한 최소한의 대기 시간
|
|
print(" [대기] 다음 요청까지 2초 대기...")
|
|
time.sleep(2)
|
|
file_count += 1
|
|
|
|
total_file_count += file_count
|
|
print(f"'{dir_name}' 디렉토리의 {file_count}개 파일 요청 완료.")
|
|
|
|
print("-" * 30)
|
|
print(
|
|
f"총 {total_file_count}개의 파일에 대한 요청을 완료했으며, {skipped_file_count}개를 건너뛰었습니다."
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(
|
|
description="멀티모달 API에 파일들을 동기적으로 요청하는 스크립트"
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--endpoint",
|
|
type=str,
|
|
default="gemma3",
|
|
help="호출할 API 엔드포인트 이름 (예: gemini, gemma3)",
|
|
dest="endpoint_name",
|
|
)
|
|
parser.add_argument("--model", type=str, help="사용할 LLM 모델 이름")
|
|
parser.add_argument(
|
|
"--prompt_file",
|
|
type=str,
|
|
help="구조화에 사용할 커스텀 .txt 프롬프트 파일의 경로",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
main(args.model, args.prompt_file, args.endpoint_name)
|