"""group_ramen_poles.py — detect catenary_pole "ramen" (rigid-frame) groups and assign group_id.

Finds an H beam (horizontal) together with its adjacent V columns (vertical)
and gives them the same group_id. Run this before post_merge_poles.py so that
the resulting JSON lets post_merge_poles.py merge only V poles without a
group_id.

Usage:
    python tools/group_ramen_poles.py INPUT.json [--inplace]
    python tools/group_ramen_poles.py INPUT.json --x-overlap 0.15 --max-dist 400
"""
import argparse
import json
import sys
from pathlib import Path

import cv2
import numpy as np


def _poly_orient(points, H, W, cos_high=0.75, cos_low=0.45, debug=False):
    """Classify one polygon as 'V', 'H', '?' or '?_ambiguous'.

    The polygon's minimum-area rotated rectangle gives its long axis. That
    axis is compared with the direction from the image centre to the
    rectangle centre; ``cos_sim`` is |cos| of the angle between the two.

    Returns:
        '?'           -- no points, shorter rectangle side < 1 px, or aspect
                         ratio < 1.3 (no usable long axis).
        '?_ambiguous' -- rectangle centre within 15 % of the image diagonal
                         from the image centre, or cos_low <= cos_sim <= cos_high.
                         Resolved later by ``_fix_ambiguous_orients``.
        'V'           -- cos_sim > cos_high (clear column).
        'H'           -- cos_sim < cos_low (clear beam).

    NOTE(review): the rule presumes vertical structures appear aligned with
    the radial direction from the image centre (e.g. wide-angle / fisheye
    imagery) -- confirm against the camera model in use.

    Args:
        points: polygon vertices as (x, y) pairs.
        H, W: image height and width in pixels.
        cos_high: cos_sim above this -> 'V'.
        cos_low: cos_sim below this -> 'H'.
        debug: print the decision and the values behind it (the caller is
            expected to have printed a prefix with ``end=""``).
    """
    pts = np.asarray(points, dtype=np.float32).reshape(-1, 2)
    if len(pts) == 0:
        # cv2.minAreaRect raises on an empty point set.
        if debug:
            print(" → ? (no points)")
        return '?'

    (rx, ry), (rw, rh), angle = cv2.minAreaRect(pts)
    if min(rw, rh) < 1:
        if debug:
            print(" → ? (min_side<1)")
        return '?'
    ar = max(rw, rh) / min(rw, rh)
    if ar < 1.3:
        if debug:
            print(f" → ? (ar={ar:.2f}<1.3)")
        return '?'

    # Vector from the image centre to the rectangle centre.
    rdx, rdy = rx - W / 2.0, ry - H / 2.0
    radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
    center_thresh = (H ** 2 + W ** 2) ** 0.5 * 0.15
    # Near the centre the radial direction is unstable; a zero norm would also
    # make the normalisation below divide by zero (possible when H == W == 0).
    if radial_norm < center_thresh or radial_norm == 0.0:
        if debug:
            print(f" → ?_ambiguous (center, norm={radial_norm:.0f}<{center_thresh:.0f})")
        return '?_ambiguous'

    # OpenCV's angle is the direction of the rectangle's "width" side; rotate
    # by 90° when the height side is the long one.
    long_angle_deg = angle if rw >= rh else angle + 90
    lx = float(np.cos(np.radians(long_angle_deg)))
    ly = float(np.sin(np.radians(long_angle_deg)))
    cos_sim = abs(lx * rdx / radial_norm + ly * rdy / radial_norm)

    if cos_sim > cos_high:
        orient = 'V'
    elif cos_sim < cos_low:
        orient = 'H'
    else:
        orient = '?_ambiguous'

    if debug:
        bw = pts[:, 0].max() - pts[:, 0].min()
        bh = pts[:, 1].max() - pts[:, 1].min()
        print(f" → {orient} (ar={ar:.2f} cos_sim={cos_sim:.3f} "
              f"bbox={int(bw)}x{int(bh)})")
    return orient


def _fix_ambiguous_orients(indices, orients, shapes):
    """Resolve '?_ambiguous' entries of ``orients`` in place to 'V' or 'H'.

    Each not-yet-assigned ambiguous polygon seeds a group together with every
    other unassigned ambiguous polygon whose bbox x-range overlaps the seed's
    (overlap is checked against the seed only, not transitively). In each
    group the polygon reaching lowest in the image (largest bbox y_max)
    becomes 'V' and the rest become 'H'. The rationale is that a pole's foot
    stands on the ground at the bottom, while beams are mounted higher up.
    A lone ambiguous polygon therefore becomes 'V'.

    Args:
        indices: shape indices to consider.
        orients: mapping shape index -> orientation label; mutated.
        shapes: AnyLabeling shape dicts (only ``"points"`` is read).
    """
    amb_ids = [i for i in indices if orients[i] == '?_ambiguous']
    if not amb_ids:
        return
    bboxes = {i: _get_bbox(shapes[i]["points"]) for i in amb_ids}
    assigned = set()
    for i in amb_ids:
        if i in assigned:
            continue
        bi = bboxes[i]
        group = [i]
        for j in amb_ids:
            if j == i or j in assigned:
                continue
            bj = bboxes[j]
            if bi[0] <= bj[2] and bj[0] <= bi[2]:  # x-ranges overlap
                group.append(j)
        bottom = max(group, key=lambda k: bboxes[k][3])  # largest y_max = lowest = V
        for k in group:
            orients[k] = 'V' if k == bottom else 'H'
            assigned.add(k)
        print(f" [ambiguous fix] 그룹{group}: "
              f"V={bottom}, H={[k for k in group if k != bottom]}")


def _get_bbox(points):
    """Return the axis-aligned bbox of ``points`` as (x_min, y_min, x_max, y_max)."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return min(xs), min(ys), max(xs), max(ys)


def _y_gap(b1, b2):
    """Vertical gap between two (x0, y0, x1, y1) bboxes; 0 when their y-ranges overlap."""
    return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))


def detect_ramen_groups(shapes, iH, iW, label="catenary_pole",
                        x_overlap_thresh=0.20, max_pole_dist=300,
                        cos_high=0.75, cos_low=0.45, *, orients_out=None):
    """Classify poles as V/H, anchor on H beams to match nearby V poles, assign group_id.

    Every shape with ``label`` first has its ``"group_id"`` reset to None; the
    matched H/V shapes then get fresh ids, numbered after the largest integer
    group_id already present on any shape. ``shapes`` is mutated in place.

    Args:
        shapes: AnyLabeling shape dicts (reads "label"/"points", writes "group_id").
        iH, iW: image height and width in pixels.
        label: label identifying pole shapes.
        x_overlap_thresh: x margin, as a fraction of the H beam width, within
            which a V pole's centre x must fall.
        max_pole_dist: maximum vertical gap (px) between H and V bboxes.
        cos_high, cos_low: orientation thresholds passed to ``_poly_orient``.
        orients_out: optional dict; if given it is filled with the final
            orientation of every pole index (after ambiguity resolution), so
            callers can reuse exactly the classification used for grouping.

    Returns:
        [(group_id, [h_indices], [v_indices]), ...]
    """
    pole_indices = [i for i, s in enumerate(shapes) if s.get("label") == label]
    for i in pole_indices:
        shapes[i]["group_id"] = None
    if not pole_indices:
        return []

    # Step 1: V/H classification
    print(f" [orient] 전체 {len(pole_indices)}개 전철주 판별:")
    orients = {}
    for i in pole_indices:
        print(f" shape[{i}]", end="")
        orients[i] = _poly_orient(shapes[i]["points"], iH, iW,
                                  cos_high=cos_high, cos_low=cos_low, debug=True)
    _fix_ambiguous_orients(pole_indices, orients, shapes)
    if orients_out is not None:
        orients_out.update(orients)

    h_indices = [i for i in pole_indices if orients[i] == 'H']
    v_indices = [i for i in pole_indices if orients[i] == 'V']
    print(f" V={len(v_indices)}, H={len(h_indices)}, "
          f"?={sum(1 for o in orients.values() if o not in ('V', 'H'))}")

    # Step 2: each H beam claims the not-yet-used V poles whose centre x lies
    # within the beam's x-range (± margin) and that are vertically close.
    v_bboxes = {vi: _get_bbox(shapes[vi]["points"]) for vi in v_indices}
    used_v = set()
    raw_groups = []
    for hi in h_indices:
        hb = _get_bbox(shapes[hi]["points"])
        hx0, _, hx1, _ = hb
        margin = x_overlap_thresh * max(hx1 - hx0, 1)
        matched_v = []
        for vi in v_indices:
            if vi in used_v:
                continue
            vb = v_bboxes[vi]
            vcx = (vb[0] + vb[2]) / 2.0
            if (hx0 - margin) <= vcx <= (hx1 + margin) and _y_gap(hb, vb) <= max_pole_dist:
                matched_v.append(vi)
        if matched_v:
            raw_groups.append(([hi], matched_v))
            used_v.update(matched_v)

    # Step 3: assign group_ids after any integer id already in the file.
    existing_gids = [s.get("group_id") for s in shapes
                     if isinstance(s.get("group_id"), int)]
    next_gid = (max(existing_gids) + 1) if existing_gids else 1
    result = []
    for h_list, v_list in raw_groups:
        gid = next_gid
        next_gid += 1
        for i in h_list + v_list:
            shapes[i]["group_id"] = gid
        result.append((gid, h_list, v_list))
        print(f" → 라멘 group_id={gid}: H{h_list} V{v_list}")
    return result


def main():
    """CLI entry point: read an AnyLabeling JSON, assign ramen group_ids, write the result."""
    ap = argparse.ArgumentParser()
    ap.add_argument("input", help="AnyLabeling JSON 파일")
    ap.add_argument("--inplace", action="store_true", help="원본 덮어쓰기")
    ap.add_argument("--output", default=None)
    ap.add_argument("--label", default="catenary_pole")
    ap.add_argument("--x-overlap", type=float, default=0.20,
                    help="H 빔 너비 기준 x 여유 비율 (기본 0.20)")
    ap.add_argument("--max-dist", type=int, default=300,
                    help="H-V y 간격 최대값 px (기본 300)")
    ap.add_argument("--cos-high", type=float, default=0.75,
                    help="cos_sim V 판별 상한 (기본 0.75)")
    ap.add_argument("--cos-low", type=float, default=0.45,
                    help="cos_sim H 판별 하한 (기본 0.45)")
    args = ap.parse_args()

    src = Path(args.input)
    if not src.exists():
        print(f"파일 없음: {src}", file=sys.stderr)
        sys.exit(1)
    try:
        data = json.loads(src.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        print(f"JSON 파싱 실패: {src}: {e}", file=sys.stderr)
        sys.exit(1)

    shapes = data.get("shapes", [])
    iW = data.get("imageWidth", 0)
    iH = data.get("imageHeight", 0)
    # Falsy check also rejects JSON null, which `== 0` would let through.
    if not iW or not iH:
        print("imageWidth/imageHeight 없음", file=sys.stderr)
        sys.exit(1)

    orients = {}
    groups = detect_ramen_groups(shapes, iH, iW, args.label, args.x_overlap,
                                 args.max_dist, args.cos_high, args.cos_low,
                                 orients_out=orients)
    print(f"\n라멘 그룹 {len(groups)}개 검출:")
    for gid, h_list, v_list in groups:
        print(f" group_id={gid}: H{h_list} V{v_list}")

    # Count from the exact classification used for grouping (user thresholds
    # and ambiguity resolution included) rather than re-classifying.
    ungrouped_v = [i for i, o in orients.items()
                   if o == 'V' and shapes[i].get("group_id") is None]
    print(f"그룹 미할당 V 기둥 (병합 대상): {len(ungrouped_v)}개")

    data["shapes"] = shapes
    if args.inplace:
        dst = src
    elif args.output:
        dst = Path(args.output)
    else:
        dst = src.with_stem(src.stem + "_grouped")
    dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"저장: {dst}")


if __name__ == "__main__":
    main()