Files
railway-client/tools/group_ramen_poles.py
minsung 4c15d5ff5d sam31server 전환, 라멘 파이프라인 정리, 문서 추가
- sam31server를 SAM3.1 서버로 전환 (x-anylabeling01 대체)
- detect_raamen.py: B/C 분류 기반 라멘형 전철주 검출 파이프라인 정비
- sam3_everything_explore.py: Discovery Sweep 탐색 모드 정리
- detect_all_objects.py: 타일 검출 개선
- docs/railway-client-guide.html: 서버·도구·파이프라인 전체 가이드 추가
- tools 추가: detect_control_box, group_ramen_poles, render_everything_by_label, render_label_polygons, debug_vh

Closes #1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-02 10:11:52 +09:00

222 lines
7.7 KiB
Python

"""group_ramen_poles.py — catenary_pole 라멘 그룹 검출 및 group_id 할당.
H 빔(수평) + 인접 V 기둥(수직) 쌍을 찾아 동일 group_id 부여.
결과 JSON은 post_merge_poles.py가 group_id 없는 V만 병합할 수 있도록 선행 실행.
Usage:
python tools/group_ramen_poles.py INPUT.json [--inplace]
python tools/group_ramen_poles.py INPUT.json --x-overlap 0.15 --max-dist 400
"""
import argparse
import json
import sys
from pathlib import Path
import cv2
import numpy as np
def _poly_orient(points, H, W, cos_high=0.75, cos_low=0.45, debug=False):
"""V/H/?_ambiguous 판별.
cos_sim > cos_high → V (명확 기둥)
cos_sim < cos_low → H (명확 빔)
그 사이 또는 중앙 근처 → ?_ambiguous (후처리에서 Y 최대=V 판별)
"""
pts = np.array(points, dtype=np.float32)
rect = cv2.minAreaRect(pts)
(rx, ry), (rw, rh), angle = rect
if min(rw, rh) < 1:
if debug: print(f" → ? (min_side<1)")
return '?'
ar = max(rw, rh) / min(rw, rh)
if ar < 1.3:
if debug: print(f" → ? (ar={ar:.2f}<1.3)")
return '?'
rdx, rdy = rx - W / 2.0, ry - H / 2.0
radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
center_thresh = (H ** 2 + W ** 2) ** 0.5 * 0.15
if radial_norm < center_thresh:
if debug: print(f" → ?_ambiguous (center, norm={radial_norm:.0f}<{center_thresh:.0f})")
return '?_ambiguous'
long_angle_deg = angle if rw >= rh else angle + 90
lx = float(np.cos(np.radians(long_angle_deg)))
ly = float(np.sin(np.radians(long_angle_deg)))
cos_sim = abs(lx * rdx / radial_norm + ly * rdy / radial_norm)
if cos_sim > cos_high:
orient = 'V'
elif cos_sim < cos_low:
orient = 'H'
else:
orient = '?_ambiguous'
if debug:
bw = pts[:, 0].max() - pts[:, 0].min()
bh = pts[:, 1].max() - pts[:, 1].min()
print(f"{orient} (ar={ar:.2f} cos_sim={cos_sim:.3f} "
f"bbox={int(bw)}x{int(bh)})")
return orient
def _fix_ambiguous_orients(indices, orients, shapes):
"""?_ambiguous 폴리곤: x-range 겹치는 그룹 내 Y 최대(가장 아래)=V, 나머지=H.
기둥 하단은 지면에 박혀 더 아래쪽(y 최대), 빔은 상단에 설치되어 위쪽.
"""
amb_ids = [i for i in indices if orients[i] == '?_ambiguous']
if not amb_ids:
return
bboxes = {i: _get_bbox(shapes[i]["points"]) for i in amb_ids}
assigned = set()
for i in amb_ids:
if i in assigned:
continue
bi = bboxes[i]
group = [i]
for j in amb_ids:
if j == i or j in assigned:
continue
bj = bboxes[j]
if bi[0] <= bj[2] and bj[0] <= bi[2]: # x-range 겹침
group.append(j)
bottom = max(group, key=lambda k: bboxes[k][3]) # y_max 최대 = 가장 아래 = V
for k in group:
orients[k] = 'V' if k == bottom else 'H'
assigned.add(k)
print(f" [ambiguous fix] 그룹{group}: "
f"V={bottom}, H={[k for k in group if k != bottom]}")
def _get_bbox(points):
xs = [p[0] for p in points]
ys = [p[1] for p in points]
return min(xs), min(ys), max(xs), max(ys)
def _y_gap(b1, b2):
return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))
def detect_ramen_groups(shapes, iH, iW, label="catenary_pole",
                        x_overlap_thresh=0.20, max_pole_dist=300,
                        cos_high=0.75, cos_low=0.45):
    """Find beam/post groups among *label* shapes and assign them group ids.

    Steps:
      1. Classify every shape carrying *label* as V / H / undecided with
         ``_poly_orient`` and resolve ambiguous ones with
         ``_fix_ambiguous_orients``.
      2. Use each H beam as an anchor and attach every still-unused V post
         whose bounding-box centre x lies inside the beam's x-range (widened
         on both sides by ``x_overlap_thresh * beam_width``) and whose
         vertical gap to the beam is at most ``max_pole_dist``.
      3. Give each beam that matched at least one post a fresh integer
         ``group_id``, shared by the beam and its posts.

    Side effects:
        ``group_id`` of every shape carrying *label* is first reset to
        ``None`` and then set for grouped shapes; *shapes* is modified in
        place.  Progress is printed to stdout.

    Args:
        shapes: List of shape dicts with ``"label"`` and ``"points"`` keys.
        iH: Image height in pixels.
        iW: Image width in pixels.
        label: Label of the shapes to process.
        x_overlap_thresh: Horizontal tolerance as a fraction of beam width.
        max_pole_dist: Maximum vertical gap between beam and post.
        cos_high: Forwarded to ``_poly_orient``.
        cos_low: Forwarded to ``_poly_orient``.

    Returns:
        List of ``(group_id, [h_indices], [v_indices])`` tuples.
    """
    pole_indices = [i for i, s in enumerate(shapes) if s.get("label") == label]
    for i in pole_indices:
        shapes[i]["group_id"] = None
    if not pole_indices:
        return []

    # Step 1: V/H classification.
    print(f" [orient] 전체 {len(pole_indices)}개 전철주 판별:")
    orients = {}
    for i in pole_indices:
        print(f" shape[{i}]", end="")
        orients[i] = _poly_orient(shapes[i]["points"], iH, iW,
                                  cos_high=cos_high, cos_low=cos_low, debug=True)
    _fix_ambiguous_orients(pole_indices, orients, shapes)

    h_indices = [i for i in pole_indices if orients[i] == 'H']
    v_indices = [i for i in pole_indices if orients[i] == 'V']
    print(f" V={len(v_indices)}, H={len(h_indices)}, "
          f"?={sum(1 for o in orients.values() if o not in ('V', 'H'))}")

    # Step 2: anchor on each H beam and collect adjacent V posts.
    # Post boxes do not depend on the beam, so compute them once.
    v_boxes = {vi: _get_bbox(shapes[vi]["points"]) for vi in v_indices}
    used_v = set()
    raw_groups = []
    for hi in h_indices:
        hb = _get_bbox(shapes[hi]["points"])
        hx0, _hy0, hx1, _hy1 = hb
        h_width = max(hx1 - hx0, 1)
        margin = x_overlap_thresh * h_width
        matched_v = []
        for vi in v_indices:
            if vi in used_v:
                continue
            vb = v_boxes[vi]
            vcx = (vb[0] + vb[2]) / 2.0
            if (hx0 - margin) <= vcx <= (hx1 + margin) and _y_gap(hb, vb) <= max_pole_dist:
                matched_v.append(vi)
        if matched_v:
            raw_groups.append(([hi], matched_v))
            used_v.update(matched_v)

    # Step 3: assign group ids, continuing after any integer id already
    # present on shapes (those carrying *label* were reset above).
    existing_gids = [s.get("group_id") for s in shapes if isinstance(s.get("group_id"), int)]
    next_gid = (max(existing_gids) + 1) if existing_gids else 1
    result = []
    for h_list, v_list in raw_groups:
        gid = next_gid
        next_gid += 1
        for i in h_list + v_list:
            shapes[i]["group_id"] = gid
        result.append((gid, h_list, v_list))
        print(f" → 라멘 group_id={gid}: H{h_list} V{v_list}")
    return result
def main():
    """Command-line entry point: group poles in a JSON file and save it.

    Reads the input JSON, runs ``detect_ramen_groups`` on its ``shapes``,
    prints a summary and writes the result.  The destination is the input
    file with ``--inplace``, otherwise ``--output`` if given, otherwise
    ``<input stem>_grouped`` next to the input.  ``--inplace`` takes
    precedence when both are supplied.

    Exits with status 1 when the input file does not exist or the JSON has
    no usable ``imageWidth`` / ``imageHeight``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", help="AnyLabeling JSON 파일")
    ap.add_argument("--inplace", action="store_true", help="원본 덮어쓰기")
    ap.add_argument("--output", default=None)
    ap.add_argument("--label", default="catenary_pole")
    ap.add_argument("--x-overlap", type=float, default=0.20,
                    help="H 빔 너비 기준 x 여유 비율 (기본 0.20)")
    ap.add_argument("--max-dist", type=int, default=300,
                    help="H-V y 간격 최대값 px (기본 300)")
    ap.add_argument("--cos-high", type=float, default=0.75,
                    help="cos_sim V 판별 상한 (기본 0.75)")
    ap.add_argument("--cos-low", type=float, default=0.45,
                    help="cos_sim H 판별 하한 (기본 0.45)")
    args = ap.parse_args()

    src = Path(args.input)
    if not src.exists():
        print(f"파일 없음: {src}", file=sys.stderr)
        sys.exit(1)

    data = json.loads(src.read_text(encoding="utf-8"))
    # A JSON null is treated like a missing key.
    shapes = data.get("shapes") or []
    iW = data.get("imageWidth", 0)
    iH = data.get("imageHeight", 0)
    # Falsy check so that null is rejected as well as 0.
    if not iW or not iH:
        print("imageWidth/imageHeight 없음", file=sys.stderr)
        sys.exit(1)

    groups = detect_ramen_groups(shapes, iH, iW, args.label, args.x_overlap, args.max_dist,
                                 args.cos_high, args.cos_low)
    print(f"\n라멘 그룹 {len(groups)}개 검출:")
    for gid, h_list, v_list in groups:
        print(f" group_id={gid}: H{h_list} V{v_list}")

    # Count posts left without a group.  Use the thresholds the user asked
    # for so the count agrees with the grouping run above.  Shapes that were
    # only classified as V by the ambiguity pass are not counted here,
    # because that per-shape result is not returned by detect_ramen_groups.
    poles = [s for s in shapes if s.get("label") == args.label]
    ungrouped_v = [
        s for s in poles
        if s.get("group_id") is None
        and _poly_orient(s["points"], iH, iW,
                         cos_high=args.cos_high, cos_low=args.cos_low) == 'V'
    ]
    print(f"그룹 미할당 V 기둥 (병합 대상): {len(ungrouped_v)}")

    data["shapes"] = shapes
    if args.inplace:
        dst = src
    elif args.output:
        dst = Path(args.output)
    else:
        dst = src.with_stem(src.stem + "_grouped")
    dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"저장: {dst}")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()