Files
railway-client/tools/sam3_batch_label.py
minsung ccba1266b5 프로젝트 분리 이동
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 14:28:27 +09:00

279 lines
8.9 KiB
Python

"""
SAM3 배치 자동 레이블링 파이프라인
===================================
SAM3 서버에 텍스트 프롬프트로 이미지 배치 처리 →
X-AnyLabeling 호환 JSON annotation 파일 자동 생성
사용법:
# 기본 (sample/rail 폴더, 서버 localhost:8000)
python tools/sam3_batch_label.py
# 폴더 지정
python tools/sam3_batch_label.py --input sample/rail --output output/labels
# conf 조정 (낮출수록 더 많이 검출, 오탐도 증가)
python tools/sam3_batch_label.py --conf 0.20
SAM3 서버 실행:
cd X-AnyLabeling-Server
uvicorn app.main:app --host 0.0.0.0 --port 8000
"""
import argparse
import base64
import json
import sys
from pathlib import Path
import cv2
import numpy as np
import requests
SAM3_SERVER = "http://localhost:8000"
MODEL_ID = "segment_anything_3"
# 검출 대상 + 한국어 레이블 매핑
TARGETS = {
"pole": "전철주_세로",
"catenary arm": "전철주_가로",
"junction box": "통신박스",
"electrical box": "전기박스",
"fence": "펜스",
}
# 시각화 색상 (BGR)
COLORS = {
"전철주_세로": (0, 200, 255),
"전철주_가로": (0, 100, 255),
"통신박스": (255, 180, 0),
"전기박스": (100, 255, 200),
"펜스": (0, 255, 100),
}
def encode_image(image_bgr: np.ndarray) -> str:
_, buf = cv2.imencode(".png", image_bgr) # PNG: 무손실 풀해상도
return base64.b64encode(buf).decode("utf-8")
def sam3_text_predict(image_bgr: np.ndarray, text_prompt: str, conf: float) -> list:
"""SAM3 텍스트 프롬프트로 segmentation. shapes 리스트 반환."""
payload = {
"model": MODEL_ID,
"image": encode_image(image_bgr),
"params": {
"text_prompt": text_prompt,
"show_masks": True,
"show_boxes": False,
"conf_threshold": conf,
"epsilon_factor": 0.002,
},
}
try:
resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()
if data.get("success"):
return data.get("data", {}).get("shapes", [])
except Exception as e:
print(f" [ERROR] SAM3 호출 실패: {e}")
return []
def process_image(image_path: Path, conf: float, vis_dir: Path | None) -> dict:
"""이미지 1장: 모든 클래스 검출 → annotation dict 반환."""
# 한글 경로 대응: np.fromfile + imdecode
buf = np.fromfile(str(image_path), dtype=np.uint8)
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
if image_bgr is None:
return {}
h, w = image_bgr.shape[:2]
all_shapes = []
class_counts = {}
for eng_label, kor_label in TARGETS.items():
shapes = sam3_text_predict(image_bgr, eng_label, conf)
# label 필드를 한국어로 교체
for s in shapes:
s["label"] = kor_label
all_shapes.extend(shapes)
if shapes:
class_counts[kor_label] = len(shapes)
# X-AnyLabeling JSON 형식
annotation = {
"version": "3.3.9",
"flags": {},
"shapes": [
{
"label": s["label"],
"points": s["points"],
"group_id": None,
"shape_type": s.get("shape_type", "polygon"),
"flags": {},
"score": round(float(s.get("score", 0)), 4),
}
for s in all_shapes
],
"imagePath": image_path.name,
"imageData": None,
"imageHeight": h,
"imageWidth": w,
}
# 시각화
if vis_dir is not None:
vis = draw_vis(image_bgr, all_shapes)
vis_path = vis_dir / f"{image_path.stem}_vis.jpg"
cv2.imencode(".jpg", vis)[1].tofile(str(vis_path))
return annotation, class_counts
def draw_vis(image_bgr: np.ndarray, shapes: list) -> np.ndarray:
vis = image_bgr.copy()
overlay = image_bgr.copy()
for shape in shapes:
pts = np.array(shape["points"], dtype=np.int32)
if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
pts = pts[:-1]
label = shape.get("label", "unknown")
color = COLORS.get(label, (200, 200, 200))
cv2.fillPoly(overlay, [pts], color)
cv2.polylines(vis, [pts], True, color, 2)
# 레이블 텍스트
cx = int(pts[:, 0].mean())
cy = int(pts[:, 1].mean())
cv2.putText(vis, label, (cx - 30, cy),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
cv2.addWeighted(overlay, 0.3, vis, 0.7, 0, vis)
# 범례
y = 25
for kor, color in COLORS.items():
cv2.rectangle(vis, (10, y - 12), (24, y + 2), color, -1)
cv2.putText(vis, kor, (28, y),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
y += 20
return vis
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input", default="sample/rail")
parser.add_argument("--output", default="output/sam3_labels",
help="JSON annotation 저장 폴더")
parser.add_argument("--vis", default="output/sam3_vis",
help="시각화 이미지 저장 폴더 ('' 로 비활성화)")
parser.add_argument("--conf", type=float, default=0.20,
help="SAM3 confidence threshold (기본 0.20)")
parser.add_argument("--classes", nargs="+",
help="처리할 클래스 키 목록 (기본: 전체). 예: catenary_pole catenary_arm")
parser.add_argument("--server", default="http://localhost:8000")
args = parser.parse_args()
global SAM3_SERVER, TARGETS
SAM3_SERVER = args.server
if args.classes:
key_map = {
"catenary_pole": ("catenary pole", "전철주_세로"),
"concrete_pole": ("concrete pole", "전철주_세로"),
"catenary_arm": ("catenary arm", "전철주_가로"),
"junction_box": ("junction box", "통신박스"),
"electrical_box": ("electrical box", "전기박스"),
"fence": ("fence", "펜스"),
}
TARGETS = {key_map[k][0]: key_map[k][1] for k in args.classes if k in key_map}
# 서버 확인
try:
r = requests.get(f"{SAM3_SERVER}/health", timeout=5)
print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
except Exception:
print(f"[ERROR] SAM3 서버 연결 실패: {SAM3_SERVER}")
print(" 서버를 먼저 실행하세요:")
print(" cd X-AnyLabeling-Server && uvicorn app.main:app --port 8000")
sys.exit(1)
# 입력
input_path = Path(args.input)
if input_path.is_file():
images = [input_path]
elif input_path.is_dir():
images = sorted(
list(input_path.glob("*.jpg")) +
list(input_path.glob("*.jpeg")) +
list(input_path.glob("*.png"))
)
else:
print(f"[ERROR] 경로 없음: {input_path}")
sys.exit(1)
if not images:
print(f"[ERROR] 이미지 없음: {input_path}")
sys.exit(1)
out_dir = Path(args.output)
out_dir.mkdir(parents=True, exist_ok=True)
vis_dir = None
if args.vis:
vis_dir = Path(args.vis)
vis_dir.mkdir(parents=True, exist_ok=True)
print(f"\n처리 대상: {len(images)}")
print(f"검출 클래스: {list(TARGETS.values())}")
print(f"SAM3 conf: {args.conf}")
print(f"annotation 저장: {out_dir}")
if vis_dir:
print(f"시각화 저장: {vis_dir}")
print()
total_counts: dict = {v: 0 for v in TARGETS.values()}
processed = 0
for img_path in images:
print(f"[{processed+1}/{len(images)}] {img_path.name}")
result = process_image(img_path, args.conf, vis_dir)
if not result:
print(f" [SKIP] 처리 실패")
continue
annotation, class_counts = result
n = len(annotation["shapes"])
print(f"{n}개 객체 검출: {class_counts if class_counts else '없음'}")
# JSON 저장
json_path = out_dir / f"{img_path.stem}.json"
with open(json_path, "w", encoding="utf-8") as f:
json.dump(annotation, f, ensure_ascii=False, indent=2)
for k, v in class_counts.items():
total_counts[k] = total_counts.get(k, 0) + v
processed += 1
# 요약
print("\n" + "="*50)
print(f"완료: {processed}/{len(images)}장 처리")
print("\n클래스별 총 검출 수:")
for cls, cnt in total_counts.items():
avg = cnt / max(processed, 1)
bar = "#" * min(cnt, 30)
print(f" {cls:12s}: {cnt:4d}개 평균 {avg:.1f}/장 {bar}")
print(f"\nJSON 저장: {out_dir.resolve()}")
if vis_dir:
print(f"시각화: {vis_dir.resolve()}")
print("\n다음 단계:")
print(" X-AnyLabeling → Open Image Folder → annotation 폴더 선택")
print(" → 오탐/미탐 수동 수정 → Export → YOLO11-seg 학습")
if __name__ == "__main__":
main()