279 lines
8.9 KiB
Python
279 lines
8.9 KiB
Python
"""
|
|
SAM3 배치 자동 레이블링 파이프라인
|
|
===================================
|
|
SAM3 서버에 텍스트 프롬프트로 이미지 배치 처리 →
|
|
X-AnyLabeling 호환 JSON annotation 파일 자동 생성
|
|
|
|
사용법:
|
|
# 기본 (sample/rail 폴더, 서버 localhost:8000)
|
|
python tools/sam3_batch_label.py
|
|
|
|
# 폴더 지정
|
|
python tools/sam3_batch_label.py --input sample/rail --output output/labels
|
|
|
|
# conf 조정 (낮출수록 더 많이 검출, 오탐도 증가)
|
|
python tools/sam3_batch_label.py --conf 0.20
|
|
|
|
SAM3 서버 실행:
|
|
cd X-AnyLabeling-Server
|
|
uvicorn app.main:app --host 0.0.0.0 --port 8000
|
|
"""
|
|
|
|
import argparse
|
|
import base64
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import requests
|
|
|
|
SAM3_SERVER = "http://localhost:8000"
|
|
MODEL_ID = "segment_anything_3"
|
|
|
|
# 검출 대상 + 한국어 레이블 매핑
|
|
TARGETS = {
|
|
"pole": "전철주_세로",
|
|
"catenary arm": "전철주_가로",
|
|
"junction box": "통신박스",
|
|
"electrical box": "전기박스",
|
|
"fence": "펜스",
|
|
}
|
|
|
|
# 시각화 색상 (BGR)
|
|
COLORS = {
|
|
"전철주_세로": (0, 200, 255),
|
|
"전철주_가로": (0, 100, 255),
|
|
"통신박스": (255, 180, 0),
|
|
"전기박스": (100, 255, 200),
|
|
"펜스": (0, 255, 100),
|
|
}
|
|
|
|
|
|
def encode_image(image_bgr: np.ndarray) -> str:
|
|
_, buf = cv2.imencode(".png", image_bgr) # PNG: 무손실 풀해상도
|
|
return base64.b64encode(buf).decode("utf-8")
|
|
|
|
|
|
def sam3_text_predict(image_bgr: np.ndarray, text_prompt: str, conf: float) -> list:
|
|
"""SAM3 텍스트 프롬프트로 segmentation. shapes 리스트 반환."""
|
|
payload = {
|
|
"model": MODEL_ID,
|
|
"image": encode_image(image_bgr),
|
|
"params": {
|
|
"text_prompt": text_prompt,
|
|
"show_masks": True,
|
|
"show_boxes": False,
|
|
"conf_threshold": conf,
|
|
"epsilon_factor": 0.002,
|
|
},
|
|
}
|
|
try:
|
|
resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60)
|
|
resp.raise_for_status()
|
|
data = resp.json()
|
|
if data.get("success"):
|
|
return data.get("data", {}).get("shapes", [])
|
|
except Exception as e:
|
|
print(f" [ERROR] SAM3 호출 실패: {e}")
|
|
return []
|
|
|
|
|
|
def process_image(image_path: Path, conf: float, vis_dir: Path | None) -> dict:
|
|
"""이미지 1장: 모든 클래스 검출 → annotation dict 반환."""
|
|
# 한글 경로 대응: np.fromfile + imdecode
|
|
buf = np.fromfile(str(image_path), dtype=np.uint8)
|
|
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
|
if image_bgr is None:
|
|
return {}
|
|
h, w = image_bgr.shape[:2]
|
|
|
|
all_shapes = []
|
|
class_counts = {}
|
|
|
|
for eng_label, kor_label in TARGETS.items():
|
|
shapes = sam3_text_predict(image_bgr, eng_label, conf)
|
|
# label 필드를 한국어로 교체
|
|
for s in shapes:
|
|
s["label"] = kor_label
|
|
all_shapes.extend(shapes)
|
|
if shapes:
|
|
class_counts[kor_label] = len(shapes)
|
|
|
|
# X-AnyLabeling JSON 형식
|
|
annotation = {
|
|
"version": "3.3.9",
|
|
"flags": {},
|
|
"shapes": [
|
|
{
|
|
"label": s["label"],
|
|
"points": s["points"],
|
|
"group_id": None,
|
|
"shape_type": s.get("shape_type", "polygon"),
|
|
"flags": {},
|
|
"score": round(float(s.get("score", 0)), 4),
|
|
}
|
|
for s in all_shapes
|
|
],
|
|
"imagePath": image_path.name,
|
|
"imageData": None,
|
|
"imageHeight": h,
|
|
"imageWidth": w,
|
|
}
|
|
|
|
# 시각화
|
|
if vis_dir is not None:
|
|
vis = draw_vis(image_bgr, all_shapes)
|
|
vis_path = vis_dir / f"{image_path.stem}_vis.jpg"
|
|
cv2.imencode(".jpg", vis)[1].tofile(str(vis_path))
|
|
|
|
return annotation, class_counts
|
|
|
|
|
|
def draw_vis(image_bgr: np.ndarray, shapes: list) -> np.ndarray:
|
|
vis = image_bgr.copy()
|
|
overlay = image_bgr.copy()
|
|
|
|
for shape in shapes:
|
|
pts = np.array(shape["points"], dtype=np.int32)
|
|
if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
|
|
pts = pts[:-1]
|
|
label = shape.get("label", "unknown")
|
|
color = COLORS.get(label, (200, 200, 200))
|
|
cv2.fillPoly(overlay, [pts], color)
|
|
cv2.polylines(vis, [pts], True, color, 2)
|
|
|
|
# 레이블 텍스트
|
|
cx = int(pts[:, 0].mean())
|
|
cy = int(pts[:, 1].mean())
|
|
cv2.putText(vis, label, (cx - 30, cy),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
|
|
|
|
cv2.addWeighted(overlay, 0.3, vis, 0.7, 0, vis)
|
|
|
|
# 범례
|
|
y = 25
|
|
for kor, color in COLORS.items():
|
|
cv2.rectangle(vis, (10, y - 12), (24, y + 2), color, -1)
|
|
cv2.putText(vis, kor, (28, y),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
|
y += 20
|
|
|
|
return vis
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--input", default="sample/rail")
|
|
parser.add_argument("--output", default="output/sam3_labels",
|
|
help="JSON annotation 저장 폴더")
|
|
parser.add_argument("--vis", default="output/sam3_vis",
|
|
help="시각화 이미지 저장 폴더 ('' 로 비활성화)")
|
|
parser.add_argument("--conf", type=float, default=0.20,
|
|
help="SAM3 confidence threshold (기본 0.20)")
|
|
parser.add_argument("--classes", nargs="+",
|
|
help="처리할 클래스 키 목록 (기본: 전체). 예: catenary_pole catenary_arm")
|
|
parser.add_argument("--server", default="http://localhost:8000")
|
|
args = parser.parse_args()
|
|
|
|
global SAM3_SERVER, TARGETS
|
|
SAM3_SERVER = args.server
|
|
|
|
if args.classes:
|
|
key_map = {
|
|
"catenary_pole": ("catenary pole", "전철주_세로"),
|
|
"concrete_pole": ("concrete pole", "전철주_세로"),
|
|
"catenary_arm": ("catenary arm", "전철주_가로"),
|
|
"junction_box": ("junction box", "통신박스"),
|
|
"electrical_box": ("electrical box", "전기박스"),
|
|
"fence": ("fence", "펜스"),
|
|
}
|
|
TARGETS = {key_map[k][0]: key_map[k][1] for k in args.classes if k in key_map}
|
|
|
|
# 서버 확인
|
|
try:
|
|
r = requests.get(f"{SAM3_SERVER}/health", timeout=5)
|
|
print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
|
|
except Exception:
|
|
print(f"[ERROR] SAM3 서버 연결 실패: {SAM3_SERVER}")
|
|
print(" 서버를 먼저 실행하세요:")
|
|
print(" cd X-AnyLabeling-Server && uvicorn app.main:app --port 8000")
|
|
sys.exit(1)
|
|
|
|
# 입력
|
|
input_path = Path(args.input)
|
|
if input_path.is_file():
|
|
images = [input_path]
|
|
elif input_path.is_dir():
|
|
images = sorted(
|
|
list(input_path.glob("*.jpg")) +
|
|
list(input_path.glob("*.jpeg")) +
|
|
list(input_path.glob("*.png"))
|
|
)
|
|
else:
|
|
print(f"[ERROR] 경로 없음: {input_path}")
|
|
sys.exit(1)
|
|
|
|
if not images:
|
|
print(f"[ERROR] 이미지 없음: {input_path}")
|
|
sys.exit(1)
|
|
|
|
out_dir = Path(args.output)
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
vis_dir = None
|
|
if args.vis:
|
|
vis_dir = Path(args.vis)
|
|
vis_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
print(f"\n처리 대상: {len(images)}장")
|
|
print(f"검출 클래스: {list(TARGETS.values())}")
|
|
print(f"SAM3 conf: {args.conf}")
|
|
print(f"annotation 저장: {out_dir}")
|
|
if vis_dir:
|
|
print(f"시각화 저장: {vis_dir}")
|
|
print()
|
|
|
|
total_counts: dict = {v: 0 for v in TARGETS.values()}
|
|
processed = 0
|
|
|
|
for img_path in images:
|
|
print(f"[{processed+1}/{len(images)}] {img_path.name}")
|
|
result = process_image(img_path, args.conf, vis_dir)
|
|
if not result:
|
|
print(f" [SKIP] 처리 실패")
|
|
continue
|
|
|
|
annotation, class_counts = result
|
|
n = len(annotation["shapes"])
|
|
print(f" → {n}개 객체 검출: {class_counts if class_counts else '없음'}")
|
|
|
|
# JSON 저장
|
|
json_path = out_dir / f"{img_path.stem}.json"
|
|
with open(json_path, "w", encoding="utf-8") as f:
|
|
json.dump(annotation, f, ensure_ascii=False, indent=2)
|
|
|
|
for k, v in class_counts.items():
|
|
total_counts[k] = total_counts.get(k, 0) + v
|
|
processed += 1
|
|
|
|
# 요약
|
|
print("\n" + "="*50)
|
|
print(f"완료: {processed}/{len(images)}장 처리")
|
|
print("\n클래스별 총 검출 수:")
|
|
for cls, cnt in total_counts.items():
|
|
avg = cnt / max(processed, 1)
|
|
bar = "#" * min(cnt, 30)
|
|
print(f" {cls:12s}: {cnt:4d}개 평균 {avg:.1f}/장 {bar}")
|
|
print(f"\nJSON 저장: {out_dir.resolve()}")
|
|
if vis_dir:
|
|
print(f"시각화: {vis_dir.resolve()}")
|
|
print("\n다음 단계:")
|
|
print(" X-AnyLabeling → Open Image Folder → annotation 폴더 선택")
|
|
print(" → 오탐/미탐 수동 수정 → Export → YOLO11-seg 학습")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|