- sam31server를 SAM3.1 서버로 전환 (x-anylabeling01 대체)
- detect_raamen.py: B/C 분류 기반 라멘형 전철주 검출 파이프라인 정비
- sam3_everything_explore.py: Discovery Sweep 탐색 모드 정리
- detect_all_objects.py: 타일 검출 개선
- docs/railway-client-guide.html: 서버·도구·파이프라인 전체 가이드 추가
- tools 추가: detect_control_box, group_ramen_poles, render_everything_by_label, render_label_polygons, debug_vh

Closes #1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
124 lines
4.0 KiB
Python
124 lines
4.0 KiB
Python
"""
|
|
SAM3.1 control_box 단일 이미지 검출 + 결과 저장.
|
|
|
|
사용:
|
|
python tools/detect_control_box.py --input <image_path> [--conf 0.05] [--output <path>]
|
|
"""
|
|
import argparse
|
|
import base64
|
|
import json
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import requests
|
|
|
|
# Base URL of the SAM3 inference server; detect() posts to "<base>/v1/predict".
SAM3_SERVER = "http://localhost:8000"

# Text prompt sent with every request: several comma-separated phrasings of
# the target object. The returned shape labels are shortened in render() by
# stripping the "small " / " on ballast" / " on ground" / " on gravel" parts.
PROMPT = (
    "small dark object on ballast, "
    "small box on ballast, "
    "metal cover on ground, "
    "small bright object on gravel, "
    "square lid on ground, "
    "control box"
)

# Per-label drawing colours, reused cyclically when there are more labels
# than entries.
# NOTE(review): passed straight to OpenCV drawing calls, so presumably
# interpreted as BGR — confirm.
PALETTE = [
    (0, 80, 255),
    (0, 200, 0),
    (255, 100, 0),
    (180, 0, 255),
    (0, 220, 220),
    (255, 180, 0),
]
|
|
|
|
|
|
def detect(img_path: Path, conf: float) -> tuple:
    """Send one image to the SAM3 server and return the image with its shapes.

    The image is re-encoded as a quality-95 JPEG, base64-encoded and posted
    to ``{SAM3_SERVER}/v1/predict`` together with the module-level ``PROMPT``.

    Args:
        img_path: Path of the image file to analyse.
        conf: Value forwarded to the server as ``conf_threshold``.

    Returns:
        A ``(img, shapes)`` tuple: ``img`` is the decoded image array and
        ``shapes`` is the list of shape dicts taken from
        ``response["data"]["shapes"]`` (empty when the server reports none).

    Raises:
        FileNotFoundError: If the file's bytes cannot be decoded as an image.
        RuntimeError: If the image cannot be re-encoded as JPEG.
        requests.HTTPError: If the server answers with an error status.
    """
    # Read the raw bytes with NumPy and decode in memory rather than calling
    # cv2.imread.
    # NOTE(review): presumably done so that paths containing non-ASCII
    # characters load correctly — confirm.
    buf = np.fromfile(str(img_path), dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(img_path)

    # imencode reports failure through its first return value; do not send a
    # bogus payload to the server if encoding did not succeed.
    ok, enc = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 95])
    if not ok:
        raise RuntimeError(f"failed to JPEG-encode image: {img_path}")
    b64 = base64.b64encode(enc).decode()

    r = requests.post(
        f"{SAM3_SERVER}/v1/predict",
        json={
            "model": "segment_anything_3",
            "image": b64,
            "params": {
                "text_prompt": PROMPT,
                "conf_threshold": conf,
                "show_masks": True,
                "show_boxes": False,
            },
        },
        timeout=120,
    )
    r.raise_for_status()

    # "data" or "shapes" may be present but null; treat that like "missing"
    # instead of failing on None.
    data = r.json().get("data") or {}
    shapes = data.get("shapes") or []
    return img, [s if isinstance(s, dict) else s.dict() for s in shapes]
|
|
|
|
|
|
def render(img, shapes, out_path: Path):
    """Draw the detected polygons and a label legend, then save as PNG.

    Each shape whose ``shape_type`` is ``"polygon"`` is drawn as a 20 %-opacity
    fill with a 1-px outline and a shortened label at the mean of its points.
    A legend in the top-left corner lists every label with its shape count.
    The input image is not modified; drawing happens on a copy.

    Args:
        img: Image array to draw on (as returned by ``detect``).
        shapes: List of shape dicts with ``label``, ``shape_type`` and
            ``points`` keys. A missing label is treated as ``""``.
        out_path: Destination PNG path; parent directories are created.

    Raises:
        RuntimeError: If the canvas cannot be encoded as PNG.
    """
    # One pass gives both the label set and the legend counts, and keeps the
    # "missing label == ''" convention consistent between colours and counts.
    label_counts = Counter(s.get("label", "") for s in shapes)
    lc = {
        lbl: PALETTE[i % len(PALETTE)]
        for i, lbl in enumerate(sorted(label_counts))
    }
    canvas = img.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX

    for s in shapes:
        if s.get("shape_type") != "polygon":
            continue
        points = s.get("points") or []
        if not points:
            # Nothing to draw, and the centroid of zero points is undefined.
            continue
        pts = np.array(points, dtype=np.int32)
        label = s.get("label", "")
        color = lc.get(label, (128, 128, 128))

        # Blend a filled copy over the canvas for a translucent fill.
        ov = canvas.copy()
        cv2.fillPoly(ov, [pts], color)
        cv2.addWeighted(ov, 0.20, canvas, 0.80, 0, canvas)
        cv2.polylines(canvas, [pts], True, color, 1)

        # Anchor the text at the mean of the polygon's vertices.
        cx = int(np.mean([p[0] for p in points]))
        cy = int(np.mean([p[1] for p in points]))
        short = (
            label
            .replace("small ", "")
            .replace(" on ballast", "")
            .replace(" on ground", "")
            .replace(" on gravel", "")
        )
        cv2.putText(canvas, short, (cx, cy), font, 0.35, color, 1, cv2.LINE_AA)

    # Legend: lc was built from sorted labels, so iteration is already ordered.
    y = 20
    for lbl, color in lc.items():
        cnt = label_counts[lbl]
        cv2.rectangle(canvas, (10, y - 12), (22, y), color, -1)
        cv2.putText(canvas, f"{lbl} ({cnt})", (26, y), font, 0.40, color, 1, cv2.LINE_AA)
        y += 16

    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Encode in memory and write through NumPy (mirrors how detect() reads).
    ok, png = cv2.imencode(".png", canvas)
    if not ok:
        raise RuntimeError(f"failed to PNG-encode result for: {out_path}")
    png.tofile(str(out_path))
|
|
|
|
|
|
def main():
    """Command-line entry point: detect, render, and print a summary.

    Options:
        --input      image to process (required)
        --output     result PNG path; defaults to "<input stem>_detected.png"
                     next to the input
        --conf       confidence threshold passed to detect() (default 0.05)
        --save-json  also write the shapes to a .json file beside the output
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, type=Path)
    parser.add_argument("--output", default=None, type=Path)
    parser.add_argument("--conf", type=float, default=0.05)
    parser.add_argument("--save-json", action="store_true")
    opts = parser.parse_args()

    src = opts.input
    if opts.output:
        dst = opts.output
    else:
        dst = src.with_name(src.stem + "_detected.png")

    print(f"input : {src}")
    print(f"output: {dst}")
    print("detecting...")

    image, found = detect(src, opts.conf)
    render(image, found, dst)

    # Per-label tally, printed most frequent first.
    per_label = Counter(item.get("label", "") for item in found)
    print(f"total : {len(found)}")
    for name, n in per_label.most_common():
        print(f" {name}: {n}")

    if opts.save_json:
        sidecar = dst.with_suffix(".json")
        # Keep only the fields of interest from each shape.
        records = [
            {
                "label": item.get("label", ""),
                "score": item.get("score", 0),
                "points": item.get("points", []),
            }
            for item in found
        ]
        report = {
            "source": str(src),
            "total": len(found),
            "label_counts": dict(per_label),
            "shapes": records,
        }
        sidecar.write_text(
            json.dumps(report, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        print(f"json : {sidecar}")

    print("done.")


if __name__ == "__main__":
    main()
|