""" SAM3.1 탐색 모드 (Discovery Sweep) 이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 → 나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정. 이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음. 대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다. 사용법: python tools/sam3_everything_explore.py \\ --input "data/역사구간/1.회덕역/..." \\ --cols 8 --rows 6 사전 조건: SAM3.1 서버 실행 (start_server.bat) """ import argparse import base64 import json from collections import Counter from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import cv2 import numpy as np import requests SAM3_SERVER = "http://localhost:8000" SAM3_MODEL_ID = "segment_anything_3" # ── 이미지 인코딩 ───────────────────────────────────────────────────────────── def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple: h, w = image_bgr.shape[:2] scale = 1.0 if max(h, w) > max_size: scale = max_size / max(h, w) image_bgr = cv2.resize(image_bgr, (int(w * scale), int(h * scale))) _, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90]) return base64.b64encode(buf).decode("utf-8"), scale # 탐색용 넓은 프롬프트 — 철도 현장에서 흔히 보이는 모든 요소 포함 DISCOVERY_PROMPT = ( "railroad track, railway rail, " "catenary pole, overhead line pole, electric pole, " "overhead wire, catenary wire, power line cable, " "railway sleeper, concrete tie, " "guardrail, highway barrier, road fence, " "bridge, viaduct, overpass, " "vegetation, tree, bush, grass, " "building, structure, roof, wall, " "vehicle, car, truck, " "road, asphalt, pavement, " "slope, embankment, retaining wall, " "noise barrier, sound wall, " "signal, sign board, " "small dark object on ballast, small dark object on railway, " "small square metal box on ground, control box on ballast, " "gray square lid on gravel, flat metal cover on ground, " "small bright object on ballast, small white box on ballast, " "small gray box on ground, bright square object on gravel" ) # ── SAM3.1 discovery sweep 호출 ─────────────────────────────────────────────── def 
sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list: """넓은 탐색 prompt → 이미지 내 모든 요소 검출. 반환: shape dict 리스트.""" b64, scale = encode_image(tile_bgr) payload = { "model": SAM3_MODEL_ID, "image": b64, "params": { "text_prompt": prompt, "conf_threshold": conf, "show_masks": True, "show_boxes": False, }, } try: r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=300) r.raise_for_status() resp = r.json() if not resp.get("success"): return [] shapes = resp.get("data", {}).get("shapes", []) shapes = [s if isinstance(s, dict) else s.dict() for s in shapes] if scale < 1.0: inv = 1.0 / scale for s in shapes: if s.get("shape_type") == "polygon": s["points"] = [[x * inv, y * inv] for x, y in s["points"]] return [s for s in shapes if s.get("shape_type") == "polygon"] except Exception as e: print(f" [SAM3 오류] {e}") return [] # ── NMS ─────────────────────────────────────────────────────────────────────── def _bbox(pts): xs = [p[0] for p in pts]; ys = [p[1] for p in pts] return min(xs), min(ys), max(xs), max(ys) def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list: if not shapes: return [] bboxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32) scores = np.array([float(s.get("score", 0.5)) for s in shapes]) order = scores.argsort()[::-1] keep = [] while len(order): i = order[0]; keep.append(i) if len(order) == 1: break xx1 = np.maximum(bboxes[i,0], bboxes[order[1:],0]) yy1 = np.maximum(bboxes[i,1], bboxes[order[1:],1]) xx2 = np.minimum(bboxes[i,2], bboxes[order[1:],2]) yy2 = np.minimum(bboxes[i,3], bboxes[order[1:],3]) inter = np.maximum(0, xx2-xx1) * np.maximum(0, yy2-yy1) a_i = (bboxes[i,2]-bboxes[i,0])*(bboxes[i,3]-bboxes[i,1]) a_j = (bboxes[order[1:],2]-bboxes[order[1:],0])*(bboxes[order[1:],3]-bboxes[order[1:],1]) iou = inter / (a_i + a_j - inter + 1e-6) order = order[1:][iou < iou_thresh] return [shapes[i] for i in keep] # ── 타일 분할 + 병렬 검출 ───────────────────────────────────────────────────── 
def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt, zone=None):
    """Run the discovery sweep tile by tile and merge results in image coordinates.

    The image is cut into a ``cols`` x ``rows`` grid; each tile is grown by
    ``overlap`` (a fraction of the tile size) on every side and clipped to the
    image. Tiles are sent to ``sam3_everything`` from a thread pool and the
    returned polygon points are shifted by the tile origin.

    Args:
        image_bgr: Full image array (height and width are read from its shape).
        cols: Number of tile columns (>= 1).
        rows: Number of tile rows (>= 1).
        overlap: Padding per side as a fraction of the base tile size.
        conf: Confidence threshold forwarded to ``sam3_everything``.
        workers: Thread-pool size.
        prompt: Text prompt forwarded to ``sam3_everything``.
        zone: Optional ``(x1, y1, x2, y2)``; when given, only tiles that
            overlap this rectangle are processed.

    Returns:
        List of shape dicts from all processed tiles.

    Raises:
        ValueError: If ``cols`` or ``rows`` is smaller than 1.
    """
    if cols < 1 or rows < 1:
        raise ValueError(f"cols and rows must be >= 1 (got cols={cols}, rows={rows})")

    H, W = image_bgr.shape[:2]
    base_w = W / cols
    base_h = H / rows
    pad_x = int(base_w * overlap)
    pad_y = int(base_h * overlap)

    def overlaps_zone(tx0, ty0, tx1, ty1):
        if zone is None:
            return True
        zx1, zy1, zx2, zy2 = zone
        return tx0 < zx2 and tx1 > zx1 and ty0 < zy2 and ty1 > zy1

    tiles = []
    skipped = 0
    for r in range(rows):
        for c in range(cols):
            idx = r * cols + c + 1
            x0 = max(0, int(c * base_w) - pad_x)
            x1 = min(W, int((c + 1) * base_w) + pad_x)
            y0 = max(0, int(r * base_h) - pad_y)
            y1 = min(H, int((r + 1) * base_h) + pad_y)
            if x1 <= x0 or y1 <= y0:
                # Grid finer than the image: a zero-area tile cannot be encoded.
                continue
            if overlaps_zone(x0, y0, x1, y1):
                tiles.append((idx, x0, y0, x1, y1))
            else:
                skipped += 1

    total = len(tiles)
    if skipped:
        print(f"zone 필터: {skipped}타일 스킵, {total}타일 처리")

    all_shapes = []

    def process(args):
        idx, x0, y0, x1, y1 = args
        tile = image_bgr[y0:y1, x0:x1]
        shapes = sam3_everything(tile, conf, prompt)
        # Translate tile-local coordinates into full-image coordinates.
        for s in shapes:
            s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
        return shapes

    with ThreadPoolExecutor(max_workers=workers) as ex:
        futs = {ex.submit(process, t): t for t in tiles}
        for done, fut in enumerate(as_completed(futs), 1):
            try:
                all_shapes.extend(fut.result())
            except Exception as e:
                # One failed tile must not throw away the results already
                # gathered from the other tiles; report it and carry on.
                print(f" [타일 {futs[fut][0]} 오류] {e}")
            print(f" 타일 {done:02d}/{total} 완료, 누적 {len(all_shapes)}개", end="\r")
    print()
    return all_shapes


# ── Visualisation ────────────────────────────────────────────────────────────
def draw_everything(image_bgr, shapes, cols, rows):
    """Render the tile grid and every segment onto a copy of the image.

    Each segment is filled at 30 % opacity with a colour drawn from a
    fixed-seed generator, outlined, and annotated with its label (if any) at
    the mean of its points. The total segment count is written in the
    top-left corner.

    Args:
        image_bgr: Source image; it is copied, not modified.
        shapes: Shape dicts with ``points`` and optionally ``label``.
        cols: Number of grid columns to draw.
        rows: Number of grid rows to draw.

    Returns:
        The annotated image.
    """
    vis = image_bgr.copy()
    H, W = vis.shape[:2]

    # Tile borders
    for r in range(rows):
        for c in range(cols):
            bx0, by0 = int(c * W / cols), int(r * H / rows)
            bx1, by1 = int((c + 1) * W / cols), int((r + 1) * H / rows)
            cv2.rectangle(vis, (bx0, by0), (bx1, by1), (60, 60, 60), 1)

    rng = np.random.default_rng(42)
    for s in shapes:
        pts = np.array(s["points"], dtype=np.int32)
        if pts.size == 0:
            continue
        color = tuple(int(v) for v in rng.integers(80, 255, size=3))

        # Blend only inside the polygon's bounding box. Copying and blending
        # the full image for every segment costs O(segments x pixels), while
        # pixels outside the polygon are left unchanged by the blend anyway.
        bx, by, bw, bh = cv2.boundingRect(pts)
        rx0, ry0 = max(bx, 0), max(by, 0)
        rx1, ry1 = min(bx + bw, W), min(by + bh, H)
        if rx1 > rx0 and ry1 > ry0:
            roi = vis[ry0:ry1, rx0:rx1]
            overlay = roi.copy()
            local_pts = pts - np.array([rx0, ry0], dtype=np.int32)
            cv2.fillPoly(overlay, [local_pts], color)
            vis[ry0:ry1, rx0:rx1] = cv2.addWeighted(overlay, 0.30, roi, 0.70, 0)
        cv2.polylines(vis, [pts], True, color, 1)

        # Label text (when present)
        label = s.get("label", "")
        if label:
            centre = np.asarray(s["points"], dtype=np.float64).mean(axis=0)
            cx, cy = int(centre[0]), int(centre[1])
            cv2.putText(vis, str(label), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX,
                        0.4, color, 1, cv2.LINE_AA)

    cv2.putText(vis, f"total segments: {len(shapes)}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    return vis


# ── Label analysis → print text_prompt candidates ────────────────────────────
def analyze_labels(shapes):
    """Print a label frequency table and a text_prompt candidate string.

    Labels are stripped of surrounding whitespace; missing, ``None`` and empty
    labels are ignored. Up to 30 labels are listed and the 10 most frequent
    ones are joined into the candidate prompt.

    Args:
        shapes: Shape dicts, each optionally carrying ``label``.
    """
    labels = []
    for s in shapes:
        # `label` may be absent or null in the response; both mean "no label".
        label = str(s.get("label") or "").strip()
        if label:
            labels.append(label)

    if not labels:
        print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
        return

    counter = Counter(labels)
    print(f"\n{'─'*50}")
    print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
    print(f"{'─'*50}")
    for label, cnt in counter.most_common(30):
        bar = "#" * min(cnt, 40)
        print(f" {label:35s} {cnt:4d} {bar}")
    print(f"{'─'*50}")

    top_labels = [lb for lb, _ in counter.most_common(10)]
    print(f"\n[text_prompt 후보]")
    print(f' "{", ".join(top_labels)}"')


# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: sweep, NMS, report, and save JPEG + JSON.

    Exits with status 1 when the input image cannot be loaded or the
    visualisation cannot be encoded.
    """
    import time

    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--input", required=True, help="입력 이미지")
    ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
    ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
    ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
    ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
    ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
    ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
    ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
    ap.add_argument("--prompt-extra", default="",
                    help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
    ap.add_argument("--zone", type=int, nargs=4, metavar=("X1","Y1","X2","Y2"), default=None,
                    help="처리 zone 제한 (이 범위와 겹치는 타일만 처리)")
    args = ap.parse_args()

    # Strip separators first so an input such as "," cannot append an empty item.
    extra = args.prompt_extra.strip(", \t\r\n")
    prompt = DISCOVERY_PROMPT + (", " + extra if extra else "")

    img_path = Path(args.input)
    # np.fromfile + imdecode is used instead of cv2.imread so that paths with
    # non-ASCII characters can be opened.
    try:
        buf = np.fromfile(str(img_path), dtype=np.uint8)
    except OSError:
        buf = np.empty(0, dtype=np.uint8)
    image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image_bgr is None:
        print(f"이미지 로드 실패: {img_path}")
        raise SystemExit(1)

    H, W = image_bgr.shape[:2]
    print(f"이미지 : {W}×{H}")
    print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}개")
    print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")

    prompt_items = prompt.split(",")
    print(f"탐색 프롬프트 ({len(prompt_items)}개 항목):")
    for item in prompt_items:
        print(f" · {item.strip()}")
    print()

    zone = tuple(args.zone) if args.zone else None
    if zone:
        print(f"zone 제한: x={zone[0]}~{zone[2]} y={zone[1]}~{zone[3]}\n")

    t0 = time.time()
    shapes = detect_everything_tiled(
        image_bgr, args.cols, args.rows, args.overlap,
        args.conf, args.workers, prompt, zone=zone
    )
    print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
    shapes = nms_shapes(shapes, iou_thresh=args.nms)
    print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n")

    analyze_labels(shapes)

    vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        s = 4096 / max(h, w)
        vis = cv2.resize(vis, (int(w*s), int(h*s)))

    out_path = (Path(args.output) if args.output
                else img_path.parent / (img_path.stem + "_everything.jpg"))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    ok, encoded = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])
    if not ok:
        print(f"이미지 저장 실패: {out_path}")
        raise SystemExit(1)
    # ndarray.tofile is used instead of cv2.imwrite for the same non-ASCII
    # path reason as on load.
    encoded.tofile(str(out_path))
    print(f"\n저장: {out_path}")

    # Also save the label data as JSON (for analysis)
    json_path = out_path.with_suffix(".json")
    label_data = {
        "total_segments": len(shapes),
        # Missing, null and empty labels are all counted as "(no label)".
        "label_counts": dict(Counter(
            (s.get("label") or "(no label)") for s in shapes
        )),
        "segments": [
            {"label": s.get("label", ""), "score": s.get("score", 0),
             "bbox": list(_bbox(s["points"])), "points": s["points"]}
            for s in shapes
        ],
    }
    json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2),
                         encoding="utf-8")
    print(f"라벨 데이터: {json_path}")


if __name__ == "__main__":
    main()