Files
railway-client/tools/sam3_everything_explore.py
minsung 4c15d5ff5d sam31server 전환, 라멘 파이프라인 정리, 문서 추가
- sam31server를 SAM3.1 서버로 전환 (x-anylabeling01 대체)
- detect_raamen.py: B/C 분류 기반 라멘형 전철주 검출 파이프라인 정비
- sam3_everything_explore.py: Discovery Sweep 탐색 모드 정리
- detect_all_objects.py: 타일 검출 개선
- docs/railway-client-guide.html: 서버·도구·파이프라인 전체 가이드 추가
- tools 추가: detect_control_box, group_ramen_poles, render_everything_by_label, render_label_polygons, debug_vh

Closes #1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-02 10:11:52 +09:00

318 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
SAM3.1 탐색 모드 (Discovery Sweep)
이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 →
나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정.
이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음.
대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다.
사용법:
python tools/sam3_everything_explore.py \\
--input "data/역사구간/1.회덕역/..." \\
--cols 8 --rows 6
사전 조건: SAM3.1 서버 실행 (start_server.bat)
"""
import argparse
import base64
import json
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import cv2
import numpy as np
import requests
# Base URL of the SAM3.1 inference server; predict requests are POSTed to
# "<SAM3_SERVER>/v1/predict".
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent as the "model" field of every predict request.
SAM3_MODEL_ID = "segment_anything_3"
# ── Image encoding ────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
    """Shrink an image if necessary and return it as base64-encoded JPEG text.

    Args:
        image_bgr: Image array; only its first two dimensions (height, width)
            are inspected here. Presumably BGR as decoded by OpenCV.
        max_size: Longest side (in pixels) allowed before the image is shrunk.

    Returns:
        A ``(b64_text, scale)`` tuple. ``scale`` is the resize factor that was
        applied (1.0 when the image was already small enough); callers divide
        returned coordinates by it to get back to the input resolution.

    Raises:
        ValueError: If OpenCV reports that the JPEG encoding failed.
    """
    h, w = image_bgr.shape[:2]
    scale = 1.0
    longest = max(h, w)
    if longest > max_size:
        scale = max_size / longest
        # Clamp each side to 1 px: for an extremely elongated image the short
        # side would otherwise truncate to 0, which cv2.resize rejects.
        target = (max(1, int(w * scale)), max(1, int(h * scale)))
        image_bgr = cv2.resize(image_bgr, target)
    ok, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf).decode("utf-8"), scale
# Broad discovery prompt — covers every element commonly seen at a railway site.
# Terms are listed one category per row and joined with ", " into one string.
DISCOVERY_PROMPT = ", ".join((
    # track
    "railroad track", "railway rail",
    # poles
    "catenary pole", "overhead line pole", "electric pole",
    # wires
    "overhead wire", "catenary wire", "power line cable",
    # sleepers
    "railway sleeper", "concrete tie",
    # barriers
    "guardrail", "highway barrier", "road fence",
    # bridges
    "bridge", "viaduct", "overpass",
    # vegetation
    "vegetation", "tree", "bush", "grass",
    # buildings
    "building", "structure", "roof", "wall",
    # vehicles
    "vehicle", "car", "truck",
    # road surface
    "road", "asphalt", "pavement",
    # earthworks
    "slope", "embankment", "retaining wall",
    # noise protection
    "noise barrier", "sound wall",
    # signalling
    "signal", "sign board",
    # small objects on or near the ballast
    "small dark object on ballast", "small dark object on railway",
    "small square metal box on ground", "control box on ballast",
    "gray square lid on gravel", "flat metal cover on ground",
    "small bright object on ballast", "small white box on ballast",
    "small gray box on ground", "bright square object on gravel",
))
# ── SAM3.1 discovery sweep call ───────────────────────────────────────────────
def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list:
    """Send one tile to the SAM3.1 server with a broad text prompt.

    Args:
        tile_bgr: Tile image passed to ``encode_image`` before upload.
        conf: Value sent as the request's ``conf_threshold``.
        prompt: Value sent as the request's ``text_prompt``.

    Returns:
        The polygon shape dicts from the response, with ``points`` scaled back
        to the tile's own pixel coordinates when the tile was shrunk for
        upload. Best effort: an unsuccessful response yields ``[]``, and any
        error is printed and also yields ``[]``.
    """
    try:
        # Encoding happens inside the try block so that a tile that cannot be
        # encoded is reported and skipped like any other per-tile failure.
        b64, scale = encode_image(tile_bgr)
        payload = {
            "model": SAM3_MODEL_ID,
            "image": b64,
            "params": {
                "text_prompt": prompt,
                "conf_threshold": conf,
                "show_masks": True,
                "show_boxes": False,
            },
        }
        r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=300)
        r.raise_for_status()
        resp = r.json()
        if not resp.get("success"):
            return []
        # "data" / "shapes" may be present but null; treat that as "no shapes".
        raw_shapes = (resp.get("data") or {}).get("shapes") or []
        # Keep only polygons that carry points: an empty polygon has no
        # bounding box and would break the NMS and drawing steps.
        polygons = [
            s for s in raw_shapes
            if isinstance(s, dict)
            and s.get("shape_type") == "polygon"
            and s.get("points")
        ]
        if scale < 1.0:
            # The server saw the shrunken image; undo that scaling.
            inv = 1.0 / scale
            for s in polygons:
                s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
        return polygons
    except Exception as e:  # deliberate best effort: one bad tile must not stop the sweep
        print(f" [SAM3 오류] {e}")
        return []
# ── NMS ───────────────────────────────────────────────────────────────────────
def _bbox(pts):
    """Return ``(x_min, y_min, x_max, y_max)`` for a sequence of (x, y) points.

    Raises ``ValueError`` when ``pts`` is empty.
    """
    xs = [p[0] for p in pts]
    ys = [p[1] for p in pts]
    return min(xs), min(ys), max(xs), max(ys)


def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
    """Greedy non-maximum suppression over the shapes' bounding boxes.

    Shapes are visited in descending ``score`` order; each kept shape removes
    every remaining shape whose bounding-box IoU with it is at or above
    ``iou_thresh``. Labels are not consulted, so suppression is
    class-agnostic, and the IoU is computed on boxes rather than on the
    polygons themselves.

    Args:
        shapes: Shape dicts with a ``points`` list and an optional ``score``.
        iou_thresh: IoU at or above which a lower-scored shape is dropped.

    Returns:
        The surviving shape dicts (same objects), highest score first.
    """
    if not shapes:
        return []
    bboxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
    # A score that is missing *or* explicitly null falls back to 0.5;
    # float(None) would otherwise raise TypeError.
    scores = np.array(
        [0.5 if s.get("score") is None else float(s["score"]) for s in shapes]
    )
    x1, y1, x2, y2 = bboxes.T
    areas = (x2 - x1) * (y2 - y1)  # computed once instead of per iteration
    order = scores.argsort()[::-1]
    keep = []
    while order.size:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        if not rest.size:
            break
        inter_w = np.maximum(0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]))
        inter_h = np.maximum(0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]))
        inter = inter_w * inter_h
        # The epsilon keeps the division defined for zero-area boxes.
        iou = inter / (areas[best] + areas[rest] - inter + 1e-6)
        order = rest[iou < iou_thresh]
    return [shapes[i] for i in keep]
# ── Tile split + parallel detection ───────────────────────────────────────────
def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt,
                            zone=None):
    """Split the image into a padded tile grid and run the sweep on each tile.

    Args:
        image_bgr: Full image; tiles are sliced from it.
        cols, rows: Grid dimensions.
        overlap: Fraction of the base tile size added as padding on each side
            of every tile (clamped to the image bounds).
        conf, prompt: Forwarded to ``sam3_everything``.
        workers: Thread-pool size for the per-tile requests.
        zone: Optional ``(x1, y1, x2, y2)``; when given, only tiles that
            overlap this rectangle are processed.

    Returns:
        All shapes from all processed tiles, with ``points`` translated into
        full-image coordinates. Tiles complete in arbitrary order, so the
        result order is not deterministic.
    """
    H, W = image_bgr.shape[:2]
    base_w = W / cols
    base_h = H / rows
    pad_x = int(base_w * overlap)
    pad_y = int(base_h * overlap)

    def overlaps_zone(tx0, ty0, tx1, ty1):
        if zone is None:
            return True
        zx1, zy1, zx2, zy2 = zone
        return tx0 < zx2 and tx1 > zx1 and ty0 < zy2 and ty1 > zy1

    tiles = []
    skipped = 0
    for r in range(rows):
        for c in range(cols):
            x0 = max(0, int(c * base_w) - pad_x)
            x1 = min(W, int((c + 1) * base_w) + pad_x)
            y0 = max(0, int(r * base_h) - pad_y)
            y1 = min(H, int((r + 1) * base_h) + pad_y)
            if x1 <= x0 or y1 <= y0:
                # Zero-area tile (grid finer than the image): there is
                # nothing to encode or send.
                continue
            if overlaps_zone(x0, y0, x1, y1):
                tiles.append((x0, y0, x1, y1))
            else:
                skipped += 1
    total = len(tiles)
    if skipped:
        print(f"zone 필터: {skipped}타일 스킵, {total}타일 처리")

    def process(tile_box):
        x0, y0, x1, y1 = tile_box
        shapes = sam3_everything(image_bgr[y0:y1, x0:x1], conf, prompt)
        # Shift tile-local coordinates into the full image's frame.
        for s in shapes:
            s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
        return shapes

    all_shapes = []
    with ThreadPoolExecutor(max_workers=workers) as ex:
        futures = [ex.submit(process, t) for t in tiles]
        # Results are gathered on this thread, so a plain counter is enough.
        for done, fut in enumerate(as_completed(futures), start=1):
            all_shapes.extend(fut.result())
            print(f" 타일 {done:02d}/{total} 완료, 누적 {len(all_shapes)}", end="\r")
    print()
    return all_shapes
# ── Visualization ─────────────────────────────────────────────────────────────
def draw_everything(image_bgr, shapes, cols, rows):
    """Render the tile grid and every detected polygon onto a copy of the image.

    Each polygon gets a pseudo-random colour (fixed seed, so colours are
    reproducible for the same shape order), a 30 % translucent fill, a 1 px
    outline and, when the shape has a label, the label text at the mean of its
    points. The total segment count is written in the top-left corner.

    Args:
        image_bgr: Source image; it is not modified.
        shapes: Shape dicts with ``points`` and an optional ``label``.
        cols, rows: Grid dimensions used to draw the (un-padded) tile borders.

    Returns:
        The annotated copy of the image.
    """
    vis = image_bgr.copy()
    H, W = vis.shape[:2]
    # Tile borders
    for r in range(rows):
        for c in range(cols):
            bx0, by0 = int(c * W / cols), int(r * H / rows)
            bx1, by1 = int((c + 1) * W / cols), int((r + 1) * H / rows)
            cv2.rectangle(vis, (bx0, by0), (bx1, by1), (60, 60, 60), 1)
    rng = np.random.default_rng(42)
    for s in shapes:
        pts = np.array(s["points"], dtype=np.int32)
        color = tuple(int(v) for v in rng.integers(80, 255, size=3))
        # Blend only inside the polygon's bounding box (clipped to the image)
        # instead of copying and blending the whole image for every shape.
        rx, ry, rw, rh = cv2.boundingRect(pts)
        x0, y0 = max(rx, 0), max(ry, 0)
        x1, y1 = min(rx + rw, W), min(ry + rh, H)
        if x1 > x0 and y1 > y0:
            roi = vis[y0:y1, x0:x1]
            overlay = roi.copy()
            # offset shifts the polygon into the ROI's local coordinates.
            cv2.fillPoly(overlay, [pts], color, offset=(-x0, -y0))
            vis[y0:y1, x0:x1] = cv2.addWeighted(overlay, 0.30, roi, 0.70, 0)
        cv2.polylines(vis, [pts], True, color, 1)
        # Label text (when present)
        label = s.get("label", "")
        if label:
            cx = int(np.mean([p[0] for p in s["points"]]))
            cy = int(np.mean([p[1] for p in s["points"]]))
            cv2.putText(vis, label, (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)
    cv2.putText(vis, f"total segments: {len(shapes)}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    return vis
# ── Label analysis → text_prompt candidates ───────────────────────────────────
def analyze_labels(shapes):
    """Print a label-frequency table and a suggested text_prompt.

    Labels are stripped of surrounding whitespace; shapes whose label is
    missing, null or blank are ignored for the counts (they still count
    towards the printed segment total). Up to 30 labels are listed, each with
    a bar of at most 40 ``#`` characters, and the 10 most common labels are
    joined into the prompt candidate.

    Args:
        shapes: Shape dicts with an optional ``label``.

    Returns:
        None; everything is written to stdout.
    """
    # `or ""` also covers a label that is present but null, which would
    # otherwise fail on .strip().
    cleaned = (str(s.get("label") or "").strip() for s in shapes)
    labels = [lb for lb in cleaned if lb]
    if not labels:
        print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
        return
    counter = Counter(labels)
    # Separator line. The expression previously repeated an empty string, so
    # no rule was printed at all.
    rule = "─" * 50
    print(f"\n{rule}")
    print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
    print(rule)
    for label, cnt in counter.most_common(30):
        bar = "#" * min(cnt, 40)
        print(f" {label:35s} {cnt:4d} {bar}")
    print(rule)
    candidate = ", ".join(lb for lb, _ in counter.most_common(10))
    print("\n[text_prompt 후보]")
    print(f' "{candidate}"')
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point for the discovery sweep.

    Parses the arguments, loads the input image, runs the tiled detection and
    NMS, prints the label analysis, and writes two files next to each other:
    the annotated JPEG (downscaled so its longest side is at most 4096 px) and
    a JSON file with per-label counts and every segment's label, score,
    bounding box and points.
    """
    import time

    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--input", required=True, help="입력 이미지")
    ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
    ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
    ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
    ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
    ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
    ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
    ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
    ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
    ap.add_argument("--zone", type=int, nargs=4, metavar=("X1", "Y1", "X2", "Y2"), default=None,
                    help="처리 zone 제한 (이 범위와 겹치는 타일만 처리)")
    args = ap.parse_args()
    # Zero or negative values would otherwise surface later as a
    # ZeroDivisionError (grid) or a ValueError from the thread pool.
    if min(args.cols, args.rows, args.workers) < 1:
        ap.error("--cols, --rows, --workers는 1 이상이어야 합니다")

    # Strip separators first so an extra made only of commas/spaces does not
    # append an empty term to the prompt.
    extra = args.prompt_extra.strip(", ")
    prompt = DISCOVERY_PROMPT + (", " + extra if extra else "")

    img_path = Path(args.input)
    # np.fromfile + imdecode (rather than cv2.imread) is kept from the
    # original — presumably for non-ASCII paths; confirm before changing.
    # A missing or unreadable file is routed to the same "load failed"
    # message as an undecodable one.
    try:
        buf = np.fromfile(str(img_path), dtype=np.uint8)
    except OSError:
        buf = np.empty(0, dtype=np.uint8)
    image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image_bgr is None:
        print(f"이미지 로드 실패: {img_path}")
        return

    H, W = image_bgr.shape[:2]
    print(f"이미지 : {W}×{H}")
    print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}")
    print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")

    items = prompt.split(",")
    print(f"탐색 프롬프트 ({len(items)}개 항목):")
    for item in items:
        print(f" · {item.strip()}")
    print()

    zone = tuple(args.zone) if args.zone else None
    if zone:
        print(f"zone 제한: x={zone[0]}~{zone[2]} y={zone[1]}~{zone[3]}\n")

    t0 = time.time()
    shapes = detect_everything_tiled(
        image_bgr, args.cols, args.rows, args.overlap,
        args.conf, args.workers, prompt, zone=zone
    )
    print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
    shapes = nms_shapes(shapes, iou_thresh=args.nms)
    print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n")

    analyze_labels(shapes)

    vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        # Cap the saved preview's longest side at 4096 px.
        shrink = 4096 / max(h, w)
        vis = cv2.resize(vis, (int(w * shrink), int(h * shrink)))

    out_path = (Path(args.output) if args.output
                else img_path.parent / (img_path.stem + "_everything.jpg"))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # imencode + tofile mirrors the loading path above.
    ok, encoded = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])
    if ok:
        encoded.tofile(str(out_path))
        print(f"\n저장: {out_path}")
    else:
        print(f"\n이미지 인코딩 실패: {out_path}")

    # Also save the label data as JSON (for analysis)
    json_path = out_path.with_suffix(".json")
    label_data = {
        "total_segments": len(shapes),
        "label_counts": dict(Counter(
            s.get("label", "(no label)") for s in shapes
        )),
        "segments": [
            {"label": s.get("label", ""), "score": s.get("score", 0),
             "bbox": list(_bbox(s["points"])),
             "points": s["points"]}
            for s in shapes
        ]
    }
    json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"라벨 데이터: {json_path}")
# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()