291
tools/sam3_everything_explore.py
Normal file
291
tools/sam3_everything_explore.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
SAM3.1 탐색 모드 (Discovery Sweep)
|
||||
이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 →
|
||||
나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정.
|
||||
|
||||
이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음.
|
||||
대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다.
|
||||
|
||||
사용법:
|
||||
python tools/sam3_everything_explore.py \\
|
||||
--input "data/역사구간/1.회덕역/..." \\
|
||||
--cols 8 --rows 6
|
||||
|
||||
사전 조건: SAM3.1 서버 실행 (start_server.bat)
|
||||
"""
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
from collections import Counter
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
# Base URL of the SAM3.1 inference server; requests are POSTed to "<SAM3_SERVER>/v1/predict".
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent as the "model" field of every predict payload.
SAM3_MODEL_ID = "segment_anything_3"
|
||||
|
||||
|
||||
# ── 이미지 인코딩 ─────────────────────────────────────────────────────────────
|
||||
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
    """Downscale an image if needed and return it as a base64-encoded JPEG.

    Args:
        image_bgr: Image array; only ``shape[:2]`` (height, width) is read
            here before the array is handed to OpenCV.
        max_size: Longest allowed side in pixels. Larger images are shrunk
            proportionally so their longest side becomes ``max_size``.

    Returns:
        ``(b64, scale)``: the JPEG (quality 90) as a base64 ``str`` and the
        resize factor that was applied (``1.0`` when no resize happened).

    Raises:
        ValueError: If OpenCV reports that JPEG encoding failed.
    """
    h, w = image_bgr.shape[:2]
    scale = 1.0
    longest = max(h, w)
    if longest > max_size:
        scale = max_size / longest
        # Clamp to 1 px: for extremely elongated images the short side would
        # otherwise truncate to 0, which cv2.resize rejects.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        image_bgr = cv2.resize(image_bgr, new_size)
    ok, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        # Do not silently send an empty/garbage payload to the server.
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf).decode("utf-8"), scale
|
||||
|
||||
|
||||
# Broad discovery vocabulary covering the elements commonly seen at a railway
# site. Each inner tuple is one thematic group; the prompt sent to the server
# is every term joined with ", ".
_DISCOVERY_TERM_GROUPS = (
    ("railroad track", "railway rail"),
    ("catenary pole", "overhead line pole", "electric pole"),
    ("overhead wire", "catenary wire", "power line cable"),
    ("railway sleeper", "concrete tie"),
    ("guardrail", "highway barrier", "road fence"),
    ("bridge", "viaduct", "overpass"),
    ("vegetation", "tree", "bush", "grass"),
    ("building", "structure", "roof", "wall"),
    ("vehicle", "car", "truck"),
    ("road", "asphalt", "pavement"),
    ("slope", "embankment", "retaining wall"),
    ("noise barrier", "sound wall"),
    ("signal", "sign board"),
)
DISCOVERY_PROMPT = ", ".join(
    term for group in _DISCOVERY_TERM_GROUPS for term in group
)
|
||||
|
||||
|
||||
# ── SAM3.1 discovery sweep 호출 ───────────────────────────────────────────────
|
||||
def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list:
    """Send one tile to the SAM3.1 server with a broad text prompt.

    Args:
        tile_bgr: Tile image; it is JPEG/base64 encoded via ``encode_image``.
        conf: Value sent as ``params.conf_threshold``.
        prompt: Value sent as ``params.text_prompt``.

    Returns:
        List of shape dicts whose ``shape_type`` is ``"polygon"`` and whose
        ``points`` list is non-empty. When ``encode_image`` downscaled the
        tile, the points are multiplied by ``1 / scale`` so they refer to the
        tile as passed in. An empty list is returned when the request or the
        decoding fails, or when the response has a falsy ``success`` field.
    """
    # Deliberately best-effort: a single failing tile (encoding error, network
    # error, malformed response) is reported and skipped instead of aborting
    # the whole sweep. Encoding is inside the try for the same reason.
    try:
        b64, scale = encode_image(tile_bgr)
        payload = {
            "model": SAM3_MODEL_ID,
            "image": b64,
            "params": {
                "text_prompt": prompt,
                "conf_threshold": conf,
                "show_masks": True,
                "show_boxes": False,
            },
        }
        r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
        r.raise_for_status()
        resp = r.json()
        if not resp.get("success"):
            return []

        # "data" / "shapes" may be present but null; treat that as no shapes.
        raw_shapes = (resp.get("data") or {}).get("shapes") or []
        inv = 1.0 / scale

        polygons = []
        for s in raw_shapes:
            if not isinstance(s, dict):
                s = s.dict()
            # Polygons without points carry no geometry and would break the
            # bounding-box computation done later during NMS.
            if s.get("shape_type") != "polygon" or not s.get("points"):
                continue
            if scale < 1.0:
                s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
            polygons.append(s)
        return polygons
    except Exception as e:
        print(f" [SAM3 오류] {e}")
        return []
|
||||
|
||||
|
||||
# ── NMS ───────────────────────────────────────────────────────────────────────
|
||||
def _bbox(pts):
    """Return the axis-aligned bounding box ``(x_min, y_min, x_max, y_max)``.

    ``pts`` is a sequence of points whose first two items are x and y.
    """
    x_coords = []
    y_coords = []
    for point in pts:
        x_coords.append(point[0])
        y_coords.append(point[1])
    return min(x_coords), min(y_coords), max(x_coords), max(y_coords)
|
||||
|
||||
|
||||
def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
    """Greedy non-maximum suppression over the bounding boxes of polygons.

    Shapes are visited from highest to lowest ``score``; a shape is dropped
    when the IoU of its bounding box with an already kept shape is
    ``>= iou_thresh``. Labels are not consulted, so suppression happens
    across labels as well.

    Args:
        shapes: Shape dicts with a ``points`` list of ``[x, y]`` pairs and an
            optional ``score``. A missing or ``None`` score counts as 0.5.
            Shapes without points are discarded.
        iou_thresh: IoU at or above which the lower-scored shape is removed.

    Returns:
        The surviving shape dicts (same objects), ordered by descending
        score; equal scores keep their input order.
    """
    # Shapes without geometry cannot be boxed; drop them instead of crashing.
    candidates = [s for s in shapes if s.get("points")]
    if not candidates:
        return []

    boxes = np.empty((len(candidates), 4), dtype=np.float32)
    for row, shape in enumerate(candidates):
        xy = np.asarray(shape["points"], dtype=np.float64)[:, :2]
        boxes[row, :2] = xy.min(axis=0)
        boxes[row, 2:] = xy.max(axis=0)

    scores = np.array(
        [0.5 if s.get("score") is None else float(s["score"]) for s in candidates],
        dtype=np.float64,
    )
    # Stable sort on the negated scores: descending order with deterministic
    # tie-breaking (important when the server returns no scores at all).
    order = np.argsort(-scores, kind="stable")
    # Box areas never change, so compute them once outside the loop.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    keep = []
    while order.size:
        best = int(order[0])
        keep.append(best)
        rest = order[1:]
        if not rest.size:
            break
        ix1 = np.maximum(boxes[best, 0], boxes[rest, 0])
        iy1 = np.maximum(boxes[best, 1], boxes[rest, 1])
        ix2 = np.minimum(boxes[best, 2], boxes[rest, 2])
        iy2 = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.maximum(0, ix2 - ix1) * np.maximum(0, iy2 - iy1)
        # The epsilon guards against 0/0 for degenerate (zero-area) boxes.
        iou = inter / (areas[best] + areas[rest] - inter + 1e-6)
        order = rest[iou < iou_thresh]
    return [candidates[i] for i in keep]
|
||||
|
||||
|
||||
# ── 타일 분할 + 병렬 검출 ─────────────────────────────────────────────────────
|
||||
def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt):
    """Split the image into an overlapping grid and run detection per tile.

    Args:
        image_bgr: Full image array; ``shape[:2]`` gives (height, width).
        cols: Number of tile columns (>= 1).
        rows: Number of tile rows (>= 1).
        overlap: Fraction of one grid cell added as padding on every side of
            a tile (clipped to the image bounds).
        conf: Confidence threshold forwarded to ``sam3_everything``.
        workers: Number of worker threads (values below 1 are treated as 1).
        prompt: Text prompt forwarded to ``sam3_everything``.

    Returns:
        All polygons from all tiles, with points shifted by each tile's
        origin into full-image coordinates, in tile completion order. No
        de-duplication is done here.

    Raises:
        ValueError: If ``cols`` or ``rows`` is smaller than 1.
    """
    if cols < 1 or rows < 1:
        raise ValueError(f"cols and rows must be >= 1 (got cols={cols}, rows={rows})")

    H, W = image_bgr.shape[:2]
    pad_x = int(W / cols * overlap)
    pad_y = int(H / rows * overlap)

    # Cell edges use integer floor division so they are exact; the float
    # product c * (W / cols) can land just below the true edge and lose a pixel.
    tiles = []
    for r in range(rows):
        y0 = max(0, r * H // rows - pad_y)
        y1 = min(H, (r + 1) * H // rows + pad_y)
        for c in range(cols):
            x0 = max(0, c * W // cols - pad_x)
            x1 = min(W, (c + 1) * W // cols + pad_x)
            # Skip zero-area tiles (grid finer than the image resolution).
            if x1 > x0 and y1 > y0:
                tiles.append((x0, y0, x1, y1))

    def process(bounds):
        """Detect on one tile and shift its polygons to image coordinates."""
        x0, y0, x1, y1 = bounds
        shapes = sam3_everything(image_bgr[y0:y1, x0:x1], conf, prompt)
        for s in shapes:
            s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
        return shapes

    total = len(tiles)
    all_shapes = []
    with ThreadPoolExecutor(max_workers=max(1, workers)) as ex:
        futures = [ex.submit(process, t) for t in tiles]
        # Results are gathered on the calling thread, so the progress counter
        # and the result list need no locking.
        for done, fut in enumerate(as_completed(futures), start=1):
            all_shapes.extend(fut.result())
            print(f" 타일 {done:02d}/{total} 완료, 누적 {len(all_shapes)}개", end="\r")
    print()
    return all_shapes
|
||||
|
||||
|
||||
# ── 시각화 ────────────────────────────────────────────────────────────────────
|
||||
def draw_everything(image_bgr, shapes, cols, rows):
    """Render the tile grid and all segments onto a copy of the image.

    Args:
        image_bgr: Source image; it is copied and left unmodified.
        shapes: Shape dicts with ``points`` (``[x, y]`` pairs) and an
            optional ``label``.
        cols: Number of tile columns, used only to draw the grid.
        rows: Number of tile rows, used only to draw the grid.

    Returns:
        The annotated image: grey grid lines, each polygon filled at 30 %
        opacity with an outline, its label at the mean of its points, and a
        ``total segments`` caption in the top-left corner.
    """
    vis = image_bgr.copy()
    H, W = vis.shape[:2]

    # Tile boundaries.
    for r in range(rows):
        for c in range(cols):
            bx0, by0 = int(c * W / cols), int(r * H / rows)
            bx1, by1 = int((c + 1) * W / cols), int((r + 1) * H / rows)
            cv2.rectangle(vis, (bx0, by0), (bx1, by1), (60, 60, 60), 1)

    # Fixed seed so the colour assignment is reproducible between runs.
    rng = np.random.default_rng(42)
    for s in shapes:
        pts = np.array(s["points"], dtype=np.int32)
        color = tuple(int(v) for v in rng.integers(80, 255, size=3))
        if pts.ndim != 2 or len(pts) == 0:
            continue  # nothing drawable

        # Blend only inside the polygon's bounding rectangle. Copying and
        # blending the whole image for every segment gives the same pixels
        # but costs O(image size) per segment.
        x, y, w, h = cv2.boundingRect(pts)
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, W), min(y + h, H)
        if x1 > x0 and y1 > y0:
            base = vis[y0:y1, x0:x1].copy()
            overlay = base.copy()
            local_pts = pts - np.array([x0, y0], dtype=np.int32)
            cv2.fillPoly(overlay, [local_pts], color)
            vis[y0:y1, x0:x1] = cv2.addWeighted(overlay, 0.30, base, 0.70, 0)
        cv2.polylines(vis, [pts], True, color, 1)

        # Label text (when present), anchored at the mean of the points.
        label = s.get("label", "")
        if label:
            cx = int(np.mean([p[0] for p in s["points"]]))
            cy = int(np.mean([p[1] for p in s["points"]]))
            cv2.putText(vis, label, (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)

    cv2.putText(vis, f"total segments: {len(shapes)}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    return vis
|
||||
|
||||
|
||||
# ── 라벨 분석 → text_prompt 후보 출력 ────────────────────────────────────────
|
||||
def analyze_labels(shapes):
    """Print a label frequency table and a suggested ``text_prompt``.

    Labels are stripped of surrounding whitespace; shapes whose label is
    missing, empty, ``None`` or not a string are ignored for the statistics.
    Prints up to 30 labels with counts and a bar (capped at 40 characters),
    followed by the 10 most frequent labels joined with ", ".

    Args:
        shapes: Shape dicts, each optionally carrying a ``label``.

    Returns:
        None. All output goes to stdout.
    """
    labels = []
    for s in shapes:
        raw = s.get("label")
        # A JSON null (or any non-string) label has no .strip(); treat it as
        # unlabeled instead of raising AttributeError.
        text = raw.strip() if isinstance(raw, str) else ""
        if text:
            labels.append(text)

    if not labels:
        print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
        return

    counter = Counter(labels)
    rule = "─" * 50
    print(f"\n{rule}")
    print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
    print(rule)
    for label, cnt in counter.most_common(30):
        bar = "#" * min(cnt, 40)
        print(f" {label:35s} {cnt:4d} {bar}")
    print(rule)

    top_labels = [lb for lb, _ in counter.most_common(10)]
    joined = ", ".join(top_labels)
    print("\n[text_prompt 후보]")
    print(f' "{joined}"')
|
||||
|
||||
|
||||
# ── 메인 ─────────────────────────────────────────────────────────────────────
|
||||
def main():
    """Command-line entry point for the discovery sweep.

    Steps: parse arguments, load the image, run tiled detection, apply NMS,
    print label statistics, then write an annotated JPEG plus a JSON file
    (same path with a ``.json`` suffix) holding label counts and per-segment
    label / score / bounding box.

    Returns:
        None. On an unreadable input image a message is printed and the
        function returns early.
    """
    import time

    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--input", required=True, help="입력 이미지")
    ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
    ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
    ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
    ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
    ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
    ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
    ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
    ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
    args = ap.parse_args()

    # Decide on the cleaned value: an input such as "," must not append an
    # empty item (a dangling ", ") to the prompt.
    extra = args.prompt_extra.strip().strip(", ")
    prompt = f"{DISCOVERY_PROMPT}, {extra}" if extra else DISCOVERY_PROMPT

    # np.fromfile + cv2.imdecode instead of cv2.imread: the path is read by
    # NumPy, so OpenCV never has to interpret a non-ASCII file name.
    img_path = Path(args.input)
    try:
        buf = np.fromfile(str(img_path), dtype=np.uint8)
    except OSError:
        buf = np.empty(0, dtype=np.uint8)
    # imdecode raises on an empty buffer, so guard it explicitly.
    image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image_bgr is None:
        print(f"이미지 로드 실패: {img_path}")
        return

    H, W = image_bgr.shape[:2]
    print(f"이미지 : {W}×{H}")
    print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}개")
    print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")

    prompt_items = prompt.split(",")
    print(f"탐색 프롬프트 ({len(prompt_items)}개 항목):")
    for item in prompt_items:
        print(f" · {item.strip()}")
    print()

    t0 = time.time()
    shapes = detect_everything_tiled(
        image_bgr, args.cols, args.rows, args.overlap,
        args.conf, args.workers, prompt
    )
    print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
    shapes = nms_shapes(shapes, iou_thresh=args.nms)
    print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n")

    analyze_labels(shapes)

    vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        # Cap the longest side of the preview at 4096 px.
        s = 4096 / max(h, w)
        vis = cv2.resize(vis, (int(w*s), int(h*s)))

    out_path = (Path(args.output) if args.output
                else img_path.parent / (img_path.stem + "_everything.jpg"))
    # imencode + tofile (rather than cv2.imwrite) so the file is written by
    # NumPy and the path is not interpreted by OpenCV.
    ok, encoded = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])
    if ok:
        encoded.tofile(str(out_path))
        print(f"\n저장: {out_path}")
    else:
        print(f"\n이미지 인코딩 실패: {out_path}")

    # Also store the label data as JSON for later analysis.
    json_path = out_path.with_suffix(".json")
    label_data = {
        "total_segments": len(shapes),
        # Missing, null and empty labels all fall into the same bucket.
        "label_counts": dict(Counter(
            (s.get("label") or "(no label)") for s in shapes
        )),
        "segments": [
            {"label": s.get("label", ""), "score": s.get("score", 0),
             "bbox": list(_bbox(s["points"]))}
            for s in shapes
        ],
    }
    json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"라벨 데이터: {json_path}")
|
||||
|
||||
|
||||
# Run the discovery sweep only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user