""" SAM3.1 control_box 단일 이미지 검출 + 결과 저장. 사용: python tools/detect_control_box.py --input [--conf 0.05] [--output ] """ import argparse import base64 import json from collections import Counter from pathlib import Path import cv2 import numpy as np import requests SAM3_SERVER = "http://localhost:8000" PROMPT = ( "small dark object on ballast, small box on ballast, " "metal cover on ground, small bright object on gravel, " "square lid on ground, control box" ) PALETTE = [ (0, 80, 255), (0, 200, 0), (255, 100, 0), (180, 0, 255), (0, 220, 220), (255, 180, 0), ] def detect(img_path: Path, conf: float) -> list: buf = np.fromfile(str(img_path), dtype=np.uint8) img = cv2.imdecode(buf, cv2.IMREAD_COLOR) if img is None: raise FileNotFoundError(img_path) _, enc = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 95]) b64 = base64.b64encode(enc).decode() r = requests.post(f"{SAM3_SERVER}/v1/predict", json={ "model": "segment_anything_3", "image": b64, "params": { "text_prompt": PROMPT, "conf_threshold": conf, "show_masks": True, "show_boxes": False, }, }, timeout=120) r.raise_for_status() shapes = r.json().get("data", {}).get("shapes", []) return img, [s if isinstance(s, dict) else s.dict() for s in shapes] def render(img, shapes, out_path: Path): all_labels = sorted(set(s.get("label", "") for s in shapes)) lc = {l: PALETTE[i % len(PALETTE)] for i, l in enumerate(all_labels)} canvas = img.copy() font = cv2.FONT_HERSHEY_SIMPLEX for s in shapes: if s.get("shape_type") != "polygon": continue pts = np.array(s["points"], dtype=np.int32) color = lc.get(s.get("label", ""), (128, 128, 128)) ov = canvas.copy() cv2.fillPoly(ov, [pts], color) cv2.addWeighted(ov, 0.20, canvas, 0.80, 0, canvas) cv2.polylines(canvas, [pts], True, color, 1) cx = int(np.mean([p[0] for p in s["points"]])) cy = int(np.mean([p[1] for p in s["points"]])) short = (s.get("label", "") .replace("small ", "") .replace(" on ballast", "") .replace(" on ground", "") .replace(" on gravel", "")) cv2.putText(canvas, short, (cx, cy), font, 0.35, color, 1, cv2.LINE_AA) y = 20 for lbl, color in sorted(lc.items()): cnt = sum(1 for s in shapes if s.get("label") == lbl) cv2.rectangle(canvas, (10, y - 12), (22, y), color, -1) cv2.putText(canvas, f"{lbl} ({cnt})", (26, y), font, 0.40, color, 1, cv2.LINE_AA) y += 16 out_path.parent.mkdir(parents=True, exist_ok=True) cv2.imencode(".png", canvas)[1].tofile(str(out_path)) def main(): ap = argparse.ArgumentParser() ap.add_argument("--input", required=True, type=Path) ap.add_argument("--output", default=None, type=Path) ap.add_argument("--conf", type=float, default=0.05) ap.add_argument("--save-json", action="store_true") args = ap.parse_args() out = args.output or args.input.with_name(args.input.stem + "_detected.png") print(f"input : {args.input}") print(f"output: {out}") print("detecting...") img, shapes = detect(args.input, args.conf) render(img, shapes, out) counter = Counter(s.get("label", "") for s in shapes) print(f"total : {len(shapes)}") for lbl, cnt in counter.most_common(): print(f" {lbl}: {cnt}") if args.save_json: jpath = out.with_suffix(".json") jpath.write_text(json.dumps({ "source": str(args.input), "total": len(shapes), "label_counts": dict(counter), "shapes": [{"label": s.get("label",""), "score": s.get("score",0), "points": s.get("points",[])} for s in shapes], }, ensure_ascii=False, indent=2), encoding="utf-8") print(f"json : {jpath}") print("done.") if __name__ == "__main__": main()