334 lines
14 KiB
Python
334 lines
14 KiB
Python
"""
|
|
YOLO-World + SAM3 반자동 레이블링 파이프라인
|
|
=============================================
|
|
YOLO-World로 bbox 검출 → SAM3로 polygon mask 생성 → 시각화 저장
|
|
|
|
사용법:
|
|
python tools/yoloworld_sam3_pipeline.py --input sample/rail --output output/labeled
|
|
python tools/yoloworld_sam3_pipeline.py --input sample/rail/frame_00000.jpg
|
|
|
|
SAM3 서버가 실행 중이어야 합니다:
|
|
cd X-AnyLabeling-Server && uvicorn app.main:app --host 0.0.0.0 --port 8000
|
|
"""
|
|
|
|
import argparse
|
|
import base64
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import requests
|
|
|
|
# ── 검출 대상 클래스 (YOLO-World 텍스트 프롬프트) ──────────────────────────
|
|
TARGET_CLASSES = [
|
|
"catenary pole", # 전철주 (세로 기둥)
|
|
"catenary arm", # 전철주 (가로 암)
|
|
"junction box", # 통신/전기 박스
|
|
"fence", # 펜스
|
|
]
|
|
|
|
# 클래스별 색상 (BGR)
|
|
CLASS_COLORS = {
|
|
"catenary pole": (0, 200, 255), # 주황
|
|
"catenary arm": (0, 100, 255), # 빨강
|
|
"junction box": (255, 180, 0), # 파랑
|
|
"fence": (0, 255, 100), # 초록
|
|
}
|
|
|
|
SAM3_SERVER = "http://localhost:8000"
|
|
MODEL_ID = "segment_anything_3"
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# YOLO-World 초기화
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
def load_yolo_world(model_size: str = "s"):
|
|
"""YOLO-World 모델 로드 (자동 다운로드)."""
|
|
from ultralytics import YOLOWorld
|
|
|
|
model_name = f"yolov8{model_size}-worldv2.pt"
|
|
print(f"[YOLO-World] 모델 로드: {model_name}")
|
|
model = YOLOWorld(model_name)
|
|
model.set_classes(TARGET_CLASSES)
|
|
print(f"[YOLO-World] 검출 클래스: {TARGET_CLASSES}")
|
|
return model
|
|
|
|
|
|
def detect_with_yoloworld(model, image_bgr: np.ndarray, conf: float = 0.15):
|
|
"""YOLO-World로 bbox 검출. [(x1,y1,x2,y2,conf,class_name), ...] 반환."""
|
|
results = model.predict(image_bgr, conf=conf, verbose=False)
|
|
detections = []
|
|
if results and len(results) > 0:
|
|
r = results[0]
|
|
boxes = r.boxes
|
|
if boxes is not None and len(boxes) > 0:
|
|
for box in boxes:
|
|
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
|
score = float(box.conf[0].cpu())
|
|
cls_idx = int(box.cls[0].cpu())
|
|
cls_name = TARGET_CLASSES[cls_idx] if cls_idx < len(TARGET_CLASSES) else f"class_{cls_idx}"
|
|
detections.append((float(x1), float(y1), float(x2), float(y2), score, cls_name))
|
|
return detections
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# SAM3 서버 호출
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
def encode_image(image_bgr: np.ndarray) -> str:
|
|
"""이미지를 base64 문자열로 인코딩."""
|
|
_, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
|
return base64.b64encode(buf).decode("utf-8")
|
|
|
|
|
|
def sam3_segment(image_bgr: np.ndarray, boxes: list, conf_threshold: float = 0.25):
|
|
"""SAM3 서버에 bbox 전달 → polygon masks 반환.
|
|
|
|
Args:
|
|
image_bgr: 원본 이미지
|
|
boxes: [(x1,y1,x2,y2,score,class_name), ...]
|
|
conf_threshold: SAM3 신뢰도 임계값
|
|
|
|
Returns:
|
|
shapes: [{"label": str, "points": [[x,y],...], "score": float}, ...]
|
|
"""
|
|
marks = [
|
|
{
|
|
"type": "rectangle",
|
|
"label": 1,
|
|
"data": [b[0], b[1], b[2], b[3]],
|
|
}
|
|
for b in boxes
|
|
]
|
|
|
|
payload = {
|
|
"model": MODEL_ID,
|
|
"image": encode_image(image_bgr),
|
|
"params": {
|
|
"marks": marks,
|
|
"show_masks": True,
|
|
"show_boxes": False,
|
|
"conf_threshold": conf_threshold,
|
|
"epsilon_factor": 0.002,
|
|
},
|
|
}
|
|
|
|
try:
|
|
resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60)
|
|
resp.raise_for_status()
|
|
data = resp.json()
|
|
except requests.exceptions.ConnectionError:
|
|
print(f" [ERROR] SAM3 서버에 연결할 수 없습니다: {SAM3_SERVER}")
|
|
return []
|
|
except Exception as e:
|
|
print(f" [ERROR] SAM3 호출 실패: {e}")
|
|
return []
|
|
|
|
if data.get("status") != "success":
|
|
print(f" [ERROR] SAM3 응답 오류: {data}")
|
|
return []
|
|
|
|
return data.get("data", {}).get("shapes", [])
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# 시각화
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
def draw_results(image_bgr: np.ndarray, detections: list, shapes: list) -> np.ndarray:
|
|
"""bbox + mask를 이미지에 그리기."""
|
|
vis = image_bgr.copy()
|
|
overlay = image_bgr.copy()
|
|
|
|
# SAM3 polygon masks 그리기
|
|
for shape in shapes:
|
|
pts = np.array(shape["points"], dtype=np.int32)
|
|
# 첫점=끝점이면 마지막 제거
|
|
if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
|
|
pts = pts[:-1]
|
|
label = shape.get("label", "unknown")
|
|
color = CLASS_COLORS.get(label, (200, 200, 200))
|
|
cv2.fillPoly(overlay, [pts], color)
|
|
cv2.polylines(vis, [pts], True, color, 2)
|
|
|
|
cv2.addWeighted(overlay, 0.35, vis, 0.65, 0, vis)
|
|
|
|
# YOLO-World bbox + 라벨 그리기
|
|
for x1, y1, x2, y2, score, cls_name in detections:
|
|
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
|
color = CLASS_COLORS.get(cls_name, (200, 200, 200))
|
|
cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
|
|
label_text = f"{cls_name} {score:.2f}"
|
|
(tw, th), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
|
|
cv2.rectangle(vis, (x1, y1 - th - 6), (x1 + tw + 4, y1), color, -1)
|
|
cv2.putText(vis, label_text, (x1 + 2, y1 - 4),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)
|
|
|
|
# 범례
|
|
legend_y = 20
|
|
for cls_name, color in CLASS_COLORS.items():
|
|
cv2.rectangle(vis, (10, legend_y), (25, legend_y + 15), color, -1)
|
|
cv2.putText(vis, cls_name, (30, legend_y + 13),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
|
legend_y += 22
|
|
|
|
return vis
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# 단일 이미지 처리
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
def process_image(yolo_model, image_path: Path, output_dir: Path,
|
|
yolo_conf: float, sam3_conf: float, skip_sam3: bool) -> dict:
|
|
"""이미지 1장 처리. 결과 딕셔너리 반환."""
|
|
image_bgr = cv2.imread(str(image_path))
|
|
if image_bgr is None:
|
|
print(f" [SKIP] 이미지 읽기 실패: {image_path}")
|
|
return {}
|
|
|
|
h, w = image_bgr.shape[:2]
|
|
print(f"\n 이미지: {image_path.name} ({w}x{h})")
|
|
|
|
# ── Step 1: YOLO-World 검출 ────────────────────────────────────────────
|
|
detections = detect_with_yoloworld(yolo_model, image_bgr, conf=yolo_conf)
|
|
print(f" YOLO-World: {len(detections)}개 검출")
|
|
for d in detections:
|
|
print(f" [{d[5]}] conf={d[4]:.3f} bbox=({d[0]:.0f},{d[1]:.0f},{d[2]:.0f},{d[3]:.0f})")
|
|
|
|
shapes = []
|
|
if not skip_sam3 and detections:
|
|
# ── Step 2: SAM3 mask 생성 ─────────────────────────────────────────
|
|
print(f" SAM3: {len(detections)}개 bbox → mask 요청 중...")
|
|
shapes = sam3_segment(image_bgr, detections, conf_threshold=sam3_conf)
|
|
print(f" SAM3: {len(shapes)}개 mask 반환")
|
|
for s in shapes:
|
|
pts_count = len(s.get("points", []))
|
|
print(f" [{s.get('label')}] score={s.get('score', 0):.3f} points={pts_count}")
|
|
|
|
# ── Step 3: 시각화 저장 ───────────────────────────────────────────────
|
|
vis = draw_results(image_bgr, detections, shapes)
|
|
output_path = output_dir / f"{image_path.stem}_result.jpg"
|
|
cv2.imwrite(str(output_path), vis)
|
|
print(f" 저장: {output_path}")
|
|
|
|
return {
|
|
"image": image_path.name,
|
|
"detections": [
|
|
{"class": d[5], "conf": round(d[4], 3),
|
|
"bbox": [round(d[0]), round(d[1]), round(d[2]), round(d[3])]}
|
|
for d in detections
|
|
],
|
|
"masks": len(shapes),
|
|
}
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# 메인
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="YOLO-World + SAM3 파이프라인")
|
|
parser.add_argument("--input", default="sample/rail",
|
|
help="이미지 파일 또는 폴더 경로")
|
|
parser.add_argument("--output", default="output/yoloworld_sam3",
|
|
help="결과 저장 폴더")
|
|
parser.add_argument("--model-size", default="s", choices=["s", "m", "l", "x"],
|
|
help="YOLO-World 모델 크기 (s/m/l/x)")
|
|
parser.add_argument("--yolo-conf", type=float, default=0.10,
|
|
help="YOLO-World 검출 임계값 (기본 0.10)")
|
|
parser.add_argument("--sam3-conf", type=float, default=0.20,
|
|
help="SAM3 마스크 임계값 (기본 0.20)")
|
|
parser.add_argument("--skip-sam3", action="store_true",
|
|
help="SAM3 건너뛰고 YOLO-World bbox만 시각화")
|
|
parser.add_argument("--server", default="http://localhost:8000",
|
|
help="SAM3 서버 주소")
|
|
args = parser.parse_args()
|
|
|
|
global SAM3_SERVER
|
|
SAM3_SERVER = args.server
|
|
|
|
# SAM3 서버 상태 확인
|
|
if not args.skip_sam3:
|
|
try:
|
|
resp = requests.get(f"{SAM3_SERVER}/health", timeout=5)
|
|
if resp.status_code == 200:
|
|
print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
|
|
else:
|
|
print(f"[WARN] SAM3 서버 응답 이상 (status={resp.status_code})")
|
|
except Exception:
|
|
print(f"[WARN] SAM3 서버 연결 실패 ({SAM3_SERVER}). --skip-sam3 로 bbox만 볼 수 있음.")
|
|
ans = input("계속 진행하시겠습니까? (y/N): ").strip().lower()
|
|
if ans != "y":
|
|
sys.exit(0)
|
|
|
|
# 입력 경로 처리
|
|
input_path = Path(args.input)
|
|
if input_path.is_file():
|
|
image_files = [input_path]
|
|
elif input_path.is_dir():
|
|
image_files = sorted(
|
|
list(input_path.glob("*.jpg")) +
|
|
list(input_path.glob("*.jpeg")) +
|
|
list(input_path.glob("*.png"))
|
|
)
|
|
else:
|
|
print(f"[ERROR] 입력 경로를 찾을 수 없습니다: {input_path}")
|
|
sys.exit(1)
|
|
|
|
if not image_files:
|
|
print(f"[ERROR] 이미지 파일 없음: {input_path}")
|
|
sys.exit(1)
|
|
|
|
print(f"\n총 {len(image_files)}개 이미지 처리 예정")
|
|
|
|
# 출력 폴더 생성
|
|
output_dir = Path(args.output)
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# YOLO-World 로드
|
|
yolo_model = load_yolo_world(args.model_size)
|
|
|
|
# 처리
|
|
summary = []
|
|
for img_path in image_files:
|
|
result = process_image(
|
|
yolo_model, img_path, output_dir,
|
|
yolo_conf=args.yolo_conf,
|
|
sam3_conf=args.sam3_conf,
|
|
skip_sam3=args.skip_sam3,
|
|
)
|
|
if result:
|
|
summary.append(result)
|
|
|
|
# 요약 저장
|
|
summary_path = output_dir / "summary.json"
|
|
with open(summary_path, "w", encoding="utf-8") as f:
|
|
json.dump(summary, f, ensure_ascii=False, indent=2)
|
|
|
|
# 통계 출력
|
|
print("\n" + "="*50)
|
|
print("처리 완료 요약")
|
|
print("="*50)
|
|
total_det = sum(len(r["detections"]) for r in summary)
|
|
total_mask = sum(r["masks"] for r in summary)
|
|
print(f"처리 이미지: {len(summary)}장")
|
|
print(f"YOLO 검출: {total_det}개 (평균 {total_det/max(len(summary),1):.1f}/장)")
|
|
print(f"SAM3 마스크: {total_mask}개 (평균 {total_mask/max(len(summary),1):.1f}/장)")
|
|
|
|
# 클래스별 집계
|
|
class_counts: dict = {}
|
|
for r in summary:
|
|
for d in r["detections"]:
|
|
cls = d["class"]
|
|
class_counts[cls] = class_counts.get(cls, 0) + 1
|
|
if class_counts:
|
|
print("\n클래스별 검출 수:")
|
|
for cls, cnt in sorted(class_counts.items(), key=lambda x: -x[1]):
|
|
print(f" {cls:20s}: {cnt}개")
|
|
|
|
print(f"\n결과 저장: {output_dir.resolve()}")
|
|
print(f"요약 JSON: {summary_path.resolve()}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|