프로젝트 분리 이동

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
minsung
2026-05-20 14:28:27 +09:00
commit ccba1266b5
24 changed files with 7900 additions and 0 deletions

View File

@@ -0,0 +1,291 @@
"""
SAM3.1 탐색 모드 (Discovery Sweep)
이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 →
나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정.
이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음.
대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다.
사용법:
python tools/sam3_everything_explore.py \\
--input "data/역사구간/1.회덕역/..." \\
--cols 8 --rows 6
사전 조건: SAM3.1 서버 실행 (start_server.bat)
"""
import argparse
import base64
import json
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import cv2
import numpy as np
import requests
SAM3_SERVER = "http://localhost:8000"
SAM3_MODEL_ID = "segment_anything_3"
# ── 이미지 인코딩 ─────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
h, w = image_bgr.shape[:2]
scale = 1.0
if max(h, w) > max_size:
scale = max_size / max(h, w)
image_bgr = cv2.resize(image_bgr, (int(w * scale), int(h * scale)))
_, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
return base64.b64encode(buf).decode("utf-8"), scale
# 탐색용 넓은 프롬프트 — 철도 현장에서 흔히 보이는 모든 요소 포함
DISCOVERY_PROMPT = (
"railroad track, railway rail, "
"catenary pole, overhead line pole, electric pole, "
"overhead wire, catenary wire, power line cable, "
"railway sleeper, concrete tie, "
"guardrail, highway barrier, road fence, "
"bridge, viaduct, overpass, "
"vegetation, tree, bush, grass, "
"building, structure, roof, wall, "
"vehicle, car, truck, "
"road, asphalt, pavement, "
"slope, embankment, retaining wall, "
"noise barrier, sound wall, "
"signal, sign board"
)
# ── SAM3.1 discovery sweep 호출 ───────────────────────────────────────────────
def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list:
"""넓은 탐색 prompt → 이미지 내 모든 요소 검출. 반환: shape dict 리스트."""
b64, scale = encode_image(tile_bgr)
payload = {
"model": SAM3_MODEL_ID,
"image": b64,
"params": {
"text_prompt": prompt,
"conf_threshold": conf,
"show_masks": True,
"show_boxes": False,
},
}
try:
r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
r.raise_for_status()
resp = r.json()
if not resp.get("success"):
return []
shapes = resp.get("data", {}).get("shapes", [])
shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
if scale < 1.0:
inv = 1.0 / scale
for s in shapes:
if s.get("shape_type") == "polygon":
s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
return [s for s in shapes if s.get("shape_type") == "polygon"]
except Exception as e:
print(f" [SAM3 오류] {e}")
return []
# ── NMS ───────────────────────────────────────────────────────────────────────
def _bbox(pts):
xs = [p[0] for p in pts]; ys = [p[1] for p in pts]
return min(xs), min(ys), max(xs), max(ys)
def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
if not shapes:
return []
bboxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
scores = np.array([float(s.get("score", 0.5)) for s in shapes])
order = scores.argsort()[::-1]
keep = []
while len(order):
i = order[0]; keep.append(i)
if len(order) == 1: break
xx1 = np.maximum(bboxes[i,0], bboxes[order[1:],0])
yy1 = np.maximum(bboxes[i,1], bboxes[order[1:],1])
xx2 = np.minimum(bboxes[i,2], bboxes[order[1:],2])
yy2 = np.minimum(bboxes[i,3], bboxes[order[1:],3])
inter = np.maximum(0, xx2-xx1) * np.maximum(0, yy2-yy1)
a_i = (bboxes[i,2]-bboxes[i,0])*(bboxes[i,3]-bboxes[i,1])
a_j = (bboxes[order[1:],2]-bboxes[order[1:],0])*(bboxes[order[1:],3]-bboxes[order[1:],1])
iou = inter / (a_i + a_j - inter + 1e-6)
order = order[1:][iou < iou_thresh]
return [shapes[i] for i in keep]
# ── 타일 분할 + 병렬 검출 ─────────────────────────────────────────────────────
def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt):
H, W = image_bgr.shape[:2]
base_w = W / cols
base_h = H / rows
pad_x = int(base_w * overlap)
pad_y = int(base_h * overlap)
tiles = []
for r in range(rows):
for c in range(cols):
idx = r * cols + c + 1
x0 = max(0, int(c * base_w) - pad_x)
x1 = min(W, int((c + 1) * base_w) + pad_x)
y0 = max(0, int(r * base_h) - pad_y)
y1 = min(H, int((r + 1) * base_h) + pad_y)
tiles.append((idx, x0, y0, x1, y1))
total = len(tiles)
done = [0]
all_shapes = []
def process(args):
idx, x0, y0, x1, y1 = args
tile = image_bgr[y0:y1, x0:x1]
shapes = sam3_everything(tile, conf, prompt)
for s in shapes:
s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
return shapes
with ThreadPoolExecutor(max_workers=workers) as ex:
futs = {ex.submit(process, t): t for t in tiles}
for fut in as_completed(futs):
result = fut.result()
all_shapes.extend(result)
done[0] += 1
print(f" 타일 {done[0]:02d}/{total} 완료, 누적 {len(all_shapes)}", end="\r")
print()
return all_shapes
# ── 시각화 ────────────────────────────────────────────────────────────────────
def draw_everything(image_bgr, shapes, cols, rows):
vis = image_bgr.copy()
H, W = vis.shape[:2]
# 타일 경계
for r in range(rows):
for c in range(cols):
bx0, by0 = int(c * W / cols), int(r * H / rows)
bx1, by1 = int((c + 1) * W / cols), int((r + 1) * H / rows)
cv2.rectangle(vis, (bx0, by0), (bx1, by1), (60, 60, 60), 1)
rng = np.random.default_rng(42)
for s in shapes:
pts = np.array(s["points"], dtype=np.int32)
color = tuple(int(v) for v in rng.integers(80, 255, size=3))
overlay = vis.copy()
cv2.fillPoly(overlay, [pts], color)
cv2.addWeighted(overlay, 0.30, vis, 0.70, 0, vis)
cv2.polylines(vis, [pts], True, color, 1)
# 라벨 표시 (있을 경우)
label = s.get("label", "")
if label:
cx = int(np.mean([p[0] for p in s["points"]]))
cy = int(np.mean([p[1] for p in s["points"]]))
cv2.putText(vis, label, (cx, cy),
cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)
cv2.putText(vis, f"total segments: {len(shapes)}",
(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
return vis
# ── 라벨 분석 → text_prompt 후보 출력 ────────────────────────────────────────
def analyze_labels(shapes):
labels = [s.get("label", "").strip() for s in shapes if s.get("label", "").strip()]
if not labels:
print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
return
counter = Counter(labels)
print(f"\n{''*50}")
print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
print(f"{''*50}")
for label, cnt in counter.most_common(30):
bar = "#" * min(cnt, 40)
print(f" {label:35s} {cnt:4d} {bar}")
print(f"{''*50}")
top_labels = [lb for lb, _ in counter.most_common(10)]
print(f"\n[text_prompt 후보]")
print(f' "{", ".join(top_labels)}"')
# ── 메인 ─────────────────────────────────────────────────────────────────────
def main():
ap = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
ap.add_argument("--input", required=True, help="입력 이미지")
ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
args = ap.parse_args()
prompt = DISCOVERY_PROMPT + (", " + args.prompt_extra.strip(", ") if args.prompt_extra.strip() else "")
img_path = Path(args.input)
buf = np.fromfile(str(img_path), dtype=np.uint8)
image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
if image_bgr is None:
print(f"이미지 로드 실패: {img_path}"); return
H, W = image_bgr.shape[:2]
print(f"이미지 : {W}×{H}")
print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}")
print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")
import time
print(f"탐색 프롬프트 ({len(prompt.split(','))}개 항목):")
for item in prompt.split(","):
print(f" · {item.strip()}")
print()
t0 = time.time()
shapes = detect_everything_tiled(
image_bgr, args.cols, args.rows, args.overlap,
args.conf, args.workers, prompt
)
print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
shapes = nms_shapes(shapes, iou_thresh=args.nms)
print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n")
analyze_labels(shapes)
vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
h, w = vis.shape[:2]
if max(h, w) > 4096:
s = 4096 / max(h, w)
vis = cv2.resize(vis, (int(w*s), int(h*s)))
out_path = (Path(args.output) if args.output
else img_path.parent / (img_path.stem + "_everything.jpg"))
cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path))
print(f"\n저장: {out_path}")
# JSON으로 라벨 데이터도 저장 (분석용)
json_path = out_path.with_suffix(".json")
label_data = {
"total_segments": len(shapes),
"label_counts": dict(Counter(
s.get("label", "(no label)") for s in shapes
)),
"segments": [
{"label": s.get("label",""), "score": s.get("score",0),
"bbox": list(_bbox(s["points"]))}
for s in shapes
]
}
json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"라벨 데이터: {json_path}")
if __name__ == "__main__":
main()