Files
railway-client/tools/post_merge_poles.py
minsung ccba1266b5 프로젝트 분리 이동
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 14:28:27 +09:00

155 lines
5.3 KiB
Python

"""post_merge_poles.py — detect_all_objects.py 출력 JSON에서 catenary_pole 병합.
detecting은 한 번만, 병합 파라미터 조정 시 이 스크립트만 재실행.
Usage:
python tools/post_merge_poles.py INPUT.json [--x-overlap 0.30] [--y-gap 150]
python tools/post_merge_poles.py INPUT.json --inplace
python tools/post_merge_poles.py INPUT.json --output OUTPUT.json
"""
import argparse
import json
import sys
from pathlib import Path
import cv2
import numpy as np
def _poly_orient(points: list, H: int, W: int) -> str:
    """Classify a polygon's long-axis direction relative to the image centre.

    The polygon's minimum-area rotated rectangle (``cv2.minAreaRect``) gives
    its long axis; that axis is compared with the radial direction from the
    image centre to the rectangle centre.

    Args:
        points: Polygon vertices as ``[[x, y], ...]``.
        H: Image height in pixels.
        W: Image width in pixels.

    Returns:
        'V' if the long axis is aligned with the radial direction
        (|cos| > 0.7) -- presumably an upright pole;
        'H' if it is not (|cos| <= 0.7) -- presumably a horizontal beam;
        '?' if undecidable: fewer than 3 vertices, a degenerate rectangle
        (shorter side < 1 px), aspect ratio < 1.3, or a rectangle centre
        within 1 px of the image centre.
    """
    # 1 or 2 vertices always give a degenerate rect (short side 0 -> '?');
    # checking up front also keeps cv2.minAreaRect from raising on [].
    if len(points) < 3:
        return '?'
    pts = np.array(points, dtype=np.float32)
    (rx, ry), (rw, rh), angle = cv2.minAreaRect(pts)
    short_side = min(rw, rh)
    if short_side < 1:
        return '?'
    if max(rw, rh) / short_side < 1.3:
        return '?'
    # The rect angle is taken as the direction of the `rw` side; rotate by
    # 90 degrees when `rh` is the longer side.
    long_angle = np.radians(angle if rw >= rh else angle + 90)
    lx = float(np.cos(long_angle))
    ly = float(np.sin(long_angle))
    rdx, rdy = rx - W / 2.0, ry - H / 2.0
    radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
    if radial_norm < 1:
        # Centre coincides with the image centre: radial direction undefined.
        return '?'
    cos_sim = abs(lx * rdx + ly * rdy) / radial_norm
    return 'V' if cos_sim > 0.7 else 'H'
def _bbox(points):
    """Return the axis-aligned bbox ``(x0, y0, x1, y1)`` of ``[[x, y], ...]``."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return min(xs), min(ys), max(xs), max(ys)


def _x_overlap_ratio(b1, b2):
    """Return the x-range intersection over x-range union (negative if disjoint)."""
    ox = min(b1[2], b2[2]) - max(b1[0], b2[0])
    ux = max(b1[2], b2[2]) - min(b1[0], b2[0])
    return ox / ux if ux > 0 else 0.0


def _y_gap(b1, b2):
    """Return the vertical distance between two bboxes (0.0 if y-ranges overlap)."""
    return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))


def _group_vertical(bboxes, is_vertical, x_overlap_thresh, y_gap_thresh):
    """Partition shape indices into merge groups.

    Each group is seeded by the lowest not-yet-used index. A vertical seed
    absorbs later vertical shapes whose bbox overlaps the group's union bbox
    in x (ratio >= x_overlap_thresh) and lies within y_gap_thresh of it
    vertically. The scan repeats until nothing more joins, so a piece that
    was skipped before the group grew towards it is reconsidered.
    Non-vertical shapes always form singleton groups. Groups are returned in
    order of their seed index; ``bboxes[k]`` is only read when
    ``is_vertical[k]`` is true.
    """
    n = len(bboxes)
    used = [False] * n
    groups = []
    for i in range(n):
        if used[i]:
            continue
        used[i] = True
        group = [i]
        groups.append(group)
        if not is_vertical[i]:
            continue
        union = tuple(bboxes[i])
        grew = True
        while grew:
            grew = False
            for j in range(i + 1, n):
                if used[j] or not is_vertical[j]:
                    continue
                jb = bboxes[j]
                if (_x_overlap_ratio(union, jb) >= x_overlap_thresh
                        and _y_gap(union, jb) <= y_gap_thresh):
                    used[j] = True
                    group.append(j)
                    union = (min(union[0], jb[0]), min(union[1], jb[1]),
                             max(union[2], jb[2]), max(union[3], jb[3]))
                    grew = True
    return groups


def _merge_group(members, H, W):
    """Merge a group of polygon shapes into one shape.

    A singleton group is returned unchanged. Otherwise all member polygons
    are rasterised onto an H x W mask, and the outer contour is re-traced and
    simplified with approxPolyDP (epsilon = 0.2 % of the perimeter). If the
    members do not touch (pieces separated by a vertical gap), the mask has
    several components. Their convex hull is then used so that no piece is
    silently dropped. The result copies the first member's keys, replaces
    ``points``, and sets ``score`` to the highest non-null member score.
    """
    if len(members) == 1:
        return members[0]
    mask = np.zeros((H, W), dtype=np.uint8)
    for s in members:
        cv2.fillPoly(mask, [np.array(s["points"], dtype=np.int32)], 255)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        # Nothing rasterised inside the image: keep the first member as-is.
        return members[0]
    if len(contours) == 1:
        outline = contours[0]
    else:
        # Disconnected pieces: taking only the largest contour would discard
        # the others, so bridge them with the hull of all contour points.
        outline = cv2.convexHull(np.vstack(contours))
    eps = 0.002 * cv2.arcLength(outline, True)
    approx = cv2.approxPolyDP(outline, eps, True)
    merged = dict(members[0])
    merged["points"] = [[float(p[0][0]), float(p[0][1])] for p in approx]
    # Null scores (e.g. manual annotations) must not crash float().
    scores = [float(s["score"]) for s in members if s.get("score") is not None]
    if scores:
        merged["score"] = max(scores)
    return merged


def merge_poles(shapes: list, H: int, W: int,
                x_overlap_thresh: float = 0.30,
                y_gap_thresh: int = 150) -> list:
    """Merge split pole polygons; only V+V combinations are merged.

    Each shape is classified with ``_poly_orient``. Shapes classified 'V' are
    grouped when their bboxes overlap in x by at least ``x_overlap_thresh``
    (intersection / union of x-ranges) and are at most ``y_gap_thresh``
    pixels apart vertically. Each group is merged into a single polygon;
    all other shapes pass through unchanged.

    Args:
        shapes: AnyLabeling-style shape dicts with a ``"points"`` list.
        H: Image height in pixels.
        W: Image width in pixels.
        x_overlap_thresh: Minimum x-range overlap ratio for merging.
        y_gap_thresh: Maximum vertical gap in pixels for merging.

    Returns:
        A new list of shapes, ordered by each group's first member
        (``shapes`` itself if it has at most one element).
    """
    if len(shapes) <= 1:
        return shapes
    orients = [_poly_orient(s["points"], H, W) for s in shapes]
    v_count = sum(1 for o in orients if o == 'V')
    h_count = sum(1 for o in orients if o == 'H')
    print(f" orient: V={v_count}, H={h_count}, ?={len(orients)-v_count-h_count}")
    is_vertical = [o == 'V' for o in orients]
    # Bboxes are only needed (and only well-defined) for 'V' shapes.
    bboxes = [_bbox(s["points"]) if v else None
              for s, v in zip(shapes, is_vertical)]
    groups = _group_vertical(bboxes, is_vertical, x_overlap_thresh, y_gap_thresh)
    result = [_merge_group([shapes[k] for k in g], H, W) for g in groups]
    merged_count = len(shapes) - len(result)
    print(f" 병합: {len(shapes)} → {len(result)}개 ({merged_count}쌍 합침)")
    return result
def main():
    """Command-line entry point: merge split poles in one AnyLabeling JSON file.

    Reads INPUT and merges the shapes whose label equals ``--label`` with
    ``merge_poles``; other shapes are kept and placed before the merged ones.
    The result is written back to INPUT with ``--inplace`` (which takes
    precedence over ``--output``), to ``--output``, or by default to
    ``INPUT_merged.json``. Exits with status 1 when INPUT is missing, is not
    a JSON object, or lacks ``imageWidth``/``imageHeight``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", help="AnyLabeling JSON 파일")
    ap.add_argument("--output", default=None, help="출력 JSON (기본: INPUT_merged.json)")
    ap.add_argument("--inplace", action="store_true", help="원본 덮어쓰기")
    ap.add_argument("--x-overlap", type=float, default=0.30, help="x-range 겹침 비율 임계값")
    ap.add_argument("--y-gap", type=int, default=150, help="y-range 간격 임계값 (px)")
    ap.add_argument("--label", default="catenary_pole", help="병합 대상 라벨명")
    args = ap.parse_args()

    src = Path(args.input)
    if not src.exists():
        print(f"파일 없음: {src}", file=sys.stderr)
        sys.exit(1)
    try:
        data = json.loads(src.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"JSON 읽기 실패: {src} ({e})", file=sys.stderr)
        sys.exit(1)
    if not isinstance(data, dict):
        print(f"JSON 최상위가 객체가 아님: {src}", file=sys.stderr)
        sys.exit(1)

    shapes = data.get("shapes") or []
    iW = data.get("imageWidth")
    iH = data.get("imageHeight")
    # Reject missing, null, or zero dimensions; a null value would
    # otherwise only fail later when the merge mask is allocated.
    if not iW or not iH:
        print("imageWidth/imageHeight 정보 없음", file=sys.stderr)
        sys.exit(1)

    target = [s for s in shapes if s.get("label") == args.label]
    others = [s for s in shapes if s.get("label") != args.label]
    print(f"{args.label}: {len(target)}개 → 병합 처리")
    merged = merge_poles(target, int(iH), int(iW), args.x_overlap, args.y_gap)
    data["shapes"] = others + merged

    if args.inplace:
        dst = src
    elif args.output:
        dst = Path(args.output)
    else:
        dst = src.with_stem(src.stem + "_merged")
    dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"저장: {dst}")


if __name__ == "__main__":
    main()