278
tools/sam3_batch_label.py
Normal file
278
tools/sam3_batch_label.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
SAM3 배치 자동 레이블링 파이프라인
|
||||
===================================
|
||||
SAM3 서버에 텍스트 프롬프트로 이미지 배치 처리 →
|
||||
X-AnyLabeling 호환 JSON annotation 파일 자동 생성
|
||||
|
||||
사용법:
|
||||
# 기본 (sample/rail 폴더, 서버 localhost:8000)
|
||||
python tools/sam3_batch_label.py
|
||||
|
||||
# 폴더 지정
|
||||
python tools/sam3_batch_label.py --input sample/rail --output output/labels
|
||||
|
||||
# conf 조정 (낮출수록 더 많이 검출, 오탐도 증가)
|
||||
python tools/sam3_batch_label.py --conf 0.20
|
||||
|
||||
SAM3 서버 실행:
|
||||
cd X-AnyLabeling-Server
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
SAM3_SERVER = "http://localhost:8000"
|
||||
MODEL_ID = "segment_anything_3"
|
||||
|
||||
# 검출 대상 + 한국어 레이블 매핑
|
||||
TARGETS = {
|
||||
"pole": "전철주_세로",
|
||||
"catenary arm": "전철주_가로",
|
||||
"junction box": "통신박스",
|
||||
"electrical box": "전기박스",
|
||||
"fence": "펜스",
|
||||
}
|
||||
|
||||
# 시각화 색상 (BGR)
|
||||
COLORS = {
|
||||
"전철주_세로": (0, 200, 255),
|
||||
"전철주_가로": (0, 100, 255),
|
||||
"통신박스": (255, 180, 0),
|
||||
"전기박스": (100, 255, 200),
|
||||
"펜스": (0, 255, 100),
|
||||
}
|
||||
|
||||
|
||||
def encode_image(image_bgr: np.ndarray) -> str:
|
||||
_, buf = cv2.imencode(".png", image_bgr) # PNG: 무손실 풀해상도
|
||||
return base64.b64encode(buf).decode("utf-8")
|
||||
|
||||
|
||||
def sam3_text_predict(image_bgr: np.ndarray, text_prompt: str, conf: float) -> list:
|
||||
"""SAM3 텍스트 프롬프트로 segmentation. shapes 리스트 반환."""
|
||||
payload = {
|
||||
"model": MODEL_ID,
|
||||
"image": encode_image(image_bgr),
|
||||
"params": {
|
||||
"text_prompt": text_prompt,
|
||||
"show_masks": True,
|
||||
"show_boxes": False,
|
||||
"conf_threshold": conf,
|
||||
"epsilon_factor": 0.002,
|
||||
},
|
||||
}
|
||||
try:
|
||||
resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if data.get("success"):
|
||||
return data.get("data", {}).get("shapes", [])
|
||||
except Exception as e:
|
||||
print(f" [ERROR] SAM3 호출 실패: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def process_image(image_path: Path, conf: float, vis_dir: Path | None) -> dict:
|
||||
"""이미지 1장: 모든 클래스 검출 → annotation dict 반환."""
|
||||
# 한글 경로 대응: np.fromfile + imdecode
|
||||
buf = np.fromfile(str(image_path), dtype=np.uint8)
|
||||
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
||||
if image_bgr is None:
|
||||
return {}
|
||||
h, w = image_bgr.shape[:2]
|
||||
|
||||
all_shapes = []
|
||||
class_counts = {}
|
||||
|
||||
for eng_label, kor_label in TARGETS.items():
|
||||
shapes = sam3_text_predict(image_bgr, eng_label, conf)
|
||||
# label 필드를 한국어로 교체
|
||||
for s in shapes:
|
||||
s["label"] = kor_label
|
||||
all_shapes.extend(shapes)
|
||||
if shapes:
|
||||
class_counts[kor_label] = len(shapes)
|
||||
|
||||
# X-AnyLabeling JSON 형식
|
||||
annotation = {
|
||||
"version": "3.3.9",
|
||||
"flags": {},
|
||||
"shapes": [
|
||||
{
|
||||
"label": s["label"],
|
||||
"points": s["points"],
|
||||
"group_id": None,
|
||||
"shape_type": s.get("shape_type", "polygon"),
|
||||
"flags": {},
|
||||
"score": round(float(s.get("score", 0)), 4),
|
||||
}
|
||||
for s in all_shapes
|
||||
],
|
||||
"imagePath": image_path.name,
|
||||
"imageData": None,
|
||||
"imageHeight": h,
|
||||
"imageWidth": w,
|
||||
}
|
||||
|
||||
# 시각화
|
||||
if vis_dir is not None:
|
||||
vis = draw_vis(image_bgr, all_shapes)
|
||||
vis_path = vis_dir / f"{image_path.stem}_vis.jpg"
|
||||
cv2.imencode(".jpg", vis)[1].tofile(str(vis_path))
|
||||
|
||||
return annotation, class_counts
|
||||
|
||||
|
||||
def draw_vis(image_bgr: np.ndarray, shapes: list) -> np.ndarray:
|
||||
vis = image_bgr.copy()
|
||||
overlay = image_bgr.copy()
|
||||
|
||||
for shape in shapes:
|
||||
pts = np.array(shape["points"], dtype=np.int32)
|
||||
if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
|
||||
pts = pts[:-1]
|
||||
label = shape.get("label", "unknown")
|
||||
color = COLORS.get(label, (200, 200, 200))
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
cv2.polylines(vis, [pts], True, color, 2)
|
||||
|
||||
# 레이블 텍스트
|
||||
cx = int(pts[:, 0].mean())
|
||||
cy = int(pts[:, 1].mean())
|
||||
cv2.putText(vis, label, (cx - 30, cy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
|
||||
|
||||
cv2.addWeighted(overlay, 0.3, vis, 0.7, 0, vis)
|
||||
|
||||
# 범례
|
||||
y = 25
|
||||
for kor, color in COLORS.items():
|
||||
cv2.rectangle(vis, (10, y - 12), (24, y + 2), color, -1)
|
||||
cv2.putText(vis, kor, (28, y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
y += 20
|
||||
|
||||
return vis
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", default="sample/rail")
|
||||
parser.add_argument("--output", default="output/sam3_labels",
|
||||
help="JSON annotation 저장 폴더")
|
||||
parser.add_argument("--vis", default="output/sam3_vis",
|
||||
help="시각화 이미지 저장 폴더 ('' 로 비활성화)")
|
||||
parser.add_argument("--conf", type=float, default=0.20,
|
||||
help="SAM3 confidence threshold (기본 0.20)")
|
||||
parser.add_argument("--classes", nargs="+",
|
||||
help="처리할 클래스 키 목록 (기본: 전체). 예: catenary_pole catenary_arm")
|
||||
parser.add_argument("--server", default="http://localhost:8000")
|
||||
args = parser.parse_args()
|
||||
|
||||
global SAM3_SERVER, TARGETS
|
||||
SAM3_SERVER = args.server
|
||||
|
||||
if args.classes:
|
||||
key_map = {
|
||||
"catenary_pole": ("catenary pole", "전철주_세로"),
|
||||
"concrete_pole": ("concrete pole", "전철주_세로"),
|
||||
"catenary_arm": ("catenary arm", "전철주_가로"),
|
||||
"junction_box": ("junction box", "통신박스"),
|
||||
"electrical_box": ("electrical box", "전기박스"),
|
||||
"fence": ("fence", "펜스"),
|
||||
}
|
||||
TARGETS = {key_map[k][0]: key_map[k][1] for k in args.classes if k in key_map}
|
||||
|
||||
# 서버 확인
|
||||
try:
|
||||
r = requests.get(f"{SAM3_SERVER}/health", timeout=5)
|
||||
print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
|
||||
except Exception:
|
||||
print(f"[ERROR] SAM3 서버 연결 실패: {SAM3_SERVER}")
|
||||
print(" 서버를 먼저 실행하세요:")
|
||||
print(" cd X-AnyLabeling-Server && uvicorn app.main:app --port 8000")
|
||||
sys.exit(1)
|
||||
|
||||
# 입력
|
||||
input_path = Path(args.input)
|
||||
if input_path.is_file():
|
||||
images = [input_path]
|
||||
elif input_path.is_dir():
|
||||
images = sorted(
|
||||
list(input_path.glob("*.jpg")) +
|
||||
list(input_path.glob("*.jpeg")) +
|
||||
list(input_path.glob("*.png"))
|
||||
)
|
||||
else:
|
||||
print(f"[ERROR] 경로 없음: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not images:
|
||||
print(f"[ERROR] 이미지 없음: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
out_dir = Path(args.output)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
vis_dir = None
|
||||
if args.vis:
|
||||
vis_dir = Path(args.vis)
|
||||
vis_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n처리 대상: {len(images)}장")
|
||||
print(f"검출 클래스: {list(TARGETS.values())}")
|
||||
print(f"SAM3 conf: {args.conf}")
|
||||
print(f"annotation 저장: {out_dir}")
|
||||
if vis_dir:
|
||||
print(f"시각화 저장: {vis_dir}")
|
||||
print()
|
||||
|
||||
total_counts: dict = {v: 0 for v in TARGETS.values()}
|
||||
processed = 0
|
||||
|
||||
for img_path in images:
|
||||
print(f"[{processed+1}/{len(images)}] {img_path.name}")
|
||||
result = process_image(img_path, args.conf, vis_dir)
|
||||
if not result:
|
||||
print(f" [SKIP] 처리 실패")
|
||||
continue
|
||||
|
||||
annotation, class_counts = result
|
||||
n = len(annotation["shapes"])
|
||||
print(f" → {n}개 객체 검출: {class_counts if class_counts else '없음'}")
|
||||
|
||||
# JSON 저장
|
||||
json_path = out_dir / f"{img_path.stem}.json"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(annotation, f, ensure_ascii=False, indent=2)
|
||||
|
||||
for k, v in class_counts.items():
|
||||
total_counts[k] = total_counts.get(k, 0) + v
|
||||
processed += 1
|
||||
|
||||
# 요약
|
||||
print("\n" + "="*50)
|
||||
print(f"완료: {processed}/{len(images)}장 처리")
|
||||
print("\n클래스별 총 검출 수:")
|
||||
for cls, cnt in total_counts.items():
|
||||
avg = cnt / max(processed, 1)
|
||||
bar = "#" * min(cnt, 30)
|
||||
print(f" {cls:12s}: {cnt:4d}개 평균 {avg:.1f}/장 {bar}")
|
||||
print(f"\nJSON 저장: {out_dir.resolve()}")
|
||||
if vis_dir:
|
||||
print(f"시각화: {vis_dir.resolve()}")
|
||||
print("\n다음 단계:")
|
||||
print(" X-AnyLabeling → Open Image Folder → annotation 폴더 선택")
|
||||
print(" → 오탐/미탐 수동 수정 → Export → YOLO11-seg 학습")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user