sam31server 전환, 라멘 파이프라인 정리, 문서 추가
- sam31server를 SAM3.1 서버로 전환 (x-anylabeling01 대체) - detect_raamen.py: B/C 분류 기반 라멘형 전철주 검출 파이프라인 정비 - sam3_everything_explore.py: Discovery Sweep 탐색 모드 정리 - detect_all_objects.py: 타일 검출 개선 - docs/railway-client-guide.html: 서버·도구·파이프라인 전체 가이드 추가 - tools 추가: detect_control_box, group_ramen_poles, render_everything_by_label, render_label_polygons, debug_vh Closes #1 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
221
tools/group_ramen_poles.py
Normal file
221
tools/group_ramen_poles.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""group_ramen_poles.py — catenary_pole 라멘 그룹 검출 및 group_id 할당.
|
||||
|
||||
H 빔(수평) + 인접 V 기둥(수직) 쌍을 찾아 동일 group_id 부여.
|
||||
결과 JSON은 post_merge_poles.py가 group_id 없는 V만 병합할 수 있도록 선행 실행.
|
||||
|
||||
Usage:
|
||||
python tools/group_ramen_poles.py INPUT.json [--inplace]
|
||||
python tools/group_ramen_poles.py INPUT.json --x-overlap 0.15 --max-dist 400
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _poly_orient(points, H, W, cos_high=0.75, cos_low=0.45, debug=False):
|
||||
"""V/H/?_ambiguous 판별.
|
||||
|
||||
cos_sim > cos_high → V (명확 기둥)
|
||||
cos_sim < cos_low → H (명확 빔)
|
||||
그 사이 또는 중앙 근처 → ?_ambiguous (후처리에서 Y 최대=V 판별)
|
||||
"""
|
||||
pts = np.array(points, dtype=np.float32)
|
||||
rect = cv2.minAreaRect(pts)
|
||||
(rx, ry), (rw, rh), angle = rect
|
||||
if min(rw, rh) < 1:
|
||||
if debug: print(f" → ? (min_side<1)")
|
||||
return '?'
|
||||
ar = max(rw, rh) / min(rw, rh)
|
||||
if ar < 1.3:
|
||||
if debug: print(f" → ? (ar={ar:.2f}<1.3)")
|
||||
return '?'
|
||||
rdx, rdy = rx - W / 2.0, ry - H / 2.0
|
||||
radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5
|
||||
center_thresh = (H ** 2 + W ** 2) ** 0.5 * 0.15
|
||||
if radial_norm < center_thresh:
|
||||
if debug: print(f" → ?_ambiguous (center, norm={radial_norm:.0f}<{center_thresh:.0f})")
|
||||
return '?_ambiguous'
|
||||
long_angle_deg = angle if rw >= rh else angle + 90
|
||||
lx = float(np.cos(np.radians(long_angle_deg)))
|
||||
ly = float(np.sin(np.radians(long_angle_deg)))
|
||||
cos_sim = abs(lx * rdx / radial_norm + ly * rdy / radial_norm)
|
||||
if cos_sim > cos_high:
|
||||
orient = 'V'
|
||||
elif cos_sim < cos_low:
|
||||
orient = 'H'
|
||||
else:
|
||||
orient = '?_ambiguous'
|
||||
if debug:
|
||||
bw = pts[:, 0].max() - pts[:, 0].min()
|
||||
bh = pts[:, 1].max() - pts[:, 1].min()
|
||||
print(f" → {orient} (ar={ar:.2f} cos_sim={cos_sim:.3f} "
|
||||
f"bbox={int(bw)}x{int(bh)})")
|
||||
return orient
|
||||
|
||||
|
||||
def _fix_ambiguous_orients(indices, orients, shapes):
|
||||
"""?_ambiguous 폴리곤: x-range 겹치는 그룹 내 Y 최대(가장 아래)=V, 나머지=H.
|
||||
|
||||
기둥 하단은 지면에 박혀 더 아래쪽(y 최대), 빔은 상단에 설치되어 위쪽.
|
||||
"""
|
||||
amb_ids = [i for i in indices if orients[i] == '?_ambiguous']
|
||||
if not amb_ids:
|
||||
return
|
||||
bboxes = {i: _get_bbox(shapes[i]["points"]) for i in amb_ids}
|
||||
assigned = set()
|
||||
for i in amb_ids:
|
||||
if i in assigned:
|
||||
continue
|
||||
bi = bboxes[i]
|
||||
group = [i]
|
||||
for j in amb_ids:
|
||||
if j == i or j in assigned:
|
||||
continue
|
||||
bj = bboxes[j]
|
||||
if bi[0] <= bj[2] and bj[0] <= bi[2]: # x-range 겹침
|
||||
group.append(j)
|
||||
bottom = max(group, key=lambda k: bboxes[k][3]) # y_max 최대 = 가장 아래 = V
|
||||
for k in group:
|
||||
orients[k] = 'V' if k == bottom else 'H'
|
||||
assigned.add(k)
|
||||
print(f" [ambiguous fix] 그룹{group}: "
|
||||
f"V={bottom}, H={[k for k in group if k != bottom]}")
|
||||
|
||||
|
||||
def _get_bbox(points):
|
||||
xs = [p[0] for p in points]
|
||||
ys = [p[1] for p in points]
|
||||
return min(xs), min(ys), max(xs), max(ys)
|
||||
|
||||
|
||||
def _y_gap(b1, b2):
|
||||
return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))
|
||||
|
||||
|
||||
def detect_ramen_groups(shapes, iH, iW, label="catenary_pole",
|
||||
x_overlap_thresh=0.20, max_pole_dist=300,
|
||||
cos_high=0.75, cos_low=0.45):
|
||||
"""V/H 판별 → H빔 anchor로 인접 V 매칭 → 라멘 group_id 할당.
|
||||
|
||||
반환: [(group_id, [h_indices], [v_indices]), ...]
|
||||
"""
|
||||
pole_indices = [i for i, s in enumerate(shapes) if s.get("label") == label]
|
||||
|
||||
for i in pole_indices:
|
||||
shapes[i]["group_id"] = None
|
||||
|
||||
if not pole_indices:
|
||||
return []
|
||||
|
||||
# Step 1: V/H 판별
|
||||
print(f" [orient] 전체 {len(pole_indices)}개 전철주 판별:")
|
||||
orients = {}
|
||||
for i in pole_indices:
|
||||
print(f" shape[{i}]", end="")
|
||||
orients[i] = _poly_orient(shapes[i]["points"], iH, iW,
|
||||
cos_high=cos_high, cos_low=cos_low, debug=True)
|
||||
_fix_ambiguous_orients(pole_indices, orients, shapes)
|
||||
|
||||
h_indices = [i for i in pole_indices if orients[i] == 'H']
|
||||
v_indices = [i for i in pole_indices if orients[i] == 'V']
|
||||
print(f" V={len(v_indices)}, H={len(h_indices)}, "
|
||||
f"?={sum(1 for o in orients.values() if o not in ('V', 'H'))}")
|
||||
|
||||
# Step 2: H빔 anchor → 인접 V 매칭
|
||||
used_v = set()
|
||||
raw_groups = []
|
||||
for hi in h_indices:
|
||||
hb = _get_bbox(shapes[hi]["points"])
|
||||
hx0, hy0, hx1, hy1 = hb
|
||||
h_width = max(hx1 - hx0, 1)
|
||||
matched_v = []
|
||||
for vi in v_indices:
|
||||
if vi in used_v:
|
||||
continue
|
||||
vb = _get_bbox(shapes[vi]["points"])
|
||||
vcx = (vb[0] + vb[2]) / 2.0
|
||||
margin = x_overlap_thresh * h_width
|
||||
if (hx0 - margin) <= vcx <= (hx1 + margin) and _y_gap(hb, vb) <= max_pole_dist:
|
||||
matched_v.append(vi)
|
||||
if matched_v:
|
||||
raw_groups.append(([hi], matched_v))
|
||||
used_v.update(matched_v)
|
||||
|
||||
# Step 3: group_id 할당
|
||||
existing_gids = [s.get("group_id") for s in shapes if isinstance(s.get("group_id"), int)]
|
||||
next_gid = (max(existing_gids) + 1) if existing_gids else 1
|
||||
|
||||
result = []
|
||||
for h_list, v_list in raw_groups:
|
||||
gid = next_gid
|
||||
next_gid += 1
|
||||
for i in h_list + v_list:
|
||||
shapes[i]["group_id"] = gid
|
||||
result.append((gid, h_list, v_list))
|
||||
print(f" → 라멘 group_id={gid}: H{h_list} V{v_list}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("input", help="AnyLabeling JSON 파일")
|
||||
ap.add_argument("--inplace", action="store_true", help="원본 덮어쓰기")
|
||||
ap.add_argument("--output", default=None)
|
||||
ap.add_argument("--label", default="catenary_pole")
|
||||
ap.add_argument("--x-overlap", type=float, default=0.20,
|
||||
help="H 빔 너비 기준 x 여유 비율 (기본 0.20)")
|
||||
ap.add_argument("--max-dist", type=int, default=300,
|
||||
help="H-V y 간격 최대값 px (기본 300)")
|
||||
ap.add_argument("--cos-high", type=float, default=0.75,
|
||||
help="cos_sim V 판별 상한 (기본 0.75)")
|
||||
ap.add_argument("--cos-low", type=float, default=0.45,
|
||||
help="cos_sim H 판별 하한 (기본 0.45)")
|
||||
args = ap.parse_args()
|
||||
|
||||
src = Path(args.input)
|
||||
if not src.exists():
|
||||
print(f"파일 없음: {src}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
data = json.loads(src.read_text(encoding="utf-8"))
|
||||
shapes = data.get("shapes", [])
|
||||
iW = data.get("imageWidth", 0)
|
||||
iH = data.get("imageHeight", 0)
|
||||
if iW == 0 or iH == 0:
|
||||
print("imageWidth/imageHeight 없음", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
groups = detect_ramen_groups(shapes, iH, iW, args.label, args.x_overlap, args.max_dist,
|
||||
args.cos_high, args.cos_low)
|
||||
|
||||
print(f"\n라멘 그룹 {len(groups)}개 검출:")
|
||||
for gid, h_list, v_list in groups:
|
||||
print(f" group_id={gid}: H{h_list} V{v_list}")
|
||||
|
||||
poles = [s for s in shapes if s.get("label") == args.label]
|
||||
ungrouped_v = [
|
||||
s for s in poles
|
||||
if s.get("group_id") is None and _poly_orient(s["points"], iH, iW) == 'V'
|
||||
]
|
||||
print(f"그룹 미할당 V 기둥 (병합 대상): {len(ungrouped_v)}개")
|
||||
|
||||
data["shapes"] = shapes
|
||||
|
||||
if args.inplace:
|
||||
dst = src
|
||||
elif args.output:
|
||||
dst = Path(args.output)
|
||||
else:
|
||||
dst = src.with_stem(src.stem + "_grouped")
|
||||
|
||||
dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"저장: {dst}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user