Files
railway-client/tools/sam3_everything_explore.py
minsung ccba1266b5 프로젝트 분리 이동
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 14:28:27 +09:00

292 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
SAM3.1 탐색 모드 (Discovery Sweep)
이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 →
나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정.
이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음.
대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다.
사용법:
python tools/sam3_everything_explore.py \\
--input "data/역사구간/1.회덕역/..." \\
--cols 8 --rows 6
사전 조건: SAM3.1 서버 실행 (start_server.bat)
"""
import argparse
import base64
import json
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import cv2
import numpy as np
import requests
# Base URL of the SAM3.1 inference server; predictions are POSTed to
# "<SAM3_SERVER>/v1/predict".
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent as the "model" field of every predict request.
SAM3_MODEL_ID = "segment_anything_3"
# ── 이미지 인코딩 ─────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
    """Downscale an image if needed and encode it as a base64 JPEG string.

    Args:
        image_bgr: Image array accepted by ``cv2.imencode`` (the parameter
            name suggests OpenCV's BGR channel order).
        max_size: Upper bound, in pixels, for the longer image side. Larger
            images are shrunk proportionally before encoding.

    Returns:
        ``(b64, scale)`` where ``b64`` is the base64 text of the JPEG
        (quality 90) and ``scale`` is the resize factor that was applied
        (``1.0`` when the image was not resized).

    Raises:
        ValueError: If OpenCV reports that the JPEG encoding failed.
    """
    h, w = image_bgr.shape[:2]
    scale = 1.0
    longest = max(h, w)
    if longest > max_size:
        scale = max_size / longest
        # int() truncation can produce 0 for extreme aspect ratios
        # (e.g. a 1x5000 strip); cv2.resize rejects a zero-sized target,
        # so keep every side at least one pixel.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        image_bgr = cv2.resize(image_bgr, (new_w, new_h))
    ok, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf).decode("utf-8"), scale
# Broad discovery prompt: lists the elements commonly seen at railway sites.
# The terms are joined into one comma-separated text prompt.
DISCOVERY_PROMPT = ", ".join([
    # track
    "railroad track", "railway rail",
    # poles
    "catenary pole", "overhead line pole", "electric pole",
    # wires
    "overhead wire", "catenary wire", "power line cable",
    # sleepers
    "railway sleeper", "concrete tie",
    # barriers
    "guardrail", "highway barrier", "road fence",
    # crossings
    "bridge", "viaduct", "overpass",
    # vegetation
    "vegetation", "tree", "bush", "grass",
    # buildings
    "building", "structure", "roof", "wall",
    # vehicles
    "vehicle", "car", "truck",
    # road surface
    "road", "asphalt", "pavement",
    # earthworks
    "slope", "embankment", "retaining wall",
    # noise protection
    "noise barrier", "sound wall",
    # signalling
    "signal", "sign board",
])
# ── SAM3.1 discovery sweep 호출 ───────────────────────────────────────────────
def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list:
    """Run a broad discovery prompt on one tile and return its polygon shapes.

    The tile is JPEG/base64-encoded and POSTed to the server's
    ``/v1/predict`` endpoint. When the tile was downscaled for upload,
    polygon coordinates are mapped back to the tile's original resolution.

    Returns:
        A list of shape dicts whose ``shape_type`` is ``"polygon"``. The list
        is empty when the response does not report success, or when any
        error occurs during the request or response handling (the error is
        printed, not raised).
    """
    encoded, scale = encode_image(tile_bgr)
    request_body = {
        "model": SAM3_MODEL_ID,
        "image": encoded,
        "params": {
            "text_prompt": prompt,
            "conf_threshold": conf,
            "show_masks": True,
            "show_boxes": False,
        },
    }
    try:
        response = requests.post(f"{SAM3_SERVER}/v1/predict", json=request_body, timeout=120)
        response.raise_for_status()
        body = response.json()
        if not body.get("success"):
            return []

        # Normalise every entry to a plain dict, keeping polygons only.
        polygons = []
        for raw in body.get("data", {}).get("shapes", []):
            shape = raw if isinstance(raw, dict) else raw.dict()
            if shape.get("shape_type") == "polygon":
                polygons.append(shape)

        # Undo the upload downscale so points are in tile pixel coordinates.
        if scale < 1.0:
            upscale = 1.0 / scale
            for shape in polygons:
                shape["points"] = [[x * upscale, y * upscale] for x, y in shape["points"]]
        return polygons
    except Exception as e:
        # Best effort per tile: report the problem and continue the sweep.
        print(f" [SAM3 오류] {e}")
        return []
# ── NMS ───────────────────────────────────────────────────────────────────────
def _bbox(pts):
xs = [p[0] for p in pts]; ys = [p[1] for p in pts]
return min(xs), min(ys), max(xs), max(ys)
def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
    """Greedy non-maximum suppression over the shapes' bounding boxes.

    Shapes are visited from highest to lowest score. Each kept shape removes
    every remaining shape whose axis-aligned bounding box overlaps it with an
    IoU of ``iou_thresh`` or more. Labels are not consulted, so overlapping
    shapes suppress each other regardless of label.

    Args:
        shapes: Shape dicts with a ``"points"`` list and an optional
            ``"score"``. A missing or null score is treated as 0.5.
        iou_thresh: Overlap at or above which a lower-scored shape is dropped.

    Returns:
        The surviving shape dicts, highest score first. Shapes without any
        points are dropped because they have no bounding box.
    """
    # An empty point list would make min()/max() in _bbox fail.
    shapes = [s for s in shapes if s.get("points")]
    if not shapes:
        return []

    boxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
    # float(None) raises TypeError, so a null score falls back to the same
    # default as a missing one.
    scores = np.array([
        0.5 if s.get("score") is None else float(s["score"])
        for s in shapes
    ])
    # Box areas do not change during suppression; compute them once.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    order = scores.argsort()[::-1]
    keep = []
    while order.size:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        if not rest.size:
            break
        # Intersection rectangle of the best box with every remaining box.
        xx1 = np.maximum(boxes[best, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[best, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[best, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        # The epsilon keeps the division defined for zero-area boxes.
        iou = inter / (areas[best] + areas[rest] - inter + 1e-6)
        order = rest[iou < iou_thresh]
    return [shapes[i] for i in keep]
# ── 타일 분할 + 병렬 검출 ─────────────────────────────────────────────────────
def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt):
    """Split the image into overlapping tiles and run detection on each in parallel.

    Args:
        image_bgr: Full image array; tiles are cropped from it by slicing.
        cols: Number of tile columns (must be >= 1).
        rows: Number of tile rows (must be >= 1).
        overlap: Fraction of a tile's width/height added as padding on every
            side, so neighbouring tiles overlap.
        conf: Confidence threshold forwarded to ``sam3_everything``.
        workers: Number of worker threads.
        prompt: Text prompt forwarded to ``sam3_everything``.

    Returns:
        All shape dicts from all tiles, with points shifted into full-image
        coordinates. Shapes are ordered by tile (row-major), independent of
        which request finished first, so repeated runs give the same order.

    Raises:
        ValueError: If ``cols`` or ``rows`` is smaller than 1.
    """
    if cols < 1 or rows < 1:
        raise ValueError(f"cols and rows must be >= 1 (got cols={cols}, rows={rows})")

    H, W = image_bgr.shape[:2]
    tiles = _tile_windows(W, H, cols, rows, overlap)
    total = len(tiles)

    def process(window):
        x0, y0, x1, y1 = window
        shapes = sam3_everything(image_bgr[y0:y1, x0:x1], conf, prompt)
        # Shift tile-local points into full-image coordinates.
        for s in shapes:
            s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
        return shapes

    per_tile = [[] for _ in range(total)]
    found = 0
    with ThreadPoolExecutor(max_workers=workers) as ex:
        futs = {ex.submit(process, t): i for i, t in enumerate(tiles)}
        # Progress is reported in completion order, but results are stored
        # by tile index so the returned order is deterministic.
        for done, fut in enumerate(as_completed(futs), start=1):
            result = fut.result()
            per_tile[futs[fut]] = result
            found += len(result)
            print(f" 타일 {done:02d}/{total} 완료, 누적 {found}", end="\r")
    print()
    return [s for tile_shapes in per_tile for s in tile_shapes]


def _tile_windows(W, H, cols, rows, overlap):
    """Return row-major ``(x0, y0, x1, y1)`` crop windows for a cols x rows grid.

    Every grid cell is grown by ``overlap`` times the cell size on each side
    and clipped to the image. Windows that contain no pixels (possible when
    the grid is finer than the image) are left out.
    """
    base_w = W / cols
    base_h = H / rows
    pad_x = int(base_w * overlap)
    pad_y = int(base_h * overlap)
    windows = []
    for r in range(rows):
        y0 = max(0, int(r * base_h) - pad_y)
        y1 = min(H, int((r + 1) * base_h) + pad_y)
        for c in range(cols):
            x0 = max(0, int(c * base_w) - pad_x)
            x1 = min(W, int((c + 1) * base_w) + pad_x)
            if x1 > x0 and y1 > y0:
                windows.append((x0, y0, x1, y1))
    return windows
# ── 시각화 ────────────────────────────────────────────────────────────────────
def draw_everything(image_bgr, shapes, cols, rows):
    """Render the tile grid and every segment on a copy of the image.

    Each shape is drawn as a translucent filled polygon with a 1-pixel
    outline in a pseudo-random colour (fixed seed, so colours depend only on
    shape order). A shape's label, when present, is written at the mean of
    its points. The total segment count is printed in the top-left corner.

    Returns:
        The annotated copy; ``image_bgr`` itself is not modified.
    """
    canvas = image_bgr.copy()
    height, width = canvas.shape[:2]

    # Tile boundaries (nominal grid cells, without the overlap padding).
    grid_color = (60, 60, 60)
    for row in range(rows):
        top = int(row * height / rows)
        bottom = int((row + 1) * height / rows)
        for col in range(cols):
            left = int(col * width / cols)
            right = int((col + 1) * width / cols)
            cv2.rectangle(canvas, (left, top), (right, bottom), grid_color, 1)

    palette = np.random.default_rng(42)
    for shape in shapes:
        outline = np.array(shape["points"], dtype=np.int32)
        color = tuple(int(channel) for channel in palette.integers(80, 255, size=3))

        # Blend a filled copy (30 %) over the current canvas (70 %).
        filled = canvas.copy()
        cv2.fillPoly(filled, [outline], color)
        cv2.addWeighted(filled, 0.30, canvas, 0.70, 0, canvas)
        cv2.polylines(canvas, [outline], True, color, 1)

        # Label text, if the shape carries one.
        label = shape.get("label", "")
        if not label:
            continue
        center_x = int(np.mean([p[0] for p in shape["points"]]))
        center_y = int(np.mean([p[1] for p in shape["points"]]))
        cv2.putText(canvas, label, (center_x, center_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)

    cv2.putText(canvas, f"total segments: {len(shapes)}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    return canvas
# ── 라벨 분석 → text_prompt 후보 출력 ────────────────────────────────────────
def analyze_labels(shapes):
    """Print a label frequency table and a suggested ``text_prompt``.

    Labels are stripped of surrounding whitespace; shapes whose label is
    missing, null or blank are ignored for the table. The 30 most frequent
    labels are listed with a ``#`` bar (capped at 40 characters), and the 10
    most frequent ones are joined into a comma-separated prompt candidate.

    Args:
        shapes: Shape dicts, each optionally carrying a ``"label"``.

    Returns:
        None. All results are written to standard output.
    """
    labels = []
    for s in shapes:
        # `or ""` covers a null label, which would otherwise fail on .strip().
        label = str(s.get("label") or "").strip()
        if label:
            labels.append(label)
    if not labels:
        print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
        return

    counter = Counter(labels)
    # Visible horizontal rule framing the table.
    rule = "─" * 50
    print(f"\n{rule}")
    print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
    print(rule)
    for label, cnt in counter.most_common(30):
        bar = "#" * min(cnt, 40)
        print(f" {label:35s} {cnt:4d} {bar}")
    print(rule)

    candidates = ", ".join(lb for lb, _ in counter.most_common(10))
    print("\n[text_prompt 후보]")
    print(f' "{candidates}"')
# ── 메인 ─────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point for the discovery sweep.

    Parses the arguments, loads the input image, runs tiled detection and
    NMS, prints the label analysis, and writes two files: the annotated
    JPEG (downscaled so its longer side is at most 4096 px) and a JSON file
    next to it holding label counts and per-segment label, score and
    bounding box.
    """
    import time

    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--input", required=True, help="입력 이미지")
    ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
    ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
    ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
    ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
    ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
    ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
    ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
    ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
    args = ap.parse_args()

    # A zero grid dimension would divide by zero when computing tile sizes,
    # and the thread pool rejects a non-positive worker count.
    if args.cols < 1 or args.rows < 1 or args.workers < 1:
        ap.error("--cols, --rows and --workers must be >= 1")

    # Strip separators first so an extra made only of commas/spaces does not
    # leave a dangling ", " (and thus an empty item) at the end of the prompt.
    extra = args.prompt_extra.strip().strip(", ")
    prompt = f"{DISCOVERY_PROMPT}, {extra}" if extra else DISCOVERY_PROMPT

    img_path = Path(args.input)
    # np.fromfile + cv2.imdecode instead of cv2.imread; presumably chosen so
    # that non-ASCII paths load correctly - TODO confirm.
    try:
        buf = np.fromfile(str(img_path), dtype=np.uint8)
    except OSError:
        # Missing or unreadable file: report it like any other load failure.
        buf = np.empty(0, dtype=np.uint8)
    # An empty buffer cannot be decoded, so skip the decode call for it.
    image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image_bgr is None:
        print(f"이미지 로드 실패: {img_path}")
        return

    H, W = image_bgr.shape[:2]
    print(f"이미지 : {W}×{H}")
    print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}")
    print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")

    items = [item.strip() for item in prompt.split(",")]
    print(f"탐색 프롬프트 ({len(items)}개 항목):")
    for item in items:
        print(f" · {item}")
    print()

    t0 = time.time()
    shapes = detect_everything_tiled(
        image_bgr, args.cols, args.rows, args.overlap,
        args.conf, args.workers, prompt
    )
    print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
    shapes = nms_shapes(shapes, iou_thresh=args.nms)
    print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n")

    analyze_labels(shapes)

    vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        s = 4096 / max(h, w)
        vis = cv2.resize(vis, (int(w*s), int(h*s)))

    out_path = (Path(args.output) if args.output
                else img_path.parent / (img_path.stem + "_everything.jpg"))
    # Create the target directory so a fresh --output location does not fail.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # imencode + tofile mirrors the fromfile/imdecode loading above.
    cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path))
    print(f"\n저장: {out_path}")

    # Also save the label data as JSON (for analysis).
    json_path = out_path.with_suffix(".json")
    label_data = {
        "total_segments": len(shapes),
        # Count stripped labels, as analyze_labels does; missing, null and
        # blank labels are grouped under "(no label)".
        "label_counts": dict(Counter(
            str(s.get("label") or "").strip() or "(no label)" for s in shapes
        )),
        "segments": [
            {"label": s.get("label", ""), "score": s.get("score", 0),
             "bbox": list(_bbox(s["points"]))}
            for s in shapes
        ]
    }
    json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"라벨 데이터: {json_path}")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()