From ccba1266b5a9dcf8cc4164b34f0b2c22afa2a703 Mon Sep 17 00:00:00 2001 From: minsung Date: Wed, 20 May 2026 14:28:27 +0900 Subject: [PATCH] =?UTF-8?q?=ED=94=84=EB=A1=9C=EC=A0=9D=ED=8A=B8=20?= =?UTF-8?q?=EB=B6=84=EB=A6=AC=20=EC=9D=B4=EB=8F=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- .gitignore | 67 +++ CLAUDE.md | 65 +++ configs/railway_zone.json | 143 +++++ tools/auto_rail_detect.py | 301 +++++++++++ tools/detect_all_objects.py | 565 ++++++++++++++++++++ tools/detect_hollow_section.py | 152 ++++++ tools/detect_raamen.py | 666 +++++++++++++++++++++++ tools/labeling_server.py | 503 ++++++++++++++++++ tools/merge_tiles_vis.py | 90 ++++ tools/post_merge_poles.py | 154 ++++++ tools/rail_alignment_fit.py | 370 +++++++++++++ tools/rail_centerline_dxf.py | 175 +++++++ tools/rail_to_dxf.py | 234 +++++++++ tools/railway_pipeline.py | 870 +++++++++++++++++++++++++++++++ tools/render_skeleton_overlay.py | 535 +++++++++++++++++++ tools/sam3_autolabel.py | 777 +++++++++++++++++++++++++++ tools/sam3_batch_label.py | 278 ++++++++++ tools/sam3_everything_explore.py | 291 +++++++++++ tools/sam3_receipt_ocr.py | 219 ++++++++ tools/sam3_segment_everything.py | 185 +++++++ tools/show_tiles.py | 72 +++ tools/video_sam3_segment.py | 237 +++++++++ tools/web_ui.py | 618 ++++++++++++++++++++++ tools/yoloworld_sam3_pipeline.py | 333 ++++++++++++ 24 files changed, 7900 insertions(+) create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 configs/railway_zone.json create mode 100644 tools/auto_rail_detect.py create mode 100644 tools/detect_all_objects.py create mode 100644 tools/detect_hollow_section.py create mode 100644 tools/detect_raamen.py create mode 100644 tools/labeling_server.py create mode 100644 tools/merge_tiles_vis.py create mode 100644 tools/post_merge_poles.py create mode 100644 tools/rail_alignment_fit.py create mode 100644 tools/rail_centerline_dxf.py create mode 
100644 tools/rail_to_dxf.py create mode 100644 tools/railway_pipeline.py create mode 100644 tools/render_skeleton_overlay.py create mode 100644 tools/sam3_autolabel.py create mode 100644 tools/sam3_batch_label.py create mode 100644 tools/sam3_everything_explore.py create mode 100644 tools/sam3_receipt_ocr.py create mode 100644 tools/sam3_segment_everything.py create mode 100644 tools/show_tiles.py create mode 100644 tools/video_sam3_segment.py create mode 100644 tools/web_ui.py create mode 100644 tools/yoloworld_sam3_pipeline.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5dfb8f8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,67 @@ +# 가상환경 (용량 크므로 제외) +.venv/ +.venv_client/ + +# 모델 파일 (용량 매우 큼) +*.pt +*.pth +*.onnx +*.gz + +# 로그 +X-AnyLabeling-Server/logs/ +*.log + +# Python +__pycache__/ +*.pyc +*.pyo + +# 캐시 +.cache/ + +# IDE +.vscode/ +.idea/ + +# OS +.DS_Store +Thumbs.db + +# 임시 파일 +create_shortcuts.ps1 + +# 클로드 개인 설정/메모리 +.claude/ + +# X-AnyLabeling-Server (별도 git repo, submodule 아님) +X-AnyLabeling-Server/ + +# 용량 큰 데이터/결과 폴더 +drone_2cm/ +output/ +runs/ +data/ + +# 용량 큰 원본 영상 파일 +*.tif +경부선_2-2구간_미션33-34(5cm).png + +# 논문/문서 파일 +*.pdf +*.txt + +# 미정리 툴 (작업 중) +tools/render_polygons_rainbow.py +tools/render_yolo_labels.py +tools/sam3_sleeper_detect.py +tools/tif_hollow_detect.py + +# 기타 +src/ +PRPs/ +nul +triton-*.whl +pyproject.toml +Project_Status.md +.usage/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..daced9b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,65 @@ +# CLAUDE.md + +Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-specific instructions as needed. + +**Tradeoff:** These guidelines bias toward caution over speed. For trivial tasks, use judgment. + +## 1. Think Before Coding + +**Don't assume. Don't hide confusion. Surface tradeoffs.** + +Before implementing: +- State your assumptions explicitly. If uncertain, ask. 
+- If multiple interpretations exist, present them - don't pick silently. +- If a simpler approach exists, say so. Push back when warranted. +- If something is unclear, stop. Name what's confusing. Ask. + +## 2. Simplicity First + +**Minimum code that solves the problem. Nothing speculative.** + +- No features beyond what was asked. +- No abstractions for single-use code. +- No "flexibility" or "configurability" that wasn't requested. +- No error handling for impossible scenarios. +- If you write 200 lines and it could be 50, rewrite it. + +Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify. + +## 3. Surgical Changes + +**Touch only what you must. Clean up only your own mess.** + +When editing existing code: +- Don't "improve" adjacent code, comments, or formatting. +- Don't refactor things that aren't broken. +- Match existing style, even if you'd do it differently. +- If you notice unrelated dead code, mention it - don't delete it. + +When your changes create orphans: +- Remove imports/variables/functions that YOUR changes made unused. +- Don't remove pre-existing dead code unless asked. + +The test: Every changed line should trace directly to the user's request. + +## 4. Goal-Driven Execution + +**Define success criteria. Loop until verified.** + +Transform tasks into verifiable goals: +- "Add validation" → "Write tests for invalid inputs, then make them pass" +- "Fix the bug" → "Write a test that reproduces it, then make it pass" +- "Refactor X" → "Ensure tests pass before and after" + +For multi-step tasks, state a brief plan: +``` +1. [Step] → verify: [check] +2. [Step] → verify: [check] +3. [Step] → verify: [check] +``` + +Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification. 
+ +--- + +**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes. diff --git a/configs/railway_zone.json b/configs/railway_zone.json new file mode 100644 index 0000000..87b7415 --- /dev/null +++ b/configs/railway_zone.json @@ -0,0 +1,143 @@ +{ + "_comment": "철도 구간 집중 검출용 카테고리 설정 (타일 9-24 기준)", + "_conf_note": "conf: 클래스별 신뢰도 임계값. priority: 낮을수록 cross-class NMS에서 우선 보존(화면 위에 표시).", + "_priority_order": "2=컨트롤박스 > 3=전철주 > 4=레일 > 5=팬스 > 6=침목 > 7=자갈(맨 밑)", + "cross_class_nms_iou": 0.45, + "categories": [ + { + "name": "control_box", + "name_kr": "컨트롤박스", + "prompt": "small square gray metal box beside rail, compact trackside junction box, small near-square electrical enclosure on the ground, small cube-shaped equipment box next to track", + "keywords": ["small square box", "junction box", "compact enclosure", "square metal box", "small cube box"], + "color_bgr": [255, 255, 0], + "conf": 0.15, + "priority": 2 + }, + { + "name": "catenary_pole", + "name_kr": "전철주", + "prompt": "railway catenary pole, overhead line pole, catenary mast, electric railway pole", + "keywords": ["catenary pole", "overhead line pole", "catenary mast", "railway pole"], + "color_bgr": [255, 0, 255], + "conf": 0.25, + "priority": 3 + }, + { + "name": "fence", + "name_kr": "팬스/울타리", + "prompt": "railway fence, trackside fence, perimeter fence, chain link fence, railway boundary fence", + "keywords": ["fence", "chain link", "perimeter fence", "boundary fence"], + "color_bgr": [0, 255, 0], + "conf": 0.50, + "priority": 5 + }, + { + "name": "railway", + "name_kr": "철도 레일", + "prompt": "two parallel steel rails, narrow metallic longitudinal beam, shiny steel rail line on sleeper", + "keywords": ["railroad", "railway rail", "steel rail", "train track", "parallel rail"], + "color_bgr": [0, 0, 255], + "conf": 0.25, + "priority": 4 + }, + { + "name": "sleeper", +
"name_kr": "침목", + "prompt": "rectangular concrete sleeper, dark crosswise tie, evenly spaced railroad tie perpendicular to rail, concrete slab between ballast", + "keywords": ["sleeper", "railroad tie", "rail tie", "wooden tie", "concrete tie", "crosswise tie"], + "color_bgr": [0, 128, 255], + "conf": 0.20, + "priority": 6 + }, + { + "name": "ballast", + "name_kr": "자갈도상", + "prompt": "crushed gray stone gravel, railway ballast aggregate, coarse gravel track bed, angular stone between rails", + "keywords": ["ballast", "gravel between rails", "track bed", "crushed stone", "aggregate"], + "color_bgr": [30, 50, 100], + "conf": 0.20, + "priority": 7 + }, + { + "name": "bracket", + "name_kr": "브라켓/암", + "prompt": "catenary bracket arm, horizontal cantilever arm, pole cross arm, overhead wire bracket", + "keywords": ["bracket", "cantilever", "cross arm"], + "color_bgr": [0, 80, 200], + "conf": 0.20, + "priority": 8 + }, + { + "name": "bridge", + "name_kr": "교량/교각", + "prompt": "railway bridge, road bridge, overpass, viaduct, concrete bridge structure", + "keywords": ["bridge", "overpass", "viaduct"], + "color_bgr": [0, 165, 255], + "conf": 0.25, + "priority": 8 + }, + { + "name": "retaining_wall", + "name_kr": "방음벽/옹벽", + "prompt": "railway retaining wall, noise barrier wall, sound barrier, concrete embankment wall", + "keywords": ["retaining wall", "noise barrier", "sound barrier", "embankment wall"], + "color_bgr": [60, 180, 60], + "conf": 0.25, + "priority": 9 + }, + { + "name": "service_road", + "name_kr": "유지보수 도로", + "prompt": "railway maintenance road, service road alongside track, unpaved railway access road", + "keywords": ["service road", "maintenance road", "access road"], + "color_bgr": [80, 120, 160], + "conf": 0.30, + "priority": 10 + }, + { + "name": "culvert", + "name_kr": "암거/소교량", + "prompt": "railway culvert, drainage culvert, small underpass tunnel, concrete drainage structure", + "keywords": ["culvert", "underpass", "drainage tunnel"], + 
"color_bgr": [180, 60, 180], + "conf": 0.20, + "priority": 9 + }, + { + "name": "vehicle", + "name_kr": "차량", + "prompt": "car, truck, vehicle, automobile, van", + "keywords": ["car", "truck", "vehicle", "automobile", "van"], + "color_bgr": [255, 255, 255], + "conf": 0.25, + "priority": 2 + }, + { + "name": "building", + "name_kr": "건물", + "prompt": "building, house, rooftop, structure, roof", + "keywords": ["building", "house", "rooftop", "structure", "roof"], + "color_bgr": [50, 50, 255], + "conf": 0.25, + "priority": 8 + }, + { + "name": "farmland", + "name_kr": "농지", + "prompt": "farmland, agricultural field, cropland, vegetable garden", + "keywords": ["farmland", "field", "cropland", "vegetable", "agricultural"], + "color_bgr": [50, 200, 50], + "conf": 0.25, + "priority": 11 + }, + { + "name": "vegetation", + "name_kr": "식생", + "prompt": "trees, forest, shrubs, vegetation, bushes", + "keywords": ["tree", "forest", "shrub", "vegetation", "bush"], + "color_bgr": [0, 120, 0], + "conf": 0.25, + "priority": 12 + } + ] +} diff --git a/tools/auto_rail_detect.py b/tools/auto_rail_detect.py new file mode 100644 index 0000000..3ba72ec --- /dev/null +++ b/tools/auto_rail_detect.py @@ -0,0 +1,301 @@ +""" +auto_rail_detect.py +=================== +항공 이미지에서 레일 라인을 자동 검출하여 Rhino용 DXF로 저장. +수동 라벨링 없이 이미지 1장에서 바로 실행. 
def detect_edges(img_gray):
    """Return a Canny edge map for a single-channel image.

    Steps: CLAHE contrast enhancement, 5x5 Gaussian blur, then Canny with
    the module-level CANNY_LOW / CANNY_HIGH thresholds.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    smoothed = cv2.GaussianBlur(equalizer.apply(img_gray), (5, 5), 0)
    return cv2.Canny(smoothed, CANNY_LOW, CANNY_HIGH, apertureSize=3)


def detect_lines(edges):
    """Detect straight segments with the probabilistic Hough transform.

    Uses the module-level HOUGH_* settings. Returns a list of
    (x1, y1, x2, y2) tuples; the list is empty when cv2.HoughLinesP
    returns None.
    """
    hough_params = {
        "rho": 1,
        "theta": np.pi / 180,
        "threshold": HOUGH_THRESHOLD,
        "minLineLength": HOUGH_MIN_LEN,
        "maxLineGap": HOUGH_MAX_GAP,
    }
    segments = cv2.HoughLinesP(edges, **hough_params)
    if segments is None:
        return []
    # Each entry is wrapped in an extra axis; seg[0] is the 4-value row.
    return [tuple(seg[0]) for seg in segments]
def line_angle(x1, y1, x2, y2):
    """Return the undirected orientation of a segment in degrees (mod 180)."""
    return np.degrees(np.arctan2(y2 - y1, x2 - x1)) % 180


def line_length(x1, y1, x2, y2):
    """Return the Euclidean length of the segment (x1, y1)-(x2, y2)."""
    return np.hypot(x2 - x1, y2 - y1)


def line_midpoint(x1, y1, x2, y2):
    """Return the midpoint of the segment as an (x, y) tuple."""
    return ((x1 + x2) / 2, (y1 + y2) / 2)


def perpendicular_dist(x1, y1, x2, y2, px, py):
    """Return the distance from point (px, py) to the line through the segment.

    The distance is measured to the infinite line, not clamped to the
    segment's extent. A zero-length segment degrades to the point-to-point
    distance.
    """
    dx, dy = x2 - x1, y2 - y1
    length = np.hypot(dx, dy)
    if length == 0:
        return np.hypot(px - x1, py - y1)
    return abs(dy * px - dx * py + x2 * y1 - y2 * x1) / length


def cluster_lines(lines, angle_tolerance=None, cluster_dist=None):
    """Group segments that share a direction and lie close to one line.

    Args:
        lines: iterable of (x1, y1, x2, y2) segments.
        angle_tolerance: width in degrees of one orientation bin; defaults
            to the module-level ANGLE_TOLERANCE.
        cluster_dist: maximum perpendicular distance (pixels) from a
            cluster's seed line to another segment's midpoint; defaults to
            the module-level CLUSTER_DIST.

    Returns:
        A list of clusters, each a list of the input segments.
    """
    if not lines:
        return []
    if angle_tolerance is None:
        angle_tolerance = ANGLE_TOLERANCE
    if cluster_dist is None:
        cluster_dist = CLUSTER_DIST

    # Step 1: bucket by quantised orientation.
    angle_groups = {}
    for line in lines:
        bin_deg = round(line_angle(*line) / angle_tolerance) * angle_tolerance
        # Orientation is periodic with 180 degrees: a segment at ~179.5 deg
        # rounds up to bin 180, which is the same direction as bin 0. Fold it
        # back so near-horizontal segments are not split into two groups.
        angle_groups.setdefault(bin_deg % 180, []).append(line)

    # Step 2: inside each orientation bucket, greedily attach every unused
    # segment whose midpoint is within cluster_dist of the seed line.
    clusters = []
    for group in angle_groups.values():
        used = [False] * len(group)
        for i, seed in enumerate(group):
            if used[i]:
                continue
            used[i] = True
            cluster = [seed]
            for j, other in enumerate(group):
                if used[j]:
                    continue
                mid_x, mid_y = line_midpoint(*other)
                if perpendicular_dist(*seed, mid_x, mid_y) < cluster_dist:
                    cluster.append(other)
                    used[j] = True
            clusters.append(cluster)

    return clusters
def total_polyline_length(polyline):
    """Return the summed length of the polyline's consecutive segments.

    Returns 0 for fewer than two vertices.
    """
    if len(polyline) < 2:
        return 0
    return sum(
        np.hypot(nxt[0] - cur[0], nxt[1] - cur[1])
        for cur, nxt in zip(polyline, polyline[1:])
    )


def save_debug_image(img, all_polylines, output_path):
    """Draw the detected polylines with 1-based numbers and save the image.

    Args:
        img: BGR or single-channel image; it is copied, never modified.
        all_polylines: sequence of polylines, each a sequence of (x, y).
        output_path: destination file; its suffix selects the encoder
            (".jpg" is used when there is no suffix).
    """
    vis = img.copy() if len(img.shape) == 3 else cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255),
              (255, 0, 255), (255, 255, 0)]
    for i, poly in enumerate(all_polylines):
        if len(poly) == 0:
            continue  # nothing to draw, and poly[len//2] below would fail
        color = colors[i % len(colors)]
        # OpenCV drawing functions need 32-bit integer points; NumPy's
        # default integer type is 64-bit on many platforms.
        pts = np.array([[int(p[0]), int(p[1])] for p in poly], dtype=np.int32)
        cv2.polylines(vis, [pts], False, color, 2)
        # Number label at the middle vertex.
        cx, cy = int(poly[len(poly) // 2][0]), int(poly[len(poly) // 2][1])
        cv2.putText(vis, str(i + 1), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX,
                    1.5, color, 3)

    # Encode in memory and write the bytes ourselves: process() already reads
    # the input this way to cope with non-ASCII (Korean) paths, and the default
    # debug path is derived from that same input name.
    out_path = Path(output_path)
    ok, encoded = cv2.imencode(out_path.suffix or ".jpg", vis)
    if not ok:
        print(f" 시각화 저장 실패: {output_path}")
        return
    encoded.tofile(str(out_path))
    print(f" 시각화 저장: {output_path}")
print(f" 엣지 픽셀: {edge_count:,}") + + # Hough 직선 검출 + print("[2단계] Hough 직선 검출...") + lines = detect_lines(edges) + print(f" 검출된 선: {len(lines)}개") + + if not lines: + print("선을 검출하지 못했습니다. HOUGH_THRESHOLD를 낮춰보세요.") + sys.exit(1) + + # 클러스터링 + print("[3단계] 레일 클러스터링...") + clusters = cluster_lines(lines) + print(f" 클러스터: {len(clusters)}개") + + # 폴리라인 변환 + 길이 필터 + polylines = [] + for cluster in clusters: + poly = merge_cluster_to_polyline(cluster) + length = total_polyline_length(poly) + if length >= MIN_TOTAL_LEN / scale: + # 원본 해상도로 역변환 + poly_orig = [(p[0] / scale, p[1] / scale) for p in poly] + polylines.append((poly_orig, length / scale)) + + # 길이순 정렬 + polylines.sort(key=lambda x: -x[1]) + print(f" 유효 레일 라인: {len(polylines)}개") + + if not polylines: + print("유효한 레일을 찾지 못했습니다. MIN_TOTAL_LEN을 낮춰보세요.") + sys.exit(1) + + for i, (poly, length) in enumerate(polylines): + print(f" 레일 {i+1}: {length:.0f}px ({length*0.05:.1f}m)") + + # DXF 저장 + print("[4단계] DXF 저장...") + doc = ezdxf.new("R2010") + msp = doc.modelspace() + doc.layers.add("RAIL_AUTO", color=1) # 빨강 — 자동검출 레일 + doc.layers.add("RAIL_AUTO_LABEL", color=2) # 노랑 — 번호 + + for i, (poly, length) in enumerate(polylines): + # Y축 반전 (이미지→DXF) + dxf_pts = [(float(p[0]), float(-p[1])) for p in poly] + msp.add_lwpolyline(dxf_pts, dxfattribs={"layer": "RAIL_AUTO"}) + + # 중간점에 번호 텍스트 + mid = poly[len(poly) // 2] + msp.add_text( + f"Rail_{i+1}", + dxfattribs={ + "layer": "RAIL_AUTO_LABEL", + "height": 20, + "insert": (float(mid[0]), float(-mid[1])), + } + ) + + Path(dxf_path).parent.mkdir(parents=True, exist_ok=True) + doc.saveas(dxf_path) + print(f"[완료] DXF 저장: {dxf_path}") + print(f" 레이어: RAIL_AUTO(빨강) = {len(polylines)}개 레일 중심선") + + # 디버그 이미지 저장 + debug_path = Path(dxf_path).parent / (Path(dxf_path).stem + "_debug.jpg") + all_polys = [p for p, _ in polylines] + save_debug_image(small, [([(x*scale, y*scale) for x, y in p]) for p in all_polys], + debug_path) + + print(f"\nRhino 사용법:") + print(f" 1. 
File -> Import -> {Path(dxf_path).name}") + print(f" 2. RAIL_AUTO 레이어 선택") + print(f" 3. Sweep1 명령 -> 레일 단면 선택 -> 3D 레일 완성") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("사용법: python tools/auto_rail_detect.py [output.dxf]") + sys.exit(1) + + img_path = sys.argv[1] + dxf_path = sys.argv[2] if len(sys.argv) >= 3 else str( + Path(img_path).with_suffix(".dxf") + ) + process(img_path, dxf_path) diff --git a/tools/detect_all_objects.py b/tools/detect_all_objects.py new file mode 100644 index 0000000..39d3729 --- /dev/null +++ b/tools/detect_all_objects.py @@ -0,0 +1,565 @@ +""" +이미지에서 객체를 SAM3.1로 검출하여 색상별로 시각화. + +전략: + - cols×rows 타일로 분할 (overlap 중복) + - --tiles 로 처리할 타일 번호 지정 (예: 9-24, 1,5,9, 전체=all) + - --categories 로 JSON 설정 파일 로드 (카테고리·프롬프트·색상 정의) + - 타일당 SAM3.1 1회 호출 (모든 카테고리 프롬프트 합산) + - 병렬 처리(ThreadPoolExecutor) → NMS → 색상 시각화 + +사용법: + python tools/detect_all_objects.py \\ + --input data/역사이미지/slope/DJI_20260306113839_0005.JPG \\ + --categories configs/railway_zone.json \\ + --tiles 9-24 \\ + --cols 8 --rows 6 --overlap 0.10 --workers 4 +""" +import argparse +import base64 +import json +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import cv2 +import numpy as np +import requests + +SAM3_SERVER = "http://localhost:8000" +SAM3_MODEL_ID = "segment_anything_3" + +# 기본 카테고리 (--categories 미지정 시 사용) +_DEFAULT_CATEGORIES = [ + {"name": "railway", "prompt": "railroad track, railway rail, steel rail", + "keywords": ["railroad", "railway rail", "steel rail"], "color_bgr": [0, 200, 255]}, + {"name": "catenary_pole", "prompt": "railway catenary pole, overhead line pole, catenary mast", + "keywords": ["catenary pole", "overhead line pole", "catenary mast"], "color_bgr": [255, 130, 0]}, + {"name": "highway", "prompt": "highway road, expressway asphalt, paved road lane", + "keywords": ["highway", "expressway", "paved road"], "color_bgr": [160, 160, 160]}, + {"name": "vehicle", "prompt": "car, truck, 
vehicle, automobile", + "keywords": ["car", "truck", "vehicle", "automobile"], "color_bgr": [0, 255, 0]}, + {"name": "building", "prompt": "building, house, rooftop, structure", + "keywords": ["building", "house", "rooftop", "structure"], "color_bgr": [50, 50, 255]}, + {"name": "farmland", "prompt": "farmland, agricultural field, cropland, vegetable garden", + "keywords": ["farmland", "field", "cropland", "vegetable"], "color_bgr": [50, 200, 50]}, + {"name": "vegetation", "prompt": "trees, forest, shrubs, vegetation, bushes", + "keywords": ["tree", "forest", "shrub", "vegetation", "bush"], "color_bgr": [0, 120, 0]}, + {"name": "guardrail", "prompt": "guardrail, highway barrier, road fence, crash barrier", + "keywords": ["guardrail", "highway barrier", "road fence", "crash barrier"], "color_bgr": [200, 0, 200]}, + {"name": "bridge", "prompt": "bridge, overpass, viaduct", + "keywords": ["bridge", "overpass", "viaduct"], "color_bgr": [0, 165, 255]}, + {"name": "wire", "prompt": "overhead wire, catenary wire, electric cable line", + "keywords": ["catenary wire", "overhead wire", "electric cable"], "color_bgr": [200, 200, 255]}, +] + + +# ── 타일 번호 파싱 ──────────────────────────────────────────────────────────── +def parse_tiles(tile_str: str, total: int) -> set: + """'9-24', '1,3,5', 'all' → tile index 집합 (1-based).""" + if tile_str.lower() == "all": + return set(range(1, total + 1)) + result = set() + for part in tile_str.split(","): + part = part.strip() + if "-" in part: + a, b = part.split("-", 1) + result.update(range(int(a), int(b) + 1)) + else: + result.add(int(part)) + return result + + +# ── 카테고리 로드 ───────────────────────────────────────────────────────────── +def load_categories(json_path: str | None) -> list: + if json_path: + data = json.loads(Path(json_path).read_text(encoding="utf-8")) + return data["categories"] + return _DEFAULT_CATEGORIES + + +def label_to_category(label: str, categories: list) -> int: + label_l = label.lower() + for i, cat in 
def build_combined_prompt(categories: list) -> str:
    """Join every category's "prompt" text into one comma-separated prompt."""
    return ", ".join(cat["prompt"] for cat in categories)


# ── SAM3 request ──────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
    """Encode an image as base64 JPEG, downscaling if it exceeds max_size.

    Returns:
        (b64_string, scale), where scale is the resize factor that was
        applied (1.0 when the image was not resized).
    """
    h, w = image_bgr.shape[:2]
    scale = 1.0
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        # Keep both sides at least 1 px so a very thin tile cannot request a
        # zero-sized resize.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        image_bgr = cv2.resize(image_bgr, new_size)
    _, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    return base64.b64encode(buf).decode("utf-8"), scale


def sam3_segment_tile(tile_bgr: np.ndarray, prompt: str, conf: float) -> list:
    """Send one tile to the SAM3 server and return its polygon shapes.

    Polygon points are scaled back to the tile's original resolution when
    the tile was downscaled for upload. The call is best-effort: any failure
    (network error, HTTP error, malformed response) is reported on stdout
    and yields an empty list, so one bad tile does not abort the run.
    """
    b64, scale = encode_image(tile_bgr)
    payload = {
        "model": SAM3_MODEL_ID,
        "image": b64,
        "params": {"text_prompt": prompt, "conf_threshold": conf,
                   "show_masks": True, "show_boxes": False},
    }
    try:
        r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
        r.raise_for_status()
        resp = r.json()
        if not resp.get("success"):
            return []
        shapes = resp.get("data", {}).get("shapes", [])
        shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
        polygons = [s for s in shapes if s.get("shape_type") == "polygon"]
        if scale < 1.0:
            inv = 1.0 / scale
            for s in polygons:
                s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
        return polygons
    except Exception as exc:
        # Deliberately broad (best-effort per tile), but never silent:
        # otherwise an unreachable server looks like "0 objects detected".
        print(f" [SAM3] 타일 요청 실패: {exc!r}")
        return []


# ── NMS ───────────────────────────────────────────────────────────────────────
def _bbox(pts):
    """Return (min_x, min_y, max_x, max_y) of a list of [x, y] points."""
    xs = [p[0] for p in pts]
    ys = [p[1] for p in pts]
    return min(xs), min(ys), max(xs), max(ys)


def _nms_core(shapes: list, iou_thresh: float) -> list:
    """Greedy non-maximum suppression over polygon shapes.

    IoU is computed on the polygons' axis-aligned bounding boxes, not on
    the polygons themselves. Shapes are visited in descending "score"
    order (0.5 when the field is missing); a shape is dropped when its IoU
    with an already kept shape is >= iou_thresh.
    """
    if not shapes:
        return []
    bboxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
    scores = np.array([float(s.get("score", 0.5)) for s in shapes])
    order = scores.argsort()[::-1]
    keep = []
    while len(order):
        i = order[0]
        keep.append(i)
        if len(order) == 1:
            break
        rest = order[1:]
        xx1 = np.maximum(bboxes[i, 0], bboxes[rest, 0])
        yy1 = np.maximum(bboxes[i, 1], bboxes[rest, 1])
        xx2 = np.minimum(bboxes[i, 2], bboxes[rest, 2])
        yy2 = np.minimum(bboxes[i, 3], bboxes[rest, 3])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        a_i = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1])
        a_j = (bboxes[rest, 2] - bboxes[rest, 0]) * (bboxes[rest, 3] - bboxes[rest, 1])
        iou = inter / (a_i + a_j - inter + 1e-6)  # epsilon avoids 0/0
        order = rest[iou < iou_thresh]
    return [shapes[i] for i in keep]


def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
    """Run bounding-box NMS on shapes (public wrapper around _nms_core)."""
    return _nms_core(shapes, iou_thresh)
def merge_nonramen_poles(shapes: list, H: int, W: int,
                         x_overlap_thresh: float = 0.30,
                         y_gap_thresh: int = 150) -> list:
    """Merge pole polygons that were split apart, V+V pairs only.

    Each polygon is classified with _poly_orient. Two polygons are fused
    only when both are 'V' and their bounding boxes overlap horizontally
    by at least x_overlap_thresh (overlap width / union width) and are
    separated vertically by at most y_gap_thresh pixels. Any pair that
    includes a non-'V' polygon is left alone.

    Returns the input list unchanged when it holds at most one shape.
    """
    if len(shapes) <= 1:
        return shapes

    orients = [_poly_orient(s["points"], H, W) for s in shapes]
    n_vert = orients.count('V')
    n_horiz = orients.count('H')
    print(f" [orient] V={n_vert}, H={n_horiz}, ?={len(orients)-n_vert-n_horiz}")

    def bounds(shape):
        xs = [p[0] for p in shape["points"]]
        ys = [p[1] for p in shape["points"]]
        return min(xs), min(ys), max(xs), max(ys)

    def horizontal_overlap(a, b):
        shared = min(a[2], b[2]) - max(a[0], b[0])
        span = max(a[2], b[2]) - min(a[0], b[0])
        if span > 0:
            return shared / span
        return 0.0

    def vertical_gap(a, b):
        return max(0.0, max(a[1], b[1]) - min(a[3], b[3]))

    def fuse(first, second):
        # Rasterise both polygons onto one mask and trace the largest outline.
        canvas = np.zeros((H, W), dtype=np.uint8)
        for shape in (first, second):
            cv2.fillPoly(canvas, [np.array(shape["points"], dtype=np.int32)], 255)
        contours, _ = cv2.findContours(canvas, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return first
        outline = max(contours, key=cv2.contourArea)
        tolerance = 0.002 * cv2.arcLength(outline, True)
        simplified = cv2.approxPolyDP(outline, tolerance, True)
        fused = dict(first)
        fused["points"] = [[float(v[0][0]), float(v[0][1])] for v in simplified]
        fused["score"] = max(float(first.get("score", 0)), float(second.get("score", 0)))
        return fused

    consumed = [False] * len(shapes)
    output = []
    pair_count = 0
    for i, seed in enumerate(shapes):
        if consumed[i]:
            continue
        current = seed
        current_orient = orients[i]
        current_box = bounds(current)
        for j in range(i + 1, len(shapes)):
            if consumed[j]:
                continue
            if not (current_orient == 'V' and orients[j] == 'V'):
                continue  # only vertical + vertical pairs are merged
            candidate_box = bounds(shapes[j])
            close_enough = (
                horizontal_overlap(current_box, candidate_box) >= x_overlap_thresh
                and vertical_gap(current_box, candidate_box) <= y_gap_thresh
            )
            if close_enough:
                current = fuse(current, shapes[j])
                current_orient = 'V'
                current_box = bounds(current)
                consumed[j] = True
                pair_count += 1
        output.append(current)
    print(f" [merge] 병합={pair_count}쌍")
    return output


def cross_class_nms(buckets: list, categories: list, iou_thresh: float) -> list:
    """Suppress overlapping detections across classes.

    All shapes are ranked by (category "priority" ascending, score
    descending); a missing priority counts as 99 and a missing score as
    0.5. Walking that ranking, each surviving shape removes every
    later-ranked shape whose bounding-box IoU with it is >= iou_thresh.

    Returns a new list of buckets, parallel to the input, holding the
    survivors in ranking order.
    """
    ranked = []
    for cls_idx, cls_shapes in enumerate(buckets):
        prio = categories[cls_idx].get("priority", 99)
        ranked.extend(
            (prio, -float(shape.get("score", 0.5)), cls_idx, shape)
            for shape in cls_shapes
        )

    # Negated score makes an ascending sort put higher scores first.
    ranked.sort(key=lambda entry: (entry[0], entry[1]))

    if not ranked:
        return [[] for _ in buckets]

    all_shapes = [entry[3] for entry in ranked]
    cls_ids = [entry[2] for entry in ranked]

    bboxes = np.array([_bbox(s["points"]) for s in all_shapes], dtype=np.float32)
    count = len(all_shapes)
    dropped = [False] * count

    for i in range(count):
        if dropped[i]:
            continue
        for j in range(i + 1, count):
            if dropped[j]:
                continue
            xx1 = max(bboxes[i,0], bboxes[j,0])
            yy1 = max(bboxes[i,1], bboxes[j,1])
            xx2 = min(bboxes[i,2], bboxes[j,2])
            yy2 = min(bboxes[i,3], bboxes[j,3])
            inter = max(0, xx2-xx1) * max(0, yy2-yy1)
            if inter == 0:
                continue
            a_i = (bboxes[i,2]-bboxes[i,0])*(bboxes[i,3]-bboxes[i,1])
            a_j = (bboxes[j,2]-bboxes[j,0])*(bboxes[j,3]-bboxes[j,1])
            iou = inter / (a_i + a_j - inter + 1e-6)
            if iou >= iou_thresh:
                dropped[j] = True  # i ranks higher, so j is the one removed

    survivors = [[] for _ in buckets]
    for is_dropped, cls_idx, shape in zip(dropped, cls_ids, all_shapes):
        if not is_dropped:
            survivors[cls_idx].append(shape)
    return survivors
# ── Visualisation ─────────────────────────────────────────────────────────────
def draw_detections(image_bgr: np.ndarray, buckets: list,
                    categories: list, tile_filter: set,
                    cols: int, rows: int, overlap: float) -> np.ndarray:
    """Render detections, processed-tile outlines and a legend on a copy.

    Args:
        image_bgr: source image; it is copied, never modified.
        buckets: one list of polygon shapes per category (same order as
            categories).
        categories: category dicts providing "name" and "color_bgr".
        tile_filter: 1-based indices of the tiles that were processed; their
            un-padded cells are outlined in white.
        cols, rows: tile grid dimensions.
        overlap: accepted for signature compatibility; not used here.

    Returns:
        The annotated BGR image.
    """
    vis = image_bgr.copy()
    H, W = vis.shape[:2]

    # Outline the processed tiles.
    base_w = W / cols
    base_h = H / rows
    for r in range(rows):
        for c in range(cols):
            idx = r * cols + c + 1
            if idx in tile_filter:
                bx0, by0 = int(c * base_w), int(r * base_h)
                bx1, by1 = min(W, int((c + 1) * base_w)), min(H, int((r + 1) * base_h))
                cv2.rectangle(vis, (bx0, by0), (bx1, by1), (255, 255, 255), 1)

    # Masks and sequence labels.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_sc = max(0.4, min(W, H) / 8000)
    thickness = max(1, int(font_sc * 2))

    for i, cat in enumerate(categories):
        color = tuple(cat["color_bgr"])
        prefix = cat["name"][:3].upper()  # e.g. RAI, CAT, BRA
        for seq, s in enumerate(buckets[i], start=1):
            pts = np.array(s["points"], dtype=np.int32)

            # Blend the translucent fill only inside the polygon's bounding
            # box. Copying and blending the whole frame for every polygon
            # costs a full-image pass per detection; pixels outside the
            # polygon are not changed by the blend anyway.
            x, y, w, h = cv2.boundingRect(pts)
            x0, y0 = max(x, 0), max(y, 0)
            x1, y1 = min(x + w, W), min(y + h, H)
            if x1 > x0 and y1 > y0:
                roi = vis[y0:y1, x0:x1]
                overlay = roi.copy()
                local_pts = pts - np.array([x0, y0], dtype=np.int32)
                cv2.fillPoly(overlay, [local_pts], color)
                vis[y0:y1, x0:x1] = cv2.addWeighted(overlay, 0.35, roi, 0.65, 0)
            cv2.polylines(vis, [pts], True, color, 2)

            # Sequence number and score, centred on the mean vertex position.
            cx = int(np.mean(pts[:, 0]))
            cy = int(np.mean(pts[:, 1]))
            score = float(s.get("score", 0.0))
            label = f"{prefix}{seq:03d} {score:.2f}"
            (tw, th), _ = cv2.getTextSize(label, font, font_sc, thickness)
            tx, ty = cx - tw // 2, cy + th // 2
            # White text on a black box.
            cv2.rectangle(vis, (tx - 2, ty - th - 2), (tx + tw + 2, ty + 2),
                          (0, 0, 0), -1)
            cv2.putText(vis, label, (tx, ty), font, font_sc,
                        (255, 255, 255), thickness, cv2.LINE_AA)

    # Legend panel in the top-right corner.
    lx, ly = W - 280, 20
    panel_h = len(categories) * 24 + 10
    # Clamp the panel's left edge: on images narrower than the panel a
    # negative slice start would be counted from the right-hand side.
    px0 = max(0, lx - 8)
    vis[0:panel_h, px0:W] = (vis[0:panel_h, px0:W] * 0.35).astype(np.uint8)
    for i, cat in enumerate(categories):
        color = tuple(cat["color_bgr"])
        prefix = cat["name"][:3].upper()
        cv2.rectangle(vis, (lx, ly - 13), (lx + 15, ly + 3), color, -1)
        cv2.putText(vis, f"[{prefix}] {cat['name']} ({len(buckets[i])})",
                    (lx + 20, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.50, color, 1, cv2.LINE_AA)
        ly += 24

    return vis
help="SAM3 신뢰도 임계값 (기본: 0.20)") + ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본: 4)") + ap.add_argument("--save-labels", action="store_true", help="YOLO 폴리곤 포맷 .txt 라벨 파일 저장") + ap.add_argument("--save-json", action="store_true", help="AnyLabeling JSON 포맷 저장 (railway.json 동일 양식)") + args = ap.parse_args() + + img_path = Path(args.input) + buf = np.fromfile(str(img_path), dtype=np.uint8) + image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if image_bgr is None: + print(f"이미지 로드 실패: {img_path}"); return + + H, W = image_bgr.shape[:2] + total_tiles = args.cols * args.rows + tile_filter = parse_tiles(args.tiles, total_tiles) + categories = load_categories(args.categories) + combined_prompt = build_combined_prompt(categories) + + # cross-class NMS IoU: JSON > 기본값 0.45 + cc_iou = 0.45 + if args.categories: + raw = json.loads(Path(args.categories).read_text(encoding="utf-8")) + cc_iou = raw.get("cross_class_nms_iou", cc_iou) + + # SAM3 호출용 conf = 모든 카테고리 conf 중 최솟값 (낮은 쪽부터 받아 후처리로 필터) + sam3_conf = min(cat.get("conf", args.conf) for cat in categories) + + print(f"이미지 : {W}×{H}") + print(f"타일 그리드: {args.cols}×{args.rows}={total_tiles}개 | 처리 대상: {sorted(tile_filter)}") + print(f"카테고리 : {len(categories)}개 | 중복: {args.overlap*100:.0f}%") + print(f"SAM3 conf : {sam3_conf} (전체 최솟값) | cross-class NMS IoU: {cc_iou}") + print(f"SAM3 호출 : {len(tile_filter)}회 | 병렬: {args.workers}스레드\n") + + t0 = time.time() + all_shapes = detect_tiled(image_bgr, args.cols, args.rows, args.overlap, + sam3_conf, args.workers, tile_filter, combined_prompt) + + print(f"전체 검출 {len(all_shapes)}개 → 분류 + per-class conf 필터 + NMS...") + buckets = [[] for _ in categories] + unmatched = 0 + for s in all_shapes: + idx = label_to_category(s.get("label", ""), categories) + if idx < 0: + unmatched += 1 + continue + # per-class conf 필터 + cat_conf = categories[idx].get("conf", args.conf) + if float(s.get("score", 0.0)) < cat_conf: + continue + buckets[idx].append(s) + + # 클래스 내 NMS + print(" [1] 
클래스 내 NMS") + for i, cat in enumerate(categories): + before = len(buckets[i]) + buckets[i] = nms_shapes(buckets[i]) + print(f" {cat['name']:18s}: {before:3d} → {len(buckets[i]):3d}개 (conf≥{cat.get('conf', args.conf)})") + + # 클래스 간 NMS + total_before = sum(len(b) for b in buckets) + print(f" [2] cross-class NMS (IoU≥{cc_iou})") + buckets = cross_class_nms(buckets, categories, cc_iou) + total_after = sum(len(b) for b in buckets) + print(f" {total_before}개 → {total_after}개") + + if unmatched: + print(f" (미분류/conf미달 {unmatched}개 제외)") + print(f"\n완료: {time.time()-t0:.0f}초") + + vis = draw_detections(image_bgr, buckets, categories, + tile_filter, args.cols, args.rows, args.overlap) + h, w = vis.shape[:2] + if max(h, w) > 4096: + s = 4096 / max(h, w) + vis = cv2.resize(vis, (int(w*s), int(h*s))) + + if args.output: + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + else: + tile_tag = args.tiles.replace(",", "_").replace("-", "to") + cat_tag = Path(args.categories).stem if args.categories else "default" + out_dir = Path("output") / "detect" / img_path.stem + out_dir.mkdir(parents=True, exist_ok=True) + base_name = f"tiles{tile_tag}_{cat_tag}" + n = 1 + while True: + out_path = out_dir / f"{base_name}_{n:03d}.jpg" + if not out_path.exists(): + break + n += 1 + + cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path)) + print(f"저장: {out_path}") + + if args.save_labels: + label_path = out_path.with_suffix(".txt") + with open(label_path, "w", encoding="utf-8") as f: + for cls_idx, shapes in enumerate(buckets): + for s in shapes: + pts_norm = [[px / W, py / H] for px, py in s["points"]] + coords = " ".join(f"{x:.6f} {y:.6f}" for x, y in pts_norm) + f.write(f"{cls_idx} {coords}\n") + print(f"라벨 저장: {label_path}") + + if args.save_json: + import json as _json + json_shapes = [] + for cls_idx, shapes in enumerate(buckets): + cat_name = categories[cls_idx]["name"] if cls_idx < len(categories) else str(cls_idx) + for s 
in shapes: + json_shapes.append({ + "label": cat_name, + "score": float(s.get("score", 0.0)), + "points": [[float(px), float(py)] for px, py in s["points"]], + "group_id": None, + "description": None, + "shape_type": "polygon", + "flags": None, + }) + anylabel = { + "version": "3.3.9", + "flags": {}, + "shapes": json_shapes, + "imagePath": img_path.name, + "imageData": None, + "imageHeight": H, + "imageWidth": W, + } + json_path = out_path.with_suffix(".json") + json_path.write_text(_json.dumps(anylabel, ensure_ascii=False, indent=2), + encoding="utf-8") + print(f"JSON 저장: {json_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/detect_hollow_section.py b/tools/detect_hollow_section.py new file mode 100644 index 0000000..dcc7549 --- /dev/null +++ b/tools/detect_hollow_section.py @@ -0,0 +1,152 @@ +""" +얇은 중공단면(도너츠 형태) 검출 및 표시 스크립트 +드론 영상에서 전철주 상단 원형 단면을 찾아 표시 + +사용법: + python detect_hollow_section.py [--radius ] [--tol ] [--topk ] + + --radius : 3D 프로그램 줌 기준 단면 반지름 (픽셀). 미지정 시 자동 탐색. + --tol : radius ± 허용 오차 (기본 30px) + --topk : 표시할 최대 후보 수 (기본 3) +""" +import argparse +import cv2 +import numpy as np +from pathlib import Path + + +def ring_score(edges: np.ndarray, cx: int, cy: int, r: int, + H: int, W: int, inner_ratio: float = 0.60) -> float: + """ + 링(도넛) 특성 점수. + - 링 영역(외곽~내곽) 엣지 강도 / 내부 엣지 강도 비율 (최대 5 클리핑) + - 이미지 중심 근접도 보정 (중심에 가까울수록 가산점) + Canny edges는 외부에서 한 번만 계산해 전달받음. 
+ """ + inner_r = max(int(r * inner_ratio), 3) + + mask_full = np.zeros((H, W), np.uint8) + mask_inner = np.zeros((H, W), np.uint8) + cv2.circle(mask_full, (cx, cy), r, 255, -1) + cv2.circle(mask_inner, (cx, cy), inner_r, 255, -1) + mask_ring = cv2.subtract(mask_full, mask_inner) + + if mask_ring.any() == 0 or mask_inner.any() == 0: + return 0.0 + + ring_edge = float(edges[mask_ring > 0].mean()) + inner_edge = float(edges[mask_inner > 0].mean()) + 1e-6 + edge_ratio = min(ring_edge / inner_edge, 5.0) + + dist_from_center = np.hypot(cx - W / 2, cy - H / 2) + max_dist = np.hypot(W / 2, H / 2) + center_bonus = 1.0 - dist_from_center / max_dist # 0~1 + + return edge_ratio * (0.7 + 0.3 * center_bonus) + + +def detect_hollow_section( + image_path: str, + output_path: str | None = None, + radius_px: int | None = None, # 3D 프로그램 줌 기준 알려진 반지름 + tol_px: int = 30, # radius_px ± 허용 오차 + top_k: int = 3, +) -> str: + img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR) + if img is None: + raise FileNotFoundError(f"이미지 로드 실패: {image_path}") + + H, W = img.shape[:2] + print(f"이미지 크기: {W}×{H}") + + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)) + blurred = cv2.GaussianBlur(clahe.apply(gray), (7, 7), 2) + + # Canny 엣지를 미리 한 번만 계산 (ring_score에서 재사용) + edges = cv2.Canny(blurred, 30, 90) + + # 반경 범위 결정 + if radius_px is not None: + min_r = max(5, radius_px - tol_px) + max_r = radius_px + tol_px + else: + short = min(W, H) + min_r = max(10, short // 20) + max_r = short // 3 + + print(f"탐색 반경: {min_r} ~ {max_r} px" + + (f" (기준={radius_px}px ±{tol_px})" if radius_px else " (자동)")) + + circles = cv2.HoughCircles( + blurred, + cv2.HOUGH_GRADIENT, + dp=1.0, + minDist=min_r * 3, + param1=80, + param2=35, + minRadius=min_r, + maxRadius=max_r, + ) + + result = img.copy() + + if circles is None: + print("HoughCircles 미검출") + else: + circles = np.round(circles[0]).astype(int) + print(f"HoughCircles 후보: 
{len(circles)}개 → 링 스코어 계산 중...") + + scored = sorted( + [(ring_score(edges, cx, cy, r, H, W), cx, cy, r) + for cx, cy, r in circles], + reverse=True, + ) + + # 1위: 녹색(굵게), 2위: 노랑, 3위: 파랑 + colors = [(0, 255, 0), (0, 220, 255), (255, 120, 0)] + + print(f"\n상위 {min(top_k, len(scored))}개 중공단면 후보:") + for rank, (score, cx, cy, r) in enumerate(scored[:top_k]): + color = colors[rank % len(colors)] + inner_r = max(int(r * 0.60), 5) + lw = 3 if rank == 0 else 2 + + cv2.circle(result, (cx, cy), r, color, lw) + cv2.circle(result, (cx, cy), inner_r, color, 1) + cv2.circle(result, (cx, cy), 6, (0, 0, 255), -1) + arm = r + 15 + cv2.line(result, (cx - arm, cy), (cx + arm, cy), color, 1) + cv2.line(result, (cx, cy - arm), (cx, cy + arm), color, 1) + cv2.putText(result, f"#{rank+1} r={r}px s={score:.2f}", + (cx - r, cy - r - 8), + cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 255), 2) + print(f" #{rank+1}: 중심=({cx},{cy}), 반지름={r}px, 링스코어={score:.3f}") + + if output_path is None: + p = Path(image_path) + output_path = str(p.parent / (p.stem + "_hollow_detected" + p.suffix)) + + cv2.imencode(Path(output_path).suffix, result)[1].tofile(output_path) + print(f"\n결과 저장: {output_path}") + return output_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="얇은 중공단면 검출") + parser.add_argument("image", nargs="?", + default=r"d:\MYCLAUDE_PROJECT\x-anylabeling01\data\Message_2026-04-23T11_13_29+09_00.png") + parser.add_argument("--radius", type=int, default=None, + help="3D 프로그램 줌 기준 단면 반지름(px). 
미지정 시 자동 탐색.") + parser.add_argument("--tol", type=int, default=30, + help="radius ± 허용 오차 (기본 30px)") + parser.add_argument("--topk", type=int, default=3, + help="표시할 최대 후보 수 (기본 3)") + args = parser.parse_args() + + detect_hollow_section( + args.image, + radius_px=args.radius, + tol_px=args.tol, + top_k=args.topk, + ) diff --git a/tools/detect_raamen.py b/tools/detect_raamen.py new file mode 100644 index 0000000..c1a997e --- /dev/null +++ b/tools/detect_raamen.py @@ -0,0 +1,666 @@ +""" +드론 사선 촬영 이미지에서 라멘형(門자형) 전철주 검출. +기존 픽셀 위상수학(Skeleton) 방식을 공간 기하학 방식으로 교체. + +파이프라인: + Phase 1: 폴리곤 단순화(approxPolyDP) + 소실점(Vanishing Point) 계산 + Phase 2: 동적 V/H 분류 (소실점 기반 기대 각도) + Phase 3: 근접성 기반 그룹핑 (H 앵커 → 아래 V 탐색) + Phase 4: 라멘 구조 판정 + 예외(가림) 처리 + +사용: + python tools/detect_raamen.py \ + --image --label --output \ + [--class-ids 1] [--epsilon 4.0] [--v-thresh 20.0] +""" +import argparse +import numpy as np +import cv2 +from pathlib import Path + + +# ── Phase 1: 파싱 + 단순화 + 소실점 ───────────────────────────────────── + +def load_polygons(label_path, W, H, class_ids=None, class_names=None): + """ + AnyLabeling JSON 또는 YOLO .txt 자동 감지 후 파싱. + Returns: (픽셀 좌표 폴리곤 리스트, 절대 shapes[] 인덱스 리스트) + 절대 인덱스는 AnyLabeling JSON shapes[] 배열 인덱스와 동일. 
+ """ + import json as _json + if not label_path.exists(): + raise FileNotFoundError(f"라벨 파일을 찾을 수 없습니다: {label_path}") + + if label_path.suffix.lower() == ".json": + data = _json.loads(label_path.read_text(encoding="utf-8")) + shapes = data.get("shapes", []) + polys, abs_indices = [], [] + for abs_idx, s in enumerate(shapes): + if class_names is not None and s.get("label", "") not in class_names: + continue + pts = np.array([[float(p[0]), float(p[1])] for p in s.get("points", [])], + dtype=np.float32) + if len(pts) >= 3: + polys.append(pts) + abs_indices.append(abs_idx) + return polys, abs_indices + + # YOLO .txt + polys, abs_indices = [], [] + for abs_idx, line in enumerate(label_path.read_text(encoding="utf-8").splitlines()): + parts = line.split() + if not parts: + continue + if class_ids is not None and int(parts[0]) not in class_ids: + continue + coords = list(map(float, parts[1:])) + pts = np.array( + [[coords[i] * W, coords[i + 1] * H] for i in range(0, len(coords), 2)], + dtype=np.float32, + ) + if len(pts) >= 3: + polys.append(pts) + abs_indices.append(abs_idx) + return polys, abs_indices + + +def smooth_polygon(pts, epsilon): + """approxPolyDP로 노이즈 경계 → 직선 위주 단순화.""" + approx = cv2.approxPolyDP(pts.astype(np.int32), epsilon, closed=True) + result = approx.reshape(-1, 2).astype(np.float32) + return result if len(result) >= 3 else pts.astype(np.float32) + + +def _minrect(pts): + """cv2.minAreaRect 래퍼 (int32 변환 포함).""" + return cv2.minAreaRect(pts.astype(np.int32)) + + +def _long_axis_angle(pts): + """minAreaRect 장축 각도 (degrees) 반환.""" + (cx, cy), (rw, rh), angle = _minrect(pts) + return angle if rw >= rh else angle + 90, cx, cy, max(rw, rh), min(rw, rh) + + +def _radial_cos_sim(pts, img_cx, img_cy): + """VP 후보 bootstrap용 기존 radial cos_sim.""" + long_angle_deg, cx, cy, long_side, short_side = _long_axis_angle(pts) + if short_side < 1 or long_side / short_side < 1.3: + return 0.0 + lx = np.cos(np.radians(long_angle_deg)) + ly = 
np.sin(np.radians(long_angle_deg)) + rdx, rdy = cx - img_cx, cy - img_cy + n = (rdx ** 2 + rdy ** 2) ** 0.5 + return abs(lx * rdx / n + ly * rdy / n) if n > 1 else 0.0 + + +def compute_vanishing_point(polys): + """ + 후보 폴리곤 장축 선분들의 최소자승 교점 (소실점) 계산. + 각 장축 선분: 법선 n=(-dy, dx), 방정식 -dy*x + dx*y = -dy*cx + dx*cy + """ + A_rows, b_vals = [], [] + for pts in polys: + long_angle_deg, cx, cy, _, _ = _long_axis_angle(pts) + dx = np.cos(np.radians(long_angle_deg)) + dy = np.sin(np.radians(long_angle_deg)) + A_rows.append([-dy, dx]) + b_vals.append(-dy * cx + dx * cy) + vp, *_ = np.linalg.lstsq(np.array(A_rows), np.array(b_vals), rcond=None) + return float(vp[0]), float(vp[1]) + + +def _estimate_vp_iterative(polys, seed_indices, v_thresh, h_max_diff, vp_min_len, + x_horiz_thresh=10.0, max_iter=6): + """초기 후보에서 반복 정제 VP 추정. Returns (vp_x, vp_y, n_v, orients, adiffs).""" + n = len(polys) + orients = ['?'] * n + adiffs = [90.0] * n + if len(seed_indices) < 2: + return None, None, 0, orients, adiffs + vp_x, vp_y = compute_vanishing_point([polys[i] for i in seed_indices]) + for _ in range(max_iter): + for i, pts in enumerate(polys): + orients[i], adiffs[i] = classify_vh(pts, vp_x, vp_y, v_thresh, h_max_diff, + x_horiz_thresh) + v_cands = [i for i in range(n) + if orients[i] == 'V' and _long_axis_angle(polys[i])[3] > vp_min_len] + if len(v_cands) < 3: + break + nx, ny = compute_vanishing_point([polys[i] for i in v_cands]) + shift = ((nx - vp_x) ** 2 + (ny - vp_y) ** 2) ** 0.5 + vp_x, vp_y = nx, ny + if shift < 5.0: + for i, pts in enumerate(polys): + orients[i], adiffs[i] = classify_vh(pts, vp_x, vp_y, v_thresh, h_max_diff, + x_horiz_thresh) + break + return vp_x, vp_y, orients.count('V'), orients, adiffs + + +# ── Phase 2: 동적 V/H 분류 ────────────────────────────────────────────── + +def classify_vh(pts, vp_x, vp_y, v_thresh, h_max_diff=75.0, x_horiz_thresh=10.0): + """ + 소실점 기준 V/H 분류. + - 이미지 절대 수평(X축 ±x_horiz_thresh°) AND AR≥4 → '?' 
(레일·전선 등) + - diff < v_thresh → V (기둥) + - v_thresh ≤ diff < h_max_diff → H (빔) + - diff ≥ h_max_diff → '?' + Returns: (orient, angle_diff_deg) + """ + long_angle_deg, cx, cy, long_side, short_side = _long_axis_angle(pts) + if short_side < 1 or long_side / short_side < 1.3: + return '?', 90.0 + + # 절대 수평 제외: X축 ±x_horiz_thresh° 이내 + AR≥4 (레일·전선 등 가느다란 수평체) + abs_from_horiz = long_angle_deg % 180.0 + if abs_from_horiz > 90.0: + abs_from_horiz = 180.0 - abs_from_horiz + if abs_from_horiz < x_horiz_thresh and long_side / short_side >= 4.0: + return '?', 90.0 + + # VP 기준 상대 각도 분류 + exp_angle = np.degrees(np.arctan2(vp_y - cy, vp_x - cx)) + diff = abs(long_angle_deg - exp_angle) % 180.0 + if diff > 90.0: + diff = 180.0 - diff + if diff < v_thresh: + return 'V', diff + if diff < h_max_diff: + return 'H', diff + return '?', diff + + +# ── Phase 3: 폴리곤 접촉/교차 기반 그룹핑 ─────────────────────────────── + +def connectivity_groups(polys, orients, margin=30): + """ + 폴리곤 bbox가 margin px 이내로 닿거나 교차하면 같은 그룹 (H/V 구분 없음). + Union-Find로 연결된 폴리곤들을 묶은 뒤, 각 그룹 내에서 H/V 목록 분리. 
+ Returns: list of {'id': int, 'H': [idx,...], 'V': [idx,...]} + """ + n = len(polys) + parent = list(range(n)) + + def find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(a, b): + ra, rb = find(a), find(b) + if ra != rb: + parent[ra] = rb + + # 폴리곤 bbox 미리 계산 + bboxes = [] + for pts in polys: + bboxes.append((pts[:, 0].min(), pts[:, 1].min(), + pts[:, 0].max(), pts[:, 1].max())) + + # bbox가 margin 이내로 닿거나 교차하면 union + for i in range(n): + ax0, ay0, ax1, ay1 = bboxes[i] + for j in range(i + 1, n): + bx0, by0, bx1, by1 = bboxes[j] + if (ax1 + margin >= bx0 and bx1 + margin >= ax0 and + ay1 + margin >= by0 and by1 + margin >= ay0): + union(i, j) + + # 연결 컴포넌트 수집 + from collections import defaultdict + comp = defaultdict(list) + for i in range(n): + comp[find(i)].append(i) + + groups = [] + for gid, members in enumerate(comp.values(), 1): + h_list = sorted(i for i in members if orients[i] == 'H') + v_list = sorted(i for i in members if orients[i] == 'V') + groups.append({'id': gid, 'H': h_list, 'V': v_list}) + + # 면적 내림차순으로 ID 재부여 (큰 그룹이 G1) + for g in groups: + g['area'] = sum(cv2.contourArea(polys[i].astype(np.int32)) + for i in g['H'] + g['V']) + groups.sort(key=lambda x: x['area'], reverse=True) + for gid, g in enumerate(groups, 1): + g['id'] = gid + + return groups + + +# ── Phase 4: 라멘 구조 판정 ───────────────────────────────────────────── + +def _cluster_polys(indices, polys, margin=60): + """ + 인접 폴리곤을 클러스터링 → 물리적 객체(기둥) 단위 반환. + Returns: list of [idx, ...] 
clusters + """ + if not indices: + return [] + n = len(indices) + parent = list(range(n)) + + def find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + bboxes = [(polys[i][:, 0].min(), polys[i][:, 1].min(), + polys[i][:, 0].max(), polys[i][:, 1].max()) for i in indices] + for a in range(n): + ax0, ay0, ax1, ay1 = bboxes[a] + for b in range(a + 1, n): + bx0, by0, bx1, by1 = bboxes[b] + if (ax1 + margin >= bx0 and bx1 + margin >= ax0 and + ay1 + margin >= by0 and by1 + margin >= ay0): + ra, rb = find(a), find(b) + if ra != rb: + parent[ra] = rb + + from collections import defaultdict + clusters = defaultdict(list) + for i in range(n): + clusters[find(i)].append(indices[i]) + return list(clusters.values()) + + +def judge_raamen(group, polys, W, center_ratio=0.2, v_cluster_margin=60): + """ + 라멘 구조 판정. + - 빔 기준: 그룹 내 가장 큰 H 폴리곤 + - 기둥 수: V 폴리곤을 근접 클러스터링한 클러스터 수 + - 빔 x 범위(±50%) 밖의 V 클러스터는 잡폴리곤으로 무시 + Returns: ('RAAMEN' | 'RAAMEN_OCCLUDED' | 'PARTIAL' | '', n_poles) + """ + hs, vs = group['H'], group['V'] + + # H 없음: 중앙 영역이면 V/H 분류 자체가 신뢰 불가 (기둥·빔 각도 수렴) + # → 2개 이상의 V 폴리곤이 중앙에 모여 있으면 RAAMEN_CENTER 처리 + if not hs: + if len(vs) >= 2: + all_v_pts = np.vstack([polys[i] for i in vs]) + gcx = float(all_v_pts[:, 0].mean()) + if abs(gcx - W / 2) < W * center_ratio: + return 'RAAMEN_CENTER', len(vs) + return '', 0 + + # 큰 H 폴리곤 최대 2개를 빔 기준으로 사용 (2nd가 1st 면적의 50% 이상이면 포함) + h_by_area = sorted(hs, key=lambda i: cv2.contourArea(polys[i].astype(np.int32)), reverse=True) + main_hs = [h_by_area[0]] + if len(h_by_area) > 1: + a0 = cv2.contourArea(polys[h_by_area[0]].astype(np.int32)) + a1 = cv2.contourArea(polys[h_by_area[1]].astype(np.int32)) + if a1 >= a0 * 0.5: + main_hs.append(h_by_area[1]) + hx0 = int(min(polys[i][:, 0].min() for i in main_hs)) + hx1 = int(max(polys[i][:, 0].max() for i in main_hs)) + hcx = (hx0 + hx1) / 2.0 + is_center = abs(hcx - W / 2) < W * center_ratio + span = hx1 - hx0 + + # V 폴리곤 클러스터링 → 기둥 단위 + pole_clusters = 
_cluster_polys(vs, polys, margin=v_cluster_margin) + + # 기둥은 빔 양 끝단(좌 35% / 우 35%)에만 존재. 중앙부 클러스터는 부속물로 제외. + x_tol = max(span * 0.5, 50) + left_zone = hx0 + span * 0.35 # 좌끝단 경계 + right_zone = hx1 - span * 0.35 # 우끝단 경계 + valid_cxs = [] + for cluster in pole_clusters: + ccx = float(np.mean([polys[i][:, 0].mean() for i in cluster])) + in_range = hx0 - x_tol <= ccx <= hx1 + x_tol + in_end_zone = ccx <= left_zone or ccx >= right_zone + if in_range and in_end_zone: + valid_cxs.append(ccx) + + n_poles = len(valid_cxs) + + if n_poles >= 2: + lcx, rcx = min(valid_cxs), max(valid_cxs) + if lcx <= hx0 + span * 0.4 and rcx >= hx1 - span * 0.4: + return 'RAAMEN', n_poles + return 'PARTIAL', n_poles + + if n_poles == 1: + return ('RAAMEN_OCCLUDED', 1) if not is_center else ('', 1) + + return '', 0 + + +def _merge_poly_hull(indices, polys): + """여러 폴리곤 점들을 합쳐 Convex Hull 좌표 반환 (AnyLabeling points 형식).""" + all_pts = np.vstack([polys[i] for i in indices]) + hull = cv2.convexHull(all_pts.astype(np.int32)) + return [[float(p[0][0]), float(p[0][1])] for p in hull] + + +def group_detail(group, polys, W, center_ratio=0.2, v_cluster_margin=60): + """ + 라멘 그룹의 세부 구성 분석. 
+ Returns dict: main_h, junk_h, valid_pole_clusters, attach_clusters + """ + hs, vs = group['H'], group['V'] + if not hs: + return {'main_h': None, 'junk_h': [], 'valid_pole_clusters': [], 'attach_clusters': []} + + h_by_area = sorted(hs, key=lambda i: cv2.contourArea(polys[i].astype(np.int32)), reverse=True) + main_hs = [h_by_area[0]] + if len(h_by_area) > 1: + a0 = cv2.contourArea(polys[h_by_area[0]].astype(np.int32)) + a1 = cv2.contourArea(polys[h_by_area[1]].astype(np.int32)) + if a1 >= a0 * 0.5: + main_hs.append(h_by_area[1]) + main_h = main_hs[0] # JSON 출력용 대표 빔 + junk_h = [i for i in hs if i not in main_hs] + + hx0 = int(min(polys[i][:, 0].min() for i in main_hs)) + hx1 = int(max(polys[i][:, 0].max() for i in main_hs)) + span = hx1 - hx0 + + pole_clusters = _cluster_polys(vs, polys, margin=v_cluster_margin) + x_tol = max(span * 0.5, 50) + left_zone = hx0 + span * 0.35 + right_zone = hx1 - span * 0.35 + + valid_pole_clusters, attach_clusters = [], [] + for cluster in pole_clusters: + ccx = float(np.mean([polys[i][:, 0].mean() for i in cluster])) + in_range = hx0 - x_tol <= ccx <= hx1 + x_tol + in_end_zone = ccx <= left_zone or ccx >= right_zone + if in_range and in_end_zone: + valid_pole_clusters.append(cluster) + elif in_range: + attach_clusters.append(cluster) + + return { + 'main_h': main_h, + 'main_hs': main_hs, + 'junk_h': junk_h, + 'valid_pole_clusters': valid_pole_clusters, + 'attach_clusters': attach_clusters, + } + + +# ── 시각화 상수 ───────────────────────────────────────────────────────── + +_VH_COLOR = { + 'V': (255, 80, 0), # 주황 (수직 기둥) + 'H': ( 0, 80, 255), # 파란 (수평 빔) + '?': (140, 140, 140), # 회색 (미분류) +} +_RAAMEN_COLOR = { + 'RAAMEN': ( 0, 255, 0), # 초록 + 'RAAMEN_CENTER': ( 0, 255, 255), # 노랑 (중앙 영역, H/V 분류 불신뢰) + 'RAAMEN_OCCLUDED': ( 0, 165, 255), # 주황 (가림/부분 검출) + 'PARTIAL': (128, 128, 128), # 회색 +} + + +# ── 메인 렌더링 ───────────────────────────────────────────────────────── + +def render(image_path, label_path, output_path, args, + 
class_ids=None, class_names=None, epsilon=4.0, v_thresh=20.0, + h_max_diff=75.0, vp_min_ar=2.5, vp_min_len=80.0, vp_outer_ratio=0.2, + x_horiz_thresh=10.0): + + buf = np.fromfile(str(image_path), dtype=np.uint8) + img = cv2.imdecode(buf, cv2.IMREAD_COLOR) + H, W = img.shape[:2] + + # ── Phase 1 ────────────────────────────────────────────────────────── + raw_polys, poly_abs_idx = load_polygons(label_path, W, H, class_ids, class_names) + polys = [smooth_polygon(p, epsilon) for p in raw_polys] + print(f" {len(polys)}개 폴리곤 파싱 (epsilon={epsilon})") + + img_cx, img_cy = W / 2.0, H / 2.0 + + # ── Phase 1+2: 두 가지 VP 시드 방식으로 시도 → V 폴리곤이 더 많은 VP 채택 ── + elong_idx = [] + for i, pts in enumerate(polys): + la, cx, cy, long_side, short_side = _long_axis_angle(pts) + if short_side > 0 and long_side / short_side > vp_min_ar and long_side > vp_min_len: + elong_idx.append(i) + + # 시드 A: 지배적 장축 각도 클러스터 (이미지 방향 비의존) + seeds_A = [] + if len(elong_idx) >= 2: + avals_all = [(i, _long_axis_angle(polys[i])[0] % 180.0) for i in elong_idx] + hist, edges = np.histogram([a for _, a in avals_all], bins=12, range=(0.0, 180.0)) + pk = int(np.argmax(hist)) + peak_c = (edges[pk] + edges[pk + 1]) / 2.0 + seeds_A = [i for i, a in avals_all + if min(abs(a - peak_c), 180.0 - abs(a - peak_c)) < 20.0] + + # 시드 B: radial cos_sim (이미지 중심 방사 방향 정렬) + seeds_B = [] + for i in elong_idx: + _, cx, cy, _, _ = _long_axis_angle(polys[i]) + dist = ((cx - img_cx) ** 2 + (cy - img_cy) ** 2) ** 0.5 + if dist > min(W, H) * vp_outer_ratio and _radial_cos_sim(polys[i], img_cx, img_cy) > 0.5: + seeds_B.append(i) + + # 두 시드 모두 시도 → V가 더 많이 나오는 VP 채택 + best_vp_x, best_vp_y = img_cx, -H * 3.0 + best_n_v = 0 + orients, adiffs = ['?'] * len(polys), [90.0] * len(polys) + for label, seeds in [('지배각', seeds_A), ('radial', seeds_B)]: + if len(seeds) < 2: + continue + vx, vy, nv, ors, ads = _estimate_vp_iterative( + polys, seeds, v_thresh, h_max_diff, vp_min_len, x_horiz_thresh) + print(f" VP [{label}]: ({vx:.0f}, {vy:.0f}) 
V={nv}") + if nv > best_n_v: + best_vp_x, best_vp_y, best_n_v = vx, vy, nv + orients, adiffs = ors, ads + vp_x, vp_y = best_vp_x, best_vp_y + print(f" → 채택 VP: ({vp_x:.1f}, {vp_y:.1f})") + + print(f"\n [V/H 분류] threshold=±{v_thresh}° h_max=±{h_max_diff}°") + for i in range(len(polys)): + print(f" poly {i:>2d}: {orients[i]} diff={adiffs[i]:.1f}°") + print(f" V:{orients.count('V')} H:{orients.count('H')} ?:{orients.count('?')}") + + # ── Phase 3 ────────────────────────────────────────────────────────── + groups = connectivity_groups(polys, orients, margin=args.margin) + print(f"\n [그룹핑] 연결 컴포넌트 {len(groups)}개 (margin={args.margin}px)") + for g in groups: + print(f" G{g['id']}: H={g['H']} V={g['V']}") + + # ── Phase 4 ────────────────────────────────────────────────────────── + print(f"\n [라멘 판정]") + for g in groups: + verdict, n_poles = judge_raamen(g, polys, W) + g['verdict'] = verdict + g['n_poles'] = n_poles + pole_str = f"{n_poles}poles" if n_poles else "-" + print(f" G{g['id']}: H={g['H']} V={g['V']} → {verdict or '-':18s} ({pole_str})") + + valid_items = [g for g in groups if g['verdict']] + valid_items.sort(key=lambda x: x['area'], reverse=True) + + print(f"\n [최종 라멘 객체] {len(valid_items)}개 (면적순)") + for g in valid_items: + print(f" G{g['id']}: H={g['H']} V={g['V']} → {g['verdict']:18s} ({g['n_poles']}poles) Area={g['area']:,.0f}") + + # 최소 면적 필터링 + if args.min_group_area > 0: + before = len(valid_items) + valid_items = [g for g in valid_items if g['area'] >= args.min_group_area] + print(f" [필터] 최소 면적 {args.min_group_area} 미만 제거: {before} → {len(valid_items)}개") + + # 모든 그룹: 그룹 내 최하단 꼭짓점 포함 폴리곤이 H이면 → V로 재분류 + for g in valid_items: + all_idxs = g['H'] + g['V'] + bottom_vi = max(all_idxs, key=lambda i: polys[i][:, 1].max()) + if bottom_vi in g['H']: + g['H'].remove(bottom_vi) + g['V'].append(bottom_vi) + orients[bottom_vi] = 'V' + # RAAMEN_CENTER (H 없는 그룹): 최하단=V, 나머지=H로 display 재조정 + if not g['H']: + for vi in g['V']: + orients[vi] = 'V' if vi == bottom_vi 
else 'H' + + # ── 시각화 ─────────────────────────────────────────────────────────── + + # 1. 폴리곤 V/H 색상 반투명 오버레이 + for i, (pts, orient) in enumerate(zip(polys, orients)): + color = _VH_COLOR[orient] + pts_i = pts.astype(np.int32) + ov = img.copy() + cv2.fillPoly(ov, [pts_i], color) + cv2.addWeighted(ov, 0.25, img, 0.75, 0, img) + cv2.polylines(img, [pts_i], True, color, 2) + cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean()) + lbl = f"{i}{orient}" + cv2.putText(img, lbl, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 4) + cv2.putText(img, lbl, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2) + + # 2. VP 시드 폴리곤에 청록 테두리 (채택된 시드셋 재구성) + for i in seeds_A + [i for i in seeds_B if i not in seeds_A]: + cv2.polylines(img, [polys[i].astype(np.int32)], True, (0, 220, 220), 3) + + # 3. 소실점 표시 (이미지 내부: 원, 외부: 방향 화살표) + vp_ix, vp_iy = int(vp_x), int(vp_y) + if 0 <= vp_ix < W and 0 <= vp_iy < H: + cv2.circle(img, (vp_ix, vp_iy), 20, (0, 255, 255), 3) + cv2.putText(img, "VP", (vp_ix + 25, vp_iy), + cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2) + else: + ax_s, ay_s = W // 2, H // 5 + dx, dy = vp_x - img_cx, vp_y - img_cy + n = (dx ** 2 + dy ** 2) ** 0.5 + if n > 0: + ax_e = max(0, min(W - 1, int(ax_s + dx / n * 80))) + ay_e = max(0, min(H - 1, int(ay_s + dy / n * 80))) + cv2.arrowedLine(img, (ax_s, ay_s), (ax_e, ay_e), (0, 255, 255), 3) + cv2.putText(img, f"VP({vp_ix},{vp_iy})", (ax_s + 5, ay_s - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2) + + # 4. 
라멘 그룹 강조: 바운딩 박스 + 그룹 ID + 판정 라벨 + for g in valid_items: + verdict = g['verdict'] + color = _RAAMEN_COLOR[verdict] + all_idx = g['H'] + g['V'] + all_pts = np.vstack([polys[i] for i in all_idx]).astype(np.int32) + x0 = all_pts[:, 0].min() - 15; y0 = all_pts[:, 1].min() - 15 + x1 = all_pts[:, 0].max() + 15; y1 = all_pts[:, 1].max() + 15 + cv2.rectangle(img, (x0, y0), (x1, y1), color, 4) + lbl = f"G{g['id']} {verdict} ({g['n_poles']}poles)" + cv2.putText(img, lbl, (x0, y0 - 10), + cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 0), 6) + cv2.putText(img, lbl, (x0, y0 - 10), + cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2) + + # 이미지 저장 + scale = min(1.0, 4096 / max(H, W)) + if scale < 1.0: + img = cv2.resize(img, (int(W * scale), int(H * scale))) + output_path.parent.mkdir(parents=True, exist_ok=True) + cv2.imencode(output_path.suffix, img)[1].tofile(str(output_path)) + print(f"\n → {output_path}") + + # AnyLabeling JSON 저장 + import json + + def _shape(label, points, gid, desc): + return {"label": label, "score": None, "points": points, + "group_id": gid, "description": desc, + "shape_type": "polygon", "flags": None} + + shapes = [] + for g in valid_items: + gid = g['id'] + verdict = g['verdict'] + desc_base = f"G{gid} {verdict}" + + def abs_ids(rel_list): + """상대 폴리곤 인덱스 → 절대 JSON shapes[] 인덱스 변환.""" + return sorted(poly_abs_idx[i] for i in rel_list) + + if not g['H']: + # RAAMEN_CENTER: 최하단 꼭짓점 포함 폴리곤 = 기둥, 나머지 = 빔 + bottom_vi = max(g['V'], key=lambda i: polys[i][:, 1].max()) + for vi in g['V']: + label = "raamen_pole" if vi == bottom_vi else "raamen_beam" + shapes.append(_shape(label, + [[float(p[0]), float(p[1])] for p in polys[vi]], + gid, f"{desc_base} shape#{poly_abs_idx[vi]}")) + else: + # 일반 RAAMEN: V = 기둥, H = 빔 + for vi in g['V']: + shapes.append(_shape("raamen_pole", + [[float(p[0]), float(p[1])] for p in polys[vi]], + gid, f"{desc_base} pole shape#{poly_abs_idx[vi]}")) + for hi in g['H']: + shapes.append(_shape("raamen_beam", + [[float(p[0]), float(p[1])] for p in 
polys[hi]], + gid, f"{desc_base} beam shape#{poly_abs_idx[hi]}")) + + anylabel_json = { + "version": "3.3.9", + "flags": {}, + "shapes": shapes, + "imagePath": image_path.name, + "imageData": None, + "imageHeight": H, + "imageWidth": W, + } + json_path = output_path.with_suffix(".json") + json_path.write_text(json.dumps(anylabel_json, ensure_ascii=False, indent=2), + encoding="utf-8") + print(f" → {json_path}") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--image", required=True) + ap.add_argument("--label", required=True) + ap.add_argument("--output", required=True) + ap.add_argument("--class-ids", default="", + help="포함할 클래스 ID, 콤마 구분 (.txt 전용)") + ap.add_argument("--class-names", default="catenary_pole", + help="포함할 클래스 이름, 콤마 구분 (.json 전용, 기본: 'catenary_pole')") + ap.add_argument("--epsilon", type=float, default=4.0, + help="approxPolyDP epsilon (기본 4.0px)") + ap.add_argument("--v-thresh", type=float, default=20.0, + help="V/H 분류 각도 임계값 degrees (기본 20°)") + ap.add_argument("--h-max-diff", type=float, default=75.0, + help="H(빔) 최대 각도 diff; 이 이상은 레일/전선 등으로 제외 (기본 75°)") + ap.add_argument("--x-horiz-thresh", type=float, default=10.0, + help="X축 절대 수평 제외 임계값 degrees; AR≥4 AND 이 각도 이내 → 제외 (기본 10°)") + ap.add_argument("--margin", type=int, default=30, + help="폴리곤 접촉 판정 margin px (기본 30)") + ap.add_argument("--min-group-area", type=float, default=0, + help="라멘 그룹의 최소 면적 합계 (기본 0)") + args = ap.parse_args() + + class_ids = ({int(x) for x in args.class_ids.split(',') if x.strip()} + if args.class_ids else None) + class_names = ({x.strip() for x in args.class_names.split(',') if x.strip()} + if args.class_names else None) + + out = Path(args.output) + folder = out.parent / out.stem # e.g. output/0004_test + if folder.exists(): + n = 1 + while (out.parent / f"{out.stem}_{n}").exists(): + n += 1 + folder = out.parent / f"{out.stem}_{n}" + folder.mkdir(parents=True, exist_ok=True) + out = folder / out.name # e.g. 
output/0004_test/0004_test.jpg + print(f" [출력 폴더] {folder}") + + render(Path(args.image), Path(args.label), out, args, + class_ids=class_ids, class_names=class_names, + epsilon=args.epsilon, v_thresh=args.v_thresh, + h_max_diff=args.h_max_diff, x_horiz_thresh=args.x_horiz_thresh) + + +if __name__ == "__main__": + main() diff --git a/tools/labeling_server.py b/tools/labeling_server.py new file mode 100644 index 0000000..8c5fac0 --- /dev/null +++ b/tools/labeling_server.py @@ -0,0 +1,503 @@ +""" +Control Box 라벨링 서버 v2 — 전체 이미지 + bbox 오버레이 방식 +사용법: python tools/labeling_server.py \ + --json "data/역사이미지/slope/DJI_20260306113838_0004_everything.json" + [--reset] DB 초기화 +브라우저: http://localhost:7001 +""" +import argparse, json, sqlite3, sys +from collections import defaultdict +from pathlib import Path + +import cv2, numpy as np +from fastapi import FastAPI +from fastapi.responses import HTMLResponse, JSONResponse, Response +import uvicorn + +ROOT = Path(__file__).parent.parent +DB = ROOT / "labels" / "labeling.db" +MIN_VOTES = 3 +TRUE_RATIO = 0.6 +MAX_DIM = 3000 # served image max dimension (px) + +CONTROL_LABELS = { + "small dark object on ballast", "manhole cover", "trackside enclosure", + "dark rectangle on ground", "small square box", "small metal cabinet", + "small cube on gravel", "compact trackside junction box", + "small square gray metal box beside rail", + "small near-square electrical enclosure on the ground", +} + +app = FastAPI() +_img_data: bytes = None +_orig_w = _orig_h = _disp_w = _disp_h = 0 + + +# ── DB ──────────────────────────────────────────────────────────────────────── +def get_db(): + con = sqlite3.connect(str(DB)) + con.row_factory = sqlite3.Row + return con + + +def init_db(candidates: list): + DB.parent.mkdir(parents=True, exist_ok=True) + con = get_db() + con.executescript(""" + CREATE TABLE IF NOT EXISTS candidates ( + id INTEGER PRIMARY KEY, + json_idx INTEGER, + label TEXT, + score REAL, + bbox TEXT, + image_path TEXT + ); + CREATE TABLE 
IF NOT EXISTS votes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + candidate_id INTEGER, + user TEXT, + vote INTEGER, + ts DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(candidate_id, user) + ); + """) + if con.execute("SELECT COUNT(*) FROM candidates").fetchone()[0] > 0: + n = con.execute("SELECT COUNT(*) FROM candidates").fetchone()[0] + print(f"DB 기존 데이터 유지. candidates={n}") + con.close() + return + for c in candidates: + con.execute( + "INSERT INTO candidates(json_idx,label,score,bbox,image_path) VALUES(?,?,?,?,?)", + (c["json_idx"], c["label"], c["score"], json.dumps(c["bbox"]), c["image_path"]), + ) + con.commit() + print(f"DB 등록: {len(candidates)}개") + con.close() + + +# ── 후보 로드 ───────────────────────────────────────────────────────────────── +def load_candidates(json_path: Path) -> list: + data = json.loads(json_path.read_text(encoding="utf-8")) + stem = json_path.stem.replace("_everything", "") + img_path = next( + (json_path.parent / f"{stem}{ext}" for ext in (".JPG", ".jpg", ".png") + if (json_path.parent / f"{stem}{ext}").exists()), + None, + ) + if img_path is None: + raise FileNotFoundError(f"원본 이미지 없음: {json_path.parent / stem}.*") + + candidates = [] + for idx, seg in enumerate(data.get("segments", [])): + label = seg.get("label", "").strip() + if label not in CONTROL_LABELS: + continue + x0, y0, x1, y1 = [float(v) for v in seg["bbox"]] + if (x1 - x0) < 4 or (y1 - y0) < 4: + continue + candidates.append({ + "json_idx": idx, + "label": label, + "score": float(seg.get("score", 0)), + "bbox": [x0, y0, x1, y1], + "image_path": str(img_path), + }) + + segs_total = len(data.get("segments", [])) + print(f"필터링: {segs_total}개 → {len(candidates)}개 control_box 후보") + return candidates + + +# ── 이미지 준비 (서버 시작 시 1회) ─────────────────────────────────────────── +def prepare_image(img_path: Path): + global _img_data, _orig_w, _orig_h, _disp_w, _disp_h + buf = np.fromfile(str(img_path), dtype=np.uint8) + img = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if img is None: + 
raise ValueError(f"이미지 로드 실패: {img_path}") + _orig_h, _orig_w = img.shape[:2] + scale = min(1.0, MAX_DIM / max(_orig_w, _orig_h)) + _disp_w, _disp_h = int(_orig_w * scale), int(_orig_h * scale) + if scale < 1.0: + img = cv2.resize(img, (_disp_w, _disp_h), interpolation=cv2.INTER_LANCZOS4) + _, enc = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 92]) + _img_data = bytes(enc) + print(f"이미지 준비: {_orig_w}×{_orig_h} → {_disp_w}×{_disp_h}") + + +# ── API ─────────────────────────────────────────────────────────────────────── +@app.get("/image") +async def serve_image(): + return Response(content=_img_data, media_type="image/jpeg") + + +@app.get("/api/image_info") +async def api_image_info(): + return {"orig_w": _orig_w, "orig_h": _orig_h, "disp_w": _disp_w, "disp_h": _disp_h} + + +@app.get("/api/candidates") +async def api_candidates(user: str = ""): + con = get_db() + rows = con.execute(""" + SELECT c.id, c.bbox, c.label, c.score, v.vote + FROM candidates c + LEFT JOIN votes v ON c.id=v.candidate_id AND v.user=? 
+ """, (user,)).fetchall() + con.close() + return {"candidates": [ + { + "id": r["id"], + "bbox": json.loads(r["bbox"]), + "label": r["label"], + "score": round(r["score"], 2), + "voted": r["vote"] is not None, + "is_true": r["vote"] == 1 if r["vote"] is not None else None, + } for r in rows + ]} + + +@app.post("/api/vote") +async def api_vote(data: dict): + user = data.get("user", "").strip() + all_ids = data.get("all_ids", []) + true_ids = set(data.get("true_ids", [])) + if not user or not all_ids: + return {"ok": False, "error": "파라미터 오류"} + con = get_db() + for cid in all_ids: + con.execute( + "INSERT OR REPLACE INTO votes(candidate_id,user,vote) VALUES(?,?,?)", + (cid, user, 1 if cid in true_ids else 0), + ) + con.commit() + con.close() + return {"ok": True} + + +@app.get("/api/stats") +async def api_stats(): + con = get_db() + total = con.execute("SELECT COUNT(*) FROM candidates").fetchone()[0] + voted3 = con.execute(f""" + SELECT COUNT(*) FROM ( + SELECT candidate_id FROM votes GROUP BY candidate_id HAVING COUNT(*) >= {MIN_VOTES} + ) + """).fetchone()[0] + users = con.execute("SELECT COUNT(DISTINCT user) FROM votes").fetchone()[0] + tvotes = con.execute("SELECT COUNT(*) FROM votes").fetchone()[0] + con.close() + return {"total": total, "voted3plus": voted3, "users": users, "total_votes": tvotes} + + +@app.post("/api/export") +async def api_export(): + con = get_db() + confirmed = con.execute(f""" + SELECT c.id, c.bbox, c.image_path, + SUM(v.vote) AS true_v, COUNT(v.id) AS total_v + FROM candidates c JOIN votes v ON c.id=v.candidate_id + GROUP BY c.id + HAVING total_v >= {MIN_VOTES} AND (CAST(true_v AS REAL)/total_v) >= {TRUE_RATIO} + """).fetchall() + con.close() + + by_image = defaultdict(list) + for row in confirmed: + by_image[row["image_path"]].append(row) + + out_dir = ROOT / "labels" / "yolo_export" + out_dir.mkdir(parents=True, exist_ok=True) + + count = 0 + for img_path, rows in by_image.items(): + buf = np.fromfile(img_path, dtype=np.uint8) + img = 
cv2.imdecode(buf, cv2.IMREAD_COLOR) + H, W = img.shape[:2] + txt = out_dir / (Path(img_path).stem + ".txt") + with open(txt, "w") as f: + for row in rows: + x0, y0, x1, y1 = json.loads(row["bbox"]) + f.write(f"0 {((x0+x1)/2)/W:.6f} {((y0+y1)/2)/H:.6f} " + f"{(x1-x0)/W:.6f} {(y1-y0)/H:.6f}\n") + count += 1 + return {"ok": True, "labels": count, "dir": str(out_dir)} + + +# ── HTML ────────────────────────────────────────────────────────────────────── +@app.get("/", response_class=HTMLResponse) +async def index(): + return HTML + + +HTML = r""" + + + +Control Box 라벨링 + + + +
+

🎯 Control Box

+ + + + +
+ 미투표 + 컨트롤박스 ✓ + 아님 ✗ + 타인투표완료 +
+
+ +
+
+ +
+
휠=줌 | 드래그=이동 | 클릭=YES/NO 토글 | F=맞춤
+
+ + + +""" + + +# ── 메인 ───────────────────────────────────────────────────────────────────── +def main(): + ap = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--json", required=True, help="*_everything.json 경로") + ap.add_argument("--port", type=int, default=7001) + ap.add_argument("--reset", action="store_true", help="DB 초기화 후 재로드") + args = ap.parse_args() + + json_path = Path(args.json) + if not json_path.exists(): + print(f"JSON 없음: {json_path}"); sys.exit(1) + + if args.reset and DB.exists(): + DB.unlink() + print("DB 초기화 완료.") + + print("후보 로딩 중...") + candidates = load_candidates(json_path) + init_db(candidates) + + # Prepare image + stem = json_path.stem.replace("_everything", "") + img_path = next( + (json_path.parent / f"{stem}{ext}" for ext in (".JPG", ".jpg", ".png") + if (json_path.parent / f"{stem}{ext}").exists()), + None, + ) + if img_path is None: + print(f"원본 이미지 없음: {json_path.parent / stem}.*"); sys.exit(1) + prepare_image(img_path) + + print(f"\n라벨링 서버 시작: http://localhost:{args.port}") + uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="warning") + + +if __name__ == "__main__": + main() diff --git a/tools/merge_tiles_vis.py b/tools/merge_tiles_vis.py new file mode 100644 index 0000000..949e861 --- /dev/null +++ b/tools/merge_tiles_vis.py @@ -0,0 +1,90 @@ +""" +6×8 (또는 임의 그리드) 타일 라벨을 원본 이미지 좌표로 병합해 시각화. 
+ +사용: + python tools/merge_tiles_vis.py \ + --orig data/역사이미지/slope/DJI_20260306113838_0004.JPG \ + --labels output/autolabel/tile6x8/labels \ + --output output/autolabel/tile6x8/merged_vis.jpg \ + --cols 6 --rows 8 +""" +import argparse +import cv2 +import numpy as np +from pathlib import Path + +CLASS_NAMES = ["catenary_pole", "bracket"] +CLASS_COLORS = [(0, 200, 255), (255, 130, 0)] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--orig", required=True) + parser.add_argument("--labels", required=True) + parser.add_argument("--output", required=True) + parser.add_argument("--cols", type=int, default=6) + parser.add_argument("--rows", type=int, default=8) + parser.add_argument("--alpha", type=float, default=0.3) + args = parser.parse_args() + + buf = np.fromfile(args.orig, dtype=np.uint8) + img = cv2.imdecode(buf, cv2.IMREAD_COLOR) + H, W = img.shape[:2] + tw, th = W // args.cols, H // args.rows + + label_dir = Path(args.labels) + counts = [0] * len(CLASS_NAMES) + + for r in range(args.rows): + for c in range(args.cols): + label_file = label_dir / f"tile_r{r+1:02d}_c{c+1:02d}.txt" + if not label_file.exists(): + continue + x0 = c * tw + y0 = r * th + tile_w = tw if c < args.cols - 1 else W - x0 + tile_h = th if r < args.rows - 1 else H - y0 + + text = label_file.read_text(encoding="utf-8").strip() + if not text: + continue + for line in text.splitlines(): + parts = line.split() + if not parts: + continue + cls_id = int(parts[0]) + coords = list(map(float, parts[1:])) + pts = np.array( + [[coords[i] * tile_w + x0, coords[i + 1] * tile_h + y0] + for i in range(0, len(coords), 2)], + dtype=np.int32, + ) + color = CLASS_COLORS[cls_id % len(CLASS_COLORS)] + overlay = img.copy() + cv2.fillPoly(overlay, [pts], color) + cv2.addWeighted(overlay, args.alpha, img, 1 - args.alpha, 0, img) + cv2.polylines(img, [pts], True, color, 3) + if cls_id < len(counts): + counts[cls_id] += 1 + + # 범례 + for i, name in enumerate(CLASS_NAMES): + y = 30 + i * 50 
+ cv2.rectangle(img, (15, y - 20), (55, y + 10), CLASS_COLORS[i], -1) + cv2.putText(img, f"{name}: {counts[i]}", + (65, y), cv2.FONT_HERSHEY_SIMPLEX, 1.5, CLASS_COLORS[i], 3) + + total = sum(counts) + print(f"총 {total}개 " + ", ".join(f"{CLASS_NAMES[i]}={counts[i]}" for i in range(len(counts)))) + + # 4096px 이하로 미리보기 저장 + scale = min(1.0, 4096 / max(H, W)) + vis = cv2.resize(img, (int(W * scale), int(H * scale))) + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + cv2.imencode(out_path.suffix, vis)[1].tofile(str(out_path)) + print(f"저장: {out_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/post_merge_poles.py b/tools/post_merge_poles.py new file mode 100644 index 0000000..e140034 --- /dev/null +++ b/tools/post_merge_poles.py @@ -0,0 +1,154 @@ +"""post_merge_poles.py — detect_all_objects.py 출력 JSON에서 catenary_pole 병합. + +detecting은 한 번만, 병합 파라미터 조정 시 이 스크립트만 재실행. + +Usage: + python tools/post_merge_poles.py INPUT.json [--x-overlap 0.30] [--y-gap 150] + python tools/post_merge_poles.py INPUT.json --inplace + python tools/post_merge_poles.py INPUT.json --output OUTPUT.json +""" +import argparse +import json +import sys +from pathlib import Path + +import cv2 +import numpy as np + + +def _poly_orient(points: list, H: int, W: int) -> str: + """폴리곤 장축 방향 판별. + + V: 장축이 이미지 중심 radial 방향 정렬(cos_sim > 0.7) → 세로 기둥 + H: 직교 → 수평 빔 + ?: aspect ratio < 1.3 + """ + pts = np.array(points, dtype=np.float32) + rect = cv2.minAreaRect(pts) + (rx, ry), (rw, rh), angle = rect + if min(rw, rh) < 1: + return '?' + ar = max(rw, rh) / min(rw, rh) + if ar < 1.3: + return '?' + long_angle_deg = angle if rw >= rh else angle + 90 + lx = float(np.cos(np.radians(long_angle_deg))) + ly = float(np.sin(np.radians(long_angle_deg))) + img_cx, img_cy = W / 2.0, H / 2.0 + rdx, rdy = rx - img_cx, ry - img_cy + radial_norm = (rdx ** 2 + rdy ** 2) ** 0.5 + if radial_norm < 1: + return '?' 
+ rdx, rdy = rdx / radial_norm, rdy / radial_norm + cos_sim = abs(lx * rdx + ly * rdy) + return 'V' if cos_sim > 0.7 else 'H' + + +def merge_poles(shapes: list, H: int, W: int, + x_overlap_thresh: float = 0.30, + y_gap_thresh: int = 150) -> list: + """V+V 조합 전철주만 타일 경계 병합.""" + if len(shapes) <= 1: + return shapes + + orients = [_poly_orient(s["points"], H, W) for s in shapes] + v_count = sum(1 for o in orients if o == 'V') + h_count = sum(1 for o in orients if o == 'H') + print(f" orient: V={v_count}, H={h_count}, ?={len(orients)-v_count-h_count}") + + def get_bbox(s): + xs = [p[0] for p in s["points"]]; ys = [p[1] for p in s["points"]] + return min(xs), min(ys), max(xs), max(ys) + + def x_overlap_ratio(b1, b2): + ox = min(b1[2], b2[2]) - max(b1[0], b2[0]) + ux = max(b1[2], b2[2]) - min(b1[0], b2[0]) + return ox / ux if ux > 0 else 0.0 + + def y_gap(b1, b2): + return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3])) + + def merge_two(s1, s2): + mask = np.zeros((H, W), dtype=np.uint8) + for s in (s1, s2): + cv2.fillPoly(mask, [np.array(s["points"], dtype=np.int32)], 255) + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if not contours: + return s1 + c = max(contours, key=cv2.contourArea) + eps = 0.002 * cv2.arcLength(c, True) + approx = cv2.approxPolyDP(c, eps, True) + merged = dict(s1) + merged["points"] = [[float(p[0][0]), float(p[0][1])] for p in approx] + merged["score"] = max(float(s1.get("score", 0)), float(s2.get("score", 0))) + return merged + + merged_flags = [False] * len(shapes) + result = [] + merged_count = 0 + for i in range(len(shapes)): + if merged_flags[i]: + continue + cur = shapes[i] + cur_ori = orients[i] + cb = get_bbox(cur) + for j in range(i + 1, len(shapes)): + if merged_flags[j]: + continue + if cur_ori != 'V' or orients[j] != 'V': + continue + jb = get_bbox(shapes[j]) + if x_overlap_ratio(cb, jb) >= x_overlap_thresh and y_gap(cb, jb) <= y_gap_thresh: + cur = merge_two(cur, shapes[j]) + cur_ori = 'V' + cb = 
get_bbox(cur) + merged_flags[j] = True + merged_count += 1 + result.append(cur) + print(f" 병합: {len(shapes)} → {len(result)}개 ({merged_count}쌍 합침)") + return result + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("input", help="AnyLabeling JSON 파일") + ap.add_argument("--output", default=None, help="출력 JSON (기본: INPUT_merged.json)") + ap.add_argument("--inplace", action="store_true", help="원본 덮어쓰기") + ap.add_argument("--x-overlap", type=float, default=0.30, help="x-range 겹침 비율 임계값") + ap.add_argument("--y-gap", type=int, default=150, help="y-range 간격 임계값 (px)") + ap.add_argument("--label", default="catenary_pole", help="병합 대상 라벨명") + args = ap.parse_args() + + src = Path(args.input) + if not src.exists(): + print(f"파일 없음: {src}", file=sys.stderr) + sys.exit(1) + + data = json.loads(src.read_text(encoding="utf-8")) + shapes = data.get("shapes", []) + iW = data.get("imageWidth", 0) + iH = data.get("imageHeight", 0) + if iW == 0 or iH == 0: + print("imageWidth/imageHeight 정보 없음", file=sys.stderr) + sys.exit(1) + + target = [s for s in shapes if s.get("label") == args.label] + others = [s for s in shapes if s.get("label") != args.label] + print(f"{args.label}: {len(target)}개 → 병합 처리") + + merged = merge_poles(target, iH, iW, args.x_overlap, args.y_gap) + data["shapes"] = others + merged + + if args.inplace: + dst = src + elif args.output: + dst = Path(args.output) + else: + dst = src.with_stem(src.stem + "_merged") + + dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"저장: {dst}") + + +if __name__ == "__main__": + main() diff --git a/tools/rail_alignment_fit.py b/tools/rail_alignment_fit.py new file mode 100644 index 0000000..7ebba64 --- /dev/null +++ b/tools/rail_alignment_fit.py @@ -0,0 +1,370 @@ +""" +철도 평면 선형 피팅 모듈 +스켈레톤 좌표점 → 직선/원곡선/완화곡선 분류 및 피팅 → 매끄러운 폴리라인 + +선형 구성요소: + 1. 직선 (Straight) : y = ax + b + 2. 원곡선 (Circular) : R = 상수, Taubin 원 피팅 + 3. 
완화곡선 (Transition) : y = x³/(6RL), 3차포물선 (KR 설계기준) +""" + +import numpy as np +from scipy.signal import savgol_filter +from scipy.linalg import eig as scipy_eig + + +# ── 1. 곡률 계산 ────────────────────────────────────────────────────────── + +def compute_curvature(pts, smooth_window=15): + """ + 길이 가중 방향벡터를 이용한 곡률 계산 + + Parameters + ---------- + pts : (N, 2) array 세계 좌표 (미터) + smooth_window : int Savitzky-Golay 스무딩 윈도우 (홀수) + + Returns + ------- + curvatures : (N,) 곡률 κ [1/m] + arc_lengths : (N,) 각 점의 누적 호 길이 [m] + angles : (N,) 방향각 [rad], unwrap 처리됨 + """ + pts = np.asarray(pts, dtype=float) + n = len(pts) + if n < 3: + return np.zeros(n), np.zeros(n), np.zeros(n) + + # 세그먼트 벡터 및 길이 + segs = np.diff(pts, axis=0) # (n-1, 2) + lens = np.linalg.norm(segs, axis=1) # (n-1,) + lens = np.maximum(lens, 1e-12) + + # 단위 방향벡터 + dirs = segs / lens[:, np.newaxis] + + # 가중 방향벡터: 인접 세그먼트 길이를 가중치로 평균 + directions = np.empty((n, 2)) + directions[0] = dirs[0] + directions[-1] = dirs[-1] + for i in range(1, n - 1): + w1, w2 = lens[i - 1], lens[i] + directions[i] = (w1 * dirs[i - 1] + w2 * dirs[i]) / (w1 + w2) + + # 정규화 + norms = np.linalg.norm(directions, axis=1, keepdims=True) + directions /= np.maximum(norms, 1e-12) + + # 방향각 (연속성 보장) + angles = np.arctan2(directions[:, 1], directions[:, 0]) + angles = np.unwrap(angles) + + # 누적 호 길이 + arc_lengths = np.zeros(n) + arc_lengths[1:] = np.cumsum(lens) + + # 곡률 κ = dθ/ds (중앙차분) + curvatures = np.zeros(n) + for i in range(1, n - 1): + dtheta = angles[i + 1] - angles[i - 1] + ds = arc_lengths[i + 1] - arc_lengths[i - 1] + if ds > 1e-10: + curvatures[i] = dtheta / ds + curvatures[0] = curvatures[1] + curvatures[-1] = curvatures[-2] + + # Savitzky-Golay 스무딩 (노이즈 억제) + w = smooth_window + if n >= w and w >= 5: + w = w if w % 2 == 1 else w - 1 + curvatures = savgol_filter(curvatures, w, 3) + + return curvatures, arc_lengths, angles + + +# ── 2. 
# Segment classification
def segment_by_curvature(curvatures, arc_lengths,
                         kappa_straight=0.001,   # 1/1000 m: R > 1000 m is treated as straight
                         cv_thresh=0.35,         # coefficient-of-variation (std / mean) threshold
                         min_seg_len=3.0):       # minimum segment length [m]
    """
    Split a polyline into straight / circular / transition runs by curvature.

    Each point is classified from the absolute curvature inside a window
    centred on it (window size ``max(5, n // 15)``):

    * mean |kappa| below ``kappa_straight``            -> 'straight'
    * otherwise std / mean below ``cv_thresh``         -> 'circular'
    * otherwise                                        -> 'transition'

    Consecutive points of the same class form a segment.  Segments whose arc
    length is below ``min_seg_len`` are then merged into a neighbour: a short
    leading segment is absorbed by the following one (which keeps its own
    type), any other short segment is absorbed by the preceding one.  The
    merge pass repeats until nothing changes or a single segment remains.

    Parameters
    ----------
    curvatures : sequence of float   curvature per point
    arc_lengths : sequence of float  cumulative arc length per point
    kappa_straight : float           mean |kappa| below this is 'straight'
    cv_thresh : float                std/mean below this is 'circular'
    min_seg_len : float              segments shorter than this get merged

    Returns
    -------
    segments : list of (start_idx, end_idx, type_str)
        Inclusive index ranges covering 0 .. n-1 without gaps;
        type_str is one of {'straight', 'circular', 'transition'}.
        An empty input yields an empty list.
    """
    n = len(curvatures)
    if n == 0:
        # Nothing to classify (previously this raised IndexError).
        return []

    kappa_abs = np.abs(np.asarray(curvatures, dtype=float))
    win = max(5, n // 15)
    half = win // 2

    def classify_window(k_arr):
        km = k_arr.mean()
        if km < kappa_straight:
            return 'straight'
        # Small epsilon keeps the ratio finite when the mean is ~0.
        cv = k_arr.std() / (km + 1e-12)
        return 'circular' if cv < cv_thresh else 'transition'

    # Per-point class from the surrounding window (clipped at both ends).
    local_type = [
        classify_window(kappa_abs[max(0, i - half):min(n, i + half + 1)])
        for i in range(n)
    ]

    # Collapse runs of identical classes into (start, end, type) segments.
    segments = []
    run_start = 0
    for i in range(1, n):
        if local_type[i] != local_type[run_start]:
            segments.append((run_start, i - 1, local_type[run_start]))
            run_start = i
    segments.append((run_start, n - 1, local_type[run_start]))

    # Merge too-short segments into a neighbour until the list is stable.
    changed = True
    while changed and len(segments) > 1:
        changed = False
        kept = []
        carry_start = None  # start index handed over by a short leading segment
        for idx, (s, e, t) in enumerate(segments):
            if carry_start is not None:
                # This segment absorbs the short leading one.  It must still be
                # evaluated and kept; the original code skipped it here, which
                # dropped the beginning of the alignment.
                s, carry_start = carry_start, None
            if arc_lengths[e] - arc_lengths[s] >= min_seg_len:
                kept.append((s, e, t))
            elif idx == 0:
                carry_start = s
                changed = True
            elif kept:
                prev_s, _, prev_t = kept[-1]
                kept[-1] = (prev_s, e, prev_t)
                changed = True
            else:
                # Short, but nothing before it in this pass: keep it so the next
                # pass can merge it forward as the leading segment.
                kept.append((s, e, t))
        segments = kept

    return segments
개별 구간 피팅 ──────────────────────────────────────────────────── + +def fit_straight(pts, spacing=0.5): + """PCA 직선 피팅 → 균등 샘플 점 목록""" + pts = np.asarray(pts, dtype=float) + center = pts.mean(axis=0) + _, _, Vt = np.linalg.svd(pts - center) + direction = Vt[0] + t = (pts - center) @ direction + t_min, t_max = t.min(), t.max() + n_pts = max(2, int((t_max - t_min) / spacing) + 1) + ts = np.linspace(t_min, t_max, n_pts) + return [(center + ti * direction).tolist() for ti in ts] + + +def taubin_circle_fit(pts): + """Taubin 방법으로 안정적 원 피팅 → (cx, cy, R) 또는 None""" + pts = np.asarray(pts, dtype=float) + mean = pts.mean(axis=0) + x = pts[:, 0] - mean[0] + y = pts[:, 1] - mean[1] + z = x**2 + y**2 + + Mxx = (x * x).mean() + Myy = (y * y).mean() + Mxy = (x * y).mean() + Mxz = (x * z).mean() + Myz = (y * z).mean() + Mzz = (z * z).mean() + + Mz = Mxx + Myy + Cov_xy = Mxx * Myy - Mxy**2 + A3 = 4 * Mz + A2 = -3 * Mz**2 - Mzz + A1 = Mzz * Mz + 4 * Cov_xy * Mz - Mxz**2 - Myz**2 - Mz**3 + A0 = Mxz**2 * Myy + Myz**2 * Mxx - Mzz * Cov_xy - 2 * Mxz * Myz * Mxy + Mz**2 * Cov_xy + + try: + roots = np.roots([A3, A2, A1, A0]) + real_roots = roots[np.isreal(roots)].real + if len(real_roots) == 0: + return None + xn = max(real_roots) + yn = (xn**2 - Mz) * xn + A0 / A3 + det = xn**2 - xn * Mz + Cov_xy + if abs(det) < 1e-12: + return None + xcen = (Mxz * (Myy - xn) - Myz * Mxy) / (2 * det) + mean[0] + ycen = (Myz * (Mxx - xn) - Mxz * Mxy) / (2 * det) + mean[1] + R = np.sqrt(np.mean((pts[:, 0] - xcen)**2 + (pts[:, 1] - ycen)**2)) + return xcen, ycen, R + except Exception: + return None + + +def fit_circular(pts, spacing=0.5): + """원곡선 피팅 → 호 위의 균등 샘플 점 목록""" + result = taubin_circle_fit(pts) + if result is None: + return fit_straight(pts, spacing) + + cx, cy, R = result + pts = np.asarray(pts, dtype=float) + angs = np.arctan2(pts[:, 1] - cy, pts[:, 0] - cx) + angs = np.unwrap(angs) + a_start, a_end = angs[0], angs[-1] + + arc_len = abs(a_end - a_start) * R + n_pts = max(2, int(arc_len / spacing) + 
1) + thetas = np.linspace(a_start, a_end, n_pts) + return [(cx + R * np.cos(t), cy + R * np.sin(t)) for t in thetas] + + +def fit_transition_cubic(pts, R_adj, spacing=0.5): + """ + 3차포물선 완화곡선 피팅 y = x³ / (6·R·L) + 로컬 좌표계(시작점, 시작방향 기준)에서 피팅 후 세계 좌표로 역변환 + + Parameters + ---------- + pts : (N, 2) 세계 좌표 + R_adj : 인접 원곡선 반경 [m] (없으면 곡률 역수로 추정) + """ + pts = np.asarray(pts, dtype=float) + origin = pts[0].copy() + + # 로컬 좌표계 축 (PCA 주방향) + _, _, Vt = np.linalg.svd(pts - origin) + axis_x = Vt[0] + axis_y = np.array([-axis_x[1], axis_x[0]]) + + local = pts - origin + lx = local @ axis_x + ly = local @ axis_y + + L = lx.max() + if L < 1e-5 or R_adj < 1e-5: + return pts.tolist() + + # 균등 샘플 + n_pts = max(2, int(L / spacing) + 1) + xs = np.linspace(0, L, n_pts) + ys = xs**3 / (6.0 * R_adj * L) + + return [(origin + xi * axis_x + yi * axis_y).tolist() + for xi, yi in zip(xs, ys)] + + +# ── 4. 전체 피팅 파이프라인 ─────────────────────────────────────────────── + +def fit_alignment(pts, spacing=0.5): + """ + 메인 함수: 세계 좌표 점 목록 → 선형 피팅 → 매끄러운 폴리라인 + + Parameters + ---------- + pts : list of (x, y) EPSG:5186 세계 좌표 [m] + spacing : float 출력 점 간격 [m] + + Returns + ------- + smooth_pts : list of (x, y) 피팅된 매끄러운 폴리라인 점 목록 + seg_info : list of dict 각 구간 정보 (디버깅용) + """ + pts = np.asarray(pts, dtype=float) + n = len(pts) + + if n < 5: + return pts.tolist(), [] + + # --- 곡률 계산 --- + sw = min(21, n if n % 2 == 1 else n - 1) + sw = max(5, sw) + curvatures, arc_lengths, angles = compute_curvature(pts, sw) + + # --- 구간 분류 --- + segments = segment_by_curvature(curvatures, arc_lengths) + + # --- 인접 원곡선 R 사전 계산 (완화곡선에서 사용) --- + seg_R = {} + for idx, (s, e, t) in enumerate(segments): + if t == 'circular': + km = np.abs(curvatures[s:e+1]).mean() + seg_R[idx] = 1.0 / km if km > 1e-6 else 1000.0 + + smooth_pts = [] + seg_info = [] + + for idx, (s, e, seg_type) in enumerate(segments): + seg_pts = pts[s:e+1] + + if len(seg_pts) < 2: + smooth_pts.extend(seg_pts.tolist()) + continue + + km = 
np.abs(curvatures[s:e+1]).mean() + R_est = 1.0 / km if km > 1e-6 else 9999.0 + + if seg_type == 'straight': + fitted = fit_straight(seg_pts, spacing) + elif seg_type == 'circular': + fitted = fit_circular(seg_pts, spacing) + else: # transition + # 인접 원곡선 R 탐색 + R_use = R_est + for di in [1, -1]: + ni = idx + di + if ni in seg_R: + R_use = seg_R[ni] + break + fitted = fit_transition_cubic(seg_pts, R_use, spacing) + + seg_len = arc_lengths[e] - arc_lengths[s] + seg_info.append({ + 'type': seg_type, + 'start_idx': s, + 'end_idx': e, + 'length_m': round(seg_len, 2), + 'R_m': round(R_est, 1), + 'n_pts_in': len(seg_pts), + 'n_pts_out': len(fitted), + }) + + # 이음새 중복 제거 + if smooth_pts and fitted: + smooth_pts.extend(fitted[1:]) + else: + smooth_pts.extend(fitted) + + return smooth_pts, seg_info + + +# ── 5. 간단 테스트 ──────────────────────────────────────────────────────── + +if __name__ == '__main__': + # 직선 + 원곡선 + 완화곡선 합성 테스트 + t = np.linspace(0, 2*np.pi, 300) + # 직선 구간 + straight = [(float(i)*0.5, 0.0) for i in range(40)] + # 원곡선 구간 R=300m + R = 300.0 + theta = np.linspace(0, np.pi/4, 60) + cx, cy = straight[-1][0], straight[-1][1] + R + circ = [(cx + R*np.sin(th), cy - R*np.cos(th)) for th in theta] + pts_test = straight + circ[1:] + + smooth, info = fit_alignment(pts_test, spacing=1.0) + print(f"입력: {len(pts_test)}점 → 출력: {len(smooth)}점") + for s in info: + print(f" [{s['type']:10s}] L={s['length_m']:.1f}m R={s['R_m']:.0f}m " + f"점: {s['n_pts_in']}→{s['n_pts_out']}") diff --git a/tools/rail_centerline_dxf.py b/tools/rail_centerline_dxf.py new file mode 100644 index 0000000..f9951f1 --- /dev/null +++ b/tools/rail_centerline_dxf.py @@ -0,0 +1,175 @@ +""" +2cm GSD 드론 TIF 전체에서 레일 중심선 추출 → DXF 저장 + +중심선 추출 방법: + masks.xy 폴리곤 → PCA 주축(길이 방향) 찾기 + → 주축 방향으로 슬라이싱 → 폭 방향 중앙값 + → 계단 현상 없는 매끄러운 중심선 +""" +import numpy as np +import rasterio +from rasterio.windows import Window +from ultralytics import YOLO +from PIL import Image +import ezdxf + + +def 
polygon_to_centerline(xy_tile, spacing_px=10): + """ + 타일 픽셀 좌표 폴리곤 → 중심선 점 목록 + + Parameters + ---------- + xy_tile : list of (x, y) 타일 픽셀 좌표 (masks.xy × sx/sy) + spacing_px : int 슬라이싱 간격 [픽셀] (10px = 0.2m at 2cm GSD) + + Returns + ------- + list of (x, y) 중심선 타일 픽셀 좌표 + """ + pts = np.asarray(xy_tile, dtype=float) + if len(pts) < 4: + return [] + + # PCA: 주축(레일 길이 방향) 탐색 + center = pts.mean(axis=0) + _, _, Vt = np.linalg.svd(pts - center, full_matrices=False) + v_long = Vt[0] # 길이 방향 단위벡터 + v_perp = Vt[1] # 폭 방향 단위벡터 + + local = pts - center + t = local @ v_long # 길이 방향 좌표 + + t_min, t_max = t.min(), t.max() + if t_max - t_min < spacing_px: + # 너무 짧으면 무게중심 1점 + return [center.tolist()] + + n_bins = max(2, int((t_max - t_min) / spacing_px) + 1) + bins = np.linspace(t_min, t_max, n_bins + 1) + + centerline = [] + s_all = local @ v_perp # 폭 방향 좌표 (전체) + + for i in range(n_bins): + mask = (t >= bins[i]) & (t <= bins[i + 1]) + if mask.sum() < 2: + continue + t_c = (bins[i] + bins[i + 1]) / 2 + s_vals = s_all[mask] + # 폭 방향: 폴리곤 양 가장자리의 기하학적 중앙 + s_c = (s_vals.max() + s_vals.min()) / 2 + pt = center + t_c * v_long + s_c * v_perp + centerline.append(pt.tolist()) + + return centerline + + +# ── 설정 ────────────────────────────────────────── +SRC_TIF = "drone_2cm/22)조치원(STA.127+570~131+300).tif" +MODEL_PT = "runs/segment/output/yolo_train_2cm/rail_seg_v2/weights/best.pt" +OUT_DXF = "output/rail_centerline_2cm_pca.dxf" +TILE = 1024 +OVERLAP = 128 +STEP = TILE - OVERLAP +CONF = 0.3 +BLACK_THR = 0.4 # 검은 픽셀 비율 임계값 +SPACING = 10 # 중심선 샘플 간격 [픽셀] 10px ≈ 0.2m + +# ── 모델 로드 ────────────────────────────────────── +print("모델 로드 중...") +model = YOLO(MODEL_PT) + +# ── DXF 초기화 ──────────────────────────────────── +doc = ezdxf.new() +msp = doc.modelspace() +doc.layers.add("RAIL_CENTERLINE", color=3) + +# ── TIF 열기 ────────────────────────────────────── +src = rasterio.open(SRC_TIF) +W, H = src.width, src.height +transform = src.transform +print(f"이미지 크기: {W} x {H}") + + +def 
pixel_to_world(px, py): + """픽셀 좌표 → 세계 좌표 (EPSG:5186)""" + wx = transform.c + px * transform.a + wy = transform.f + py * transform.e + return wx, wy + + +# ── 타일 순회 ───────────────────────────────────── +total_polylines = 0 +xs_range = range(0, W - TILE // 2, STEP) +ys_range = range(0, H - TILE // 2, STEP) +total_tiles = len(xs_range) * len(ys_range) +processed = 0 + +print(f"총 타일 수(예상): {total_tiles}, 처리 시작...") + +for ty in ys_range: + for tx in xs_range: + tw = min(TILE, W - tx) + th = min(TILE, H - ty) + if tw < 64 or th < 64: + continue + + win = Window(tx, ty, tw, th) + data = src.read([1, 2, 3], window=win) + img_arr = np.transpose(data, (1, 2, 0)).astype(np.uint8) + + black_ratio = np.mean(np.all(img_arr < 10, axis=2)) + if black_ratio > BLACK_THR: + processed += 1 + continue + + pil_img = Image.fromarray(img_arr) + results = model.predict(pil_img, imgsz=TILE, conf=CONF, + device='cuda:0', verbose=False) + + if not results or results[0].masks is None: + processed += 1 + continue + + masks_data = results[0].masks.data.cpu().numpy() + h_r, w_r = masks_data.shape[1], masks_data.shape[2] + sx = tw / w_r + sy = th / h_r + + for xy in results[0].masks.xy: + if len(xy) < 4: + continue + + # YOLO 좌표 → 타일 픽셀 좌표 + xy_tile = [(float(lx) * sx, float(ly) * sy) for lx, ly in xy] + + # PCA 슬라이싱으로 중심선 추출 (타일 픽셀 좌표) + cl_tile = polygon_to_centerline(xy_tile, spacing_px=SPACING) + if len(cl_tile) < 2: + continue + + # overlap 가장자리 제거 + 세계 좌표 변환 + pts_world = [] + for lx, ly in cl_tile: + if tx + tw < W and lx > (tw - OVERLAP // 2): + continue + if ty + th < H and ly > (th - OVERLAP // 2): + continue + pts_world.append(pixel_to_world(tx + lx, ty + ly)) + + if len(pts_world) >= 2: + msp.add_lwpolyline(pts_world, + dxfattribs={"layer": "RAIL_CENTERLINE"}) + total_polylines += 1 + + processed += 1 + if processed % 500 == 0: + pct = processed / total_tiles * 100 + print(f" 진행: {processed}/{total_tiles} ({pct:.1f}%)" + f" | 폴리라인: {total_polylines}") + +src.close() + 
+doc.saveas(OUT_DXF) +print(f"\n완료: {total_polylines}개 폴리라인 → {OUT_DXF}") diff --git a/tools/rail_to_dxf.py b/tools/rail_to_dxf.py new file mode 100644 index 0000000..66b728c --- /dev/null +++ b/tools/rail_to_dxf.py @@ -0,0 +1,234 @@ +""" +rail_to_dxf.py +============== +X-AnyLabeling JSON 어노테이션에서 레일 중심선을 추출하여 Rhino용 DXF로 저장. + +사용법: + python tools/rail_to_dxf.py [output.dxf] + +예시: + python tools/rail_to_dxf.py images/rail.json + python tools/rail_to_dxf.py images/rail.json output/rail_centerline.dxf + +라벨 이름 (X-AnyLabeling에서 Finish 후 입력한 이름): + rail, railway_track, track, AUTOLABEL_OBJECT 자동 인식 + 다른 이름이면 스크립트 하단 TARGET_LABELS 수정 + +Rhino에서 사용: + 1. DXF Import + 2. RAIL_CENTERLINE 레이어 선택 + 3. Sweep2 또는 Rail Sweep으로 레일 단면 적용 +""" + +import json +import sys +import numpy as np +import cv2 +from pathlib import Path +from skimage.morphology import skeletonize + + +# ─── 설정 ───────────────────────────────────────────────────────────────────── +TARGET_LABELS = [ + "rail", + "railline", + "railway_track", + "track", + "레일", + "철로", + "AUTOLABEL_OBJECT", +] +# line/linestrip 타입은 스켈레톤 불필요 — 직접 DXF 출력 +LINE_SHAPE_TYPES = {"line", "linestrip", "lines"} +SMOOTH_WINDOW = 15 # 중심선 스무딩 강도 (클수록 부드러움, 0=비활성) +DOWNSAMPLE_STEP = 8 # Rhino 폴리라인 포인트 간격 (클수록 포인트 수 감소) +DILATION_ITER = 3 # 마스크 팽창 반복 (얇은 레일 마스크 연결 보완) +# ────────────────────────────────────────────────────────────────────────────── + + +def polygon_to_mask(points, h, w): + mask = np.zeros((h, w), dtype=np.uint8) + pts = np.array([[int(p[0]), int(p[1])] for p in points], dtype=np.int32) + cv2.fillPoly(mask, [pts], 1) + return mask + + +def extract_skeleton(mask): + kernel = np.ones((3, 3), np.uint8) + dilated = cv2.dilate(mask, kernel, iterations=DILATION_ITER) + return skeletonize(dilated > 0).astype(np.uint8) + + +def order_skeleton(skeleton): + """스켈레톤 픽셀을 끝점에서 시작해 순서대로 연결.""" + ys, xs = np.where(skeleton > 0) + if len(xs) == 0: + return [] + + pt_set = set(zip(xs.tolist(), ys.tolist())) + + def neighbors(pt): + x, y = 
pt + return [(x+dx, y+dy) + for dx in (-1,0,1) for dy in (-1,0,1) + if not (dx==0 and dy==0) and (x+dx, y+dy) in pt_set] + + # 끝점(이웃 1개) 찾기 → 없으면 임의 시작 + endpoints = [p for p in pt_set if len(neighbors(p)) == 1] + start = endpoints[0] if endpoints else next(iter(pt_set)) + + ordered, visited = [start], {start} + current = start + + while True: + nbs = [n for n in neighbors(current) if n not in visited] + if not nbs: + break + # 가장 직선에 가까운 방향 우선 선택 + if len(ordered) >= 2: + dx = current[0] - ordered[-2][0] + dy = current[1] - ordered[-2][1] + nbs.sort(key=lambda n: -((n[0]-current[0])*dx + (n[1]-current[1])*dy)) + current = nbs[0] + visited.add(current) + ordered.append(current) + + return ordered + + +def smooth_polyline(points, window): + if window < 3 or len(points) < window: + return points + pts = np.array(points, dtype=float) + half = window // 2 + out = pts.copy() + for i in range(half, len(pts) - half): + out[i] = pts[i-half:i+half+1].mean(axis=0) + return out.tolist() + + +def downsample(points, step): + if step <= 1 or len(points) <= step: + return points + sampled = points[::step] + if list(sampled[-1]) != list(points[-1]): + sampled = list(sampled) + [points[-1]] + return sampled + + +def to_dxf_coords(points, flip_y=True): + """이미지 좌표(Y↓) → DXF 좌표(Y↑)""" + if flip_y: + return [(float(p[0]), float(-p[1])) for p in points] + return [(float(p[0]), float(p[1])) for p in points] + + +def process(json_path: str, dxf_path: str): + import ezdxf + + with open(json_path, encoding="utf-8") as f: + data = json.load(f) + + H = data.get("imageHeight", 1000) + W = data.get("imageWidth", 1000) + print(f"[이미지] {W} × {H} px") + + shapes = data.get("shapes", []) + targets = [s for s in shapes if s.get("label") in TARGET_LABELS] + + if not targets: + print("⚠ TARGET_LABELS 일치 없음 → 모든 shape 사용") + targets = shapes + + type_counts = {} + for s in targets: + t = s.get("shape_type", "unknown") + type_counts[t] = type_counts.get(t, 0) + 1 + print(f"[처리] {len(targets)}개: 
{type_counts}") + + doc = ezdxf.new("R2010") + msp = doc.modelspace() + doc.layers.add("RAIL_CENTERLINE", color=1) # 빨강 — Sweep 경로 + doc.layers.add("RAIL_POLYGON", color=3) # 초록 — 원본 마스크 윤곽 + + for i, shape in enumerate(targets): + label = shape.get("label", f"shape_{i}") + pts = shape.get("points", []) + shape_type = shape.get("shape_type", "polygon") + print(f"\n [{i+1}] label={label!r} type={shape_type!r} pts={len(pts)}") + + if len(pts) < 2: + print(" → 포인트 부족, 건너뜀") + continue + + # 중복 끝점 제거 (Polygon 도구로 그린 선: 마지막 점이 앞 점과 거의 동일) + if len(pts) >= 3: + dedup = [pts[0]] + for p in pts[1:]: + if abs(p[0]-dedup[-1][0]) > 2 or abs(p[1]-dedup[-1][1]) > 2: + dedup.append(p) + if len(dedup) < len(pts): + print(f" 중복점 제거: {len(pts)}pt → {len(dedup)}pt") + pts = dedup + + # ── LINE 타입: 스켈레톤 없이 직접 DXF 출력 ────── + if shape_type in LINE_SHAPE_TYPES or len(pts) == 2: + cl_dxf = to_dxf_coords(pts) + msp.add_lwpolyline(cl_dxf, dxfattribs={"layer": "RAIL_CENTERLINE"}) + print(f" → 라인 직접 출력 ({len(pts)}pt)") + continue + + # ── POLYGON 타입: 마스크 → 스켈레톤 → 중심선 ──── + if len(pts) < 3: + print(" → 폴리곤 포인트 부족, 건너뜀") + continue + + poly_dxf = to_dxf_coords(pts) + poly_dxf.append(poly_dxf[0]) + msp.add_lwpolyline(poly_dxf, dxfattribs={"layer": "RAIL_POLYGON"}) + + mask = polygon_to_mask(pts, H, W) + area = int(mask.sum()) + print(f" 마스크 면적: {area} px²") + if area < 20: + print(" → 면적 너무 작음, 건너뜀") + continue + + skel = extract_skeleton(mask) + skel_n = int(skel.sum()) + print(f" 스켈레톤 픽셀: {skel_n}") + if skel_n < 2: + print(" → 스켈레톤 생성 실패, 건너뜀") + continue + + ordered = order_skeleton(skel) + smoothed = smooth_polyline(ordered, SMOOTH_WINDOW) + final_pts = downsample(smoothed, DOWNSAMPLE_STEP) + print(f" 중심선 포인트: {len(final_pts)}") + if len(final_pts) < 2: + continue + + cl_dxf = to_dxf_coords(final_pts) + msp.add_lwpolyline(cl_dxf, dxfattribs={"layer": "RAIL_CENTERLINE"}) + + Path(dxf_path).parent.mkdir(parents=True, exist_ok=True) + doc.saveas(dxf_path) + print(f"\n[완료] DXF 저장: 
{dxf_path}") + print(" 레이어: RAIL_CENTERLINE(빨강) = Sweep 경로") + print(" RAIL_POLYGON(초록) = 원본 마스크") + print("\nRhino 사용법:") + print(" 1. File → Import → DXF 선택") + print(" 2. RAIL_CENTERLINE 레이어 선택") + print(" 3. Rail 단면 커브 그리기 (레일 단면 약 60x60mm)") + print(" 4. Sweep1 또는 Sweep2 명령으로 단면 × 중심선 → 3D 레일 생성") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("사용법: python tools/rail_to_dxf.py [output.dxf]") + sys.exit(1) + + json_file = sys.argv[1] + dxf_file = sys.argv[2] if len(sys.argv) >= 3 else str(Path(json_file).with_suffix(".dxf")) + + process(json_file, dxf_file) diff --git a/tools/railway_pipeline.py b/tools/railway_pipeline.py new file mode 100644 index 0000000..e56547c --- /dev/null +++ b/tools/railway_pipeline.py @@ -0,0 +1,870 @@ +""" +railway_pipeline.py +=================== +정사영상(GeoTIFF/PNG)에서 철도 시설물을 자동 검출하여 +실좌표(UTM/WGS84) 기반 DXF + GeoJSON 출력. + +이미지에 실제로 보이는 것을 그대로 검출 (표준 규격 기반 아님). + +사용법: + python tools/railway_pipeline.py [output_dir] + +예시: + python tools/railway_pipeline.py 경부선.tif output/ + python tools/railway_pipeline.py 경부선.png output/ # PNG는 GSD=0.05m 가정 + +출력: + output/railway_rails.dxf - 레일 중심선 + output/railway_objects.dxf - 전체 시설물 (레이어별) + output/railway_objects.geojson - GIS 활용용 + output/railway_debug.jpg - 검출 결과 시각화 +""" + +import sys +import json +import math +import numpy as np +import cv2 +from pathlib import Path +from collections import defaultdict + + +# ─── 검출 파라미터 ───────────────────────────────────────────────────────────── +RAIL = dict( + canny_low=30, canny_high=80, + hough_threshold=400, min_len=500, max_gap=30, + cluster_dist=14, min_total_len=3000, +) +SLEEPER = dict( + search_width=60, # 레일 양옆 검색 폭 (px) + min_area=30, # 최소 면적 + max_area=2000, # 최대 면적 + aspect_min=2.0, # 최소 가로세로비 (침목은 길쭉) +) +POLE = dict( + search_margin=200, # 레일 옆 검색 범위 (px) + min_radius=3, # 최소 반지름 (px) + max_radius=25, # 최대 반지름 (px) + min_dist=40, # 검출 최소 간격 (px) + param1=50, param2=25, # Hough Circle 파라미터 +) +CBOX = dict( + 
search_margin=300, # 레일 옆 검색 범위 (px) + min_area=100, + max_area=8000, + aspect_min=1.2, # 직사각형 비율 + aspect_max=6.0, + min_solidity=0.75, # 컨투어 충실도 (직사각형에 가까울수록 1) +) +GABOR = dict( + lambdas=[5, 7, 10], # 침목 간격 범위 (px, 축소 이미지 기준) + # 5cm GSD, 8000px 폭(scale≈0.54): 600mm/50mm*0.54 ≈ 6.5px + sigma=3.0, # Gabor 커널 폭 + gamma=0.5, # 종횡비 (1=원형, 0.5=타원) + n_angles=12, # 방향 수 (0°~165°, 15° 간격) + thresh_pct=60, # 응답 임계값 백분위 (높을수록 엄격) + min_area=4000, # 최소 연결성분 면적 (작은 노이즈 제거) +) +SCALE_FACTOR = 8000 # 침목 패턴 감지를 위해 높은 해상도 유지 +# ────────────────────────────────────────────────────────────────────────────── + + +# ══════════════════════════════════════════════════════════════════════════════ +# 좌표 변환 +# ══════════════════════════════════════════════════════════════════════════════ + +class GeoTransform: + """픽셀좌표 ↔ 실좌표 변환.""" + + def __init__(self, image_path: str): + self.crs = None + self.affine = None + self.origin = (0.0, 0.0) + self.pixel_size = (0.05, 0.05) # 기본 5cm GSD + self._load(image_path) + + def _load(self, path: str): + if path.lower().endswith(('.tif', '.tiff')): + try: + import rasterio + with rasterio.open(path) as r: + self.crs = str(r.crs) + t = r.transform + self.affine = t + self.pixel_size = (abs(t.a), abs(t.e)) + self.origin = (t.c, t.f) + print(f" CRS: {self.crs}") + print(f" 픽셀크기: {self.pixel_size[0]:.4f}m x {self.pixel_size[1]:.4f}m") + print(f" 원점: ({self.origin[0]:.2f}, {self.origin[1]:.2f})") + except Exception as e: + print(f" [경고] rasterio 실패: {e} → 픽셀좌표 사용") + else: + # PNG: world file 탐색 + wld = Path(path).with_suffix('.pgw') + if not wld.exists(): + wld = Path(path).with_suffix('.wld') + if wld.exists(): + vals = [float(l.strip()) for l in wld.read_text().splitlines() if l.strip()] + if len(vals) >= 6: + self.pixel_size = (abs(vals[0]), abs(vals[3])) + self.origin = (vals[4], vals[5]) + print(f" World file 적용: origin={self.origin}") + else: + print(f" World file 없음 → 픽셀좌표 사용 (GSD=0.05m)") + + def px_to_world(self, px: float, py: float): + """픽셀 
(col, row) → 실좌표 (x, y).""" + if self.affine: + from rasterio.transform import xy as rio_xy + try: + x, y = self.affine * (px, py) + return float(x), float(y) + except Exception: + pass + x = self.origin[0] + px * self.pixel_size[0] + y = self.origin[1] - py * self.pixel_size[1] + return x, y + + def scale_coords(self, pts, scale: float): + """축소된 이미지의 픽셀좌표를 원본 픽셀좌표로 변환 후 실좌표 출력.""" + return [self.px_to_world(p[0] / scale, p[1] / scale) for p in pts] + + def scale_point(self, px: float, py: float, scale: float): + return self.px_to_world(px / scale, py / scale) + + +# ══════════════════════════════════════════════════════════════════════════════ +# 이미지 로드 (한글 경로 대응) +# ══════════════════════════════════════════════════════════════════════════════ + +def load_image(path: str): + """TIF/PNG 모두 BGR numpy array로 반환.""" + if path.lower().endswith(('.tif', '.tiff')): + try: + import rasterio + with rasterio.open(path) as r: + if r.count >= 3: + arr = np.dstack([r.read(i) for i in [1, 2, 3]]) + # uint16 → uint8 변환 + if arr.dtype != np.uint8: + arr = (arr / arr.max() * 255).astype(np.uint8) + return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) + else: + band = r.read(1) + if band.dtype != np.uint8: + band = (band / band.max() * 255).astype(np.uint8) + return cv2.cvtColor(band, cv2.COLOR_GRAY2BGR) + except Exception as e: + print(f" rasterio 읽기 실패: {e}, PIL 시도...") + # PNG 또는 fallback + with open(path, 'rb') as f: + data = f.read() + arr = np.frombuffer(data, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + return img + + +# ══════════════════════════════════════════════════════════════════════════════ +# 자갈도상 마스크 +# ══════════════════════════════════════════════════════════════════════════════ + +def extract_ballast_mask(img): + """자갈도상(ballast) 구역 마스크 추출. + 항공 정사영상에서 자갈도상은 밝고(V>70) 채도 낮음(S<60) — HSV 임계값 후 + 큰 연결성분만 유지하여 선로 띠 영역만 남김. 
+ """ + H, W = img.shape[:2] + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + + # 낮은 채도 + 중간-높은 밝기 = 자갈/회색 (채도 45 이하로 더 엄격하게) + mask = cv2.inRange(hsv, + np.array([0, 0, 80], np.uint8), + np.array([180, 45, 210], np.uint8)) + + # 작은 노이즈 제거 + k_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, k_open, iterations=2) + + # 팽창: 자갈 구역 연결 + 레일 엣지까지 포함 (2회로 줄여 과팽창 방지) + k_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) + mask = cv2.dilate(mask, k_dil, iterations=2) + + # 연결성분 중 이미지 면적 0.5% 이상만 유지 (도로·건물 등 소규모 제거) + min_area = max(H * W * 0.005, 5000) + n, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) + result = np.zeros_like(mask) + for lbl in range(1, n): + if stats[lbl, cv2.CC_STAT_AREA] >= min_area: + result[labels == lbl] = 255 + + kept = int((result > 0).sum()) + total = H * W + print(f" 자갈도상 마스크: {kept:,}px ({kept/total*100:.1f}%)") + return result + + +# ══════════════════════════════════════════════════════════════════════════════ +# Gabor 침목 패턴 마스크 + 레일 검출 (통합) +# ══════════════════════════════════════════════════════════════════════════════ + +def detect_rails_gabor(img): + """ + [1단계] Gabor 필터 뱅크로 침목 줄무늬 패턴 영역 추출 + → 자갈도상 색상이 아닌 침목 주기 텍스처로 선로 구역 판별 + → 도로는 이 패턴 없으므로 자동 제거 + + [2단계] 침목 마스크 안에서 Hough로 개별 레일 중심선 검출 + + Returns: (rail_polylines, gabor_mask) + """ + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + H, W = gray.shape + + # ── Gabor 필터 뱅크 적용 ────────────────────────────────── + response_max = np.zeros((H, W), dtype=np.float32) + n_ang = GABOR['n_angles'] + for i in range(n_ang): + theta = i * np.pi / n_ang # 0 ~ π (모든 방향) + for lam in GABOR['lambdas']: + ksize = max(int(GABOR['sigma'] * 6) | 1, 7) # 홀수 보장 + kernel = cv2.getGaborKernel( + (ksize, ksize), + GABOR['sigma'], theta, lam, + GABOR['gamma'], 0, cv2.CV_32F + ) + resp = np.abs(cv2.filter2D(gray.astype(np.float32), cv2.CV_32F, kernel)) + response_max = np.maximum(response_max, resp) + + # ── 응답 정규화 + 
임계값 ────────────────────────────────── + resp_u8 = cv2.normalize(response_max, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + nz = resp_u8[resp_u8 > 0] + thresh_val = int(np.percentile(nz, GABOR['thresh_pct'])) if len(nz) else 128 + _, binary = cv2.threshold(resp_u8, thresh_val, 255, cv2.THRESH_BINARY) + + # ── 모폴로지 정리 ───────────────────────────────────────── + k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, k5, iterations=2) + binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, k3, iterations=1) + + # ── 큰 연결성분만 유지 (소규모 노이즈 제거) ─────────────── + n_cc, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8) + gabor_mask = np.zeros_like(binary) + for lbl in range(1, n_cc): + if stats[lbl, cv2.CC_STAT_AREA] >= GABOR['min_area']: + gabor_mask[labels == lbl] = 255 + + kept = int((gabor_mask > 0).sum()) + print(f" Gabor 마스크: {kept:,}px ({kept / (H * W) * 100:.1f}%)") + + # ── Gabor 마스크 내에서 Hough 레일 중심선 검출 ─────────── + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + enhanced = clahe.apply(gray) + blurred = cv2.GaussianBlur(enhanced, (5, 5), 0) + edges = cv2.Canny(blurred, RAIL['canny_low'], RAIL['canny_high']) + edges = cv2.bitwise_and(edges, edges, mask=gabor_mask) + print(f" 마스크 내 엣지: {int(edges.sum() // 255):,}px") + + lines_raw = cv2.HoughLinesP( + edges, 1, np.pi / 180, + threshold=RAIL['hough_threshold'], + minLineLength=RAIL['min_len'], + maxLineGap=RAIL['max_gap'], + ) + if lines_raw is None: + print(" Hough 검출 없음") + return [], gabor_mask + + lines = [tuple(l[0]) for l in lines_raw] + print(f" Hough 검출: {len(lines)}개") + + # ── 클러스터링 → 폴리라인 ──────────────────────────────── + # 곡선 구간 지원: 각도 기준 아닌 미드포인트 거리 기준으로 클러스터링 + # 같은 레일의 곡선 세그먼트들은 인접 미드포인트를 가짐 → 체인으로 연결 + ATOL = RAIL['cluster_dist'] + + # 각도별 그룹화 후 수직거리 기반 클러스터링 + # 분기기 구간: 강제 연결 없이 세그먼트 단위로 유지 (분기기는 수동 연결) + def angle(x1, y1, x2, y2): + return 
np.degrees(np.arctan2(y2 - y1, x2 - x1)) % 180 + + def perp_dist(x1, y1, x2, y2, px, py): + dx, dy = x2 - x1, y2 - y1 + L = math.hypot(dx, dy) + if L == 0: + return math.hypot(px - x1, py - y1) + return abs(dy * px - dx * py + x2 * y1 - y2 * x1) / L + + angle_groups = defaultdict(list) + for l in lines: + key = round(angle(*l) / 5) * 5 + angle_groups[key].append(l) + + clusters = [] + for grp in angle_groups.values(): + used = [False] * len(grp) + for i, li in enumerate(grp): + if used[i]: + continue + cl = [li] + used[i] = True + for j, lj in enumerate(grp): + if used[j]: + continue + mxj = ((lj[0]+lj[2])/2, (lj[1]+lj[3])/2) + if perp_dist(*li, *mxj) < ATOL: + cl.append(lj) + used[j] = True + clusters.append(cl) + + def merge(cluster): + pts = [] + for x1, y1, x2, y2 in cluster: + pts += [(x1, y1), (x2, y2)] + arr = np.array(pts, float) + mean = arr.mean(0) + cov = np.cov((arr - mean).T) + if cov.ndim < 2: + direction = np.array([1.0, 0.0]) + else: + ev, evec = np.linalg.eig(cov) + direction = evec[:, np.argmax(ev)] + proj = (arr - mean).dot(direction) + sorted_pts = arr[np.argsort(proj)] + merged = [sorted_pts[0]] + for p in sorted_pts[1:]: + if math.hypot(p[0] - merged[-1][0], p[1] - merged[-1][1]) > 5: + merged.append(p) + return merged + + def poly_len(poly): + return sum( + math.hypot(poly[i + 1][0] - poly[i][0], poly[i + 1][1] - poly[i][1]) + for i in range(len(poly) - 1) + ) + + result = [] + for cl in clusters: + poly = merge(cl) + if poly_len(poly) >= RAIL['min_total_len']: + result.append(poly) + result.sort(key=lambda p: -poly_len(p)) + return result, gabor_mask + + +# ══════════════════════════════════════════════════════════════════════════════ +# 레일 검출 (Hough — fallback용, 주 검출은 detect_rails_gabor 사용) +# ══════════════════════════════════════════════════════════════════════════════ + +def detect_rails(img, ballast_mask=None): + """Hough Line으로 레일 중심선 폴리라인 반환.""" + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + clahe = cv2.createCLAHE(clipLimit=2.0, 
tileGridSize=(8,8)) + enhanced = clahe.apply(gray) + blurred = cv2.GaussianBlur(enhanced, (5,5), 0) + edges = cv2.Canny(blurred, RAIL['canny_low'], RAIL['canny_high']) + + # 자갈도상 마스크 적용 — 선로 구역 외 엣지 제거 + if ballast_mask is not None: + edges = cv2.bitwise_and(edges, edges, mask=ballast_mask) + print(f" 마스크 적용 후 엣지: {int(edges.sum()//255):,}px") + + lines_raw = cv2.HoughLinesP(edges, 1, np.pi/180, + threshold=RAIL['hough_threshold'], + minLineLength=RAIL['min_len'], + maxLineGap=RAIL['max_gap']) + if lines_raw is None: + return [] + + lines = [tuple(l[0]) for l in lines_raw] + + def angle(x1,y1,x2,y2): + return np.degrees(np.arctan2(y2-y1, x2-x1)) % 180 + + def line_len(x1,y1,x2,y2): + return math.hypot(x2-x1, y2-y1) + + def midpoint(x1,y1,x2,y2): + return ((x1+x2)/2, (y1+y2)/2) + + def perp_dist(x1,y1,x2,y2,px,py): + dx,dy = x2-x1, y2-y1 + L = math.hypot(dx,dy) + if L == 0: return math.hypot(px-x1, py-y1) + return abs(dy*px - dx*py + x2*y1 - y2*x1) / L + + # 방향 필터 불필요 — 자갈도상 마스크가 공간 제약 역할을 하므로 + # 커브 구간도 포함하려면 모든 방향 허용 + print(f" Hough 검출: {len(lines)}개") + + # 각도별 그룹화 후 거리 클러스터링 + ATOL = RAIL['cluster_dist'] + angle_groups = defaultdict(list) + for l in lines: + key = round(angle(*l) / 5) * 5 + angle_groups[key].append(l) + + clusters = [] + for grp in angle_groups.values(): + used = [False]*len(grp) + for i, li in enumerate(grp): + if used[i]: continue + cl = [li]; used[i] = True + mx, my = midpoint(*li) + for j, lj in enumerate(grp): + if used[j]: continue + if perp_dist(*li, *midpoint(*lj)) < ATOL: + cl.append(lj); used[j] = True + clusters.append(cl) + + def merge(cluster): + pts = [] + for x1,y1,x2,y2 in cluster: + pts += [(x1,y1),(x2,y2)] + arr = np.array(pts, float) + mean = arr.mean(0) + cov = np.cov((arr - mean).T) + if cov.ndim < 2: direction = np.array([1.0,0.0]) + else: + ev, evec = np.linalg.eig(cov) + direction = evec[:, np.argmax(ev)] + proj = (arr - mean).dot(direction) + idx = np.argsort(proj) + sorted_pts = arr[idx] + merged = [sorted_pts[0]] + 
for p in sorted_pts[1:]: + if math.hypot(p[0]-merged[-1][0], p[1]-merged[-1][1]) > 5: + merged.append(p) + return merged + + def poly_len(poly): + return sum(math.hypot(poly[i+1][0]-poly[i][0], poly[i+1][1]-poly[i][1]) + for i in range(len(poly)-1)) + + result = [] + for cl in clusters: + poly = merge(cl) + if poly_len(poly) >= RAIL['min_total_len']: + result.append(poly) + result.sort(key=lambda p: -poly_len(p)) + return result + + +# ══════════════════════════════════════════════════════════════════════════════ +# 침목 검출 +# ══════════════════════════════════════════════════════════════════════════════ + +def detect_sleepers(img, rail_polys, gabor_mask=None): + """레일 영역 주변에서 침목(가로 직사각형) 검출.""" + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + + # 레일 위치로 마스크 생성 + mask = np.zeros(gray.shape, np.uint8) + sw = SLEEPER['search_width'] + for poly in rail_polys: + pts = np.array([[int(p[0]), int(p[1])] for p in poly]) + cv2.polylines(mask, [pts], False, 255, sw * 2) + masked = cv2.bitwise_and(thresh, thresh, mask=mask) + + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) + cleaned = cv2.morphologyEx(masked, cv2.MORPH_CLOSE, kernel) + + contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + sleepers = [] + for cnt in contours: + area = cv2.contourArea(cnt) + if not (SLEEPER['min_area'] <= area <= SLEEPER['max_area']): + continue + rect = cv2.minAreaRect(cnt) + (cx, cy), (w, h), angle = rect + if w == 0 or h == 0: continue + aspect = max(w, h) / min(w, h) + if aspect < SLEEPER['aspect_min']: + continue + sleepers.append({ + 'center': (float(cx), float(cy)), + 'size': (float(w), float(h)), + 'angle': float(angle), + 'area': float(area), + }) + return sleepers + + +# ══════════════════════════════════════════════════════════════════════════════ +# 전철주 검출 +# ══════════════════════════════════════════════════════════════════════════════ + +def 
detect_poles(img, rail_polys): + """레일 주변에서 원형/점형 전철주 검출 (Hough Circle + Blob).""" + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + blurred = cv2.GaussianBlur(gray, (9,9), 2) + + # 레일 근처 마스크 + mask = np.zeros(gray.shape, np.uint8) + for poly in rail_polys: + pts = np.array([[int(p[0]), int(p[1])] for p in poly]) + cv2.polylines(mask, [pts], False, 255, POLE['search_margin'] * 2) + + masked = cv2.bitwise_and(blurred, blurred, mask=mask) + + circles = cv2.HoughCircles( + masked, cv2.HOUGH_GRADIENT, dp=1, + minDist=POLE['min_dist'], + param1=POLE['param1'], + param2=POLE['param2'], + minRadius=POLE['min_radius'], + maxRadius=POLE['max_radius'], + ) + + poles = [] + if circles is not None: + for (cx, cy, r) in circles[0]: + poles.append({ + 'center': (float(cx), float(cy)), + 'radius': float(r), + }) + return poles + + +# ══════════════════════════════════════════════════════════════════════════════ +# 컨트롤박스 검출 +# ══════════════════════════════════════════════════════════════════════════════ + +def detect_control_boxes(img, rail_polys): + """레일 근처 직사각형 객체 검출.""" + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + blurred = cv2.GaussianBlur(gray, (5,5), 0) + edges = cv2.Canny(blurred, 30, 90) + + # 레일 근처 마스크 + mask = np.zeros(gray.shape, np.uint8) + for poly in rail_polys: + pts = np.array([[int(p[0]), int(p[1])] for p in poly]) + cv2.polylines(mask, [pts], False, 255, CBOX['search_margin'] * 2) + masked_edges = cv2.bitwise_and(edges, edges, mask=mask) + + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3)) + closed = cv2.morphologyEx(masked_edges, cv2.MORPH_CLOSE, kernel, iterations=2) + + contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + boxes = [] + for cnt in contours: + area = cv2.contourArea(cnt) + if not (CBOX['min_area'] <= area <= CBOX['max_area']): + continue + hull = cv2.convexHull(cnt) + hull_area = cv2.contourArea(hull) + if hull_area == 0: continue + solidity = area / hull_area + if solidity < CBOX['min_solidity']: + 
continue + rect = cv2.minAreaRect(cnt) + (cx, cy), (w, h), angle = rect + if w == 0 or h == 0: continue + aspect = max(w, h) / min(w, h) + if not (CBOX['aspect_min'] <= aspect <= CBOX['aspect_max']): + continue + boxes.append({ + 'center': (float(cx), float(cy)), + 'size': (float(w), float(h)), + 'angle': float(angle), + 'area': float(area), + }) + return boxes + + +# ══════════════════════════════════════════════════════════════════════════════ +# DXF 출력 +# ══════════════════════════════════════════════════════════════════════════════ + +def save_dxf(output_path, rails_world, sleepers_world, poles_world, boxes_world): + import ezdxf + doc = ezdxf.new("R2010") + msp = doc.modelspace() + + # 레이어 정의 + doc.layers.add("RAIL", color=1) # 빨강 + doc.layers.add("SLEEPER", color=3) # 초록 + doc.layers.add("POLE", color=5) # 파랑 + doc.layers.add("CONTROL_BOX", color=6) # 마젠타 + + # 레일 중심선 + for poly in rails_world: + if len(poly) >= 2: + msp.add_lwpolyline(poly, dxfattribs={"layer": "RAIL"}) + + # 침목 + for s in sleepers_world: + cx, cy = s['world_center'] + w, h = s['size'] + ang = s['angle'] + scale = s.get('scale', 1.0) + # 회전된 직사각형 + box_pts = cv2.boxPoints(((cx, -cy), (w/scale, h/scale), ang)) + msp.add_lwpolyline( + [(float(p[0]), float(p[1])) for p in box_pts] + [(float(box_pts[0][0]), float(box_pts[0][1]))], + dxfattribs={"layer": "SLEEPER"} + ) + + # 전철주 (원) + for p in poles_world: + cx, cy = p['world_center'] + r = p.get('world_radius', 1.0) + msp.add_circle((cx, -cy), r, dxfattribs={"layer": "POLE"}) + + # 컨트롤박스 + for b in boxes_world: + cx, cy = b['world_center'] + w, h = b['size'] + ang = b['angle'] + scale = b.get('scale', 1.0) + box_pts = cv2.boxPoints(((cx, -cy), (w/scale, h/scale), ang)) + msp.add_lwpolyline( + [(float(p[0]), float(p[1])) for p in box_pts] + [(float(box_pts[0][0]), float(box_pts[0][1]))], + dxfattribs={"layer": "CONTROL_BOX"} + ) + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + doc.saveas(output_path) + print(f" DXF: 
{output_path}") + + +# ══════════════════════════════════════════════════════════════════════════════ +# GeoJSON 출력 +# ══════════════════════════════════════════════════════════════════════════════ + +def save_geojson(output_path, rails_world, sleepers_world, poles_world, boxes_world): + features = [] + + for i, poly in enumerate(rails_world): + features.append({ + "type": "Feature", + "geometry": {"type": "LineString", "coordinates": [[x, y] for x, y in poly]}, + "properties": {"type": "rail", "id": i} + }) + for i, s in enumerate(sleepers_world): + cx, cy = s['world_center'] + features.append({ + "type": "Feature", + "geometry": {"type": "Point", "coordinates": [cx, cy]}, + "properties": {"type": "sleeper", "id": i, "angle": s['angle']} + }) + for i, p in enumerate(poles_world): + cx, cy = p['world_center'] + features.append({ + "type": "Feature", + "geometry": {"type": "Point", "coordinates": [cx, cy]}, + "properties": {"type": "pole", "id": i, "radius_m": p.get('world_radius', 0)} + }) + for i, b in enumerate(boxes_world): + cx, cy = b['world_center'] + features.append({ + "type": "Feature", + "geometry": {"type": "Point", "coordinates": [cx, cy]}, + "properties": {"type": "control_box", "id": i, + "width_m": b.get('world_w', 0), "height_m": b.get('world_h', 0)} + }) + + geojson = {"type": "FeatureCollection", "features": features} + if not output_path.endswith('.geojson'): + output_path += '.geojson' + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(geojson, f, ensure_ascii=False, indent=2) + print(f" GeoJSON: {output_path}") + + +# ══════════════════════════════════════════════════════════════════════════════ +# 시각화 +# ══════════════════════════════════════════════════════════════════════════════ + +def save_debug(img, rails, sleepers, poles, boxes, output_path, ballast_mask=None): + vis = img.copy() + # 자갈도상 마스크 반투명 오버레이 (노란색) + if ballast_mask is not None: + overlay = vis.copy() + 
overlay[ballast_mask > 0] = (0, 200, 200) + cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis) + # 레일 + for poly in rails: + pts = np.array([[int(p[0]), int(p[1])] for p in poly]) + cv2.polylines(vis, [pts], False, (0,0,255), 2) + # 침목 + for s in sleepers: + cx, cy = int(s['center'][0]), int(s['center'][1]) + rect = ((cx,cy), (s['size'][0], s['size'][1]), s['angle']) + box = cv2.boxPoints(rect).astype(int) + cv2.drawContours(vis, [box], 0, (0,255,0), 1) + # 전철주 + for p in poles: + cx, cy, r = int(p['center'][0]), int(p['center'][1]), int(p['radius']) + cv2.circle(vis, (cx, cy), r, (255,0,0), 2) + # 컨트롤박스 + for b in boxes: + cx, cy = int(b['center'][0]), int(b['center'][1]) + rect = ((cx,cy), (b['size'][0], b['size'][1]), b['angle']) + box = cv2.boxPoints(rect).astype(int) + cv2.drawContours(vis, [box], 0, (255,0,255), 2) + + # 범례 + legend = [("RAIL", (0,0,255)), + ("SLEEPER", (0,255,0)), + ("POLE", (255,0,0)), + ("CONTROL_BOX", (255,0,255))] + for i, (name, color) in enumerate(legend): + cv2.rectangle(vis, (10, 10+i*28), (24, 24+i*28), color, -1) + cv2.putText(vis, name, (30, 22+i*28), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2) + + with open(str(output_path), 'wb') as f: + _, buf = cv2.imencode('.jpg', vis, [cv2.IMWRITE_JPEG_QUALITY, 85]) + f.write(buf.tobytes()) + print(f" 시각화: {output_path}") + + +# ══════════════════════════════════════════════════════════════════════════════ +# 메인 파이프라인 +# ══════════════════════════════════════════════════════════════════════════════ + +def process(image_path: str, output_dir: str): + print(f"\n{'='*60}") + print(f"입력: {image_path}") + print(f"{'='*60}") + + # 좌표 변환기 초기화 + print("\n[좌표계]") + geo = GeoTransform(image_path) + + # 이미지 로드 + print("\n[이미지 로드]") + img = load_image(image_path) + if img is None: + print("오류: 이미지를 열 수 없습니다.") + sys.exit(1) + H, W = img.shape[:2] + print(f" 크기: {W} x {H} px") + + # 처리용 축소 + scale = 1.0 + if W > SCALE_FACTOR: + scale = SCALE_FACTOR / W + small = cv2.resize(img, (int(W*scale), int(H*scale))) 
+ print(f" 축소: {int(W*scale)} x {int(H*scale)} (x{scale:.2f})") + else: + small = img.copy() + + out = Path(output_dir) + out.mkdir(parents=True, exist_ok=True) + stem = Path(image_path).stem + + # ── Gabor 침목 패턴 검출 + 레일 중심선 ─────────────────── + print("\n[1] Gabor 침목 패턴 → 레일 검출...") + rails_px, gabor_mask = detect_rails_gabor(small) + print(f" 검출: {len(rails_px)}개") + + # Gabor 마스크 저장 (확인용) + mask_path = out / f"{stem}_gabor_mask.jpg" + with open(str(mask_path), 'wb') as _f: + _, _buf = cv2.imencode('.jpg', gabor_mask) + _f.write(_buf.tobytes()) + print(f" Gabor 마스크: {mask_path}") + + # 레일 실좌표 변환 + rails_world = [] + for poly in rails_px: + world_pts = geo.scale_coords(poly, scale) + rails_world.append(world_pts) + + # ── 침목 검출 ────────────────────────────────────────── + print("\n[2] 침목 검출...") + sleepers_px = detect_sleepers(small, rails_px, gabor_mask) + print(f" 검출: {len(sleepers_px)}개") + + sleepers_world = [] + for s in sleepers_px: + cx, cy = s['center'] + wx, wy = geo.scale_point(cx, cy, scale) + ps = geo.pixel_size[0] / scale + sleepers_world.append({ + **s, + 'world_center': (wx, wy), + 'scale': scale, + }) + + # ── 전철주 검출 ──────────────────────────────────────── + print("\n[3] 전철주 검출...") + poles_px = detect_poles(small, rails_px) + print(f" 검출: {len(poles_px)}개") + + poles_world = [] + for p in poles_px: + cx, cy = p['center'] + wx, wy = geo.scale_point(cx, cy, scale) + r_m = p['radius'] * geo.pixel_size[0] / scale + poles_world.append({ + **p, + 'world_center': (wx, wy), + 'world_radius': r_m, + }) + + # ── 컨트롤박스 검출 ──────────────────────────────────── + print("\n[4] 컨트롤박스 검출...") + boxes_px = detect_control_boxes(small, rails_px) + print(f" 검출: {len(boxes_px)}개") + + boxes_world = [] + ps = geo.pixel_size[0] + for b in boxes_px: + cx, cy = b['center'] + wx, wy = geo.scale_point(cx, cy, scale) + w_m = b['size'][0] * ps / scale + h_m = b['size'][1] * ps / scale + boxes_world.append({ + **b, + 'world_center': (wx, wy), + 'world_w': w_m, + 'world_h': 
h_m, + 'scale': scale, + }) + + # ── 결과 요약 ────────────────────────────────────────── + print(f"\n{'='*60}") + print(f"검출 결과 요약") + print(f" 레일: {len(rails_world):>5}개") + print(f" 침목: {len(sleepers_world):>5}개") + print(f" 전철주: {len(poles_world):>5}개") + print(f" 컨트롤박스: {len(boxes_world):>5}개") + print(f"{'='*60}") + + # ── 출력 ────────────────────────────────────────────── + print("\n[출력]") + save_dxf( + str(out / f"{stem}_railway.dxf"), + rails_world, sleepers_world, poles_world, boxes_world + ) + save_geojson( + str(out / f"{stem}_railway.geojson"), + rails_world, sleepers_world, poles_world, boxes_world + ) + save_debug( + small, rails_px, sleepers_px, poles_px, boxes_px, + out / f"{stem}_railway_debug.jpg", + ballast_mask=gabor_mask + ) + + print(f"\n완료.") + print(f" DXF 레이어: RAIL(빨강) / SLEEPER(초록) / POLE(파랑) / CONTROL_BOX(마젠타)") + print(f" Rhino: Import DXF → 레이어별 3D 모델 교체") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("사용법: python tools/railway_pipeline.py [output_dir]") + sys.exit(1) + + img_path = sys.argv[1] + output_dir = sys.argv[2] if len(sys.argv) >= 3 else "output" + process(img_path, output_dir) diff --git a/tools/render_skeleton_overlay.py b/tools/render_skeleton_overlay.py new file mode 100644 index 0000000..be27db8 --- /dev/null +++ b/tools/render_skeleton_overlay.py @@ -0,0 +1,535 @@ +"""SAM 원본 폴리곤의 skeleton을 까만색으로 이미지에 겹쳐 표시. +union mask 기준 skeleton을 그리므로, 겹쳐있는 fragment들의 합쳐진 골격이 보임. 
+ +사용: + python tools/render_skeleton_overlay.py \ + --image data/역사이미지/slope/DJI_20260306113842_0006.JPG \ + --label output/autolabel/raw_0006/labels/DJI_20260306113842_0006.txt \ + --output output/autolabel/raw_0006/vis_skeleton.jpg +""" +import argparse +import base64 +import requests +import cv2 +import numpy as np +from pathlib import Path + +SAM3_SERVER = "http://localhost:8000" +SAM3_MODEL_ID = "segment_anything_3" + +SUB_CATEGORIES = [ + {"name": "pole_shaft", "name_kr": "전주기둥", + "prompt": "vertical steel pole shaft, cylindrical mast, tubular pole body", + "keywords": ["pole", "shaft", "mast", "tubular"], + "color_bgr": [0, 200, 255]}, + {"name": "bracket_arm", "name_kr": "브라켓/암", + "prompt": "horizontal cantilever arm, bracket arm, cross arm, pole cross arm", + "keywords": ["bracket", "cantilever", "cross arm", "arm"], + "color_bgr": [255, 100, 0]}, + {"name": "insulator", "name_kr": "애자", + "prompt": "ceramic insulator, glass insulator, suspension insulator, string of insulators", + "keywords": ["insulator"], + "color_bgr": [0, 255, 255]}, + {"name": "wire", "name_kr": "가선/조가선", + "prompt": "overhead contact wire, catenary wire, messenger wire, electric cable", + "keywords": ["wire", "cable", "catenary"], + "color_bgr": [200, 200, 255]}, +] + + +def _sam3_call(crop_bgr: np.ndarray, prompt: str, conf: float = 0.15) -> list: + h, w = crop_bgr.shape[:2] + scale = 1.0 + if max(h, w) > 1280: + scale = 1280 / max(h, w) + crop_bgr = cv2.resize(crop_bgr, (int(w * scale), int(h * scale))) + _, buf = cv2.imencode(".jpg", crop_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90]) + b64 = base64.b64encode(buf).decode() + payload = { + "model": SAM3_MODEL_ID, + "image": b64, + "params": {"text_prompt": prompt, "conf_threshold": conf, + "show_masks": True, "show_boxes": False}, + } + try: + r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120) + r.raise_for_status() + resp = r.json() + if not resp.get("success"): + return [] + shapes = resp.get("data", 
{}).get("shapes", []) + shapes = [s if isinstance(s, dict) else s.dict() for s in shapes] + if scale < 1.0: + inv = 1.0 / scale + for s in shapes: + if s.get("shape_type") == "polygon": + s["points"] = [[x * inv, y * inv] for x, y in s["points"]] + return [s for s in shapes if s.get("shape_type") == "polygon"] + except Exception as e: + print(f" SAM3 오류: {e}") + return [] + + +def _label_to_subcat(label_str: str): + ll = label_str.lower() + for cat in SUB_CATEGORIES: + if any(kw in ll for kw in cat["keywords"]): + return cat + return None + + +def _subdetect_group(img: np.ndarray, group_polys: list, pad: int = 80) -> int: + """그룹 폴리곤들의 union bbox crop → SAM3 sub-detect → img 오버레이.""" + all_pts = np.vstack(group_polys) + H, W = img.shape[:2] + x0 = max(0, int(all_pts[:, 0].min()) - pad) + y0 = max(0, int(all_pts[:, 1].min()) - pad) + x1 = min(W, int(all_pts[:, 0].max()) + pad) + y1 = min(H, int(all_pts[:, 1].max()) + pad) + + crop = img[y0:y1, x0:x1].copy() + combined_prompt = ", ".join(c["prompt"] for c in SUB_CATEGORIES) + shapes = _sam3_call(crop, combined_prompt) + + font = cv2.FONT_HERSHEY_SIMPLEX + for s in shapes: + cat = _label_to_subcat(s.get("label", "")) + color = tuple(cat["color_bgr"]) if cat else (128, 128, 128) + pts = np.array([[int(px + x0), int(py + y0)] for px, py in s["points"]], dtype=np.int32) + overlay = img.copy() + cv2.fillPoly(overlay, [pts], color) + cv2.addWeighted(overlay, 0.30, img, 0.70, 0, img) + cv2.polylines(img, [pts], True, color, 2) + cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean()) + name = cat["name"] if cat else "?" 
+ cv2.putText(img, name, (cx, cy), font, 0.7, color, 2, cv2.LINE_AA) + + cv2.rectangle(img, (x0, y0), (x1, y1), (255, 200, 0), 2) + return len(shapes) + + +def render(image_path: Path, label_path: Path, output_path: Path, + selected_indices=None, branch_radius: int = 10, class_ids=None, + subdetect: bool = False, subdetect_polys: set = None): + buf = np.fromfile(str(image_path), dtype=np.uint8) + img = cv2.imdecode(buf, cv2.IMREAD_COLOR) + img_orig = img.copy() + H, W = img.shape[:2] + + text = label_path.read_text(encoding="utf-8").strip() if label_path.exists() else "" + if class_ids is not None: + text = "\n".join( + ln for ln in text.splitlines() + if ln.split() and int(ln.split()[0]) in class_ids + ) + all_polygons = [] + for line in text.splitlines(): + parts = line.split() + if not parts: + continue + coords = list(map(float, parts[1:])) + pts = np.array( + [[coords[i] * W, coords[i + 1] * H] for i in range(0, len(coords), 2)], + dtype=np.int32, + ) + if len(pts) >= 3: + all_polygons.append(pts) + + if selected_indices is not None: + keep_idx = [i for i in sorted(selected_indices) if 0 <= i < len(all_polygons)] + else: + keep_idx = list(range(len(all_polygons))) + polygons = [all_polygons[i] for i in keep_idx] + print(f" total polygons in label: {len(all_polygons)}, " + f"selected: {len(polygons)} (indices: {keep_idx})") + + # 수직/수평 판별: 장축 vs 이미지 중심→폴리곤 중심 radial 방향 비교 + # 드론 사선 촬영에서 수직 기둥은 이미지 중심에서 방사형으로 기울어져 보임 + img_cx, img_cy = W / 2.0, H / 2.0 + + def _poly_orient(pts): + rect = cv2.minAreaRect(pts) + (rx, ry), (rw, rh), angle = rect + if min(rw, rh) < 1: + return (160, 160, 160), '?', 0.0 + ar = max(rw, rh) / min(rw, rh) + if ar < 1.3: + return (160, 160, 160), '?', 0.0 + # 장축 방향 벡터 + long_angle_deg = angle if rw >= rh else angle + 90 + lx = np.cos(np.radians(long_angle_deg)) + ly = np.sin(np.radians(long_angle_deg)) + # radial 방향: 이미지 중심 → 폴리곤 중심 + rdx, rdy = rx - img_cx, ry - img_cy + radial_norm = (rdx**2 + rdy**2) ** 0.5 + if radial_norm < 1: + 
return (160, 160, 160), '?', 0.0 + rdx, rdy = rdx / radial_norm, rdy / radial_norm + # 장축과 radial 방향의 정렬도 (|cos θ|) + cos_sim = abs(lx * rdx + ly * rdy) + # cos_sim → 1: 장축이 radial 방향 = 수직 기둥 + # cos_sim → 0: 장축이 radial ⊥ 방향 = 수평 빔 + if cos_sim > 0.7: + return (255, 80, 0), 'V', cos_sim + else: + return (0, 80, 255), 'H', cos_sim + + orient_map = {} # orig_idx → 'V'/'H'/'?' + print("\n [폴리곤 수직/수평 분류]") + for orig_idx, pts in zip(keep_idx, polygons): + color, orient, cos_sim = _poly_orient(pts) + orient_map[orig_idx] = orient + print(f" poly {orig_idx:>2d}: {orient} cos_sim={cos_sim:.3f}") + overlay = img.copy() + cv2.fillPoly(overlay, [pts], color) + cv2.addWeighted(overlay, 0.25, img, 0.75, 0, img) + cv2.polylines(img, [pts], True, color, 2) + cx = int(pts[:, 0].mean()) + cy = int(pts[:, 1].mean()) + label = f"{orig_idx}{orient}" + cv2.putText(img, label, (cx, cy), + cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 255), 5) + cv2.putText(img, label, (cx, cy), + cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2) + + # 모든 폴리곤 합친 마스크 → skeleton + merged = np.zeros((H, W), dtype=np.uint8) + for pts in polygons: + cv2.fillPoly(merged, [pts], 255) + # 미세한 픽셀 overlap으로 다른 구조가 합쳐지는 것 방지: 작은 erosion + # 5px dilation으로 인접 폴리곤 gap 브리지 후 3px erosion으로 복원 + merged_dilated = cv2.dilate(merged, np.ones((11, 11), dtype=np.uint8)) + merged_eroded = cv2.erode(merged_dilated, np.ones((7, 7), dtype=np.uint8)) + skel = cv2.ximgproc.thinning(merged_eroded) + + # skeleton 두껍게 (시인성) + skel_thick = cv2.dilate(skel, np.ones((5, 5), dtype=np.uint8)) + img[skel_thick > 0] = (0, 0, 0) + + # spur pruning: 짧은 dead-end arm 제거 → 진짜 branch만 남음 + _box3 = np.ones((3, 3), dtype=np.float32) + + def _prune_spurs(sk, min_arm_px): + """branch pixel 제거 후 arm component 분석, 짧은 arm 제거.""" + s_i = (sk > 0).astype(np.int32) + sp2 = np.pad(s_i, 1, mode='constant') + h2, w2 = s_i.shape + n2 = [sp2[0:h2, 0:w2], sp2[0:h2, 1:w2+1], sp2[0:h2, 2:w2+2], + sp2[1:h2+1, 2:w2+2], sp2[2:h2+2, 2:w2+2], sp2[2:h2+2, 1:w2+1], + sp2[2:h2+2, 
0:w2], sp2[1:h2+1, 0:w2]] + cn2 = np.zeros((h2, w2), dtype=np.int32) + for i2 in range(8): + cn2 += (1 - n2[i2]) * n2[(i2 + 1) % 8] + cn2 *= s_i + arm_sk = sk.copy() + arm_sk[cn2 >= 3] = 0 # branch pixel 제거 + n_arm, arm_lbl, arm_stats, arm_cents = cv2.connectedComponentsWithStats(arm_sk) + sk_out = sk.copy() + spur_cs = [] + for a_id in range(1, n_arm): + if arm_stats[a_id, cv2.CC_STAT_AREA] < min_arm_px: + sk_out[arm_lbl == a_id] = 0 + spur_cs.append((int(arm_cents[a_id, 0]), int(arm_cents[a_id, 1]))) + return sk_out, spur_cs + + skel_pruned, spur_centroids = _prune_spurs(skel, min_arm_px=30) + print(f" spur pruning: {len(spur_centroids)}개 arm 제거 (min_arm_px=30)") + + skel_pf = (skel_pruned > 0).astype(np.float32) + nbr_p = cv2.filter2D(skel_pf, cv2.CV_32F, _box3) - skel_pf + deg_p = (nbr_p * skel_pf).astype(np.int32) + + branch_mask_global = ((deg_p >= 3) & (skel_pruned > 0)).astype(np.uint8) * 255 + endpoint_mask_global = ((deg_p == 1) & (skel_pruned > 0)).astype(np.uint8) * 255 + + def cluster_centroids(mask, cluster_radius): + if not mask.any(): + return [] + dilated = cv2.dilate(mask, np.ones((cluster_radius, cluster_radius), np.uint8)) + num, _, _, cents = cv2.connectedComponentsWithStats(dilated) + return [(int(cents[i, 0]), int(cents[i, 1])) for i in range(1, num)] + + def extreme_endpoint_pair(eps): + if len(eps) < 2: + return None + best, best_d2 = None, -1 + for i in range(len(eps)): + for j in range(i + 1, len(eps)): + dx = eps[i][0] - eps[j][0] + dy = eps[i][1] - eps[j][1] + d2 = dx * dx + dy * dy + if d2 > best_d2: + best_d2 = d2 + best = (eps[i], eps[j]) + return best + + # 그룹핑 = pruned skeleton의 connected component + num_comp, comp_labels = cv2.connectedComponents(skel_pruned) + + # 진단: 각 polygon이 어느 component에 속하는지 + poly_idx→cids 역매핑 빌드 + print(f"\n [Diagnostic] skeleton components: {num_comp - 1}") + poly_to_cids: dict = {} + for i, pts in enumerate(polygons): + poly_mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(poly_mask, [pts], 255) + 
in_skel = (skel_pruned > 0) & (poly_mask > 0) + if in_skel.any(): + cids = sorted(set(int(c) for c in np.unique(comp_labels[in_skel]) if c > 0)) + poly_to_cids[i] = cids + else: + print(f" poly {i:>2d}: (skeleton 없음 — 너무 작거나 erosion됨)") + + # 역매핑: component → polygon 목록 + cid_to_polys: dict = {} + for poly_i, cids in poly_to_cids.items(): + for cid in cids: + cid_to_polys.setdefault(cid, []).append(poly_i) + + print(f"\n [그룹 → 폴리곤 목록] (전체 {len(cid_to_polys)}그룹)") + for cid in sorted(cid_to_polys.keys()): + n_px = int((comp_labels == cid).sum()) + polys = cid_to_polys[cid] + orients = [f"{pi}({orient_map.get(pi, '?')})" for pi in polys] + has_v = any(orient_map.get(pi) == 'V' for pi in polys) + has_h = any(orient_map.get(pi) == 'H' for pi in polys) + tag = " ★라멘?" if (has_v and has_h) else "" + print(f" Comp {cid:>2d} ({n_px:>5d}px): {orients}{tag}") + print() + + # subdetect_polys가 지정된 경우, 해당 poly들이 속한 component만 sub-detect + do_subdetect = subdetect or bool(subdetect_polys) + target_cids: set = set() + if subdetect_polys: + for pi in subdetect_polys: + for cid_ in poly_to_cids.get(pi, []): + target_cids.add(cid_) + print(f" sub-detect 대상 그룹: {sorted(target_cids)} (poly {sorted(subdetect_polys)} 기준)\n") + + # cid → 해당 component에 속하는 polygon pts 목록 (subdetect용) + comp_to_polys: dict = {} + if do_subdetect: + for pts in polygons: + pmask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(pmask, [pts], 255) + in_s = (skel_pruned > 0) & (pmask > 0) + if in_s.any(): + for cid_ in set(int(c) for c in comp_labels[in_s] if c > 0): + comp_to_polys.setdefault(cid_, []).append(pts) + + total_branches = 0 + total_endpoints = 0 + total_groups = 0 + total_lines = 0 + + for cid in range(1, num_comp): + comp_mask = (comp_labels == cid) + if int(comp_mask.sum()) < 50: + continue + total_groups += 1 + + comp_skel = (skel_pruned * comp_mask.astype(np.uint8)).astype(np.uint8) + comp_branch = cv2.bitwise_and(branch_mask_global, comp_skel) + comp_endpoint = 
cv2.bitwise_and(endpoint_mask_global, comp_skel) + + branches = cluster_centroids(comp_branch, 7) + endpoints = cluster_centroids(comp_endpoint, 5) + + # Virtual endpoints: endpoint 3개인 component에만 적용 (라멘에서 1개 누락된 경우). + # 알고리즘: + # 1) skeleton bbox 계산 → 4 변 (top/bottom/left/right) + # 2) 각 endpoint를 가장 가까운 변에 배정 + # 3) 4 변 중 endpoint 없는 "빠진 방향" 찾음 + # 4) 빠진 방향에 따라 branch 중 극값(가장 X/Y가 작거나 큰)을 virtual endpoint로 + virtual_endpoints = [] + if len(endpoints) == 3 and branches: + ys_arr, xs_arr = np.where(comp_skel > 0) + if len(ys_arr) > 0: + xmin, xmax = int(xs_arr.min()), int(xs_arr.max()) + ymin, ymax = int(ys_arr.min()), int(ys_arr.max()) + # 각 endpoint를 가장 가까운 변에 배정 + occupied = set() + for ex, ey in endpoints: + d_top = ey - ymin + d_bot = ymax - ey + d_left = ex - xmin + d_right = xmax - ex + side, _ = min( + (('top', d_top), ('bottom', d_bot), + ('left', d_left), ('right', d_right)), + key=lambda x: x[1]) + occupied.add(side) + # 빠진 방향들 + all_sides = {'top', 'bottom', 'left', 'right'} + missing_sides = all_sides - occupied + # 각 빠진 방향마다 극값 branch 찾기 + for ms in missing_sides: + if ms == 'right': + target = max(branches, key=lambda b: b[0]) + elif ms == 'left': + target = min(branches, key=lambda b: b[0]) + elif ms == 'top': + target = min(branches, key=lambda b: b[1]) + elif ms == 'bottom': + target = max(branches, key=lambda b: b[1]) + too_close = any((target[0] - px) ** 2 + (target[1] - py) ** 2 < 80 ** 2 + for px, py in endpoints + virtual_endpoints) + if not too_close: + virtual_endpoints.append(target) + print(f" virtual endpoint: missing={list(missing_sides)}, " + f"added={virtual_endpoints}") + + # 통합 endpoint 리스트 (real + virtual) + endpoints_all = endpoints + virtual_endpoints + total_branches += len(branches) + total_endpoints += len(endpoints_all) + + # ── 새 알고리즘: 최상단 노드 기반 topology ────────────────────── + if len(endpoints_all) < 2: + pass + else: + def node_dist(a, b): + return ((a[0]-b[0])**2 + (a[1]-b[1])**2) ** 0.5 + + # 1. 
최상단 노드 T (y 최솟값) + top_i = min(range(len(endpoints_all)), + key=lambda i: (endpoints_all[i][1], endpoints_all[i][0])) + T = endpoints_all[top_i] + other_idx = [i for i in range(len(endpoints_all)) if i != top_i] + + connected_idx = {top_i} + lines_to_draw = [] + + if len(other_idx) == 1: + j = other_idx[0] + lines_to_draw.append((T, endpoints_all[j])) + connected_idx.add(j) + longer_i = j + else: + # 2. Left(min x), Right(max x) → T에서 이 둘에만 연결 + left_i = min(other_idx, key=lambda i: endpoints_all[i][0]) + right_i = max(other_idx, key=lambda i: endpoints_all[i][0]) + lines_to_draw.append((T, endpoints_all[left_i])) + lines_to_draw.append((T, endpoints_all[right_i])) + connected_idx.update([left_i, right_i]) + + # 3. 긴 쪽 결정 + d_l = node_dist(T, endpoints_all[left_i]) + d_r = node_dist(T, endpoints_all[right_i]) + longer_i = left_i if d_l >= d_r else right_i + + # 4. 체이닝: 긴 쪽 끝점에서 y+ 방향 미연결 노드로 계속 연결 + current_i = longer_i + while True: + cur = endpoints_all[current_i] + below = [(i, node_dist(cur, endpoints_all[i])) + for i in range(len(endpoints_all)) + if i not in connected_idx and endpoints_all[i][1] > cur[1]] + if not below: + break + next_i = min(below, key=lambda x: x[1])[0] + lines_to_draw.append((cur, endpoints_all[next_i])) + connected_idx.add(next_i) + current_i = next_i + + for p1, p2 in lines_to_draw: + cv2.line(img, p1, p2, (0, 255, 255), 5) + total_lines += 1 + print(f" line {p1}→{p2}: {node_dist(p1,p2):.0f}px") + + # 노드 표시 (line 위에 올라오게) + for cx, cy in branches: + cv2.circle(img, (cx, cy), 14, (0, 220, 0), 3) # 초록 중간 원 (진짜 T/X junction) + for cx, cy in endpoints: + cv2.circle(img, (cx, cy), 18, (0, 255, 0), 4) + # virtual endpoint = 오렌지 원 + for cx, cy in virtual_endpoints: + cv2.circle(img, (cx, cy), 18, (0, 140, 255), 4) + + should_sub = do_subdetect and cid in comp_to_polys + if should_sub and target_cids: + should_sub = cid in target_cids + if should_sub: + n = _subdetect_group(img, comp_to_polys[cid]) + print(f" 그룹 {cid}: sub-detect {n}개") + + # 
제거된 spur arm centroids = 분홍 작은 원 + for cx, cy in spur_centroids: + cv2.circle(img, (cx, cy), 6, (180, 120, 180), 2) + + print(f" {len(polygons)} polygons, skeleton 픽셀 {int((skel > 0).sum())}") + print(f" 그룹(skeleton component) {total_groups}개") + print(f" 분기점 {total_branches}개, 끝점 {total_endpoints}개, 연결선 {total_lines}개") + + # 그룹 시각화: comp별 색상으로 폴리곤 표시 (원본 이미지 위) + _PALETTE = [ + (255, 80, 80), ( 80, 220, 80), ( 80, 120, 255), (220, 200, 50), + (220, 80, 220), ( 60, 220, 220), (255, 160, 60), (100, 200, 120), + (180, 80, 255), (255, 140, 100), + ] + _poly2comp = {pi: cid for cid, pis in cid_to_polys.items() for pi in pis} + _idx2pos = {orig: i for i, orig in enumerate(keep_idx)} + + img = img_orig.copy() + for orig_idx, pts in zip(keep_idx, polygons): + cid = _poly2comp.get(orig_idx) + col = _PALETTE[(cid - 1) % len(_PALETTE)] if cid else (140, 140, 140) + ov = img.copy() + cv2.fillPoly(ov, [pts], col) + cv2.addWeighted(ov, 0.35, img, 0.65, 0, img) + cv2.polylines(img, [pts], True, col, 3) + cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean()) + cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 4) + cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2) + + for cid, pis in cid_to_polys.items(): + col = _PALETTE[(cid - 1) % len(_PALETTE)] + valid = [_idx2pos[pi] for pi in pis if pi in _idx2pos] + if not valid: + continue + cx = int(sum(int(polygons[i][:, 0].mean()) for i in valid) / len(valid)) + cy = int(sum(int(polygons[i][:, 1].mean()) for i in valid) / len(valid)) + cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 255), 7) + cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, col, 3) + + h, w = img.shape[:2] + scale = min(1.0, 4096 / max(h, w)) + if scale < 1.0: + img = cv2.resize(img, (int(w * scale), int(h * scale))) + + output_path.parent.mkdir(parents=True, exist_ok=True) + cv2.imencode(output_path.suffix, 
def main():
    """Parse command-line arguments and run ``render`` once.

    ``--polygons``, ``--class-ids`` and ``--subdetect-polys`` take
    comma-separated integers. A blank value means "no filter" and is passed
    to ``render`` as ``None``. A malformed value is reported as an argparse
    usage error instead of an uncaught ``ValueError``.
    """

    def _int_set(raw: str):
        # Shared converter for the three comma-separated integer options.
        if not raw.strip():
            return None
        try:
            return {int(tok) for tok in raw.split(",") if tok.strip()}
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"expected comma-separated integers, got {raw!r}") from None

    ap = argparse.ArgumentParser()
    ap.add_argument("--image", required=True)
    ap.add_argument("--label", required=True)
    ap.add_argument("--output", required=True)
    # argparse also runs `type` on string defaults, so "" becomes None.
    ap.add_argument("--polygons", type=_int_set, default="",
                    help="comma-separated polygon indices (예: '10,11,12,26,27'). 비우면 전체")
    ap.add_argument("--class-ids", type=_int_set, default="",
                    help="포함할 클래스 ID (예: '1' or '1,3'). 비우면 전체")
    ap.add_argument("--branch-radius", type=int, default=10,
                    help="branch 판정 원 반지름 px (기본 10)")
    ap.add_argument("--subdetect", action="store_true",
                    help="모든 그룹 sub-detect (SAM3 서버 필요)")
    ap.add_argument("--subdetect-polys", type=_int_set, default="",
                    help="지정 polygon이 속한 그룹만 sub-detect (예: '2,27')")
    args = ap.parse_args()

    render(Path(args.image), Path(args.label), Path(args.output), args.polygons,
           branch_radius=args.branch_radius, class_ids=args.class_ids,
           subdetect=args.subdetect, subdetect_polys=args.subdetect_polys)
def _load_prompt(filename: str, default: str, prompt_dir=None) -> str:
    """Load a prompt file and join its entries into one comma-separated prompt.

    Each non-blank line is one entry. Lines whose first non-whitespace
    character is ``#`` are comments and are ignored, including indented ones.

    Args:
        filename: File name inside the prompt directory.
        default: Prompt returned when the file is missing or has no entries.
        prompt_dir: Directory to read from; ``None`` uses the module-level
            ``_PROMPT_DIR``.

    Returns:
        The entries joined with ``", "``, or ``default``.
    """
    base = _PROMPT_DIR if prompt_dir is None else Path(prompt_dir)
    path = base / filename
    if not path.exists():
        return default

    # utf-8-sig drops a leading BOM, which would otherwise hide a "#" on the
    # first line from the comment check.
    stripped = (raw.strip() for raw in path.read_text(encoding="utf-8-sig").splitlines())
    # Test the stripped text so that "   # note" is treated as a comment too.
    entries = [entry for entry in stripped if entry and not entry.startswith("#")]
    return ", ".join(entries) if entries else default
def shapes_to_pairs(shapes: list) -> list:
    """Convert polygon shapes into ``(det, shape)`` pairs.

    ``det`` is ``(label, score, x1, y1, x2, y2)``, where the box is the
    axis-aligned bounding box of the polygon points. Shapes that are not
    polygons, or that have fewer than three points, are skipped.

    A shape whose ``score`` or ``points`` is present but ``None`` is handled:
    the score becomes 0.0 and a ``None`` point list counts as empty.

    Args:
        shapes: Shape dicts with ``shape_type``, ``points`` and optionally
            ``label`` / ``score``.

    Returns:
        A list of ``(det, shape)`` tuples; ``shape`` is the original dict.
    """
    pairs = []
    for shape in shapes:
        if shape.get("shape_type") != "polygon":
            continue
        pts = shape.get("points") or []
        if len(pts) < 3:
            continue
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        score = shape.get("score")
        det = (
            shape.get("label") or "",
            0.0 if score is None else float(score),
            min(xs), min(ys), max(xs), max(ys),
        )
        pairs.append((det, shape))
    return pairs
def filter_by_size(pairs: list, min_h: int, debug: bool = False) -> list:
    """Keep only detections that are tall enough to be a catenary pole.

    A pair is dropped when its bbox height is below ``min_h``. If the shape
    has polygon points, it is also dropped when the polygon's own height is
    below ``0.7 * min_h`` or its width exceeds three times its height.

    Args:
        pairs: ``(det, shape)`` pairs, ``det = (label, conf, x1, y1, x2, y2)``.
        min_h: Minimum bbox height in pixels.
        debug: Print the reason for every dropped detection.

    Returns:
        The surviving pairs, in input order.
    """
    survivors = []
    for det, shape in pairs:
        label, conf, x1, y1, x2, y2 = det

        bbox_h = y2 - y1
        if bbox_h < min_h:
            if debug:
                print(f" DROP [{label} {conf:.2f}] h={bbox_h:.0f} (최소:{min_h})")
            continue

        outline = shape.get("points", [])
        if outline:
            cols = [p[0] for p in outline]
            rows = [p[1] for p in outline]
            mh = max(rows) - min(rows)
            mw = max(cols) - min(cols)
            # The polygon itself must also be tall enough ...
            if mh < min_h * 0.7:
                if debug:
                    print(f" DROP [{label} {conf:.2f}] sam_h={mh:.0f} (최소:{min_h*0.7:.0f})")
                continue
            # ... and must not be far wider than it is tall.
            if mw > mh * 3:
                if debug:
                    print(f" DROP [{label} {conf:.2f}] 너무 넓음 w{mw:.0f}>h{mh:.0f}x3")
                continue

        survivors.append(((label, conf, x1, y1, x2, y2), shape))
    return survivors
def filter_long_beams(beam_pairs: list, rail_shapes: list,
                      min_beam_width: int = 400,
                      ratio_thresh: float = 2.0,
                      rail_x_margin: int = 500,
                      debug: bool = False) -> list:
    """Keep only long, horizontal beams that lie over the detected rails.

    A beam is kept when all of the following hold:
      1. bbox width >= ``min_beam_width``;
      2. bbox width / height >= ``ratio_thresh``;
      3. its x-extent overlaps the rail x-extent widened by
         ``rail_x_margin`` on both sides.

    Rail shapes that are not polygons, or whose ``points`` are missing or
    empty, are ignored when computing the rail x-extent.

    Args:
        beam_pairs: ``(det, shape)`` pairs, ``det = (label, conf, x1, y1, x2, y2)``.
        rail_shapes: Rail shape dicts used to derive the rail x-extent.
        min_beam_width: Minimum bbox width in pixels.
        ratio_thresh: Minimum width/height ratio.
        rail_x_margin: Horizontal tolerance around the rail extent, in pixels.
        debug: Print the reason for every dropped beam.

    Returns:
        The surviving pairs in input order; an empty list when no usable
        rail polygon exists.
    """
    if not rail_shapes:
        if debug:
            print(" [Beam filter] rail 미검출 → 빔 zone 보강 비활성")
        return []

    rail_xs = []
    for s in rail_shapes:
        if s.get("shape_type") != "polygon":
            continue
        xs = [p[0] for p in (s.get("points") or [])]
        if xs:  # min()/max() raise on an empty polygon
            rail_xs.extend([min(xs), max(xs)])
    if not rail_xs:
        return []
    rail_x_min, rail_x_max = min(rail_xs), max(rail_xs)

    kept = []
    for (label, conf, x1, y1, x2, y2), shape in beam_pairs:
        bw = x2 - x1
        bh = max(1, y2 - y1)  # avoid division by zero for degenerate boxes
        ratio = bw / bh
        if bw < min_beam_width:
            if debug:
                print(f" DROP beam [{label} {conf:.2f}] w={bw:.0f} (최소:{min_beam_width})")
            continue
        if ratio < ratio_thresh:
            if debug:
                print(f" DROP beam [{label} {conf:.2f}] w/h={ratio:.2f} (최소:{ratio_thresh})")
            continue
        if x2 < rail_x_min - rail_x_margin or x1 > rail_x_max + rail_x_margin:
            if debug:
                print(f" DROP beam [{label} {conf:.2f}] rail x-범위 벗어남")
            continue
        kept.append(((label, conf, x1, y1, x2, y2), shape))
    return kept
def split_pole_arm(pts: list, H: int, W: int,
                   arm_min_factor: float = 0.5) -> tuple:
    """Split a pole polygon into a vertical column and a horizontal arm.

    Method (row-width analysis):
      1. Rasterise the polygon into an ``H x W`` mask.
      2. Measure the mask width of every row inside the polygon bbox.
      3. Column width = median of the narrower half of the non-empty rows.
      4. Column centre x = median, over the narrow rows, of the midpoint
         between the first and last mask pixel of the row.
      5. Column mask = mask inside the column x-band.
      6. Arm mask = mask outside the band, keeping only rows whose remaining
         width is at least ``max(20, column_w * arm_min_factor)``.

    Args:
        pts: Polygon points ``[[x, y], ...]``.
        H: Image height in pixels.
        W: Image width in pixels.
        arm_min_factor: Minimum arm row width as a fraction of column width.

    Returns:
        ``(column_pts, arm_pts)``. ``arm_pts`` is ``None`` when no arm is
        found or its area is under 5 % of the column's. ``column_pts`` falls
        back to ``pts`` whenever the analysis cannot be carried out.
    """
    polygon = np.array(pts, dtype=np.int32)
    if len(polygon) < 3:
        return pts, None

    mask = np.zeros((H, W), dtype=np.uint8)
    cv2.fillPoly(mask, [polygon], 255)

    x1 = max(0, int(polygon[:, 0].min()))
    x2 = min(W - 1, int(polygon[:, 0].max()))
    y1 = max(0, int(polygon[:, 1].min()))
    y2 = min(H - 1, int(polygon[:, 1].max()))
    sub = mask[y1:y2 + 1, x1:x2 + 1]

    row_widths = (sub > 0).sum(axis=1).astype(np.float32)
    valid = row_widths > 0
    if valid.sum() < 5:
        return pts, None

    # Column width: median over the narrower half of the non-empty rows.
    sorted_widths = np.sort(row_widths[valid])
    column_w = float(np.median(sorted_widths[:max(1, len(sorted_widths) // 2)]))
    if column_w < 5:
        return pts, None

    # Column centre: median midpoint (first/last mask pixel) of narrow rows.
    narrow_rows = valid & (row_widths <= column_w * 1.3)
    if not narrow_rows.any():
        return pts, None
    filled = sub[narrow_rows] > 0
    # Every selected row has at least one set pixel, so argmax is its first
    # True; the reversed argmax gives the last True.
    first = filled.argmax(axis=1)
    last = filled.shape[1] - 1 - filled[:, ::-1].argmax(axis=1)
    column_cx = int(np.median((first + last) / 2)) + x1
    column_half = max(4, int(column_w * 0.7))

    col_x_min = max(0, column_cx - column_half)
    col_x_max = min(W, column_cx + column_half + 1)

    column_mask = np.zeros_like(mask)
    column_mask[:, col_x_min:col_x_max] = mask[:, col_x_min:col_x_max]

    arm_mask = mask.copy()
    arm_mask[:, col_x_min:col_x_max] = 0

    # Arm: keep only rows that extend far enough sideways. This is one
    # boolean-mask assignment instead of a Python loop over all H rows.
    arm_min_w = max(20, int(column_w * arm_min_factor))
    arm_row_widths = (arm_mask > 0).sum(axis=1)
    arm_mask[arm_row_widths < arm_min_w] = 0

    column_pts = _mask_to_largest_polygon(column_mask)
    arm_pts = _mask_to_largest_polygon(arm_mask) if arm_mask.any() else None

    # Discard an arm that is tiny relative to the column (< 5 % of its area).
    if arm_pts and column_pts:
        c_area = cv2.contourArea(np.array(column_pts, dtype=np.int32))
        a_area = cv2.contourArea(np.array(arm_pts, dtype=np.int32))
        if a_area < c_area * 0.05:
            arm_pts = None

    return (column_pts or pts), arm_pts
+ """ + skel = _compute_skeleton(mask) + if not skel.any(): + return None, None + + # 원본 skeleton의 degree map + deg = _compute_degree_map(skel) + free_endpoint_mask = ((deg == 1) & (skel > 0)) + branch_mask = ((deg >= 3) & (skel > 0)).astype(np.uint8) * 255 + + # branch 제거 → segment 분리 + if branch_mask.any(): + bk = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) + seg_skel = cv2.bitwise_and(skel, cv2.bitwise_not(cv2.dilate(branch_mask, bk))) + else: + # branch가 전혀 없음 = 단일 선/곡선 → 라멘 아님 + return None, None + + num, labels = cv2.connectedComponents(seg_skel) + interior_seeds = [] # free endpoint 0개 segment + edge_seeds = [] # free endpoint 1개+ segment + + for sid in range(1, num): + seg = (labels == sid) + if seg.sum() < 10: + continue + n_free = int((free_endpoint_mask & seg).sum()) + seg_u8 = seg.astype(np.uint8) * 255 + if n_free == 0: + interior_seeds.append(seg_u8) + else: + edge_seeds.append(seg_u8) + + # 라멘이 되려면 interior(beam)와 edge(column)가 모두 있어야 함 + if not interior_seeds or not edge_seeds: + return None, None + + col_seed = np.zeros_like(mask) + beam_seed = np.zeros_like(mask) + for s in edge_seeds: + col_seed = cv2.bitwise_or(col_seed, s) + for s in interior_seeds: + beam_seed = cv2.bitwise_or(beam_seed, s) + + col_dist = cv2.distanceTransform(255 - col_seed, cv2.DIST_L2, 5) + beam_dist = cv2.distanceTransform(255 - beam_seed, cv2.DIST_L2, 5) + + mask_bool = mask > 0 + col_region = (mask_bool & (col_dist <= beam_dist)).astype(np.uint8) * 255 + beam_region = (mask_bool & (col_dist > beam_dist)).astype(np.uint8) * 255 + return col_region, beam_region + + +def _extract_polygons(mask: np.ndarray, min_area: int = 200, + epsilon_factor: float = 0.002) -> list: + """Mask의 모든 connected component → polygon points 리스트.""" + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + polys = [] + for c in contours: + if cv2.contourArea(c) < min_area: + continue + eps = epsilon_factor * cv2.arcLength(c, True) + approx = cv2.approxPolyDP(c, 
def save_vis(vis_path: Path, image_bgr: np.ndarray, pairs: list,
             rail_shapes: list = None, long_beam_pairs: list = None):
    """Render detections onto a copy of the image and write it to ``vis_path``.

    Draws, in order:
      * each detection polygon (30 % fill + outline) in its class colour,
        with a ``"<class> <conf>"`` caption above the bbox;
      * rail polygons (25 % red fill + outline), if ``rail_shapes`` is given;
      * long-beam bboxes in magenta with a ``"beam <conf>"`` caption, if
        ``long_beam_pairs`` is given.

    The result is downscaled so its longer side is at most 2048 px and
    encoded by the file suffix of ``vis_path``.

    Args:
        vis_path: Output file; the suffix selects the image format.
        image_bgr: Source image; it is not modified.
        pairs: ``(det, shape)`` pairs, ``det = (label, conf, x1, y1, x2, y2)``;
            the class id is read from ``shape["_cls_id"]`` (default 0).
        rail_shapes: Optional rail shape dicts.
        long_beam_pairs: Optional ``(det, shape)`` pairs for long beams.
    """
    vis = image_bgr.copy()
    for (label, conf, x1, y1, x2, y2), shape in pairs:
        cls_id = shape.get("_cls_id", 0)
        color = _CLS_COLORS[cls_id % len(_CLS_COLORS)]
        # The colour lookup wraps with modulo, so an id beyond the name table
        # is tolerated there. Give the caption a fallback instead of letting
        # CLASS_NAMES[cls_id] raise IndexError.
        try:
            cls_name = CLASS_NAMES[cls_id]
        except IndexError:
            cls_name = f"cls{cls_id}"
        if shape.get("shape_type") == "polygon":
            pts = np.array(shape["points"], dtype=np.int32)
            overlay = vis.copy()
            cv2.fillPoly(overlay, [pts], color)
            cv2.addWeighted(overlay, 0.3, vis, 0.7, 0, vis)
            cv2.polylines(vis, [pts], True, color, 2)
        cv2.putText(vis, f"{cls_name} {conf:.2f}",
                    (int(x1), max(0, int(y1) - 5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # Rail zone (red, semi-transparent).
    if rail_shapes:
        for s in rail_shapes:
            if s.get("shape_type") == "polygon":
                pts = np.array(s["points"], dtype=np.int32)
                overlay = vis.copy()
                cv2.fillPoly(overlay, [pts], (0, 0, 200))
                cv2.addWeighted(overlay, 0.25, vis, 0.75, 0, vis)
                cv2.polylines(vis, [pts], True, (0, 0, 255), 2)

    # Long beams (magenta bbox).
    if long_beam_pairs:
        for (label, conf, x1, y1, x2, y2), _ in long_beam_pairs:
            cv2.rectangle(vis, (int(x1), int(y1)), (int(x2), int(y2)),
                          (255, 0, 255), 2)
            cv2.putText(vis, f"beam {conf:.2f}", (int(x1), max(0, int(y1) - 5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1)

    h, w = vis.shape[:2]
    if max(h, w) > 2048:
        scale = 2048 / max(h, w)
        vis = cv2.resize(vis, (int(w * scale), int(h * scale)))
    # imencode + tofile rather than cv2.imwrite, as in the original.
    # NOTE(review): presumably chosen so non-ASCII output paths work on
    # Windows — confirm.
    cv2.imencode(vis_path.suffix, vis)[1].tofile(str(vis_path))
tile_bgr = image_bgr[ty0:ty1, tx0:tx1] + shapes = sam3_text_segment(tile_bgr, prompt, conf) + for s in shapes: + if s.get("shape_type") == "polygon": + s["points"] = [[x + tx0, y + ty0] for x, y in s["points"]] + all_pairs.extend(shapes_to_pairs(shapes)) + if idx % 5 == 0: + print(f" 타일 {idx}/{len(tile_coords)}, 누적 {len(all_pairs)}개") + return nms(all_pairs) + + +# ── 이미지 처리 ─────────────────────────────────────────────────────────────── +def process_image(image_path: Path, + out_label_dir: Path, out_vis_dir: Path, + sam3_conf: float, + tile_size: int, tile_overlap: float, + min_pole_h: int = 160, + rail_margin: int = 300, + rail_gap_close: int = 500, + beam_min_width: int = 400, + beam_ratio: float = 2.0, + beam_downward: int = 400, + rail_lateral: float = 0.0, + raw_polygons: bool = False) -> int: + + buf = np.fromfile(str(image_path), dtype=np.uint8) + image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if image_bgr is None: + print(f" 이미지 로드 실패: {image_path}") + return 0 + H, W = image_bgr.shape[:2] + + pole_prompt, rail_prompt, bracket_prompt = _get_prompts() + use_tile = tile_size > 0 and (W > tile_size or H > tile_size) + + # A: 전철주 검출 + if use_tile: + print(f" [A] 타일 {tile_size}px: '{pole_prompt}'") + pole_pairs = _detect_tiled(image_bgr, pole_prompt, sam3_conf, tile_size, tile_overlap) + else: + print(f" [A] 전철주 검출: '{pole_prompt}'") + pole_pairs = nms(shapes_to_pairs(sam3_text_segment(image_bgr, pole_prompt, sam3_conf))) + + # A 높이 필터 (> min_pole_h) + before = len(pole_pairs) + pole_pairs = filter_by_size(pole_pairs, min_h=min_pole_h, debug=True) + print(f" [A] 높이 필터(h>={min_pole_h}): {before}→{len(pole_pairs)}개") + + for _, shape in pole_pairs: + shape["_cls_id"] = 0 + + # B: 철로 검출 + print(f" [B] 철로 검출: '{rail_prompt}'") + rail_shapes = [s for s in sam3_text_segment(image_bgr, rail_prompt, sam3_conf) + if s.get("shape_type") == "polygon"] + + # [C] Beam/Arm 쿼리는 사용하지 않음 (검출되는 "긴 빔"이 실제로는 레일 일부였음) + + # A ∩ Zone 필터 (Zone = rail + 측면 확장) + if rail_margin >= 
0: + if rail_shapes: + zone_mask = build_zone_mask(rail_shapes, H, W, rail_margin, rail_gap_close, + lateral_expand_ratio=rail_lateral) + if rail_lateral > 0: + print(f" [Zone] rail + 측면 {rail_lateral*100:.0f}%") + before = len(pole_pairs) + kept = [] + for d, s in pole_pairs: + if pole_in_zone(d, zone_mask): + kept.append((d, s)) + else: + label, conf, x1, y1, x2, y2 = d + print(f" DROP zone [{label} {conf:.2f}] bbox=({x1:.0f},{y1:.0f})-({x2:.0f},{y2:.0f})") + pole_pairs = kept + print(f" [A∩Zone] 필터(+{rail_margin}px): {before}→{len(pole_pairs)}개") + else: + print(f" ! [B] 철로 미검출 → 전체 제외") + pole_pairs = [] + + # 후처리: raw_polygons=True 면 SAM 원본 폴리곤만 cls=0으로 그대로 저장 + if raw_polygons: + all_pairs = [] + for det, shape in pole_pairs: + shape["_cls_id"] = 0 + shape["label"] = "catenary_pole" + all_pairs.append((det, shape)) + print(f" [Raw] SAM 폴리곤 {len(all_pairs)}개 그대로 저장 (split 미적용)") + else: + # skeleton split (column / beam) + column_polys, beam_polys = skeleton_split_columns_beams(pole_pairs, H, W) + column_pairs = [] + arm_pairs = [] + for poly in column_polys: + xs = [p[0] for p in poly] + ys = [p[1] for p in poly] + det = ("catenary_pole", 1.0, min(xs), min(ys), max(xs), max(ys)) + shape = {"shape_type": "polygon", "points": poly, + "_cls_id": 0, "label": "catenary_pole"} + column_pairs.append((det, shape)) + for poly in beam_polys: + xs = [p[0] for p in poly] + ys = [p[1] for p in poly] + det = ("bracket", 1.0, min(xs), min(ys), max(xs), max(ys)) + shape = {"shape_type": "polygon", "points": poly, + "_cls_id": 1, "label": "bracket"} + arm_pairs.append((det, shape)) + print(f" [Split] column {len(column_pairs)}개, beam {len(arm_pairs)}개") + all_pairs = column_pairs + arm_pairs + if not all_pairs: + return 0 + + save_yolo_label(out_label_dir / (image_path.stem + ".txt"), all_pairs, W, H) + save_vis(out_vis_dir / (image_path.stem + "_vis.jpg"), image_bgr, all_pairs, + rail_shapes, None) + return len(all_pairs) + + +# ── classes.txt 저장 
────────────────────────────────────────────────────────── +def save_classes(out_dir: Path): + (out_dir / "classes.txt").write_text("\n".join(CLASS_NAMES), encoding="utf-8") + + +# ── 메인 ───────────────────────────────────────────────────────────────────── +def main(): + parser = argparse.ArgumentParser(description="SAM 3.1 텍스트 프롬프트 자동 라벨링") + parser.add_argument("--input", required=True, help="이미지 파일 또는 디렉터리") + parser.add_argument("--output", default="output/autolabel") + parser.add_argument("--sam3_conf", type=float, default=0.20) + parser.add_argument("--ext", default=".jpg,.jpeg,.png,.tif") + parser.add_argument("--tile_size", type=int, default=1280, + help="타일 크기 px (0=타일 없음, 기본:1280)") + parser.add_argument("--tile_overlap", type=float, default=0.2) + parser.add_argument("--min_pole_h", type=int, default=160, + help="전철주 최소 높이 px (기본:160, 드론 거리 변화 마진 포함)") + parser.add_argument("--rail_margin", type=int, default=300, + help="철로 mask 팽창 거리 px (-1=비활성, 기본:300)") + parser.add_argument("--rail_gap_close", type=int, default=500, + help="beam 차폐 갭 보정 수평 closing 폭 px (0=비활성, 기본:500)") + parser.add_argument("--beam_min_width", type=int, default=400, + help="긴 빔 최소 폭 px - 가로등/짧은 arm 차단 (기본:400, 역사형만)") + parser.add_argument("--beam_ratio", type=float, default=2.0, + help="긴 빔 폭/높이 최소 비율 (기본:2.0)") + parser.add_argument("--beam_downward", type=int, default=400, + help="빔 zone 아래 확장 px (기본:400)") + parser.add_argument("--rail_lateral", type=float, default=0.0, + help="rail zone 측면(좌우) 확장 비율 (이미지 폭 기준, 예:0.10=10%%)") + parser.add_argument("--raw_polygons", action="store_true", + help="후처리 skeleton split 없이 SAM 원본 폴리곤만 저장") + args = parser.parse_args() + + input_path = Path(args.input) + out_dir = Path(args.output) + out_label = out_dir / "labels" + out_vis = out_dir / "vis" + out_label.mkdir(parents=True, exist_ok=True) + out_vis.mkdir(parents=True, exist_ok=True) + + exts = [e.strip().lower() for e in args.ext.split(",")] + if input_path.is_dir(): + images = sorted([p 
for p in input_path.iterdir() if p.suffix.lower() in exts], + key=lambda p: p.name) + else: + images = [input_path] + + pole_prompt, rail_prompt, bracket_prompt = _get_prompts() + print(f"처리 대상: {len(images)}장") + print(f"전철주 프롬프트: {pole_prompt}") + print(f"철로 프롬프트: {rail_prompt}") + print(f"타일: {args.tile_size}px 겹침: {args.tile_overlap*100:.0f}%") + print(f"출력: {out_dir}\n") + save_classes(out_dir) + + total = 0 + for i, img_path in enumerate(images, 1): + print(f"[{i}/{len(images)}] {img_path.name}") + total += process_image(img_path, out_label, out_vis, + args.sam3_conf, + args.tile_size, args.tile_overlap, + args.min_pole_h, args.rail_margin, args.rail_gap_close, + args.beam_min_width, args.beam_ratio, args.beam_downward, + args.rail_lateral, args.raw_polygons) + + print(f"\n완료: 총 {total}개 라벨 생성") + print(f"라벨: {out_label}") + print(f"시각화: {out_vis}") + + +if __name__ == "__main__": + main() diff --git a/tools/sam3_batch_label.py b/tools/sam3_batch_label.py new file mode 100644 index 0000000..120b288 --- /dev/null +++ b/tools/sam3_batch_label.py @@ -0,0 +1,278 @@ +""" +SAM3 배치 자동 레이블링 파이프라인 +=================================== +SAM3 서버에 텍스트 프롬프트로 이미지 배치 처리 → +X-AnyLabeling 호환 JSON annotation 파일 자동 생성 + +사용법: + # 기본 (sample/rail 폴더, 서버 localhost:8000) + python tools/sam3_batch_label.py + + # 폴더 지정 + python tools/sam3_batch_label.py --input sample/rail --output output/labels + + # conf 조정 (낮출수록 더 많이 검출, 오탐도 증가) + python tools/sam3_batch_label.py --conf 0.20 + +SAM3 서버 실행: + cd X-AnyLabeling-Server + uvicorn app.main:app --host 0.0.0.0 --port 8000 +""" + +import argparse +import base64 +import json +import sys +from pathlib import Path + +import cv2 +import numpy as np +import requests + +SAM3_SERVER = "http://localhost:8000" +MODEL_ID = "segment_anything_3" + +# 검출 대상 + 한국어 레이블 매핑 +TARGETS = { + "pole": "전철주_세로", + "catenary arm": "전철주_가로", + "junction box": "통신박스", + "electrical box": "전기박스", + "fence": "펜스", +} + +# 시각화 색상 (BGR) +COLORS = { + "전철주_세로": (0, 200, 
255), + "전철주_가로": (0, 100, 255), + "통신박스": (255, 180, 0), + "전기박스": (100, 255, 200), + "펜스": (0, 255, 100), +} + + +def encode_image(image_bgr: np.ndarray) -> str: + _, buf = cv2.imencode(".png", image_bgr) # PNG: 무손실 풀해상도 + return base64.b64encode(buf).decode("utf-8") + + +def sam3_text_predict(image_bgr: np.ndarray, text_prompt: str, conf: float) -> list: + """SAM3 텍스트 프롬프트로 segmentation. shapes 리스트 반환.""" + payload = { + "model": MODEL_ID, + "image": encode_image(image_bgr), + "params": { + "text_prompt": text_prompt, + "show_masks": True, + "show_boxes": False, + "conf_threshold": conf, + "epsilon_factor": 0.002, + }, + } + try: + resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60) + resp.raise_for_status() + data = resp.json() + if data.get("success"): + return data.get("data", {}).get("shapes", []) + except Exception as e: + print(f" [ERROR] SAM3 호출 실패: {e}") + return [] + + +def process_image(image_path: Path, conf: float, vis_dir: Path | None) -> dict: + """이미지 1장: 모든 클래스 검출 → annotation dict 반환.""" + # 한글 경로 대응: np.fromfile + imdecode + buf = np.fromfile(str(image_path), dtype=np.uint8) + image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if image_bgr is None: + return {} + h, w = image_bgr.shape[:2] + + all_shapes = [] + class_counts = {} + + for eng_label, kor_label in TARGETS.items(): + shapes = sam3_text_predict(image_bgr, eng_label, conf) + # label 필드를 한국어로 교체 + for s in shapes: + s["label"] = kor_label + all_shapes.extend(shapes) + if shapes: + class_counts[kor_label] = len(shapes) + + # X-AnyLabeling JSON 형식 + annotation = { + "version": "3.3.9", + "flags": {}, + "shapes": [ + { + "label": s["label"], + "points": s["points"], + "group_id": None, + "shape_type": s.get("shape_type", "polygon"), + "flags": {}, + "score": round(float(s.get("score", 0)), 4), + } + for s in all_shapes + ], + "imagePath": image_path.name, + "imageData": None, + "imageHeight": h, + "imageWidth": w, + } + + # 시각화 + if vis_dir is not None: + vis = 
draw_vis(image_bgr, all_shapes) + vis_path = vis_dir / f"{image_path.stem}_vis.jpg" + cv2.imencode(".jpg", vis)[1].tofile(str(vis_path)) + + return annotation, class_counts + + +def draw_vis(image_bgr: np.ndarray, shapes: list) -> np.ndarray: + vis = image_bgr.copy() + overlay = image_bgr.copy() + + for shape in shapes: + pts = np.array(shape["points"], dtype=np.int32) + if len(pts) > 1 and np.array_equal(pts[0], pts[-1]): + pts = pts[:-1] + label = shape.get("label", "unknown") + color = COLORS.get(label, (200, 200, 200)) + cv2.fillPoly(overlay, [pts], color) + cv2.polylines(vis, [pts], True, color, 2) + + # 레이블 텍스트 + cx = int(pts[:, 0].mean()) + cy = int(pts[:, 1].mean()) + cv2.putText(vis, label, (cx - 30, cy), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA) + + cv2.addWeighted(overlay, 0.3, vis, 0.7, 0, vis) + + # 범례 + y = 25 + for kor, color in COLORS.items(): + cv2.rectangle(vis, (10, y - 12), (24, y + 2), color, -1) + cv2.putText(vis, kor, (28, y), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) + y += 20 + + return vis + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input", default="sample/rail") + parser.add_argument("--output", default="output/sam3_labels", + help="JSON annotation 저장 폴더") + parser.add_argument("--vis", default="output/sam3_vis", + help="시각화 이미지 저장 폴더 ('' 로 비활성화)") + parser.add_argument("--conf", type=float, default=0.20, + help="SAM3 confidence threshold (기본 0.20)") + parser.add_argument("--classes", nargs="+", + help="처리할 클래스 키 목록 (기본: 전체). 
예: catenary_pole catenary_arm") + parser.add_argument("--server", default="http://localhost:8000") + args = parser.parse_args() + + global SAM3_SERVER, TARGETS + SAM3_SERVER = args.server + + if args.classes: + key_map = { + "catenary_pole": ("catenary pole", "전철주_세로"), + "concrete_pole": ("concrete pole", "전철주_세로"), + "catenary_arm": ("catenary arm", "전철주_가로"), + "junction_box": ("junction box", "통신박스"), + "electrical_box": ("electrical box", "전기박스"), + "fence": ("fence", "펜스"), + } + TARGETS = {key_map[k][0]: key_map[k][1] for k in args.classes if k in key_map} + + # 서버 확인 + try: + r = requests.get(f"{SAM3_SERVER}/health", timeout=5) + print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}") + except Exception: + print(f"[ERROR] SAM3 서버 연결 실패: {SAM3_SERVER}") + print(" 서버를 먼저 실행하세요:") + print(" cd X-AnyLabeling-Server && uvicorn app.main:app --port 8000") + sys.exit(1) + + # 입력 + input_path = Path(args.input) + if input_path.is_file(): + images = [input_path] + elif input_path.is_dir(): + images = sorted( + list(input_path.glob("*.jpg")) + + list(input_path.glob("*.jpeg")) + + list(input_path.glob("*.png")) + ) + else: + print(f"[ERROR] 경로 없음: {input_path}") + sys.exit(1) + + if not images: + print(f"[ERROR] 이미지 없음: {input_path}") + sys.exit(1) + + out_dir = Path(args.output) + out_dir.mkdir(parents=True, exist_ok=True) + + vis_dir = None + if args.vis: + vis_dir = Path(args.vis) + vis_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n처리 대상: {len(images)}장") + print(f"검출 클래스: {list(TARGETS.values())}") + print(f"SAM3 conf: {args.conf}") + print(f"annotation 저장: {out_dir}") + if vis_dir: + print(f"시각화 저장: {vis_dir}") + print() + + total_counts: dict = {v: 0 for v in TARGETS.values()} + processed = 0 + + for img_path in images: + print(f"[{processed+1}/{len(images)}] {img_path.name}") + result = process_image(img_path, args.conf, vis_dir) + if not result: + print(f" [SKIP] 처리 실패") + continue + + annotation, class_counts = result + n = len(annotation["shapes"]) + print(f" → {n}개 
객체 검출: {class_counts if class_counts else '없음'}") + + # JSON 저장 + json_path = out_dir / f"{img_path.stem}.json" + with open(json_path, "w", encoding="utf-8") as f: + json.dump(annotation, f, ensure_ascii=False, indent=2) + + for k, v in class_counts.items(): + total_counts[k] = total_counts.get(k, 0) + v + processed += 1 + + # 요약 + print("\n" + "="*50) + print(f"완료: {processed}/{len(images)}장 처리") + print("\n클래스별 총 검출 수:") + for cls, cnt in total_counts.items(): + avg = cnt / max(processed, 1) + bar = "#" * min(cnt, 30) + print(f" {cls:12s}: {cnt:4d}개 평균 {avg:.1f}/장 {bar}") + print(f"\nJSON 저장: {out_dir.resolve()}") + if vis_dir: + print(f"시각화: {vis_dir.resolve()}") + print("\n다음 단계:") + print(" X-AnyLabeling → Open Image Folder → annotation 폴더 선택") + print(" → 오탐/미탐 수동 수정 → Export → YOLO11-seg 학습") + + +if __name__ == "__main__": + main() diff --git a/tools/sam3_everything_explore.py b/tools/sam3_everything_explore.py new file mode 100644 index 0000000..bae2369 --- /dev/null +++ b/tools/sam3_everything_explore.py @@ -0,0 +1,291 @@ +""" +SAM3.1 탐색 모드 (Discovery Sweep) +이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 → +나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정. + +이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음. +대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다. + +사용법: + python tools/sam3_everything_explore.py \\ + --input "data/역사구간/1.회덕역/..." 
\\ + --cols 8 --rows 6 + +사전 조건: SAM3.1 서버 실행 (start_server.bat) +""" +import argparse +import base64 +import json +from collections import Counter +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import cv2 +import numpy as np +import requests + +SAM3_SERVER = "http://localhost:8000" +SAM3_MODEL_ID = "segment_anything_3" + + +# ── 이미지 인코딩 ───────────────────────────────────────────────────────────── +def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple: + h, w = image_bgr.shape[:2] + scale = 1.0 + if max(h, w) > max_size: + scale = max_size / max(h, w) + image_bgr = cv2.resize(image_bgr, (int(w * scale), int(h * scale))) + _, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90]) + return base64.b64encode(buf).decode("utf-8"), scale + + +# 탐색용 넓은 프롬프트 — 철도 현장에서 흔히 보이는 모든 요소 포함 +DISCOVERY_PROMPT = ( + "railroad track, railway rail, " + "catenary pole, overhead line pole, electric pole, " + "overhead wire, catenary wire, power line cable, " + "railway sleeper, concrete tie, " + "guardrail, highway barrier, road fence, " + "bridge, viaduct, overpass, " + "vegetation, tree, bush, grass, " + "building, structure, roof, wall, " + "vehicle, car, truck, " + "road, asphalt, pavement, " + "slope, embankment, retaining wall, " + "noise barrier, sound wall, " + "signal, sign board" +) + + +# ── SAM3.1 discovery sweep 호출 ─────────────────────────────────────────────── +def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list: + """넓은 탐색 prompt → 이미지 내 모든 요소 검출. 
반환: shape dict 리스트.""" + b64, scale = encode_image(tile_bgr) + payload = { + "model": SAM3_MODEL_ID, + "image": b64, + "params": { + "text_prompt": prompt, + "conf_threshold": conf, + "show_masks": True, + "show_boxes": False, + }, + } + try: + r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120) + r.raise_for_status() + resp = r.json() + if not resp.get("success"): + return [] + shapes = resp.get("data", {}).get("shapes", []) + shapes = [s if isinstance(s, dict) else s.dict() for s in shapes] + if scale < 1.0: + inv = 1.0 / scale + for s in shapes: + if s.get("shape_type") == "polygon": + s["points"] = [[x * inv, y * inv] for x, y in s["points"]] + return [s for s in shapes if s.get("shape_type") == "polygon"] + except Exception as e: + print(f" [SAM3 오류] {e}") + return [] + + +# ── NMS ─────────────────────────────────────────────────────────────────────── +def _bbox(pts): + xs = [p[0] for p in pts]; ys = [p[1] for p in pts] + return min(xs), min(ys), max(xs), max(ys) + + +def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list: + if not shapes: + return [] + bboxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32) + scores = np.array([float(s.get("score", 0.5)) for s in shapes]) + order = scores.argsort()[::-1] + keep = [] + while len(order): + i = order[0]; keep.append(i) + if len(order) == 1: break + xx1 = np.maximum(bboxes[i,0], bboxes[order[1:],0]) + yy1 = np.maximum(bboxes[i,1], bboxes[order[1:],1]) + xx2 = np.minimum(bboxes[i,2], bboxes[order[1:],2]) + yy2 = np.minimum(bboxes[i,3], bboxes[order[1:],3]) + inter = np.maximum(0, xx2-xx1) * np.maximum(0, yy2-yy1) + a_i = (bboxes[i,2]-bboxes[i,0])*(bboxes[i,3]-bboxes[i,1]) + a_j = (bboxes[order[1:],2]-bboxes[order[1:],0])*(bboxes[order[1:],3]-bboxes[order[1:],1]) + iou = inter / (a_i + a_j - inter + 1e-6) + order = order[1:][iou < iou_thresh] + return [shapes[i] for i in keep] + + +# ── 타일 분할 + 병렬 검출 ───────────────────────────────────────────────────── +def 
detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt): + H, W = image_bgr.shape[:2] + base_w = W / cols + base_h = H / rows + pad_x = int(base_w * overlap) + pad_y = int(base_h * overlap) + + tiles = [] + for r in range(rows): + for c in range(cols): + idx = r * cols + c + 1 + x0 = max(0, int(c * base_w) - pad_x) + x1 = min(W, int((c + 1) * base_w) + pad_x) + y0 = max(0, int(r * base_h) - pad_y) + y1 = min(H, int((r + 1) * base_h) + pad_y) + tiles.append((idx, x0, y0, x1, y1)) + + total = len(tiles) + done = [0] + all_shapes = [] + + def process(args): + idx, x0, y0, x1, y1 = args + tile = image_bgr[y0:y1, x0:x1] + shapes = sam3_everything(tile, conf, prompt) + for s in shapes: + s["points"] = [[px + x0, py + y0] for px, py in s["points"]] + return shapes + + with ThreadPoolExecutor(max_workers=workers) as ex: + futs = {ex.submit(process, t): t for t in tiles} + for fut in as_completed(futs): + result = fut.result() + all_shapes.extend(result) + done[0] += 1 + print(f" 타일 {done[0]:02d}/{total} 완료, 누적 {len(all_shapes)}개", end="\r") + print() + return all_shapes + + +# ── 시각화 ──────────────────────────────────────────────────────────────────── +def draw_everything(image_bgr, shapes, cols, rows): + vis = image_bgr.copy() + H, W = vis.shape[:2] + + # 타일 경계 + for r in range(rows): + for c in range(cols): + bx0, by0 = int(c * W / cols), int(r * H / rows) + bx1, by1 = int((c + 1) * W / cols), int((r + 1) * H / rows) + cv2.rectangle(vis, (bx0, by0), (bx1, by1), (60, 60, 60), 1) + + rng = np.random.default_rng(42) + for s in shapes: + pts = np.array(s["points"], dtype=np.int32) + color = tuple(int(v) for v in rng.integers(80, 255, size=3)) + overlay = vis.copy() + cv2.fillPoly(overlay, [pts], color) + cv2.addWeighted(overlay, 0.30, vis, 0.70, 0, vis) + cv2.polylines(vis, [pts], True, color, 1) + + # 라벨 표시 (있을 경우) + label = s.get("label", "") + if label: + cx = int(np.mean([p[0] for p in s["points"]])) + cy = int(np.mean([p[1] for p in s["points"]])) 
+ cv2.putText(vis, label, (cx, cy), + cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA) + + cv2.putText(vis, f"total segments: {len(shapes)}", + (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2) + return vis + + +# ── 라벨 분석 → text_prompt 후보 출력 ──────────────────────────────────────── +def analyze_labels(shapes): + labels = [s.get("label", "").strip() for s in shapes if s.get("label", "").strip()] + if not labels: + print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.") + return + + counter = Counter(labels) + print(f"\n{'─'*50}") + print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)") + print(f"{'─'*50}") + for label, cnt in counter.most_common(30): + bar = "#" * min(cnt, 40) + print(f" {label:35s} {cnt:4d} {bar}") + print(f"{'─'*50}") + + top_labels = [lb for lb, _ in counter.most_common(10)] + print(f"\n[text_prompt 후보]") + print(f' "{", ".join(top_labels)}"') + + +# ── 메인 ───────────────────────────────────────────────────────────────────── +def main(): + ap = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--input", required=True, help="입력 이미지") + ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)") + ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)") + ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)") + ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)") + ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)") + ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)") + ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)") + ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)") + args = ap.parse_args() + + prompt = DISCOVERY_PROMPT + (", " + args.prompt_extra.strip(", ") if args.prompt_extra.strip() else "") + + img_path = Path(args.input) + buf = np.fromfile(str(img_path), 
dtype=np.uint8) + image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if image_bgr is None: + print(f"이미지 로드 실패: {img_path}"); return + + H, W = image_bgr.shape[:2] + print(f"이미지 : {W}×{H}") + print(f"타일 그리드: {args.cols}×{args.rows}={args.cols*args.rows}개") + print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n") + + import time + print(f"탐색 프롬프트 ({len(prompt.split(','))}개 항목):") + for item in prompt.split(","): + print(f" · {item.strip()}") + print() + + t0 = time.time() + shapes = detect_everything_tiled( + image_bgr, args.cols, args.rows, args.overlap, + args.conf, args.workers, prompt + ) + print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...") + shapes = nms_shapes(shapes, iou_thresh=args.nms) + print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n") + + analyze_labels(shapes) + + vis = draw_everything(image_bgr, shapes, args.cols, args.rows) + h, w = vis.shape[:2] + if max(h, w) > 4096: + s = 4096 / max(h, w) + vis = cv2.resize(vis, (int(w*s), int(h*s))) + + out_path = (Path(args.output) if args.output + else img_path.parent / (img_path.stem + "_everything.jpg")) + cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path)) + print(f"\n저장: {out_path}") + + # JSON으로 라벨 데이터도 저장 (분석용) + json_path = out_path.with_suffix(".json") + label_data = { + "total_segments": len(shapes), + "label_counts": dict(Counter( + s.get("label", "(no label)") for s in shapes + )), + "segments": [ + {"label": s.get("label",""), "score": s.get("score",0), + "bbox": list(_bbox(s["points"]))} + for s in shapes + ] + } + json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"라벨 데이터: {json_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/sam3_receipt_ocr.py b/tools/sam3_receipt_ocr.py new file mode 100644 index 0000000..685c871 --- /dev/null +++ b/tools/sam3_receipt_ocr.py @@ -0,0 +1,219 @@ +import argparse +import sys +import os +import time +import json +from pathlib 
import Path +import cv2 +import numpy as np +import torch +from PIL import Image + +# Add server to path so we can import sam3 locally +server_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server")) +models_path = os.path.join(server_path, "app", "models") +if server_path not in sys.path: + sys.path.insert(0, server_path) +if models_path not in sys.path: + sys.path.insert(0, models_path) + +from sam3.model_builder import build_sam3_image_model +from sam3.model.sam3_image_processor import Sam3Processor + +# PaddleOCR will be imported inside the function to avoid errors if not installed +def run_paddleocr(image_np): + # Disable PIR API and oneDNN to avoid NotImplementedError on PaddlePaddle 3.x Windows + import os as _os + _os.environ["FLAGS_use_mkldnn"] = "0" + _os.environ["PADDLE_WITH_MKLDNN"] = "0" + _os.environ["FLAGS_enable_pir_api"] = "0" # Disable the new PIR API which causes this error + + from paddleocr import PaddleOCR + # Force use_gpu=False if oneDNN is failing on CPU, or let it detect. + # On Windows, sometimes CPU + oneDNN is the default and it fails. 
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an evenly spaced grid of normalized (x, y) sample points.

    The unit square is split into ``n_per_side`` x ``n_per_side`` cells and
    the centre of each cell is returned, so no point lies on the border.

    Args:
        n_per_side: Number of points along each axis; must be >= 1.

    Returns:
        Array of shape ``(n_per_side ** 2, 2)`` with columns ``(x, y)`` in
        the open interval (0, 1). Rows are ordered row-major: x varies
        fastest, y slowest.

    Raises:
        ValueError: If ``n_per_side`` is smaller than 1 (previously 0
            surfaced as an opaque ZeroDivisionError).
    """
    if n_per_side < 1:
        raise ValueError(f"n_per_side must be >= 1, got {n_per_side}")

    # Half a cell width: shifts samples from cell edges to cell centres.
    offset = 1.0 / (2 * n_per_side)
    axis = np.linspace(offset, 1 - offset, n_per_side)
    pts_x, pts_y = np.meshgrid(axis, axis)
    return np.stack([pts_x.ravel(), pts_y.ravel()], axis=1)
= np.argmax(areas) + return masks[best_mask_idx] + +def crop_and_warp(image, mask): + # Ensure mask is 2D + mask = np.squeeze(mask) + if mask.ndim != 2: + print(f"Warning: Mask has unexpected dimensions {mask.shape}, trying to flatten...") + if mask.ndim == 3: mask = mask[0] + + # Find contours + mask_uint8 = (mask > 0).astype(np.uint8) * 255 + if np.sum(mask_uint8) == 0: + print("Warning: Mask is empty.") + return image + + contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if not contours: + return image + + cnt = max(contours, key=cv2.contourArea) + + # Get bounding box + x, y, w, h = cv2.boundingRect(cnt) + + # Try to find 4 corners for perspective transform + epsilon = 0.02 * cv2.arcLength(cnt, True) + approx = cv2.approxPolyDP(cnt, epsilon, True) + + if len(approx) == 4: + print("Found 4 corners, applying perspective transform...") + pts = approx.reshape(4, 2) + # Sort points: top-left, top-right, bottom-right, bottom-left + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + diff = np.diff(pts, axis=1) + rect[1] = pts[np.argmin(diff)] + rect[3] = pts[np.argmax(diff)] + + (tl, tr, br, bl) = rect + widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) + widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) + maxWidth = max(int(widthA), int(widthB)) + + heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) + heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) + maxHeight = max(int(heightA), int(heightB)) + + dst = np.array([ + [0, 0], + [maxWidth - 1, 0], + [maxWidth - 1, maxHeight - 1], + [0, maxHeight - 1]], dtype="float32") + + M = cv2.getPerspectiveTransform(rect, dst) + warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight)) + return warped + else: + print("Could not find 4 clear corners, just cropping to bounding box.") + # Create a mask image to black out background + masked_img = 
cv2.bitwise_and(image, image, mask=mask_uint8) + crop = masked_img[y:y+h, x:x+w] + return crop + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input", required=True, help="Input image path") + parser.add_argument("--output_dir", default="output/ocr", help="Output directory") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + # Read image + buf = np.fromfile(args.input, dtype=np.uint8) + image = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if image is None: + print(f"Failed to read {args.input}") + return + + model_path = os.path.join(server_path, "sam3.pt") + + # 1. SAM3 Masking + mask = get_receipt_mask(image, model_path) + if mask is None: + print("No receipt-like object found.") + return + + # 2. Cropping / Warping + processed_img = crop_and_warp(image, mask) + + # Save processed image for debugging + processed_path = os.path.join(args.output_dir, "processed_receipt.jpg") + cv2.imwrite(processed_path, processed_img) + print(f"Saved processed image to {processed_path}") + + # 3. PaddleOCR + print("Running PaddleOCR...") + ocr_results = run_paddleocr(processed_img) + + # 4. 
def mask_iou(mask1, mask2):
    """Return the intersection-over-union of two masks as a Python float.

    Any non-zero element counts as foreground, matching the truthiness rule
    of ``np.logical_and`` / ``np.logical_or``.

    Args:
        mask1: Array-like mask.
        mask2: Array-like mask of a shape compatible with ``mask1``.

    Returns:
        IoU in [0.0, 1.0]; ``0.0`` when both masks are empty. The result is
        always a built-in ``float`` (previously ``int`` 0 or ``np.float64``
        depending on the branch), so it serializes and compares uniformly.
    """
    a = np.asarray(mask1, dtype=bool)
    b = np.asarray(mask2, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        # Two empty masks: define IoU as 0 rather than dividing by zero.
        return 0.0
    inter = np.logical_and(a, b).sum()
    return float(inter / union)
in state and len(state["masks"]) > 0: + # Take the mask with the highest score + best_idx = torch.argmax(state["scores"]) + mask = state["masks"][best_idx].cpu().numpy() + score = state["scores"][best_idx].item() + if score > conf_thresh: + masks.append(mask) + scores.append(score) + print(f"Grid prediction done in {time.time() - t0:.2f}s") + print(f"Found {len(masks)} raw masks.") + + if not masks: + return [] + + # Simple NMS based on IoU + print("Applying NMS...") + order = np.argsort(scores)[::-1] + keep = [] + + for idx in order: + if len(keep) == 0: + keep.append(idx) + continue + + current_mask = masks[idx] + overlap = False + for k in keep: + iou = mask_iou(current_mask, masks[k]) + if iou > nms_thresh: + overlap = True + break + if not overlap: + keep.append(idx) + + final_masks = [masks[idx] for idx in keep] + final_scores = [scores[idx] for idx in keep] + print(f"Kept {len(final_masks)} masks after NMS.") + + # Convert to polygons + results = [] + for m, s in zip(final_masks, final_scores): + poly = mask_to_polygon(m) + if poly: + results.append({"polygon": poly, "score": s}) + + return results + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input", required=True, help="Input image path") + parser.add_argument("--output", required=True, help="Output vis image path") + parser.add_argument("--points", type=int, default=32, help="Points per side") + args = parser.parse_args() + + buf = np.fromfile(args.input, dtype=np.uint8) + image = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if image is None: + print(f"Failed to read {args.input}") + return + + # Shrink image if too large just to make NMS faster + h, w = image.shape[:2] + max_dim = 1024 + if max(h, w) > max_dim: + scale = max_dim / max(h, w) + image_proc = cv2.resize(image, (int(w * scale), int(h * scale))) + else: + image_proc = image.copy() + + model_path = os.path.join(server_path, "sam3.pt") + + results = segment_everything(image_proc, model_path, points_per_side=args.points, 
def main():
    """Draw a numbered cols x rows tile grid on an image and save it as JPEG.

    Command-line arguments:
        --input    Source image path (read via ``np.fromfile`` so that
                   non-ASCII paths work on Windows).
        --output   Destination path; defaults to ``<input stem>_tiles.jpg``
                   next to the input.
        --cols     Number of tile columns (default 8, must be >= 1).
        --rows     Number of tile rows (default 6, must be >= 1).
        --overlap  Accepted for CLI compatibility; the drawn rectangles are
                   the non-overlapping base boundaries, so this value does
                   not affect the output.

    The visualization is downscaled so its longer side is at most 4096 px.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True)
    ap.add_argument("--output", default=None)
    ap.add_argument("--cols", type=int, default=8)
    ap.add_argument("--rows", type=int, default=6)
    ap.add_argument("--overlap", type=float, default=0.10)
    args = ap.parse_args()

    if args.cols < 1 or args.rows < 1:
        # Avoids a ZeroDivisionError / empty grid further down.
        ap.error("--cols and --rows must be >= 1")

    img_path = Path(args.input)
    try:
        buf = np.fromfile(str(img_path), dtype=np.uint8)
    except OSError:
        buf = np.empty(0, dtype=np.uint8)
    # cv2.imdecode raises on an empty buffer, so short-circuit that case.
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if img is None:
        print(f"이미지 로드 실패: {img_path}")
        return

    H, W = img.shape[:2]
    base_w = W / args.cols
    base_h = H / args.rows

    # Text metrics depend only on the image size, not on the tile.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = max(0.8, min(W, H) / 3000)
    thickness = max(2, int(font_scale * 3))

    vis = img.copy()
    for r in range(args.rows):
        for c in range(args.cols):
            idx = r * args.cols + c + 1

            # Tile boundary without overlap padding.
            bx0 = int(c * base_w)
            by0 = int(r * base_h)
            bx1 = min(W, int((c + 1) * base_w))
            by1 = min(H, int((r + 1) * base_h))
            cv2.rectangle(vis, (bx0, by0), (bx1, by1), (0, 200, 255), 3)

            # Tile number near the top-left corner, on a black backing box.
            tx, ty = bx0 + 10, by0 + int(base_h * 0.12)
            label = str(idx)
            (tw, th), _ = cv2.getTextSize(label, font, font_scale * 2, thickness)
            cv2.rectangle(vis, (tx - 4, ty - th - 4), (tx + tw + 4, ty + 4), (0, 0, 0), -1)
            cv2.putText(vis, label, (tx, ty),
                        font, font_scale * 2, (0, 200, 255), thickness, cv2.LINE_AA)

    # Cap the output size.
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        s = 4096 / max(h, w)
        vis = cv2.resize(vis, (int(w * s), int(h * s)))

    out = (Path(args.output) if args.output
           else img_path.parent / (img_path.stem + "_tiles.jpg"))
    # imencode + tofile instead of imwrite keeps non-ASCII output paths working.
    cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 92])[1].tofile(str(out))
    print(f"저장: {out} ({args.cols}×{args.rows}={args.cols*args.rows}타일)")
python tools/video_sam3_segment.py +""" + +import base64 +import cv2 +import json +import numpy as np +import requests +import sys +from pathlib import Path + +# === 설정 === +VIDEO_PATH = Path("sample/rail.mp4") +OUTPUT_DIR = Path("output/video_segmentation") +SAM3_URL = "http://localhost:8000" +MODEL_ID = "segment_anything_3" +FRAME_INTERVAL = 30 # 30fps 영상에서 1초 간격 + +# 철도 시설물 + 일반 객체 프롬프트 +PROMPTS = [ + "catenary pole", # 전철주 + "junction box", # 전기박스 + "utility box", # 통신박스 + "rail track", # 레일 + "fence", # 펜스 + "cable", # 전선/케이블 + "sign", # 표지판 + "building", # 건물 + "vegetation", # 식생 +] + +# 클래스별 색상 (BGR) +COLORS = { + "catenary pole": (0, 0, 255), # 빨강 + "junction box": (0, 165, 255), # 주황 + "utility box": (0, 255, 255), # 노랑 + "rail track": (255, 0, 0), # 파랑 + "fence": (255, 0, 255), # 자홍 + "cable": (0, 255, 0), # 초록 + "sign": (255, 255, 0), # 시안 + "building": (128, 128, 255), # 연한 빨강 + "vegetation": (0, 128, 0), # 진한 초록 +} + + +def extract_frames(video_path: Path, interval: int) -> list[tuple[int, np.ndarray]]: + """동영상에서 일정 간격으로 프레임 추출""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + print(f"ERROR: Cannot open video: {video_path}") + sys.exit(1) + + fps = cap.get(cv2.CAP_PROP_FPS) + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"Video: {video_path.name} | {total} frames @ {fps:.0f}fps | {total/fps:.1f}s") + + frames = [] + idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + if idx % interval == 0: + frames.append((idx, frame)) + idx += 1 + + cap.release() + print(f"Extracted {len(frames)} frames (interval={interval})") + return frames + + +def encode_frame(frame: np.ndarray) -> str: + """프레임을 base64로 인코딩""" + _, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 95]) + return base64.b64encode(buffer).decode("utf-8") + + +def predict_sam3(image_b64: str, text_prompt: str) -> dict: + """SAM3 서버에 예측 요청""" + payload = { + "model": MODEL_ID, + "image": image_b64, + "params": { + "text_prompt": 
def draw_shapes_on_frame(frame: np.ndarray, shapes: list, prompt: str) -> np.ndarray:
    """Draw segmentation shapes for one prompt onto ``frame`` in place.

    Polygons (>= 3 points) and rectangles (exactly 2 points) are filled on a
    copy of the frame and outlined on the frame itself; the filled copy is
    then blended back at 30% opacity. Other shape types only get a label.

    Args:
        frame: BGR image; modified in place and also returned.
        shapes: Shape dicts with ``"points"`` and optional ``"shape_type"``
            (default ``"polygon"``), ``"label"`` and ``"score"``.
        prompt: Text prompt that produced the shapes; selects the colour
            from ``COLORS`` and is the fallback label.

    Returns:
        The same ``frame`` object with the drawing applied.
    """
    overlay = frame.copy()
    color = COLORS.get(prompt, (200, 200, 200))

    for shape in shapes:
        # A shape without points is skipped instead of raising KeyError.
        points = np.array(shape.get("points", []), dtype=np.int32)
        shape_type = shape.get("shape_type", "polygon")

        if shape_type == "polygon" and len(points) >= 3:
            cv2.fillPoly(overlay, [points], color)
            cv2.polylines(frame, [points], True, color, 2)
        elif shape_type == "rectangle" and len(points) == 2:
            # Plain Python ints: some OpenCV builds reject numpy scalars here.
            p0 = tuple(int(v) for v in points[0])
            p1 = tuple(int(v) for v in points[1])
            cv2.rectangle(overlay, p0, p1, color, -1)
            cv2.rectangle(frame, p0, p1, color, 2)

        # Label text; a score of exactly 0.0 is still a score and is shown.
        label = shape.get("label", prompt)
        score = shape.get("score")
        text = f"{label}" + (f" {score:.2f}" if score is not None else "")
        if len(points) > 0:
            tx, ty = int(points[0][0]), int(points[0][1]) - 5
            cv2.putText(frame, text, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # Blend the filled overlay over the outlined frame.
    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
    return frame
main(): + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # 1. 서버 상태 확인 + print("=== SAM3 서버 연결 확인 ===") + try: + health = requests.get(f"{SAM3_URL}/health", timeout=5) + print(f"Server status: {health.json()}") + except requests.exceptions.ConnectionError: + print("ERROR: SAM3 서버가 실행 중이 아닙니다!") + print(" → start_server.bat을 먼저 실행하세요.") + sys.exit(1) + + # 2. 프레임 추출 + print("\n=== 프레임 추출 ===") + frames = extract_frames(VIDEO_PATH, FRAME_INTERVAL) + + # 3. 각 프레임별, 각 프롬프트별 세그멘테이션 + all_results = {} + + for frame_idx, (fidx, frame) in enumerate(frames): + print(f"\n=== Frame {fidx} ({frame_idx+1}/{len(frames)}) ===") + + frame_b64 = encode_frame(frame) + frame_results = {} + composite = frame.copy() + + for prompt in PROMPTS: + print(f" Segmenting: {prompt}...", end=" ") + result = predict_sam3(frame_b64, prompt) + + if result.get("success") and result.get("data", {}).get("shapes"): + shapes = result["data"]["shapes"] + n = len(shapes) + print(f"→ {n} objects found") + frame_results[prompt] = shapes + + # 합성 이미지에 그리기 + composite = draw_shapes_on_frame(composite, shapes, prompt) + + # 개별 레이어 저장 (PNG with alpha) + layer = create_layer_image(frame.shape, shapes, prompt) + layer_name = prompt.replace(" ", "_") + layer_path = OUTPUT_DIR / f"frame_{fidx:04d}_layer_{layer_name}.png" + cv2.imwrite(str(layer_path), layer) + else: + print("→ no objects") + + # 합성 이미지 저장 + composite_path = OUTPUT_DIR / f"frame_{fidx:04d}_composite.jpg" + cv2.imwrite(str(composite_path), composite) + + # 원본 프레임 저장 (비교용) + original_path = OUTPUT_DIR / f"frame_{fidx:04d}_original.jpg" + cv2.imwrite(str(original_path), frame) + + all_results[f"frame_{fidx}"] = { + "frame_index": fidx, + "detections": { + prompt: len(shapes) + for prompt, shapes in frame_results.items() + }, + } + + # 4. 결과 요약 저장 + summary_path = OUTPUT_DIR / "segmentation_summary.json" + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(all_results, f, indent=2, ensure_ascii=False) + + # 5. 
범례 이미지 생성 + legend = np.zeros((len(PROMPTS) * 30 + 20, 300, 3), dtype=np.uint8) + for i, (prompt, color) in enumerate(COLORS.items()): + y = i * 30 + 20 + cv2.rectangle(legend, (10, y - 12), (30, y + 5), color, -1) + cv2.putText(legend, prompt, (40, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) + cv2.imwrite(str(OUTPUT_DIR / "legend.jpg"), legend) + + print(f"\n=== 완료 ===") + print(f"결과 저장: {OUTPUT_DIR}/") + print(f" - frame_XXXX_original.jpg : 원본 프레임") + print(f" - frame_XXXX_composite.jpg : 전체 세그멘테이션 합성") + print(f" - frame_XXXX_layer_*.png : 클래스별 개별 레이어 (투명 배경)") + print(f" - segmentation_summary.json : 검출 요약") + print(f" - legend.jpg : 색상 범례") + + +if __name__ == "__main__": + main() diff --git a/tools/web_ui.py b/tools/web_ui.py new file mode 100644 index 0000000..f99e5c6 --- /dev/null +++ b/tools/web_ui.py @@ -0,0 +1,618 @@ +""" +Railway Detection Web UI +사용법: python tools/web_ui.py +브라우저: http://localhost:7000 +""" +import asyncio +import base64 +import json +import os +import queue +import subprocess +import sys +import threading +import uuid +from pathlib import Path + +import cv2 +import numpy as np + +try: + from fastapi import FastAPI + from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse + import uvicorn +except ImportError: + print("FastAPI/uvicorn 설치 필요: pip install fastapi uvicorn") + sys.exit(1) + +app = FastAPI() +jobs: dict = {} # job_id -> {queue, status, result_path} + +ROOT = Path(__file__).parent.parent # 프로젝트 루트 + + +# ── 미리보기 이미지 생성 ─────────────────────────────────────────────────────── +def make_preview(image_path: str, cols: int, rows: int, selected_rows: list) -> str: + buf = np.fromfile(image_path, dtype=np.uint8) + img = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if img is None: + raise ValueError(f"이미지 로드 실패: {image_path}") + + H, W = img.shape[:2] + tw = 2000 + scale = tw / W + vis = cv2.resize(img, (tw, int(H * scale))) + H_s, W_s = vis.shape[:2] + + tile_w = W_s / cols + tile_h = H_s / rows + + for r in 
def tiles_from_rows(selected_rows: list, cols: int) -> str:
    """Convert selected 1-based grid rows into a ``--tiles`` argument string.

    Tiles are numbered row-major starting at 1, so row ``r`` owns tiles
    ``(r - 1) * cols + 1`` through ``r * cols``.

    Args:
        selected_rows: 1-based row numbers in any order. Duplicates are
            ignored (previously they produced repeated tile ids and
            prevented range compression).
        cols: Number of tile columns in the grid.

    Returns:
        ``"all"`` when no row is selected, ``"first-last"`` when the
        resulting tile ids are contiguous, otherwise a comma-separated list
        of tile ids.
    """
    if not selected_rows:
        return "all"

    ids = []
    for r in sorted(set(selected_rows)):
        ids.extend(range((r - 1) * cols + 1, r * cols + 1))

    # Contiguous ids collapse to a compact range.
    if ids and ids[-1] - ids[0] + 1 == len(ids):
        return f"{ids[0]}-{ids[-1]}"
    return ",".join(str(t) for t in ids)
q, "status": "running", "result_path": None, "proc": None} + + def run(): + tiles_str = tiles_from_rows(data.get("selected_rows", []), data.get("cols", 8)) + cmd = [ + sys.executable, "-u", str(ROOT / "tools" / "detect_all_objects.py"), + "--input", data["image_path"], + "--categories", data.get("categories", "configs/railway_zone.json"), + "--tiles", tiles_str, + "--cols", str(data.get("cols", 8)), + "--rows", str(data.get("rows", 6)), + "--overlap", str(data.get("overlap", 0.20)), + "--conf", str(data.get("conf", 0.20)), + "--workers", str(data.get("workers", 8)), + "--save-json" + ] + + env = {**os.environ, "PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"} + flags = subprocess.CREATE_NEW_PROCESS_GROUP if os.name == "nt" else 0 + try: + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, encoding="utf-8", errors="replace", + env=env, cwd=str(ROOT), creationflags=flags, + ) + jobs[job_id]["proc"] = proc + for raw in proc.stdout: + for line in raw.replace("\r", "\n").split("\n"): + line = line.strip() + if not line: + continue + q.put(("log", line)) + if line.startswith("저장:"): + jobs[job_id]["result_path"] = line.replace("저장:", "").strip() + proc.wait() + if jobs[job_id]["status"] == "stopped": + q.put(("error", "사용자가 중지했습니다")) + elif proc.returncode == 0: + q.put(("done", jobs[job_id]["result_path"] or "완료")) + jobs[job_id]["status"] = "done" + else: + q.put(("error", f"종료코드 {proc.returncode}")) + jobs[job_id]["status"] = "error" + except Exception as e: + q.put(("error", str(e))) + jobs[job_id]["status"] = "error" + + threading.Thread(target=run, daemon=True).start() + return {"job_id": job_id} + + +@app.post("/api/stop/{job_id}") +async def api_stop(job_id: str): + job = jobs.get(job_id) + if not job: + return {"ok": False, "error": "not found"} + job["status"] = "stopped" + proc = job.get("proc") + if proc and proc.poll() is None: + if os.name == "nt": + subprocess.call(["taskkill", "/F", "/T", "/PID", 
str(proc.pid)], + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + else: + proc.terminate() + return {"ok": True} + + +@app.get("/api/progress/{job_id}") +async def api_progress(job_id: str): + if job_id not in jobs: + return JSONResponse({"error": "not found"}, status_code=404) + + async def event_gen(): + q = jobs[job_id]["queue"] + while True: + try: + type_, msg = q.get_nowait() + yield f"data: {json.dumps({'type': type_, 'msg': msg})}\n\n" + if type_ in ("done", "error"): + break + except queue.Empty: + await asyncio.sleep(0.25) + yield ": ping\n\n" + + return StreamingResponse(event_gen(), media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) + + +@app.get("/api/result/{job_id}") +async def api_result(job_id: str): + job = jobs.get(job_id, {}) + path = job.get("result_path") + if not path or not Path(path).exists(): + return {"ok": False, "error": "결과 없음"} + buf = np.fromfile(path, dtype=np.uint8) + img = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if img is None: + return {"ok": False, "error": "이미지 로드 실패"} + h, w = img.shape[:2] + if max(h, w) > 3000: + s = 3000 / max(h, w) + img = cv2.resize(img, (int(w * s), int(h * s))) + _, enc = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 90]) + return {"ok": True, "image": base64.b64encode(enc).decode(), "path": path} + + +@app.post("/api/open-folder") +async def api_open_folder(data: dict): + path = Path(data.get("path", "output/detect")) + folder = path.parent if path.is_file() else path + if os.name == "nt": + os.startfile(str(folder.resolve())) + return {"ok": True} + + +# ── HTML ────────────────────────────────────────────────────────────────────── +@app.get("/", response_class=HTMLResponse) +async def index(): + return HTML + + +HTML = r""" + + + +Railway Detection UI + + + +

🛤 Railway Detection UI

+ +
+ + +
+ +
+
+ +
+ +
+ + +
+
+ +
+ +
+
+ +
+ + +
+ +
+ +
+
+
+
+
+
+
+
+ + + + +
+ +
+ + 휠: 확대/축소  |  드래그: 이동 + 이미지 경로 입력 후 미리보기 클릭 + +
+
+ +
+
+ + 결과 없음 + + 휠: 확대/축소  |  드래그: 이동 +
+
+
검출 완료 후 결과가 표시됩니다
+ +
+
+ +
+
+
대기 중
+
"""
YOLO-World + SAM3 semi-automatic labeling pipeline
==================================================
Detect bounding boxes with YOLO-World -> request polygon masks from a SAM3
server for those boxes -> save a visualization per image plus a JSON summary.

Usage:
    python tools/yoloworld_sam3_pipeline.py --input sample/rail --output output/labeled
    python tools/yoloworld_sam3_pipeline.py --input sample/rail/frame_00000.jpg

The SAM3 server must be running:
    cd X-AnyLabeling-Server && uvicorn app.main:app --host 0.0.0.0 --port 8000
"""

import argparse
import base64
import json
import sys
from collections import Counter
from pathlib import Path

import cv2
import numpy as np
import requests

# ── Target classes (YOLO-World text prompts) ─────────────────────────────────
TARGET_CLASSES = [
    "catenary pole",  # catenary mast (vertical pole)
    "catenary arm",   # catenary mast (horizontal arm)
    "junction box",   # communication / electrical box
    "fence",          # fence
]

# Per-class drawing colors (BGR)
CLASS_COLORS = {
    "catenary pole": (0, 200, 255),  # orange
    "catenary arm": (0, 100, 255),   # red
    "junction box": (255, 180, 0),   # blue
    "fence": (0, 255, 100),          # green
}

# Color used when neither a label nor an overlapping detection names a class.
FALLBACK_COLOR = (200, 200, 200)

# File suffixes (lower-case) picked up when --input is a directory.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

SAM3_SERVER = "http://localhost:8000"
MODEL_ID = "segment_anything_3"


# ─────────────────────────────────────────────────────────────────────────────
# YOLO-World initialisation
# ─────────────────────────────────────────────────────────────────────────────
def load_yolo_world(model_size: str = "s"):
    """Load a YOLO-World v2 model and restrict it to ``TARGET_CLASSES``.

    ``ultralytics`` is imported lazily so the rest of the module can be
    imported without it.

    Args:
        model_size: Size letter inserted into the weight file name
            (``yolov8{model_size}-worldv2.pt``).

    Returns:
        The ``YOLOWorld`` model with its class prompts set.
    """
    from ultralytics import YOLOWorld

    model_name = f"yolov8{model_size}-worldv2.pt"
    print(f"[YOLO-World] 모델 로드: {model_name}")
    model = YOLOWorld(model_name)
    model.set_classes(TARGET_CLASSES)
    print(f"[YOLO-World] 검출 클래스: {TARGET_CLASSES}")
    return model


def detect_with_yoloworld(model, image_bgr: np.ndarray, conf: float = 0.15):
    """Run YOLO-World on one image and return its detections.

    Args:
        model: Model returned by :func:`load_yolo_world`.
        image_bgr: Image array passed straight to ``model.predict``.
        conf: Confidence threshold forwarded to ``model.predict``.

    Returns:
        A list of ``(x1, y1, x2, y2, score, class_name)`` tuples with plain
        Python floats; empty when nothing was detected.
    """
    results = model.predict(image_bgr, conf=conf, verbose=False)
    if not results:
        return []

    boxes = results[0].boxes
    if boxes is None or len(boxes) == 0:
        return []

    # Move each field to host memory once instead of once per box.
    coords = boxes.xyxy.cpu().numpy()
    scores = boxes.conf.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy()

    detections = []
    for (x1, y1, x2, y2), score, raw_idx in zip(coords, scores, class_ids):
        cls_idx = int(raw_idx)
        # Guard both ends so a negative index cannot wrap around the list.
        if 0 <= cls_idx < len(TARGET_CLASSES):
            cls_name = TARGET_CLASSES[cls_idx]
        else:
            cls_name = f"class_{cls_idx}"
        detections.append(
            (float(x1), float(y1), float(x2), float(y2), float(score), cls_name)
        )
    return detections


# ─────────────────────────────────────────────────────────────────────────────
# SAM3 server calls
# ─────────────────────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray) -> str:
    """Encode an image as a base64 JPEG string (quality 95).

    Raises:
        ValueError: If OpenCV reports that JPEG encoding failed.
    """
    ok, buf = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95])
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf.tobytes()).decode("utf-8")


def sam3_segment(image_bgr: np.ndarray, boxes: list, conf_threshold: float = 0.25,
                 server=None):
    """Send bounding boxes to the SAM3 server and return the polygon shapes.

    Every box is sent as a ``rectangle`` mark with label ``1``; the class name
    of the detection is not transmitted.

    Args:
        image_bgr: Source image.
        boxes: ``[(x1, y1, x2, y2, score, class_name), ...]``; only the first
            four values of each entry are used.
        conf_threshold: Value sent as ``conf_threshold`` in the request params.
        server: Base URL of the server. Defaults to the module-level
            ``SAM3_SERVER`` when omitted.

    Returns:
        The ``shapes`` list from the server reply (the original author
        documents entries as ``{"label", "points", "score"}`` dicts), or an
        empty list when there are no boxes or the request fails in any way.
    """
    if not boxes:
        # Nothing to segment: avoid a pointless round trip.
        return []

    base_url = (server or SAM3_SERVER).rstrip("/")
    marks = [
        {
            "type": "rectangle",
            "label": 1,
            # Cast so numpy scalars cannot break JSON serialisation.
            "data": [float(v) for v in b[:4]],
        }
        for b in boxes
    ]

    try:
        payload = {
            "model": MODEL_ID,
            "image": encode_image(image_bgr),
            "params": {
                "marks": marks,
                "show_masks": True,
                "show_boxes": False,
                "conf_threshold": conf_threshold,
                "epsilon_factor": 0.002,
            },
        }
        resp = requests.post(f"{base_url}/v1/predict", json=payload, timeout=60)
        resp.raise_for_status()
        data = resp.json()
    except requests.exceptions.ConnectionError:
        print(f" [ERROR] SAM3 서버에 연결할 수 없습니다: {base_url}")
        return []
    except Exception as e:  # best-effort: one bad image must not stop a batch
        print(f" [ERROR] SAM3 호출 실패: {e}")
        return []

    if not isinstance(data, dict) or data.get("status") != "success":
        print(f" [ERROR] SAM3 응답 오류: {data}")
        return []

    # Tolerate "data": null / "shapes": null in the reply.
    result = data.get("data") or {}
    if not isinstance(result, dict):
        return []
    return result.get("shapes") or []


# ─────────────────────────────────────────────────────────────────────────────
# Visualisation
# ─────────────────────────────────────────────────────────────────────────────
def _polygon_points(points):
    """Convert a shape's point list to an (N, 2) int32 array, or None if unusable."""
    try:
        pts = np.asarray(points, dtype=np.float64)
    except (TypeError, ValueError):
        return None
    if pts.ndim != 2 or pts.shape[1] != 2:
        return None
    pts = np.rint(pts).astype(np.int32)
    # Drop the closing vertex when the polygon repeats its first point.
    if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
        pts = pts[:-1]
    if len(pts) < 3:
        return None
    return pts


def _class_for_polygon(pts, detections):
    """Return the class of the detection whose bbox best overlaps the polygon, or None."""
    px1, py1 = pts.min(axis=0)
    px2, py2 = pts.max(axis=0)
    poly_area = float(px2 - px1) * float(py2 - py1)

    best_name, best_iou = None, 0.0
    for x1, y1, x2, y2, _score, cls_name in detections:
        inter_w = min(px2, x2) - max(px1, x1)
        inter_h = min(py2, y2) - max(py1, y1)
        if inter_w <= 0 or inter_h <= 0:
            continue
        inter = float(inter_w) * float(inter_h)
        union = poly_area + (x2 - x1) * (y2 - y1) - inter
        iou = inter / union if union > 0 else 0.0
        if iou > best_iou:
            best_name, best_iou = cls_name, iou
    return best_name


def draw_results(image_bgr: np.ndarray, detections: list, shapes: list) -> np.ndarray:
    """Draw SAM3 polygons, YOLO-World boxes and a class legend on a copy of the image.

    Args:
        image_bgr: Source image; it is not modified.
        detections: ``[(x1, y1, x2, y2, score, class_name), ...]``.
        shapes: Shape dicts as returned by :func:`sam3_segment`. Entries whose
            ``points`` are missing, malformed or fewer than three are skipped.

    Returns:
        A new image with the overlays drawn.
    """
    vis = image_bgr.copy()
    overlay = image_bgr.copy()

    # SAM3 polygon masks: outline on `vis`, fill on `overlay`, blended below.
    for shape in shapes:
        pts = _polygon_points(shape.get("points"))
        if pts is None:
            continue
        color = CLASS_COLORS.get(shape.get("label", "unknown"))
        if color is None:
            # The request sends every mark with label 1, so the returned label
            # is presumably not a class name (unconfirmed: the server's reply
            # format is not visible here). Borrow the class of the detection
            # that best overlaps this polygon.
            color = CLASS_COLORS.get(_class_for_polygon(pts, detections), FALLBACK_COLOR)
        cv2.fillPoly(overlay, [pts], color)
        cv2.polylines(vis, [pts], True, color, 2)

    cv2.addWeighted(overlay, 0.35, vis, 0.65, 0, vis)

    # YOLO-World boxes with a "<class> <score>" tag
    for x1, y1, x2, y2, score, cls_name in detections:
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        color = CLASS_COLORS.get(cls_name, FALLBACK_COLOR)
        cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
        label_text = f"{cls_name} {score:.2f}"
        (tw, th), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
        # Tag sits above the box; if that would leave the image, put it just
        # inside the top edge of the box instead.
        tag_bottom = y1 if y1 - th - 6 >= 0 else y1 + th + 6
        cv2.rectangle(vis, (x1, tag_bottom - th - 6), (x1 + tw + 4, tag_bottom), color, -1)
        cv2.putText(vis, label_text, (x1 + 2, tag_bottom - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)

    # Legend
    legend_y = 20
    for cls_name, color in CLASS_COLORS.items():
        cv2.rectangle(vis, (10, legend_y), (25, legend_y + 15), color, -1)
        cv2.putText(vis, cls_name, (30, legend_y + 13),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        legend_y += 22

    return vis


# ─────────────────────────────────────────────────────────────────────────────
# Single-image processing
# ─────────────────────────────────────────────────────────────────────────────
def process_image(yolo_model, image_path: Path, output_dir: Path,
                  yolo_conf: float, sam3_conf: float, skip_sam3: bool) -> dict:
    """Process one image: detect, optionally segment, and save the visualization.

    Args:
        yolo_model: Model returned by :func:`load_yolo_world`.
        image_path: Image file to read.
        output_dir: Existing directory that receives ``<stem>_result.jpg``.
        yolo_conf: YOLO-World confidence threshold.
        sam3_conf: Confidence threshold sent to the SAM3 server.
        skip_sam3: When true, only boxes are drawn and no server call is made.

    Returns:
        ``{"image", "detections", "masks"}`` for the image, or an empty dict
        when the image could not be read.
    """
    image_bgr = cv2.imread(str(image_path))
    if image_bgr is None:
        print(f" [SKIP] 이미지 읽기 실패: {image_path}")
        return {}

    h, w = image_bgr.shape[:2]
    print(f"\n 이미지: {image_path.name} ({w}x{h})")

    # ── Step 1: YOLO-World detection ─────────────────────────────────────────
    detections = detect_with_yoloworld(yolo_model, image_bgr, conf=yolo_conf)
    print(f" YOLO-World: {len(detections)}개 검출")
    for d in detections:
        print(f" [{d[5]}] conf={d[4]:.3f} bbox=({d[0]:.0f},{d[1]:.0f},{d[2]:.0f},{d[3]:.0f})")

    shapes = []
    if not skip_sam3 and detections:
        # ── Step 2: SAM3 mask generation ─────────────────────────────────────
        print(f" SAM3: {len(detections)}개 bbox → mask 요청 중...")
        shapes = sam3_segment(image_bgr, detections, conf_threshold=sam3_conf)
        print(f" SAM3: {len(shapes)}개 mask 반환")
        for s in shapes:
            pts_count = len(s.get("points", []))
            print(f" [{s.get('label')}] score={s.get('score', 0):.3f} points={pts_count}")

    # ── Step 3: save visualization ───────────────────────────────────────────
    vis = draw_results(image_bgr, detections, shapes)
    output_path = output_dir / f"{image_path.stem}_result.jpg"
    # imwrite signals failure through its return value, not an exception.
    if cv2.imwrite(str(output_path), vis):
        print(f" 저장: {output_path}")
    else:
        print(f" [ERROR] 저장 실패: {output_path}")

    return {
        "image": image_path.name,
        "detections": [
            {"class": d[5], "conf": round(d[4], 3),
             "bbox": [round(d[0]), round(d[1]), round(d[2]), round(d[3])]}
            for d in detections
        ],
        "masks": len(shapes),
    }


# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]``.

    Exits with status 1 when the input path is missing or holds no images, and
    with status 0 when the user declines to continue without a SAM3 server.
    """
    global SAM3_SERVER

    parser = argparse.ArgumentParser(description="YOLO-World + SAM3 파이프라인")
    parser.add_argument("--input", default="sample/rail",
                        help="이미지 파일 또는 폴더 경로")
    parser.add_argument("--output", default="output/yoloworld_sam3",
                        help="결과 저장 폴더")
    parser.add_argument("--model-size", default="s", choices=["s", "m", "l", "x"],
                        help="YOLO-World 모델 크기 (s/m/l/x)")
    parser.add_argument("--yolo-conf", type=float, default=0.10,
                        help="YOLO-World 검출 임계값 (기본 0.10)")
    parser.add_argument("--sam3-conf", type=float, default=0.20,
                        help="SAM3 마스크 임계값 (기본 0.20)")
    parser.add_argument("--skip-sam3", action="store_true",
                        help="SAM3 건너뛰고 YOLO-World bbox만 시각화")
    parser.add_argument("--server", default=SAM3_SERVER,
                        help="SAM3 서버 주소")
    args = parser.parse_args(argv)

    # sam3_segment() reads this module-level value on every call.
    SAM3_SERVER = args.server.rstrip("/")

    # SAM3 server health check
    if not args.skip_sam3:
        try:
            resp = requests.get(f"{SAM3_SERVER}/health", timeout=5)
        except requests.exceptions.RequestException:
            print(f"[WARN] SAM3 서버 연결 실패 ({SAM3_SERVER}). --skip-sam3 로 bbox만 볼 수 있음.")
            try:
                ans = input("계속 진행하시겠습니까? (y/N): ").strip().lower()
            except EOFError:
                # Non-interactive stdin: take the default answer (no).
                ans = ""
            if ans != "y":
                sys.exit(0)
        else:
            if resp.status_code == 200:
                print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
            else:
                print(f"[WARN] SAM3 서버 응답 이상 (status={resp.status_code})")

    # Resolve the input path into a list of image files
    input_path = Path(args.input)
    if input_path.is_file():
        image_files = [input_path]
    elif input_path.is_dir():
        # Compare suffixes in lower case so .JPG / .PNG are not missed on
        # case-sensitive file systems.
        image_files = sorted(
            p for p in input_path.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        )
    else:
        print(f"[ERROR] 입력 경로를 찾을 수 없습니다: {input_path}")
        sys.exit(1)

    if not image_files:
        print(f"[ERROR] 이미지 파일 없음: {input_path}")
        sys.exit(1)

    print(f"\n총 {len(image_files)}개 이미지 처리 예정")

    # Create the output folder
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load YOLO-World
    yolo_model = load_yolo_world(args.model_size)

    # Process every image; unreadable ones return {} and are left out.
    summary = []
    for img_path in image_files:
        result = process_image(
            yolo_model, img_path, output_dir,
            yolo_conf=args.yolo_conf,
            sam3_conf=args.sam3_conf,
            skip_sam3=args.skip_sam3,
        )
        if result:
            summary.append(result)

    # Save the summary
    summary_path = output_dir / "summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    # Print statistics
    print("\n" + "=" * 50)
    print("처리 완료 요약")
    print("=" * 50)
    image_count = max(len(summary), 1)  # avoid dividing by zero
    total_det = sum(len(r["detections"]) for r in summary)
    total_mask = sum(r["masks"] for r in summary)
    print(f"처리 이미지: {len(summary)}장")
    print(f"YOLO 검출: {total_det}개 (평균 {total_det / image_count:.1f}/장)")
    print(f"SAM3 마스크: {total_mask}개 (평균 {total_mask / image_count:.1f}/장)")

    # Per-class totals, most frequent first
    class_counts = Counter(d["class"] for r in summary for d in r["detections"])
    if class_counts:
        print("\n클래스별 검출 수:")
        for cls, cnt in class_counts.most_common():
            print(f" {cls:20s}: {cnt}개")

    print(f"\n결과 저장: {output_dir.resolve()}")
    print(f"요약 JSON: {summary_path.resolve()}")


if __name__ == "__main__":
    main()