"""post_merge_poles.py — merge catenary_pole shapes in detect_all_objects.py output JSON.

Detection only has to run once; when tuning the merge parameters, re-run
only this script on the saved JSON.

Usage:
    python tools/post_merge_poles.py INPUT.json [--x-overlap 0.30] [--y-gap 150]
    python tools/post_merge_poles.py INPUT.json --inplace
    python tools/post_merge_poles.py INPUT.json --output OUTPUT.json
"""

import argparse
import json
import sys
from pathlib import Path

import cv2
import numpy as np


def _poly_orient(points: list, H: int, W: int) -> str:
    """Classify the long-axis orientation of a polygon.

    The polygon's minimum-area rectangle is compared against the radial
    direction from the image centre (W/2, H/2) to the rectangle centre.

    Returns:
        'V' if the long axis is aligned with that radial direction
            (|cos| > 0.7) -> treated as an upright pole.
        'H' if it is not aligned -> treated as a horizontal beam.
        '?' if undecidable: fewer than 2 points, a degenerate rectangle
            (short side < 1 px), aspect ratio < 1.3, or a rectangle centred
            within 1 px of the image centre.
    """
    pts = np.asarray(points, dtype=np.float32).reshape(-1, 2)
    if len(pts) < 2:
        # cv2.minAreaRect raises on an empty point set; nothing to classify.
        return '?'
    (rx, ry), (rw, rh), angle = cv2.minAreaRect(pts)
    if min(rw, rh) < 1:
        return '?'
    ar = max(rw, rh) / min(rw, rh)
    if ar < 1.3:
        return '?'
    # `angle` describes the rect's width side; rotate by 90 deg when the
    # height is the long side. Only |cos| is used below, so the direction
    # ambiguity (mod 180 deg) between OpenCV angle conventions is harmless.
    long_angle_deg = angle if rw >= rh else angle + 90
    lx = float(np.cos(np.radians(long_angle_deg)))
    ly = float(np.sin(np.radians(long_angle_deg)))
    img_cx, img_cy = W / 2.0, H / 2.0
    rdx, rdy = rx - img_cx, ry - img_cy
    radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
    if radial_norm < 1:
        # Rectangle sits on the image centre: radial direction is undefined.
        return '?'
    rdx, rdy = rdx / radial_norm, rdy / radial_norm
    cos_sim = abs(lx * rdx + ly * rdy)
    return 'V' if cos_sim > 0.7 else 'H'


def merge_poles(shapes: list, H: int, W: int,
                x_overlap_thresh: float = 0.30,
                y_gap_thresh: int = 150) -> list:
    """Merge pole fragments split across tile boundaries (V+V pairs only).

    Two shapes are merged only if both are classified 'V' by
    ``_poly_orient`` and their axis-aligned bounding boxes satisfy:
      * x-overlap ratio (intersection / union of the x-ranges)
        >= ``x_overlap_thresh``, and
      * vertical gap between the y-ranges <= ``y_gap_thresh`` pixels.

    Merging is greedy in index order. Each surviving 'V' shape absorbs later
    matching 'V' shapes. The scan is repeated after every absorption, because
    the grown bounding box may now match shapes that were rejected earlier.

    Args:
        shapes: shape dicts with a "points" list of [x, y] pairs and an
            optional "score".
        H, W: image height/width in pixels. Used for orientation and as the
            rasterisation canvas when two shapes are merged.
        x_overlap_thresh: minimum x-range overlap ratio.
        y_gap_thresh: maximum y-range gap in pixels.

    Returns:
        A new list with merged shapes. The input is returned unchanged when it
        has at most one shape. Merged shapes copy all keys of the
        lower-index shape, take a new polygon, and take the max "score"
        (missing scores count as 0).
    """
    if len(shapes) <= 1:
        return shapes

    orients = [_poly_orient(s["points"], H, W) for s in shapes]
    v_count = orients.count('V')
    h_count = orients.count('H')
    print(f" orient: V={v_count}, H={h_count}, ?={len(orients)-v_count-h_count}")

    def get_bbox(s):
        xs = [p[0] for p in s["points"]]
        ys = [p[1] for p in s["points"]]
        return min(xs), min(ys), max(xs), max(ys)

    def x_overlap_ratio(b1, b2):
        ox = min(b1[2], b2[2]) - max(b1[0], b2[0])
        ux = max(b1[2], b2[2]) - min(b1[0], b2[0])
        return ox / ux if ux > 0 else 0.0

    def y_gap(b1, b2):
        return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))

    def merge_two(s1, s2):
        # Rasterise both polygons onto an image-sized mask (coordinates are
        # truncated to int and clipped to the canvas by fillPoly).
        mask = np.zeros((H, W), dtype=np.uint8)
        for s in (s1, s2):
            cv2.fillPoly(mask, [np.array(s["points"], dtype=np.int32)], 255)
        # [-2] selects the contour list under both OpenCV 3 (3 return
        # values) and OpenCV 4 (2 return values).
        contours = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            return s1
        if len(contours) == 1:
            outline = contours[0]
        else:
            # The fragments do not touch (a gap of up to y_gap_thresh px is
            # allowed). Keeping only the largest contour would silently drop
            # the other fragment, so bridge them with the convex hull of
            # all pieces.
            outline = cv2.convexHull(np.vstack(contours))
        eps = 0.002 * cv2.arcLength(outline, True)
        approx = cv2.approxPolyDP(outline, eps, True)
        merged = dict(s1)
        merged["points"] = [[float(x), float(y)] for x, y in approx.reshape(-1, 2)]
        merged["score"] = max(float(s1.get("score", 0)), float(s2.get("score", 0)))
        return merged

    consumed = [False] * len(shapes)
    result = []
    merged_count = 0
    for i, cur in enumerate(shapes):
        if consumed[i]:
            continue
        if orients[i] != 'V':
            # Only V+V pairs are merged; pass everything else through untouched.
            result.append(cur)
            continue
        cb = get_bbox(cur)
        changed = True
        while changed:
            changed = False
            for j in range(i + 1, len(shapes)):
                if consumed[j] or orients[j] != 'V':
                    continue
                jb = get_bbox(shapes[j])
                if (x_overlap_ratio(cb, jb) >= x_overlap_thresh
                        and y_gap(cb, jb) <= y_gap_thresh):
                    cur = merge_two(cur, shapes[j])
                    cb = get_bbox(cur)
                    consumed[j] = True
                    merged_count += 1
                    changed = True
        result.append(cur)
    print(f" 병합: {len(shapes)} → {len(result)}개 ({merged_count}쌍 합침)")
    return result


def main():
    """CLI entry point: merge poles in one JSON file and write the result.

    Shapes whose "label" equals ``--label`` are merged; all other shapes are
    kept as-is. In the output they come first, followed by the merged shapes.
    The output path is chosen in this order: ``--inplace`` (overwrite the
    input), then ``--output``, then ``<input stem>_merged<suffix>`` next to
    the input. Exits with status 1 if the input is missing or lacks a
    positive imageWidth/imageHeight.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", help="AnyLabeling JSON 파일")
    ap.add_argument("--output", default=None, help="출력 JSON (기본: INPUT_merged.json)")
    ap.add_argument("--inplace", action="store_true", help="원본 덮어쓰기")
    ap.add_argument("--x-overlap", type=float, default=0.30, help="x-range 겹침 비율 임계값")
    ap.add_argument("--y-gap", type=int, default=150, help="y-range 간격 임계값 (px)")
    ap.add_argument("--label", default="catenary_pole", help="병합 대상 라벨명")
    args = ap.parse_args()

    src = Path(args.input)
    if not src.exists():
        print(f"파일 없음: {src}", file=sys.stderr)
        sys.exit(1)

    data = json.loads(src.read_text(encoding="utf-8"))
    shapes = data.get("shapes", [])
    # `or 0` also catches an explicit null, which previously slipped past the
    # check and crashed later inside np.zeros.
    iW = data.get("imageWidth") or 0
    iH = data.get("imageHeight") or 0
    if iW <= 0 or iH <= 0:
        print("imageWidth/imageHeight 정보 없음", file=sys.stderr)
        sys.exit(1)

    target = [s for s in shapes if s.get("label") == args.label]
    others = [s for s in shapes if s.get("label") != args.label]
    print(f"{args.label}: {len(target)}개 → 병합 처리")

    merged = merge_poles(target, iH, iW, args.x_overlap, args.y_gap)
    data["shapes"] = others + merged

    if args.inplace:
        dst = src
    elif args.output:
        dst = Path(args.output)
    else:
        dst = src.with_stem(src.stem + "_merged")
    dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"저장: {dst}")


if __name__ == "__main__":
    main()