- sam31server를 SAM3.1 서버로 전환 (x-anylabeling01 대체) - detect_raamen.py: B/C 분류 기반 라멘형 전철주 검출 파이프라인 정비 - sam3_everything_explore.py: Discovery Sweep 탐색 모드 정리 - detect_all_objects.py: 타일 검출 개선 - docs/railway-client-guide.html: 서버·도구·파이프라인 전체 가이드 추가 - tools 추가: detect_control_box, group_ramen_poles, render_everything_by_label, render_label_polygons, debug_vh Closes #1 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
318 lines
13 KiB
Python
318 lines
13 KiB
Python
"""
|
||
SAM3.1 탐색 모드 (Discovery Sweep)
|
||
이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 →
|
||
나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정.
|
||
|
||
이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음.
|
||
대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다.
|
||
|
||
사용법:
|
||
python tools/sam3_everything_explore.py \\
|
||
--input "data/역사구간/1.회덕역/..." \\
|
||
--cols 8 --rows 6
|
||
|
||
사전 조건: SAM3.1 서버 실행 (start_server.bat)
|
||
"""
|
||
import argparse
|
||
import base64
|
||
import json
|
||
from collections import Counter
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from pathlib import Path
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import requests
|
||
|
||
# Base URL of the SAM3.1 inference server (started via start_server.bat, per the usage notes).
SAM3_SERVER: str = "http://localhost:8000"
# Model identifier sent as the "model" field of every /v1/predict request.
SAM3_MODEL_ID: str = "segment_anything_3"
|
||
|
||
|
||
# ── Image encoding ────────────────────────────────────────────────────────────

def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
    """Shrink *image_bgr* so its longer side is at most *max_size*, then JPEG+base64 it.

    Args:
        image_bgr: image array as produced by OpenCV (presumably BGR channel order).
        max_size: upper bound, in pixels, for the longer side of the uploaded image.

    Returns:
        ``(b64_jpeg, scale)``: the base64 (UTF-8 str) of a quality-90 JPEG, and the
        factor applied to the image (``1.0`` when no resize was needed). Coordinates
        returned for the encoded image map back to *image_bgr* by dividing by *scale*.

    Raises:
        ValueError: if OpenCV reports that JPEG encoding failed.
    """
    h, w = image_bgr.shape[:2]
    scale = 1.0
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        # Clamp to >= 1 px: for a very thin strip the short side would otherwise
        # truncate to 0 and cv2.resize would raise on the zero-sized target.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        image_bgr = cv2.resize(image_bgr, new_size)
    ok, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf).decode("utf-8"), scale
|
||
|
||
|
||
# Broad discovery vocabulary: every element type commonly visible at a railway
# site, flattened into one comma-separated text_prompt. Order matters only for
# readability of the printed prompt list.
DISCOVERY_PROMPT = ", ".join((
    # track and catenary system
    "railroad track", "railway rail",
    "catenary pole", "overhead line pole", "electric pole",
    "overhead wire", "catenary wire", "power line cable",
    "railway sleeper", "concrete tie",
    # trackside infrastructure and surroundings
    "guardrail", "highway barrier", "road fence",
    "bridge", "viaduct", "overpass",
    "vegetation", "tree", "bush", "grass",
    "building", "structure", "roof", "wall",
    "vehicle", "car", "truck",
    "road", "asphalt", "pavement",
    "slope", "embankment", "retaining wall",
    "noise barrier", "sound wall",
    "signal", "sign board",
    # small objects on the ballast (control-box style targets)
    "small dark object on ballast", "small dark object on railway",
    "small square metal box on ground", "control box on ballast",
    "gray square lid on gravel", "flat metal cover on ground",
    "small bright object on ballast", "small white box on ballast",
    "small gray box on ground", "bright square object on gravel",
))
|
||
|
||
|
||
# ── SAM3.1 discovery sweep call ───────────────────────────────────────────────

def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list:
    """Run one SAM3.1 text-grounded prediction on *tile_bgr*; return its polygon shapes.

    The tile is downscaled by ``encode_image`` before upload, and polygon points
    in the response are scaled back to *tile_bgr* pixel coordinates. Only shapes
    whose ``shape_type`` is ``"polygon"`` are returned.

    Failures are best-effort: a request/HTTP/JSON error, a malformed response, or a
    response without a truthy ``success`` is reported on stdout and yields ``[]``,
    so one bad tile does not abort a whole sweep. Errors raised by ``encode_image``
    happen before the request and are NOT caught here.

    Args:
        tile_bgr: tile image array.
        conf: value sent as ``conf_threshold``.
        prompt: value sent as ``text_prompt``.

    Returns:
        List of shape dicts (as returned by the server) with rescaled ``points``.
    """
    b64, scale = encode_image(tile_bgr)
    payload = {
        "model": SAM3_MODEL_ID,
        "image": b64,
        "params": {
            "text_prompt": prompt,
            "conf_threshold": conf,
            "show_masks": True,
            "show_boxes": False,
        },
    }
    try:
        r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=300)
        r.raise_for_status()
        resp = r.json()
        if not resp.get("success"):
            # Previously returned [] silently, indistinguishable from "nothing found".
            print(f" [SAM3 실패 응답] {str(resp)[:200]}")
            return []
        # ``data`` / ``shapes`` may be missing or null; treat both as empty.
        raw_shapes = (resp.get("data") or {}).get("shapes") or []
        # JSON decoding only yields plain dicts for objects; skip anything else.
        polygons = [
            s for s in raw_shapes
            if isinstance(s, dict) and s.get("shape_type") == "polygon"
        ]
        if scale < 1.0:
            inv = 1.0 / scale  # undo encode_image's downscale
            for s in polygons:
                s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
        return polygons
    except Exception as e:
        # Deliberately broad at this per-tile boundary: log and degrade to "no detections".
        print(f" [SAM3 오류] {e}")
        return []
|
||
|
||
|
||
# ── NMS ───────────────────────────────────────────────────────────────────────
|
||
def _bbox(pts):
|
||
xs = [p[0] for p in pts]; ys = [p[1] for p in pts]
|
||
return min(xs), min(ys), max(xs), max(ys)
|
||
|
||
|
||
def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
    """Greedy, label-agnostic non-maximum suppression on polygon bounding boxes.

    Shapes are visited from highest to lowest ``score`` (a missing score counts as
    0.5). Each visited shape is kept, and every not-yet-visited shape whose
    axis-aligned bbox IoU with it is ``>= iou_thresh`` is discarded.

    Args:
        shapes: shape dicts, each with a ``"points"`` list of ``[x, y]`` pairs.
        iou_thresh: IoU at or above which the lower-scored shape is suppressed.

    Returns:
        The surviving shape dicts (the same objects), highest score first.
    """
    if not shapes:
        return []

    # One (x_min, y_min, x_max, y_max) row per shape.
    boxes = np.array(
        [
            (
                min(p[0] for p in shape["points"]),
                min(p[1] for p in shape["points"]),
                max(p[0] for p in shape["points"]),
                max(p[1] for p in shape["points"]),
            )
            for shape in shapes
        ],
        dtype=np.float32,
    )
    confidences = np.array([float(shape.get("score", 0.5)) for shape in shapes])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    survivors = []
    queue = confidences.argsort()[::-1]
    while queue.size:
        best, rest = queue[0], queue[1:]
        survivors.append(best)
        left = np.maximum(boxes[best, 0], boxes[rest, 0])
        top = np.maximum(boxes[best, 1], boxes[rest, 1])
        right = np.minimum(boxes[best, 2], boxes[rest, 2])
        bottom = np.minimum(boxes[best, 3], boxes[rest, 3])
        overlap = np.maximum(0, right - left) * np.maximum(0, bottom - top)
        ratio = overlap / (areas[best] + areas[rest] - overlap + 1e-6)
        # An empty ``rest`` yields an empty queue, ending the loop.
        queue = rest[ratio < iou_thresh]

    return [shapes[k] for k in survivors]
|
||
|
||
|
||
# ── Tiling + parallel detection ───────────────────────────────────────────────

def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt,
                            zone=None):
    """Split *image_bgr* into a ``rows x cols`` grid and run ``sam3_everything`` per tile.

    Each tile is padded on every side by ``overlap`` times the base tile
    width/height (clipped to the image). Detections inside the overlapping strips
    can come back from several tiles; they are NOT deduplicated here.

    Args:
        image_bgr: full image array.
        cols, rows: grid size.
        overlap: per-side padding as a fraction of the base tile size.
        conf: confidence threshold forwarded to the server.
        workers: number of concurrent tile requests.
        prompt: text prompt forwarded to the server.
        zone: optional ``(x1, y1, x2, y2)``; when given, only tiles whose padded
            extent intersects it are processed.

    Returns:
        Polygon shape dicts with points in full-image coordinates, concatenated in
        row-major tile order (independent of which request finishes first).
    """
    H, W = image_bgr.shape[:2]
    base_w = W / cols
    base_h = H / rows
    pad_x = int(base_w * overlap)
    pad_y = int(base_h * overlap)

    def overlaps_zone(tx0, ty0, tx1, ty1):
        if zone is None:
            return True
        zx1, zy1, zx2, zy2 = zone
        return tx0 < zx2 and tx1 > zx1 and ty0 < zy2 and ty1 > zy1

    tiles = []
    skipped = 0
    for r in range(rows):
        for c in range(cols):
            idx = r * cols + c + 1
            x0 = max(0, int(c * base_w) - pad_x)
            x1 = min(W, int((c + 1) * base_w) + pad_x)
            y0 = max(0, int(r * base_h) - pad_y)
            y1 = min(H, int((r + 1) * base_h) + pad_y)
            if overlaps_zone(x0, y0, x1, y1):
                tiles.append((idx, x0, y0, x1, y1))
            else:
                skipped += 1

    total = len(tiles)
    if skipped:
        print(f"zone 필터: {skipped}타일 스킵, {total}타일 처리")

    def process(tile_spec):
        _, x0, y0, x1, y1 = tile_spec
        shapes = sam3_everything(image_bgr[y0:y1, x0:x1], conf, prompt)
        # Shift tile-local coordinates into full-image coordinates.
        for s in shapes:
            s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
        return shapes

    per_tile = {}
    found = 0
    with ThreadPoolExecutor(max_workers=workers) as ex:
        future_to_idx = {ex.submit(process, t): t[0] for t in tiles}
        for n_done, fut in enumerate(as_completed(future_to_idx), start=1):
            shapes = fut.result()
            per_tile[future_to_idx[fut]] = shapes
            found += len(shapes)
            print(f" 타일 {n_done:02d}/{total} 완료, 누적 {found}개", end="\r")
    print()

    # Merge in tile order, not completion order, so repeated runs against the same
    # server output give the same list (stable NMS tie-breaking and JSON order).
    return [s for t in tiles for s in per_tile[t[0]]]
|
||
|
||
|
||
# ── Visualization ─────────────────────────────────────────────────────────────

def draw_everything(image_bgr, shapes, cols, rows):
    """Return a copy of *image_bgr* annotated with the tile grid and every segment.

    Draws the un-padded ``rows x cols`` tile boundaries in dark gray. Then, for
    each shape, it draws a 30%-opacity fill, a 1 px outline and (if the shape has
    a non-empty ``label``) the label at the vertex mean. Each shape gets a color
    from a fixed-seed RNG. Finally a total-count caption goes in the top-left
    corner. *image_bgr* itself is not modified.
    """
    vis = image_bgr.copy()
    H, W = vis.shape[:2]

    # Tile boundaries (without the overlap padding used during detection).
    for r in range(rows):
        for c in range(cols):
            bx0, by0 = int(c * W / cols), int(r * H / rows)
            bx1, by1 = int((c + 1) * W / cols), int((r + 1) * H / rows)
            cv2.rectangle(vis, (bx0, by0), (bx1, by1), (60, 60, 60), 1)

    rng = np.random.default_rng(42)
    for s in shapes:
        pts = np.array(s["points"], dtype=np.int32)
        color = tuple(int(v) for v in rng.integers(80, 255, size=3))

        # Blend the fill only inside the polygon's bbox, clipped to the image.
        # Pixels outside the polygon are unchanged by a 0.3/0.7 blend with
        # themselves, so this matches blending the full frame, without copying
        # the whole image once per segment.
        x0 = max(0, int(pts[..., 0].min()))
        y0 = max(0, int(pts[..., 1].min()))
        x1 = min(W, int(pts[..., 0].max()) + 1)
        y1 = min(H, int(pts[..., 1].max()) + 1)
        if x0 < x1 and y0 < y1:
            patch = vis[y0:y1, x0:x1]
            overlay = patch.copy()
            local_pts = pts - np.array([x0, y0], dtype=np.int32)
            cv2.fillPoly(overlay, [local_pts], color)
            vis[y0:y1, x0:x1] = cv2.addWeighted(overlay, 0.30, patch, 0.70, 0)
        cv2.polylines(vis, [pts], True, color, 1)

        # Label at the mean of the polygon vertices, when one is present.
        label = s.get("label", "")
        if label:
            cx = int(np.mean([p[0] for p in s["points"]]))
            cy = int(np.mean([p[1] for p in s["points"]]))
            cv2.putText(vis, label, (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)

    cv2.putText(vis, f"total segments: {len(shapes)}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    return vis
|
||
|
||
|
||
# ── Label analysis → text_prompt candidates ──────────────────────────────────

def analyze_labels(shapes):
    """Print a label-frequency table for *shapes* and a suggested text_prompt.

    Labels are converted to str and whitespace-stripped; missing, null or blank
    labels are ignored. Prints up to the 30 most common labels with a bar chart,
    then the 10 most common joined by ", " as a candidate ``text_prompt``.
    """
    labels = []
    for s in shapes:
        # ``label`` may be absent or null in server output; treat both as blank
        # (calling .strip() on None used to raise AttributeError).
        label = str(s.get("label") or "").strip()
        if label:
            labels.append(label)

    if not labels:
        if shapes:
            # Segments exist, they just carry no label text.
            print(f"\n[라벨 없음] segment {len(shapes)}개에 라벨이 없습니다.")
        else:
            print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
        return

    counter = Counter(labels)
    print(f"\n{'─'*50}")
    print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
    print(f"{'─'*50}")
    for label, cnt in counter.most_common(30):
        bar = "#" * min(cnt, 40)
        print(f" {label:35s} {cnt:4d} {bar}")
    print(f"{'─'*50}")

    top_labels = [lb for lb, _ in counter.most_common(10)]
    print("\n[text_prompt 후보]")
    print(f' "{", ".join(top_labels)}"')
|
||
|
||
|
||
# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    """Command-line entry point for the discovery sweep.

    It performs these steps:
    - Loads ``--input`` and runs ``detect_everything_tiled`` with
      ``DISCOVERY_PROMPT`` plus ``--prompt-extra``.
    - Applies ``nms_shapes`` and prints the label summary.
    - Writes an annotated JPEG (longer side capped at 4096 px) and a JSON file
      with the same stem listing every kept segment.

    Exits non-zero for invalid grid/worker/overlap/zone arguments (via
    ``argparse``), or when the input image cannot be read or the output cannot be
    encoded (via ``SystemExit``).
    """
    import time  # only used for the elapsed-time report

    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--input", required=True, help="입력 이미지")
    ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
    ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
    ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
    ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
    ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
    ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
    ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
    ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
    ap.add_argument("--zone", type=int, nargs=4, metavar=("X1","Y1","X2","Y2"), default=None,
                    help="처리 zone 제한 (이 범위와 겹치는 타일만 처리)")
    args = ap.parse_args()

    # Fail fast with a usage error instead of a ZeroDivisionError in the tile
    # grid, a ValueError from ThreadPoolExecutor, or silently empty tiles.
    if args.cols < 1 or args.rows < 1:
        ap.error("--cols, --rows는 1 이상이어야 합니다")
    if args.workers < 1:
        ap.error("--workers는 1 이상이어야 합니다")
    if args.overlap < 0:
        ap.error("--overlap은 0 이상이어야 합니다")
    if args.zone and not (args.zone[0] < args.zone[2] and args.zone[1] < args.zone[3]):
        ap.error("--zone은 X1<X2, Y1<Y2 이어야 합니다")

    # Strip stray separators so e.g. --prompt-extra "," does not append an empty item.
    extra = args.prompt_extra.strip(", ")
    prompt = DISCOVERY_PROMPT + (", " + extra if extra.strip() else "")

    img_path = Path(args.input)
    # Decode from a byte buffer rather than cv2.imread — presumably so non-ASCII
    # paths (like the Korean example in the usage notes) load on Windows.
    try:
        buf = np.fromfile(str(img_path), dtype=np.uint8)
    except OSError as e:
        raise SystemExit(f"이미지 로드 실패: {img_path} ({e})") from e
    image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image_bgr is None:
        raise SystemExit(f"이미지 로드 실패: {img_path}")

    H, W = image_bgr.shape[:2]
    print(f"이미지 : {W}×{H}")
    print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}개")
    print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")

    print(f"탐색 프롬프트 ({len(prompt.split(','))}개 항목):")
    for item in prompt.split(","):
        print(f" · {item.strip()}")
    print()

    zone = tuple(args.zone) if args.zone else None
    if zone:
        print(f"zone 제한: x={zone[0]}~{zone[2]} y={zone[1]}~{zone[3]}\n")

    t0 = time.perf_counter()
    shapes = detect_everything_tiled(
        image_bgr, args.cols, args.rows, args.overlap,
        args.conf, args.workers, prompt, zone=zone
    )
    print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
    shapes = nms_shapes(shapes, iou_thresh=args.nms)
    print(f"NMS 후 {len(shapes)}개 ({time.perf_counter()-t0:.0f}초)\n")

    analyze_labels(shapes)

    vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        shrink = 4096 / max(h, w)
        # INTER_AREA is OpenCV's recommended interpolation for downscaling;
        # clamp to >= 1 px so extreme aspect ratios cannot produce a 0-size target.
        vis = cv2.resize(vis, (max(1, int(w * shrink)), max(1, int(h * shrink))),
                         interpolation=cv2.INTER_AREA)

    out_path = (Path(args.output) if args.output
                else img_path.parent / (img_path.stem + "_everything.jpg"))
    ok, encoded = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])
    if not ok:
        raise SystemExit(f"이미지 인코딩 실패: {out_path}")
    encoded.tofile(str(out_path))  # tofile also handles non-ASCII paths
    print(f"\n저장: {out_path}")

    # Also save the label data as JSON (for analysis).
    json_path = out_path.with_suffix(".json")
    label_data = {
        "total_segments": len(shapes),
        # ``or`` rather than a .get default, so empty/null labels are also
        # counted under "(no label)" instead of under an "" / null key.
        "label_counts": dict(Counter(
            s.get("label") or "(no label)" for s in shapes
        )),
        "segments": [
            {"label": s.get("label", ""), "score": s.get("score", 0),
             "bbox": list(_bbox(s["points"])),
             "points": s["points"]}
            for s in shapes
        ],
    }
    json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"라벨 데이터: {json_path}")


if __name__ == "__main__":
    main()
|