프로젝트 분리 이동

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
minsung
2026-05-20 14:28:27 +09:00
commit ccba1266b5
24 changed files with 7900 additions and 0 deletions

67
.gitignore vendored Normal file
View File

@@ -0,0 +1,67 @@
# 가상환경 (용량 크므로 제외)
.venv/
.venv_client/
# 모델 파일 (용량 매우 큼)
*.pt
*.pth
*.onnx
*.gz
# 로그
X-AnyLabeling-Server/logs/
*.log
# Python
__pycache__/
*.pyc
*.pyo
# 캐시
.cache/
# IDE
.vscode/
.idea/
# OS
.DS_Store
Thumbs.db
# 임시 파일
create_shortcuts.ps1
# 클로드 개인 설정/메모리
.claude/
# X-AnyLabeling-Server (별도 git repo, submodule 아님)
X-AnyLabeling-Server/
# 용량 큰 데이터/결과 폴더
drone_2cm/
output/
runs/
data/
# 용량 큰 원본 영상 파일
*.tif
경부선_2-2구간_미션33-34(5cm).png
# 논문/문서 파일
*.pdf
*.txt
# 미정리 툴 (작업 중)
tools/render_polygons_rainbow.py
tools/render_yolo_labels.py
tools/sam3_sleeper_detect.py
tools/tif_hollow_detect.py
# 기타
src/
PRPs/
nul
triton-*.whl
pyproject.toml
Project_Status.md
.usage/

65
CLAUDE.md Normal file
View File

@@ -0,0 +1,65 @@
# CLAUDE.md
Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-specific instructions as needed.
**Tradeoff:** These guidelines bias toward caution over speed. For trivial tasks, use judgment.
## 1. Think Before Coding
**Don't assume. Don't hide confusion. Surface tradeoffs.**
Before implementing:
- State your assumptions explicitly. If uncertain, ask.
- If multiple interpretations exist, present them - don't pick silently.
- If a simpler approach exists, say so. Push back when warranted.
- If something is unclear, stop. Name what's confusing. Ask.
## 2. Simplicity First
**Minimum code that solves the problem. Nothing speculative.**
- No features beyond what was asked.
- No abstractions for single-use code.
- No "flexibility" or "configurability" that wasn't requested.
- No error handling for impossible scenarios.
- If you write 200 lines and it could be 50, rewrite it.
Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
## 3. Surgical Changes
**Touch only what you must. Clean up only your own mess.**
When editing existing code:
- Don't "improve" adjacent code, comments, or formatting.
- Don't refactor things that aren't broken.
- Match existing style, even if you'd do it differently.
- If you notice unrelated dead code, mention it - don't delete it.
When your changes create orphans:
- Remove imports/variables/functions that YOUR changes made unused.
- Don't remove pre-existing dead code unless asked.
The test: Every changed line should trace directly to the user's request.
## 4. Goal-Driven Execution
**Define success criteria. Loop until verified.**
Transform tasks into verifiable goals:
- "Add validation" → "Write tests for invalid inputs, then make them pass"
- "Fix the bug" → "Write a test that reproduces it, then make it pass"
- "Refactor X" → "Ensure tests pass before and after"
For multi-step tasks, state a brief plan:
```
1. [Step] → verify: [check]
2. [Step] → verify: [check]
3. [Step] → verify: [check]
```
Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
---
**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes.

143
configs/railway_zone.json Normal file
View File

@@ -0,0 +1,143 @@
{
"_comment": "철도 구간 집중 검출용 카테고리 설정 (타일 9-24 기준)",
"_conf_note": "conf: 클래스별 신뢰도 임계값. priority: 낮을수록 cross-class NMS에서 우선 보존(화면 위에 표시).",
"_priority_order": "2=컨트롤박스 > 3=전철주 > 4=팬스 > 5=레일 > 6=침목 > 7=자갈(맨 밑)",
"cross_class_nms_iou": 0.45,
"categories": [
{
"name": "control_box",
"name_kr": "컨트롤박스",
"prompt": "small square gray metal box beside rail, compact trackside junction box, small near-square electrical enclosure on the ground, small cube-shaped equipment box next to track",
"keywords": ["small square box", "junction box", "compact enclosure", "square metal box", "small cube box"],
"color_bgr": [255, 255, 0],
"conf": 0.15,
"priority": 2
},
{
"name": "catenary_pole",
"name_kr": "전철주",
"prompt": "railway catenary pole, overhead line pole, catenary mast, electric railway pole",
"keywords": ["catenary pole", "overhead line pole", "catenary mast", "railway pole"],
"color_bgr": [255, 0, 255],
"conf": 0.25,
"priority": 3
},
{
"name": "fence",
"name_kr": "팬스/울타리",
"prompt": "railway fence, trackside fence, perimeter fence, chain link fence, railway boundary fence",
"keywords": ["fence", "chain link", "perimeter fence", "boundary fence"],
"color_bgr": [0, 255, 0],
"conf": 0.50,
"priority": 5
},
{
"name": "railway",
"name_kr": "철도 레일",
"prompt": "two parallel steel rails, narrow metallic longitudinal beam, shiny steel rail line on sleeper",
"keywords": ["railroad", "railway rail", "steel rail", "train track", "parallel rail"],
"color_bgr": [0, 0, 255],
"conf": 0.25,
"priority": 4
},
{
"name": "sleeper",
"name_kr": "침목",
"prompt": "rectangular concrete sleeper, dark crosswise tie, evenly spaced railroad tie perpendicular to rail, concrete slab between ballast",
"keywords": ["sleeper", "railroad tie", "rail tie", "wooden tie", "concrete tie", "crosswise tie"],
"color_bgr": [0, 128, 255],
"conf": 0.20,
"priority": 6
},
{
"name": "ballast",
"name_kr": "자갈도상",
"prompt": "crushed gray stone gravel, railway ballast aggregate, coarse gravel track bed, angular stone between rails",
"keywords": ["ballast", "gravel between rails", "track bed", "crushed stone", "aggregate"],
"color_bgr": [30, 50, 100],
"conf": 0.20,
"priority": 7
},
{
"name": "bracket",
"name_kr": "브라켓/암",
"prompt": "catenary bracket arm, horizontal cantilever arm, pole cross arm, overhead wire bracket",
"keywords": ["bracket", "cantilever", "cross arm"],
"color_bgr": [0, 80, 200],
"conf": 0.20,
"priority": 8
},
{
"name": "bridge",
"name_kr": "교량/교각",
"prompt": "railway bridge, road bridge, overpass, viaduct, concrete bridge structure",
"keywords": ["bridge", "overpass", "viaduct"],
"color_bgr": [0, 165, 255],
"conf": 0.25,
"priority": 8
},
{
"name": "retaining_wall",
"name_kr": "방음벽/옹벽",
"prompt": "railway retaining wall, noise barrier wall, sound barrier, concrete embankment wall",
"keywords": ["retaining wall", "noise barrier", "sound barrier", "embankment wall"],
"color_bgr": [60, 180, 60],
"conf": 0.25,
"priority": 9
},
{
"name": "service_road",
"name_kr": "유지보수 도로",
"prompt": "railway maintenance road, service road alongside track, unpaved railway access road",
"keywords": ["service road", "maintenance road", "access road"],
"color_bgr": [80, 120, 160],
"conf": 0.30,
"priority": 10
},
{
"name": "culvert",
"name_kr": "암거/소교량",
"prompt": "railway culvert, drainage culvert, small underpass tunnel, concrete drainage structure",
"keywords": ["culvert", "underpass", "drainage tunnel"],
"color_bgr": [180, 60, 180],
"conf": 0.20,
"priority": 9
},
{
"name": "vehicle",
"name_kr": "차량",
"prompt": "car, truck, vehicle, automobile, van",
"keywords": ["car", "truck", "vehicle", "automobile", "van"],
"color_bgr": [255, 255, 255],
"conf": 0.25,
"priority": 2
},
{
"name": "building",
"name_kr": "건물",
"prompt": "building, house, rooftop, structure, roof",
"keywords": ["building", "house", "rooftop", "structure", "roof"],
"color_bgr": [50, 50, 255],
"conf": 0.25,
"priority": 8
},
{
"name": "farmland",
"name_kr": "농지",
"prompt": "farmland, agricultural field, cropland, vegetable garden",
"keywords": ["farmland", "field", "cropland", "vegetable", "agricultural"],
"color_bgr": [50, 200, 50],
"conf": 0.25,
"priority": 11
},
{
"name": "vegetation",
"name_kr": "식생",
"prompt": "trees, forest, shrubs, vegetation, bushes",
"keywords": ["tree", "forest", "shrub", "vegetation", "bush"],
"color_bgr": [0, 120, 0],
"conf": 0.25,
"priority": 12
}
]
}

301
tools/auto_rail_detect.py Normal file
View File

@@ -0,0 +1,301 @@
"""
auto_rail_detect.py
===================
항공 이미지에서 레일 라인을 자동 검출하여 Rhino용 DXF로 저장.
수동 라벨링 없이 이미지 1장에서 바로 실행.
사용법:
python tools/auto_rail_detect.py <image.png> [output.dxf]
예시:
python tools/auto_rail_detect.py "경부선_2-2구간_미션33-34(5cm).png"
python tools/auto_rail_detect.py "경부선.png" output/rail.dxf
원리:
이미지 → 엣지 검출 → Hough 직선 검출 → 방향별 클러스터링 → DXF 폴리라인
"""
import sys
import cv2
import numpy as np
from pathlib import Path
# ─── Settings (tunable) ───────────────────────────────────────────────────────
CANNY_LOW = 30  # Lower Canny threshold (lower -> more edges detected)
CANNY_HIGH = 80  # Upper Canny threshold
HOUGH_THRESHOLD = 100  # Hough vote threshold (higher -> only longer lines are detected)
HOUGH_MIN_LEN = 200  # Minimum segment length (pixels)
HOUGH_MAX_GAP = 30  # Allowed gap within a line (pixels; larger joins broken lines)
ANGLE_TOLERANCE = 5  # Angular range treated as the same direction (degrees)
CLUSTER_DIST = 20  # Line spacing treated as the same rail (pixels)
MIN_TOTAL_LEN = 500  # Minimum total length of a final rail (pixels; drops short stray lines)
# ──────────────────────────────────────────────────────────────────────────────
def detect_edges(img_gray):
    """Return a Canny edge map of a grayscale image.

    Pipeline: CLAHE contrast enhancement -> 5x5 Gaussian blur -> Canny
    with the module-level CANNY_LOW / CANNY_HIGH thresholds.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    smoothed = cv2.GaussianBlur(equalizer.apply(img_gray), (5, 5), 0)
    return cv2.Canny(smoothed, CANNY_LOW, CANNY_HIGH, apertureSize=3)
def detect_lines(edges):
    """Detect line segments with the probabilistic Hough transform.

    Returns a list of (x1, y1, x2, y2) tuples, or an empty list when
    OpenCV reports no lines.
    """
    found = cv2.HoughLinesP(
        edges,
        rho=1,
        theta=np.pi / 180,
        threshold=HOUGH_THRESHOLD,
        minLineLength=HOUGH_MIN_LEN,
        maxLineGap=HOUGH_MAX_GAP,
    )
    segments = []
    if found is not None:
        # Each result row wraps one 4-element segment; unwrap it to a tuple.
        for row in found:
            segments.append(tuple(row[0]))
    return segments
def line_angle(x1, y1, x2, y2):
    """Return the undirected orientation of a segment in degrees, in [0, 180)."""
    radians = np.arctan2(y2 - y1, x2 - x1)
    return np.degrees(radians) % 180
def line_length(x1, y1, x2, y2):
    """Return the Euclidean length of the segment (x1, y1)-(x2, y2)."""
    dx = x2 - x1
    dy = y2 - y1
    return np.hypot(dx, dy)
def line_midpoint(x1, y1, x2, y2):
    """Return the midpoint of the segment as an (x, y) tuple."""
    mid_x = (x1 + x2) / 2
    mid_y = (y1 + y2) / 2
    return (mid_x, mid_y)
def perpendicular_dist(x1, y1, x2, y2, px, py):
    """Distance from point (px, py) to the line through (x1, y1) and (x2, y2).

    A zero-length segment falls back to the point-to-point distance.
    """
    dx = x2 - x1
    dy = y2 - y1
    seg_len = np.hypot(dx, dy)
    if seg_len != 0:
        cross = dy * px - dx * py + x2 * y1 - y2 * x1
        return abs(cross) / seg_len
    return np.hypot(px - x1, py - y1)
def cluster_lines(lines):
    """Group segments with similar direction and position into rail candidates.

    Segments are first binned by orientation (bin width ANGLE_TOLERANCE
    degrees). Within each bin, a seed segment greedily collects every
    still-unassigned segment whose midpoint lies closer than CLUSTER_DIST
    (perpendicular distance) to the seed's supporting line.

    Args:
        lines: iterable of (x1, y1, x2, y2) segments.

    Returns:
        A list of clusters; each cluster is a list of (x1, y1, x2, y2).
    """
    if not lines:
        return []

    # Orientation is undirected, so 0 deg and 180 deg are the same direction.
    # Wrapping the rounded bin with % 180 keeps e.g. a 179 deg line and a
    # 1 deg line in one bin instead of splitting them into "180" and "0".
    angle_groups = {}
    for line in lines:
        binned = round(line_angle(*line) / ANGLE_TOLERANCE) * ANGLE_TOLERANCE
        angle_groups.setdefault(binned % 180, []).append(line)

    clusters = []
    for grp in angle_groups.values():
        # Distance-based clustering inside one direction bin.
        used = [False] * len(grp)
        for i, seed in enumerate(grp):
            if used[i]:
                continue
            used[i] = True
            cluster = [seed]
            for j, other in enumerate(grp):
                if used[j]:
                    continue
                mx, my = line_midpoint(*other)
                if perpendicular_dist(*seed, mx, my) < CLUSTER_DIST:
                    cluster.append(other)
                    used[j] = True
            clusters.append(cluster)
    return clusters
def merge_cluster_to_polyline(cluster):
    """Collapse a cluster of segments into one ordered polyline.

    All segment endpoints are projected onto the principal axis of the
    point cloud and sorted along it; a point is kept only if it is more
    than 5 px away from the previously kept point.

    Returns a list of 2-element float arrays, or [] for an empty cluster.
    """
    endpoints = [pt for x1, y1, x2, y2 in cluster for pt in ((x1, y1), (x2, y2))]
    if not endpoints:
        return []

    pts_arr = np.array(endpoints, dtype=float)
    centered = pts_arr - pts_arr.mean(axis=0)
    cov = np.cov(centered.T)
    if cov.ndim >= 2:
        # Principal direction = eigenvector of the largest eigenvalue.
        eigenvalues, eigenvectors = np.linalg.eig(cov)
        direction = eigenvectors[:, np.argmax(eigenvalues)]
    else:
        direction = np.array([1.0, 0.0])

    order = np.argsort(centered.dot(direction))
    merged = []
    for pt in pts_arr[order]:
        # The first point is always kept; later ones must be > 5 px away.
        if not merged or np.hypot(pt[0] - merged[-1][0], pt[1] - merged[-1][1]) > 5:
            merged.append(pt)
    return merged
def total_polyline_length(polyline):
    """Sum of the segment lengths along a polyline (0 for fewer than two points)."""
    if len(polyline) < 2:
        return 0
    return sum(
        np.hypot(nxt[0] - cur[0], nxt[1] - cur[1])
        for cur, nxt in zip(polyline, polyline[1:])
    )
def save_debug_image(img, all_polylines, output_path):
    """Draw the detected polylines on a copy of the image and save it.

    Args:
        img: BGR or single-channel image; it is not modified.
        all_polylines: iterable of polylines, each a sequence of (x, y)
            points in the pixel coordinates of ``img``.
        output_path: destination file; its suffix selects the image format.

    The file is written with ``cv2.imencode`` + ``ndarray.tofile`` instead of
    ``cv2.imwrite`` so that non-ASCII (e.g. Korean) output paths work on
    Windows, mirroring how this script decodes its input image.
    """
    vis = img.copy() if len(img.shape) == 3 else cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255),
              (255, 0, 255), (255, 255, 0)]
    for i, poly in enumerate(all_polylines):
        if len(poly) == 0:
            continue  # nothing to draw or label
        color = colors[i % len(colors)]
        # cv2.polylines requires 32-bit integer points; the platform default
        # integer dtype is not guaranteed to be int32.
        pts = np.array([[int(p[0]), int(p[1])] for p in poly], dtype=np.int32)
        cv2.polylines(vis, [pts], False, color, 2)
        # Label the polyline with its 1-based index at its middle vertex.
        cx, cy = pts[len(pts) // 2]
        cv2.putText(vis, str(i + 1), (int(cx), int(cy)), cv2.FONT_HERSHEY_SIMPLEX,
                    1.5, color, 3)
    ok, encoded = cv2.imencode(Path(output_path).suffix, vis)
    if not ok:
        print(f" 시각화 저장 실패: {output_path}")
        return
    encoded.tofile(str(output_path))
    print(f" 시각화 저장: {output_path}")
def process(image_path: str, dxf_path: str, gsd_m: float = 0.05):
    """Detect rail lines in an aerial image and export them as DXF polylines.

    Pipeline: load image -> downscale to 4000 px width if wider -> edge
    detection -> Hough segments -> direction/position clustering ->
    polylines -> DXF, plus a ``<stem>_debug.jpg`` visualization written
    next to the DXF.

    Args:
        image_path: input image path (non-ASCII paths are supported).
        dxf_path: output DXF path; parent directories are created.
        gsd_m: metres per original-resolution pixel, used only for the
            length shown in the console report. Defaults to 0.05, the
            value that was previously hard-coded.

    Exits the process with status 1 when the image cannot be decoded, no
    segment is detected, or no polyline passes the length filter.
    """
    import ezdxf

    print(f"[입력] {image_path}")
    # Read the bytes ourselves and decode with cv2.imdecode so that non-ASCII
    # (e.g. Korean) paths are handled.
    with open(image_path, "rb") as f:
        data = f.read()
    # An empty file would make cv2.imdecode raise; treat it as undecodable.
    img = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR) if data else None
    if img is None:
        print(f"오류: 이미지를 열 수 없습니다: {image_path}")
        sys.exit(1)
    H, W = img.shape[:2]
    print(f"[이미지] {W} x {H} px")

    # Preprocessing: shrink wide images for speed.
    scale = 1.0
    if W > 4000:
        scale = 4000 / W
        small = cv2.resize(img, (int(W * scale), int(H * scale)))
        print(f" 축소: {int(W*scale)} x {int(H*scale)} (scale={scale:.2f})")
    else:
        small = img.copy()
    gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)

    # Step 1: edge detection.
    print("[1단계] 엣지 검출...")
    edges = detect_edges(gray)
    edge_count = int(edges.sum() / 255)
    print(f" 엣지 픽셀: {edge_count:,}")

    # Step 2: Hough line detection.
    print("[2단계] Hough 직선 검출...")
    lines = detect_lines(edges)
    print(f" 검출된 선: {len(lines)}")
    if not lines:
        print("선을 검출하지 못했습니다. HOUGH_THRESHOLD를 낮춰보세요.")
        sys.exit(1)

    # Step 3: clustering.
    print("[3단계] 레일 클러스터링...")
    clusters = cluster_lines(lines)
    print(f" 클러스터: {len(clusters)}")

    # Convert clusters to polylines and filter by length.
    # Polylines are measured on the downscaled image, so convert the length
    # back to original-resolution pixels before comparing it with
    # MIN_TOTAL_LEN. (The previous test `length >= MIN_TOTAL_LEN / scale`
    # mixed the two unit systems and became stricter the larger the image.)
    polylines = []
    for cluster in clusters:
        poly = merge_cluster_to_polyline(cluster)
        length_orig = total_polyline_length(poly) / scale
        if length_orig >= MIN_TOTAL_LEN:
            # Map the points back to original-resolution coordinates.
            poly_orig = [(p[0] / scale, p[1] / scale) for p in poly]
            polylines.append((poly_orig, length_orig))

    # Longest first.
    polylines.sort(key=lambda x: -x[1])
    print(f" 유효 레일 라인: {len(polylines)}")
    if not polylines:
        print("유효한 레일을 찾지 못했습니다. MIN_TOTAL_LEN을 낮춰보세요.")
        sys.exit(1)
    for i, (poly, length) in enumerate(polylines):
        print(f" 레일 {i+1}: {length:.0f}px ({length*gsd_m:.1f}m)")

    # Step 4: write the DXF.
    print("[4단계] DXF 저장...")
    doc = ezdxf.new("R2010")
    msp = doc.modelspace()
    doc.layers.add("RAIL_AUTO", color=1)  # red: auto-detected rails
    doc.layers.add("RAIL_AUTO_LABEL", color=2)  # yellow: index labels
    for i, (poly, length) in enumerate(polylines):
        # Flip the Y axis (image coordinates -> DXF coordinates).
        dxf_pts = [(float(p[0]), float(-p[1])) for p in poly]
        msp.add_lwpolyline(dxf_pts, dxfattribs={"layer": "RAIL_AUTO"})
        # Index label at the middle vertex.
        mid = poly[len(poly) // 2]
        msp.add_text(
            f"Rail_{i+1}",
            dxfattribs={
                "layer": "RAIL_AUTO_LABEL",
                "height": 20,
                "insert": (float(mid[0]), float(-mid[1])),
            }
        )
    Path(dxf_path).parent.mkdir(parents=True, exist_ok=True)
    doc.saveas(dxf_path)
    print(f"[완료] DXF 저장: {dxf_path}")
    print(f" 레이어: RAIL_AUTO(빨강) = {len(polylines)}개 레일 중심선")

    # Debug visualization, drawn on the downscaled image.
    debug_path = Path(dxf_path).parent / (Path(dxf_path).stem + "_debug.jpg")
    all_polys = [p for p, _ in polylines]
    save_debug_image(small, [[(x * scale, y * scale) for x, y in p] for p in all_polys],
                     debug_path)
    print(f"\nRhino 사용법:")
    print(f" 1. File -> Import -> {Path(dxf_path).name}")
    print(f" 2. RAIL_AUTO 레이어 선택")
    print(f" 3. Sweep1 명령 -> 레일 단면 선택 -> 3D 레일 완성")
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    if not cli_args:
        print("사용법: python tools/auto_rail_detect.py <image.png> [output.dxf]")
        sys.exit(1)
    img_path = cli_args[0]
    # Default output: the image path with its suffix replaced by .dxf.
    if len(cli_args) >= 2:
        dxf_path = cli_args[1]
    else:
        dxf_path = str(Path(img_path).with_suffix(".dxf"))
    process(img_path, dxf_path)

565
tools/detect_all_objects.py Normal file
View File

@@ -0,0 +1,565 @@
"""
이미지에서 객체를 SAM3.1로 검출하여 색상별로 시각화.
전략:
- cols×rows 타일로 분할 (overlap 중복)
- --tiles 로 처리할 타일 번호 지정 (예: 9-24, 1,5,9, 전체=all)
- --categories 로 JSON 설정 파일 로드 (카테고리·프롬프트·색상 정의)
- 타일당 SAM3.1 1회 호출 (모든 카테고리 프롬프트 합산)
- 병렬 처리(ThreadPoolExecutor) → NMS → 색상 시각화
사용법:
python tools/detect_all_objects.py \\
--input data/역사이미지/slope/DJI_20260306113839_0005.JPG \\
--categories configs/railway_zone.json \\
--tiles 9-24 \\
--cols 8 --rows 6 --overlap 0.10 --workers 4
"""
import argparse
import base64
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import cv2
import numpy as np
import requests
# Base URL of the SAM3 inference server; requests go to "<SAM3_SERVER>/v1/predict".
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent in the "model" field of each prediction request.
SAM3_MODEL_ID = "segment_anything_3"
# Default categories (used when --categories is not given).
# Each entry carries: name, prompt (text sent to SAM3), keywords (substrings
# matched against returned labels) and color_bgr (drawing color).
_DEFAULT_CATEGORIES = [
    {"name": "railway", "prompt": "railroad track, railway rail, steel rail",
     "keywords": ["railroad", "railway rail", "steel rail"], "color_bgr": [0, 200, 255]},
    {"name": "catenary_pole", "prompt": "railway catenary pole, overhead line pole, catenary mast",
     "keywords": ["catenary pole", "overhead line pole", "catenary mast"], "color_bgr": [255, 130, 0]},
    {"name": "highway", "prompt": "highway road, expressway asphalt, paved road lane",
     "keywords": ["highway", "expressway", "paved road"], "color_bgr": [160, 160, 160]},
    {"name": "vehicle", "prompt": "car, truck, vehicle, automobile",
     "keywords": ["car", "truck", "vehicle", "automobile"], "color_bgr": [0, 255, 0]},
    {"name": "building", "prompt": "building, house, rooftop, structure",
     "keywords": ["building", "house", "rooftop", "structure"], "color_bgr": [50, 50, 255]},
    {"name": "farmland", "prompt": "farmland, agricultural field, cropland, vegetable garden",
     "keywords": ["farmland", "field", "cropland", "vegetable"], "color_bgr": [50, 200, 50]},
    {"name": "vegetation", "prompt": "trees, forest, shrubs, vegetation, bushes",
     "keywords": ["tree", "forest", "shrub", "vegetation", "bush"], "color_bgr": [0, 120, 0]},
    {"name": "guardrail", "prompt": "guardrail, highway barrier, road fence, crash barrier",
     "keywords": ["guardrail", "highway barrier", "road fence", "crash barrier"], "color_bgr": [200, 0, 200]},
    {"name": "bridge", "prompt": "bridge, overpass, viaduct",
     "keywords": ["bridge", "overpass", "viaduct"], "color_bgr": [0, 165, 255]},
    {"name": "wire", "prompt": "overhead wire, catenary wire, electric cable line",
     "keywords": ["catenary wire", "overhead wire", "electric cable"], "color_bgr": [200, 200, 255]},
]
# ── 타일 번호 파싱 ────────────────────────────────────────────────────────────
def parse_tiles(tile_str: str, total: int) -> set:
    """Parse a tile selection string into a set of 1-based tile indices.

    Accepted forms (comma-separated, surrounding whitespace ignored):
        "all"    -> every tile 1..total
        "9-24"   -> inclusive range (a reversed range such as "24-9" works too)
        "1,3,5"  -> individual tiles

    Empty items (e.g. from a trailing comma) are skipped, and indices
    outside 1..total are dropped because no such tile exists in the grid.

    Raises:
        ValueError: if an item is neither an integer nor an integer range.
    """
    if tile_str.strip().lower() == "all":
        return set(range(1, total + 1))
    result = set()
    for part in tile_str.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            a, b = part.split("-", 1)
            lo, hi = sorted((int(a), int(b)))
            # Clamp the range to the tiles that actually exist.
            result.update(range(max(lo, 1), min(hi, total) + 1))
        else:
            idx = int(part)
            if 1 <= idx <= total:
                result.add(idx)
    return result
# ── 카테고리 로드 ─────────────────────────────────────────────────────────────
def load_categories(json_path: str | None) -> list:
if json_path:
data = json.loads(Path(json_path).read_text(encoding="utf-8"))
return data["categories"]
return _DEFAULT_CATEGORIES
def label_to_category(label: str, categories: list) -> int:
    """Return the index of the first category with a keyword contained in the label.

    The label is lower-cased before the substring test; keywords are used
    as given. Returns -1 when no keyword matches.
    """
    lowered = label.lower()
    return next(
        (idx for idx, cat in enumerate(categories)
         if any(kw in lowered for kw in cat["keywords"])),
        -1,
    )
def build_combined_prompt(categories: list) -> str:
    """Join every category's ``prompt`` text into one comma-separated prompt."""
    prompts = [cat["prompt"] for cat in categories]
    return ", ".join(prompts)
# ── SAM3 호출 ─────────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
    """JPEG-encode an image as base64 text, shrinking it if it is too large.

    The image is resized so that its longer side equals ``max_size`` when
    that side exceeds ``max_size``; it is then encoded as JPEG (quality 90).

    Returns:
        (base64_text, scale) where ``scale`` is the resize factor applied,
        1.0 when the image was not resized.
    """
    height, width = image_bgr.shape[:2]
    longest = max(height, width)
    scale = 1.0
    if longest > max_size:
        scale = max_size / longest
        image_bgr = cv2.resize(image_bgr, (int(width * scale), int(height * scale)))
    _, jpeg = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    return base64.b64encode(jpeg).decode("utf-8"), scale
def sam3_segment_tile(tile_bgr: np.ndarray, prompt: str, conf: float) -> list:
    """Send one tile to the SAM3 server and return its polygon shapes.

    The tile is JPEG/base64 encoded (downscaled by ``encode_image`` if
    large) and posted to ``{SAM3_SERVER}/v1/predict``. Polygon points are
    scaled back to the tile's own pixel coordinates when the tile was
    downscaled for transfer.

    Args:
        tile_bgr: tile image in BGR order.
        prompt: text prompt forwarded as ``text_prompt``.
        conf: confidence threshold forwarded as ``conf_threshold``.

    Returns:
        The list of shape dicts whose ``shape_type`` is ``"polygon"``.
        This stays best-effort: on any request, HTTP, decoding or response
        format error an empty list is returned, but the failure is now
        printed so a broken server is not mistaken for "no detections".
    """
    b64, scale = encode_image(tile_bgr)
    payload = {
        "model": SAM3_MODEL_ID,
        "image": b64,
        "params": {"text_prompt": prompt, "conf_threshold": conf,
                   "show_masks": True, "show_boxes": False},
    }
    try:
        r = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=120)
        r.raise_for_status()
        resp = r.json()
        if not resp.get("success"):
            print("\n[경고] SAM3 서버가 success=false 를 반환했습니다 (타일 건너뜀)")
            return []
        shapes = resp.get("data", {}).get("shapes", [])
        shapes = [s if isinstance(s, dict) else s.dict() for s in shapes]
        polygons = [s for s in shapes if s.get("shape_type") == "polygon"]
        if scale < 1.0:
            # Undo the transfer downscale so points match the tile pixels.
            inv = 1.0 / scale
            for s in polygons:
                s["points"] = [[x * inv, y * inv] for x, y in s["points"]]
        return polygons
    except Exception as exc:
        # Deliberately broad: one failed tile must not abort the whole run,
        # but the cause is reported instead of being swallowed silently.
        print(f"\n[경고] SAM3 요청 실패 (타일 건너뜀): {type(exc).__name__}: {exc}")
        return []
# ── NMS ───────────────────────────────────────────────────────────────────────
def _bbox(pts):
xs = [p[0] for p in pts]; ys = [p[1] for p in pts]
return min(xs), min(ys), max(xs), max(ys)
def _nms_core(shapes: list, iou_thresh: float) -> list:
    """Greedy non-maximum suppression on the shapes' bounding boxes.

    Shapes are ranked by ``score`` (0.5 when the field is missing). The
    best remaining shape is kept, and only the other remaining shapes
    whose bounding-box IoU with it is below ``iou_thresh`` survive to the
    next round.

    Returns the kept shapes in the order they were selected.
    """
    if not shapes:
        return []
    boxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
    scores = np.array([float(s.get("score", 0.5)) for s in shapes])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = scores.argsort()[::-1]
    keep = []
    while len(order):
        best, rest = order[0], order[1:]
        keep.append(best)
        if len(order) == 1:
            break
        left = np.maximum(boxes[best, 0], boxes[rest, 0])
        top = np.maximum(boxes[best, 1], boxes[rest, 1])
        right = np.minimum(boxes[best, 2], boxes[rest, 2])
        bottom = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.maximum(0, right - left) * np.maximum(0, bottom - top)
        iou = inter / (areas[best] + areas[rest] - inter + 1e-6)
        order = rest[iou < iou_thresh]
    return [shapes[i] for i in keep]
def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
    """Run ``_nms_core`` on ``shapes`` with a default IoU threshold of 0.4."""
    kept = _nms_core(shapes, iou_thresh)
    return kept
def _poly_orient(points: list, H: int, W: int) -> str:  # also used by post_merge_poles.py
    """Classify a polygon's long axis relative to the image-centre radial direction.

    Per the original note this mirrors the logic in render_skeleton_overlay.py.

    Returns:
        'V' when the long axis of the minimum-area rectangle is aligned with
            the direction from the image centre to the rectangle centre
            (|cosine| > 0.7) -- described in the file as a vertical pole.
        'H' otherwise -- described in the file as a horizontal beam.
        '?' when undecidable: shorter side < 1 px, aspect ratio < 1.3, or
            the rectangle centre is within 1 px of the image centre.
    """
    (cx, cy), (side_a, side_b), angle = cv2.minAreaRect(np.array(points, dtype=np.float32))
    short_side = min(side_a, side_b)
    if short_side < 1:
        return '?'
    if max(side_a, side_b) / short_side < 1.3:
        return '?'

    # Direction of the long side as a unit vector.
    theta = np.radians(angle if side_a >= side_b else angle + 90)
    axis_x = float(np.cos(theta))
    axis_y = float(np.sin(theta))

    # Unit vector from the image centre to the rectangle centre.
    radial_x = cx - W / 2.0
    radial_y = cy - H / 2.0
    radial_len = (radial_x ** 2 + radial_y ** 2) ** 0.5
    if radial_len < 1:
        return '?'

    cos_sim = abs(axis_x * (radial_x / radial_len) + axis_y * (radial_y / radial_len))
    if cos_sim > 0.7:
        return 'V'
    return 'H'
def merge_nonramen_poles(shapes: list, H: int, W: int,
                         x_overlap_thresh: float = 0.30,
                         y_gap_thresh: int = 150) -> list:
    """Merge catenary-pole polygons split at tile borders -- V+V pairs only.

    Each polygon is classified as V / H / ? by ``_poly_orient``. Two
    polygons are merged only when both are 'V' (vertical pole, per the
    file's terminology) and the spatial criteria are met. Pairs that
    include an 'H' (horizontal beam) polygon are treated by the original
    author as ramen (rigid-frame) related fragments and are not merged.

    Args:
        shapes: shape dicts with a "points" list (and optionally "score").
        H, W: image height and width; used for orientation and as the
            size of the temporary merge mask.
        x_overlap_thresh: minimum ratio of bbox x-overlap to the combined
            x-extent for two polygons to be merged.
        y_gap_thresh: maximum vertical gap between the two bboxes (pixels).

    Returns:
        A new list with merged polygons; the input list is returned
        unchanged when it has at most one shape.
    """
    if len(shapes) <= 1:
        return shapes
    orients = [_poly_orient(s["points"], H, W) for s in shapes]
    v_count = sum(1 for o in orients if o == 'V')
    h_count = sum(1 for o in orients if o == 'H')
    print(f" [orient] V={v_count}, H={h_count}, ?={len(orients)-v_count-h_count}")

    def get_bbox(s):
        # Axis-aligned bbox (min_x, min_y, max_x, max_y) of a shape's points.
        xs = [p[0] for p in s["points"]]; ys = [p[1] for p in s["points"]]
        return min(xs), min(ys), max(xs), max(ys)

    def x_overlap_ratio(b1, b2):
        # Overlap of the x-intervals divided by their combined extent;
        # negative when the intervals do not overlap.
        ox = min(b1[2], b2[2]) - max(b1[0], b2[0])
        ux = max(b1[2], b2[2]) - min(b1[0], b2[0])
        return ox / ux if ux > 0 else 0.0

    def y_gap(b1, b2):
        # Vertical gap between the two boxes, 0 when they overlap in y.
        return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))

    def merge_two(s1, s2):
        # Rasterize both polygons into one mask, then take the largest
        # external contour (simplified) as the merged polygon.
        mask = np.zeros((H, W), dtype=np.uint8)
        for s in (s1, s2):
            cv2.fillPoly(mask, [np.array(s["points"], dtype=np.int32)], 255)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return s1
        c = max(contours, key=cv2.contourArea)
        eps = 0.002 * cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, eps, True)
        # Other fields are copied from s1; the score is the higher of the two.
        merged = dict(s1)
        merged["points"] = [[float(p[0][0]), float(p[0][1])] for p in approx]
        merged["score"] = max(float(s1.get("score", 0)), float(s2.get("score", 0)))
        return merged

    merged_flags = [False] * len(shapes)
    result = []
    merged_count = 0
    for i in range(len(shapes)):
        if merged_flags[i]:
            continue
        cur = shapes[i]
        cur_ori = orients[i]
        cb = get_bbox(cur)
        for j in range(i + 1, len(shapes)):
            if merged_flags[j]:
                continue
            if cur_ori != 'V' or orients[j] != 'V':
                continue  # merge only when both polygons are 'V'
            jb = get_bbox(shapes[j])
            if x_overlap_ratio(cb, jb) >= x_overlap_thresh and y_gap(cb, jb) <= y_gap_thresh:
                # Grow the current polygon; later candidates are compared
                # against the bbox of the merged result.
                cur = merge_two(cur, shapes[j])
                cur_ori = 'V'
                cb = get_bbox(cur)
                merged_flags[j] = True
                merged_count += 1
        result.append(cur)
    print(f" [merge] 병합={merged_count}")
    return result
def cross_class_nms(buckets: list, categories: list, iou_thresh: float) -> list:
    """Cross-class NMS: when shapes overlap, keep the higher-ranked one.

    All shapes from every bucket are ranked by (priority ascending, score
    descending), where ``priority`` comes from the shape's category
    (99 when absent) and a missing score counts as 0.5. Walking the
    ranking, each surviving shape suppresses every lower-ranked shape
    whose bounding-box IoU with it is >= ``iou_thresh``.

    Returns:
        New per-class buckets (same length as ``buckets``) holding the
        surviving shapes in ranking order.
    """
    ranked = []
    for cls_idx, shapes in enumerate(buckets):
        priority = categories[cls_idx].get("priority", 99)
        ranked.extend(
            (priority, -float(s.get("score", 0.5)), cls_idx, s) for s in shapes
        )
    # Negated score => ascending sort yields score-descending within a priority.
    ranked.sort(key=lambda item: (item[0], item[1]))
    if not ranked:
        return [[] for _ in buckets]

    boxes = np.array([_bbox(item[3]["points"]) for item in ranked], dtype=np.float32)
    count = len(ranked)
    suppressed = [False] * count
    for hi in range(count):
        if suppressed[hi]:
            continue
        for lo in range(hi + 1, count):
            if suppressed[lo]:
                continue
            left = max(boxes[hi, 0], boxes[lo, 0])
            top = max(boxes[hi, 1], boxes[lo, 1])
            right = min(boxes[hi, 2], boxes[lo, 2])
            bottom = min(boxes[hi, 3], boxes[lo, 3])
            inter = max(0, right - left) * max(0, bottom - top)
            if inter == 0:
                continue
            area_hi = (boxes[hi, 2] - boxes[hi, 0]) * (boxes[hi, 3] - boxes[hi, 1])
            area_lo = (boxes[lo, 2] - boxes[lo, 0]) * (boxes[lo, 3] - boxes[lo, 1])
            iou = inter / (area_hi + area_lo - inter + 1e-6)
            if iou >= iou_thresh:
                suppressed[lo] = True  # `hi` ranks higher, so `lo` is dropped

    new_buckets = [[] for _ in buckets]
    for is_suppressed, (_, _, cls_idx, shape) in zip(suppressed, ranked):
        if not is_suppressed:
            new_buckets[cls_idx].append(shape)
    return new_buckets
# ── 타일 그리드 검출 (병렬) ───────────────────────────────────────────────────
def detect_tiled(image_bgr: np.ndarray, cols: int, rows: int, overlap: float,
                 conf: float, workers: int, tile_filter: set,
                 prompt: str) -> list:
    """Run SAM3 on the selected tiles of a cols x rows grid, in parallel.

    Each selected tile is grown by ``overlap`` times the base tile size on
    every side (clipped to the image), segmented with
    ``sam3_segment_tile``, and its polygon points are shifted back into
    full-image coordinates. Tiles are numbered 1-based in row-major order.

    Returns:
        All shapes from all processed tiles, in completion order.
    """
    H, W = image_bgr.shape[:2]
    base_w = W / cols
    base_h = H / rows
    pad_x = int(base_w * overlap)
    pad_y = int(base_h * overlap)

    tiles = []
    for cell in range(rows * cols):
        r, c = divmod(cell, cols)
        idx = cell + 1
        if idx in tile_filter:
            left = max(0, int(c * base_w) - pad_x)
            right = min(W, int((c + 1) * base_w) + pad_x)
            top = max(0, int(r * base_h) - pad_y)
            bottom = min(H, int((r + 1) * base_h) + pad_y)
            tiles.append((idx, left, top, right, bottom))

    total = len(tiles)
    done = [0]
    all_shapes = []

    def run_tile(tile):
        # Segment one tile and move its points into image coordinates.
        _, x0, y0, x1, y1 = tile
        found = sam3_segment_tile(image_bgr[y0:y1, x0:x1], prompt, conf)
        for shape in found:
            shape["points"] = [[px + x0, py + y0] for px, py in shape["points"]]
        return found

    with ThreadPoolExecutor(max_workers=workers) as ex:
        futs = [ex.submit(run_tile, t) for t in tiles]
        for fut in as_completed(futs):
            all_shapes.extend(fut.result())
            done[0] += 1
            print(f" 타일 {done[0]:02d}/{total} 완료, 누적 {len(all_shapes)}", end="\r")
    print()
    return all_shapes
# ── 시각화 ────────────────────────────────────────────────────────────────────
def draw_detections(image_bgr: np.ndarray, buckets: list,
                    categories: list, tile_filter: set,
                    cols: int, rows: int, overlap: float) -> np.ndarray:
    """Render tile outlines, class-colored masks, labels and a legend.

    Args:
        image_bgr: source image; a copy is drawn on and returned.
        buckets: per-category lists of shape dicts ("points", optional "score"),
            indexed in parallel with ``categories``.
        categories: category dicts providing "name" and "color_bgr".
        tile_filter: 1-based indices of the tiles whose (un-padded) borders
            are outlined.
        cols, rows: tile grid dimensions.
        overlap: accepted for signature parity; not used in this function.

    Returns:
        The annotated BGR image.
    """
    vis = image_bgr.copy()
    H, W = vis.shape[:2]
    # Outline the processed tiles.
    base_w = W / cols
    base_h = H / rows
    for r in range(rows):
        for c in range(cols):
            idx = r * cols + c + 1
            if idx in tile_filter:
                bx0, by0 = int(c * base_w), int(r * base_h)
                bx1, by1 = min(W, int((c+1)*base_w)), min(H, int((r+1)*base_h))
                cv2.rectangle(vis, (bx0, by0), (bx1, by1), (255, 255, 255), 1)
    # Masks + sequence-number labels.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_sc = max(0.4, min(W, H) / 8000)
    thickness = max(1, int(font_sc * 2))
    for i, cat in enumerate(categories):
        color = tuple(cat["color_bgr"])
        prefix = cat["name"][:3].upper()  # e.g. RAI, CAT, BRA ...
        for seq, s in enumerate(buckets[i], start=1):
            pts = np.array(s["points"], dtype=np.int32)
            # Blend a filled copy at 35% opacity, then draw the outline.
            overlay = vis.copy()
            cv2.fillPoly(overlay, [pts], color)
            cv2.addWeighted(overlay, 0.35, vis, 0.65, 0, vis)
            cv2.polylines(vis, [pts], True, color, 2)
            # Put the label at the mean of the polygon vertices.
            cx = int(np.mean(pts[:, 0]))
            cy = int(np.mean(pts[:, 1]))
            score = float(s.get("score", 0.0))
            label = f"{prefix}{seq:03d} {score:.2f}"
            (tw, th), _ = cv2.getTextSize(label, font, font_sc, thickness)
            tx, ty = cx - tw // 2, cy + th // 2
            # Black background box with white text.
            cv2.rectangle(vis, (tx - 2, ty - th - 2), (tx + tw + 2, ty + 2),
                          (0, 0, 0), -1)
            cv2.putText(vis, label, (tx, ty), font, font_sc,
                        (255, 255, 255), thickness, cv2.LINE_AA)
    # Legend in the top-right corner, on a darkened panel.
    lx, ly = W - 280, 20
    panel_h = len(categories) * 24 + 10
    vis[0:panel_h, lx-8:W] = (vis[0:panel_h, lx-8:W] * 0.35).astype(np.uint8)
    for i, cat in enumerate(categories):
        color = tuple(cat["color_bgr"])
        prefix = cat["name"][:3].upper()
        cv2.rectangle(vis, (lx, ly-13), (lx+15, ly+3), color, -1)
        cv2.putText(vis, f"[{prefix}] {cat['name']} ({len(buckets[i])})",
                    (lx+20, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.50, color, 1, cv2.LINE_AA)
        ly += 24
    return vis
# ── 메인 ──────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: tile the image, query SAM3, filter, NMS, and save.

    Steps:
        1. Parse arguments and load the image (non-ASCII paths supported).
        2. Run SAM3 on the selected tiles with the combined category prompt,
           using the lowest per-category confidence as the server threshold.
        3. Assign shapes to categories by label keyword, apply each
           category's own confidence threshold, then within-class NMS and
           cross-class NMS.
        4. Save the visualization (longer side capped at 4096 px) and,
           optionally, YOLO polygon labels and an AnyLabeling-style JSON.
    """
    ap = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                 description=__doc__)
    ap.add_argument("--input", required=True, help="입력 이미지 경로")
    ap.add_argument("--output", default=None, help="출력 이미지 경로 (기본: 입력파일명_out.jpg)")
    ap.add_argument("--categories", default=None, help="카테고리 JSON 경로 (기본: 내장 10개)")
    ap.add_argument("--tiles", default="all", help="처리할 타일 번호: 9-24 / 1,5,9 / all (기본: all)")
    ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본: 8)")
    ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본: 6)")
    ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본: 0.10)")
    ap.add_argument("--conf", type=float, default=0.20, help="SAM3 신뢰도 임계값 (기본: 0.20)")
    ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본: 4)")
    ap.add_argument("--save-labels", action="store_true", help="YOLO 폴리곤 포맷 .txt 라벨 파일 저장")
    ap.add_argument("--save-json", action="store_true", help="AnyLabeling JSON 포맷 저장 (railway.json 동일 양식)")
    args = ap.parse_args()

    img_path = Path(args.input)
    # np.fromfile + cv2.imdecode handles non-ASCII paths. A missing or empty
    # file is routed to the same "load failed" message instead of raising.
    try:
        buf = np.fromfile(str(img_path), dtype=np.uint8)
    except OSError:
        buf = np.empty(0, dtype=np.uint8)
    image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image_bgr is None:
        print(f"이미지 로드 실패: {img_path}")
        return
    H, W = image_bgr.shape[:2]
    total_tiles = args.cols * args.rows
    # Drop tile numbers outside the grid so the reported call count is accurate.
    tile_filter = {t for t in parse_tiles(args.tiles, total_tiles) if 1 <= t <= total_tiles}
    categories = load_categories(args.categories)
    combined_prompt = build_combined_prompt(categories)

    # Cross-class NMS IoU: value from the JSON config, else 0.45.
    cc_iou = 0.45
    if args.categories:
        raw = json.loads(Path(args.categories).read_text(encoding="utf-8"))
        cc_iou = raw.get("cross_class_nms_iou", cc_iou)

    # Server-side threshold = lowest per-category conf; the stricter
    # per-category thresholds are applied afterwards.
    sam3_conf = min((cat.get("conf", args.conf) for cat in categories), default=args.conf)
    print(f"이미지 : {W}×{H}")
    print(f"타일 그리드: {args.cols}×{args.rows}={total_tiles}개 | 처리 대상: {sorted(tile_filter)}")
    print(f"카테고리 : {len(categories)}개 | 중복: {args.overlap*100:.0f}%")
    print(f"SAM3 conf : {sam3_conf} (전체 최솟값) | cross-class NMS IoU: {cc_iou}")
    print(f"SAM3 호출 : {len(tile_filter)}회 | 병렬: {args.workers}스레드\n")

    t0 = time.time()
    all_shapes = detect_tiled(image_bgr, args.cols, args.rows, args.overlap,
                              sam3_conf, args.workers, tile_filter, combined_prompt)
    print(f"전체 검출 {len(all_shapes)}개 → 분류 + per-class conf 필터 + NMS...")

    buckets = [[] for _ in categories]
    unmatched = 0  # label matched no category keyword
    low_conf = 0   # matched a category but scored below its threshold
    for s in all_shapes:
        idx = label_to_category(s.get("label", ""), categories)
        if idx < 0:
            unmatched += 1
            continue
        # Per-class confidence filter.
        cat_conf = categories[idx].get("conf", args.conf)
        if float(s.get("score", 0.0)) < cat_conf:
            low_conf += 1
            continue
        buckets[idx].append(s)

    # Within-class NMS.
    print(" [1] 클래스 내 NMS")
    for i, cat in enumerate(categories):
        before = len(buckets[i])
        buckets[i] = nms_shapes(buckets[i])
        print(f" {cat['name']:18s}: {before:3d} → {len(buckets[i]):3d}개 (conf≥{cat.get('conf', args.conf)})")

    # Cross-class NMS.
    total_before = sum(len(b) for b in buckets)
    print(f" [2] cross-class NMS (IoU≥{cc_iou})")
    buckets = cross_class_nms(buckets, categories, cc_iou)
    total_after = sum(len(b) for b in buckets)
    print(f" {total_before}개 → {total_after}")
    if unmatched or low_conf:
        print(f" (미분류 {unmatched}개, conf미달 {low_conf}개 제외)")
    print(f"\n완료: {time.time()-t0:.0f}")

    vis = draw_detections(image_bgr, buckets, categories,
                          tile_filter, args.cols, args.rows, args.overlap)
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        s = 4096 / max(h, w)
        vis = cv2.resize(vis, (int(w*s), int(h*s)))

    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        # Auto-name under output/detect/<image stem>/ with the first free
        # 3-digit sequence number.
        tile_tag = args.tiles.replace(",", "_").replace("-", "to")
        cat_tag = Path(args.categories).stem if args.categories else "default"
        out_dir = Path("output") / "detect" / img_path.stem
        out_dir.mkdir(parents=True, exist_ok=True)
        base_name = f"tiles{tile_tag}_{cat_tag}"
        n = 1
        while True:
            out_path = out_dir / f"{base_name}_{n:03d}.jpg"
            if not out_path.exists():
                break
            n += 1
    # imencode + tofile so non-ASCII output paths work.
    cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path))
    print(f"저장: {out_path}")

    if args.save_labels:
        # One line per polygon: class index followed by x y pairs normalized
        # by the original image width and height.
        label_path = out_path.with_suffix(".txt")
        with open(label_path, "w", encoding="utf-8") as f:
            for cls_idx, shapes in enumerate(buckets):
                for s in shapes:
                    pts_norm = [[px / W, py / H] for px, py in s["points"]]
                    coords = " ".join(f"{x:.6f} {y:.6f}" for x, y in pts_norm)
                    f.write(f"{cls_idx} {coords}\n")
        print(f"라벨 저장: {label_path}")

    if args.save_json:
        json_shapes = []
        for cls_idx, shapes in enumerate(buckets):
            cat_name = categories[cls_idx]["name"] if cls_idx < len(categories) else str(cls_idx)
            for s in shapes:
                json_shapes.append({
                    "label": cat_name,
                    "score": float(s.get("score", 0.0)),
                    "points": [[float(px), float(py)] for px, py in s["points"]],
                    "group_id": None,
                    "description": None,
                    "shape_type": "polygon",
                    "flags": None,
                })
        anylabel = {
            "version": "3.3.9",
            "flags": {},
            "shapes": json_shapes,
            "imagePath": img_path.name,
            "imageData": None,
            "imageHeight": H,
            "imageWidth": W,
        }
        json_path = out_path.with_suffix(".json")
        json_path.write_text(json.dumps(anylabel, ensure_ascii=False, indent=2),
                             encoding="utf-8")
        print(f"JSON 저장: {json_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,152 @@
"""
얇은 중공단면(도너츠 형태) 검출 및 표시 스크립트
드론 영상에서 전철주 상단 원형 단면을 찾아 표시
사용법:
python detect_hollow_section.py <image> [--radius <px>] [--tol <px>] [--topk <n>]
--radius : 3D 프로그램 줌 기준 단면 반지름 (픽셀). 미지정 시 자동 탐색.
--tol : radius ± 허용 오차 (기본 30px)
--topk : 표시할 최대 후보 수 (기본 3)
"""
import argparse
import cv2
import numpy as np
from pathlib import Path
def ring_score(edges: np.ndarray, cx: int, cy: int, r: int,
               H: int, W: int, inner_ratio: float = 0.60) -> float:
    """Score how ring-like (donut-like) the circle ``(cx, cy, r)`` is.

    The score is the mean of ``edges`` inside the annulus between the inner
    radius and ``r``, divided by the mean of ``edges`` inside the inner disc
    (clipped to 5), then weighted by how close the centre is to the image
    centre. The edge map is computed once by the caller and passed in.

    Args:
        edges: Edge map of shape ``(H, W)``.
        cx, cy: Circle centre in pixels.
        r: Outer radius in pixels.
        H, W: Image height and width used to build the masks.
        inner_ratio: Inner radius as a fraction of ``r`` (at least 3 px).

    Returns:
        ``0.0`` when either mask is empty, otherwise
        ``edge_ratio * (0.7 + 0.3 * center_bonus)``.
    """
    inner_radius = max(int(r * inner_ratio), 3)
    center = (cx, cy)

    outer_disc = np.zeros((H, W), np.uint8)
    inner_disc = np.zeros((H, W), np.uint8)
    cv2.circle(outer_disc, center, r, 255, -1)
    cv2.circle(inner_disc, center, inner_radius, 255, -1)
    annulus = cv2.subtract(outer_disc, inner_disc)

    # Nothing to measure if either region has no pixels inside the image.
    if not annulus.any() or not inner_disc.any():
        return 0.0

    ring_edge = float(edges[annulus > 0].mean())
    inner_edge = float(edges[inner_disc > 0].mean()) + 1e-6  # avoid division by zero
    edge_ratio = min(ring_edge / inner_edge, 5.0)

    # Proximity bonus: 1 at the image centre, 0 at a corner.
    offset = np.hypot(cx - W / 2, cy - H / 2)
    farthest = np.hypot(W / 2, H / 2)
    center_bonus = 1.0 - offset / farthest
    return edge_ratio * (0.7 + 0.3 * center_bonus)
def detect_hollow_section(
    image_path: str,
    output_path: str | None = None,
    radius_px: int | None = None,
    tol_px: int = 30,
    top_k: int = 3,
) -> str:
    """Detect thin hollow (ring-shaped) sections and draw the best candidates.

    Circles are proposed with ``cv2.HoughCircles`` on a CLAHE-equalised,
    blurred grayscale image, ranked with ``ring_score`` and the top ``top_k``
    are drawn on a copy of the input, which is then written to disk.

    Args:
        image_path: Path of the input image.
        output_path: Where to save the annotated image. Defaults to
            ``<stem>_hollow_detected<suffix>`` next to the input.
        radius_px: Known section radius in pixels. When ``None`` the radius
            range is derived from the shorter image side.
        tol_px: Allowed deviation around ``radius_px``.
        top_k: Maximum number of candidates to draw.

    Returns:
        The path the annotated image was written to.

    Raises:
        FileNotFoundError: If the image cannot be decoded.
    """
    img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"이미지 로드 실패: {image_path}")
    H, W = img.shape[:2]
    print(f"이미지 크기: {W}×{H}")

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    blurred = cv2.GaussianBlur(clahe.apply(gray), (7, 7), 2)
    # Compute the Canny edge map once; ring_score reuses it for every candidate.
    edges = cv2.Canny(blurred, 30, 90)

    # Decide the radius search range. The label uses the same
    # ``is not None`` test as the range itself so that radius_px=0 is not
    # reported as automatic mode.
    if radius_px is not None:
        min_r = max(5, radius_px - tol_px)
        max_r = radius_px + tol_px
        range_note = f" (기준={radius_px}px ±{tol_px})"
    else:
        short = min(W, H)
        min_r = max(10, short // 20)
        max_r = short // 3
        range_note = " (자동)"
    print(f"탐색 반경: {min_r} ~ {max_r} px" + range_note)

    circles = cv2.HoughCircles(
        blurred,
        cv2.HOUGH_GRADIENT,
        dp=1.0,
        minDist=min_r * 3,
        param1=80,
        param2=35,
        minRadius=min_r,
        maxRadius=max_r,
    )

    result = img.copy()
    if circles is None:
        print("HoughCircles 미검출")
    else:
        circles = np.round(circles[0]).astype(int)
        print(f"HoughCircles 후보: {len(circles)}개 → 링 스코어 계산 중...")
        scored = sorted(
            [(ring_score(edges, cx, cy, r, H, W), cx, cy, r)
             for cx, cy, r in circles],
            reverse=True,
        )
        # One colour per rank; the palette repeats when top_k exceeds its length.
        colors = [(0, 255, 0), (0, 220, 255), (255, 120, 0)]
        print(f"\n상위 {min(top_k, len(scored))}개 중공단면 후보:")
        for rank, (score, cx, cy, r) in enumerate(scored[:top_k]):
            color = colors[rank % len(colors)]
            inner_r = max(int(r * 0.60), 5)
            lw = 3 if rank == 0 else 2  # best candidate is drawn thicker
            cv2.circle(result, (cx, cy), r, color, lw)
            cv2.circle(result, (cx, cy), inner_r, color, 1)
            cv2.circle(result, (cx, cy), 6, (0, 0, 255), -1)
            arm = r + 15  # crosshair reaches slightly past the circle
            cv2.line(result, (cx - arm, cy), (cx + arm, cy), color, 1)
            cv2.line(result, (cx, cy - arm), (cx, cy + arm), color, 1)
            cv2.putText(result, f"#{rank+1} r={r}px s={score:.2f}",
                        (cx - r, cy - r - 8),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 255), 2)
            print(f" #{rank+1}: 중심=({cx},{cy}), 반지름={r}px, 링스코어={score:.3f}")

    if output_path is None:
        p = Path(image_path)
        output_path = str(p.parent / (p.stem + "_hollow_detected" + p.suffix))
    # Make sure the destination folder exists before writing.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    # imencode + tofile is used here instead of cv2.imwrite.
    cv2.imencode(Path(output_path).suffix, result)[1].tofile(output_path)
    print(f"\n결과 저장: {output_path}")
    return output_path
if __name__ == "__main__":
    # Command-line entry point: parse options and run the detector once.
    parser = argparse.ArgumentParser(description="얇은 중공단면 검출")
    parser.add_argument("image", nargs="?",
                        default=r"d:\MYCLAUDE_PROJECT\x-anylabeling01\data\Message_2026-04-23T11_13_29+09_00.png")
    # Optional: lets the caller choose the output file instead of always
    # writing next to the input image.
    parser.add_argument("--output", default=None,
                        help="결과 이미지 저장 경로. 미지정 시 입력 이미지 옆에 저장.")
    parser.add_argument("--radius", type=int, default=None,
                        help="3D 프로그램 줌 기준 단면 반지름(px). 미지정 시 자동 탐색.")
    parser.add_argument("--tol", type=int, default=30,
                        help="radius ± 허용 오차 (기본 30px)")
    parser.add_argument("--topk", type=int, default=3,
                        help="표시할 최대 후보 수 (기본 3)")
    args = parser.parse_args()
    detect_hollow_section(
        args.image,
        output_path=args.output,
        radius_px=args.radius,
        tol_px=args.tol,
        top_k=args.topk,
    )

666
tools/detect_raamen.py Normal file
View File

@@ -0,0 +1,666 @@
"""
드론 사선 촬영 이미지에서 라멘형(門자형) 전철주 검출.
기존 픽셀 위상수학(Skeleton) 방식을 공간 기하학 방식으로 교체.
파이프라인:
Phase 1: 폴리곤 단순화(approxPolyDP) + 소실점(Vanishing Point) 계산
Phase 2: 동적 V/H 분류 (소실점 기반 기대 각도)
Phase 3: 폴리곤 bbox 접촉/교차 기반 그룹핑 (Union-Find 연결 컴포넌트)
Phase 4: 라멘 구조 판정 + 예외(가림) 처리
사용:
python tools/detect_raamen.py \
--image <path> --label <path> --output <path> \
[--class-ids 1] [--epsilon 4.0] [--v-thresh 20.0]
"""
import argparse
import numpy as np
import cv2
from pathlib import Path
# ── Phase 1: 파싱 + 단순화 + 소실점 ─────────────────────────────────────
def load_polygons(label_path, W, H, class_ids=None, class_names=None):
    """Parse polygons from an AnyLabeling JSON file or a YOLO ``.txt`` file.

    The format is chosen from the file suffix (``.json`` vs anything else).
    YOLO coordinates are scaled by ``W``/``H``; JSON points are used as-is.
    Polygons with fewer than three points are dropped.

    Args:
        label_path: ``Path`` of the label file.
        W, H: Image width and height in pixels.
        class_ids: Class ids to keep (YOLO only); ``None`` keeps all.
        class_names: Labels to keep (JSON only); ``None`` keeps all.

    Returns:
        ``(polys, abs_indices)``: float32 point arrays, and for each one its
        index in the JSON ``shapes`` list or its line number in the text file.

    Raises:
        FileNotFoundError: If ``label_path`` does not exist.
    """
    import json as _json

    if not label_path.exists():
        raise FileNotFoundError(f"라벨 파일을 찾을 수 없습니다: {label_path}")

    polys, abs_indices = [], []

    def keep(vertices, source_index):
        # A polygon needs at least three vertices.
        if len(vertices) >= 3:
            polys.append(vertices)
            abs_indices.append(source_index)

    if label_path.suffix.lower() == ".json":
        payload = _json.loads(label_path.read_text(encoding="utf-8"))
        for shape_no, shape in enumerate(payload.get("shapes", [])):
            if class_names is not None and shape.get("label", "") not in class_names:
                continue
            vertices = np.array(
                [[float(pt[0]), float(pt[1])] for pt in shape.get("points", [])],
                dtype=np.float32,
            )
            keep(vertices, shape_no)
        return polys, abs_indices

    # YOLO .txt: "<class_id> x1 y1 x2 y2 ..." with normalised coordinates.
    text = label_path.read_text(encoding="utf-8")
    for line_no, raw in enumerate(text.splitlines()):
        fields = raw.split()
        if not fields:
            continue
        if class_ids is not None and int(fields[0]) not in class_ids:
            continue
        values = [float(tok) for tok in fields[1:]]
        vertices = np.array(
            [[values[k] * W, values[k + 1] * H] for k in range(0, len(values), 2)],
            dtype=np.float32,
        )
        keep(vertices, line_no)
    return polys, abs_indices
def smooth_polygon(pts, epsilon):
    """Simplify a polygon outline with ``cv2.approxPolyDP``.

    Returns the simplified float32 points, or the original points (as
    float32) when the simplification leaves fewer than three vertices.
    """
    simplified = cv2.approxPolyDP(pts.astype(np.int32), epsilon, closed=True)
    simplified = simplified.reshape(-1, 2).astype(np.float32)
    if len(simplified) >= 3:
        return simplified
    return pts.astype(np.float32)
def _minrect(pts):
    """Return ``cv2.minAreaRect`` of ``pts`` after casting them to int32."""
    int_pts = pts.astype(np.int32)
    return cv2.minAreaRect(int_pts)
def _long_axis_angle(pts):
    """Describe the long axis of the minimum-area rectangle around ``pts``.

    Returns:
        ``(angle_deg, cx, cy, long_side, short_side)`` where ``angle_deg`` is
        the rectangle angle, shifted by 90 degrees when the rectangle's
        second side is the longer one.
    """
    (cx, cy), (rw, rh), angle = _minrect(pts)
    if rw >= rh:
        long_angle = angle
    else:
        long_angle = angle + 90
    return long_angle, cx, cy, max(rw, rh), min(rw, rh)
def _radial_cos_sim(pts, img_cx, img_cy):
    """Absolute cosine between a polygon's long axis and its radial direction.

    The radial direction points from ``(img_cx, img_cy)`` to the polygon's
    rectangle centre. Used to bootstrap vanishing-point seed candidates.
    Returns ``0.0`` for near-square polygons (aspect ratio below 1.3) and for
    polygons whose centre is within one pixel of the given point.
    """
    long_angle_deg, cx, cy, long_side, short_side = _long_axis_angle(pts)
    if short_side < 1 or long_side / short_side < 1.3:
        return 0.0

    theta = np.radians(long_angle_deg)
    lx = np.cos(theta)
    ly = np.sin(theta)
    rdx, rdy = cx - img_cx, cy - img_cy
    n = (rdx ** 2 + rdy ** 2) ** 0.5
    if n > 1:
        return abs(lx * rdx / n + ly * rdy / n)
    return 0.0
def compute_vanishing_point(polys):
    """Least-squares intersection of the polygons' long-axis lines.

    Each polygon contributes the line through its rectangle centre
    ``(cx, cy)`` with direction ``(dx, dy)``, written with normal
    ``(-dy, dx)`` as ``-dy*x + dx*y = -dy*cx + dx*cy``.

    Returns:
        ``(vp_x, vp_y)`` as Python floats.
    """
    normals, offsets = [], []
    for poly in polys:
        angle_deg, cx, cy, _, _ = _long_axis_angle(poly)
        theta = np.radians(angle_deg)
        dx = np.cos(theta)
        dy = np.sin(theta)
        normals.append([-dy, dx])
        offsets.append(-dy * cx + dx * cy)
    solution = np.linalg.lstsq(np.array(normals), np.array(offsets), rcond=None)[0]
    return float(solution[0]), float(solution[1])
def _estimate_vp_iterative(polys, seed_indices, v_thresh, h_max_diff, vp_min_len,
                           x_horiz_thresh=10.0, max_iter=6):
    """Estimate the vanishing point by iterative refinement from seed polygons.

    Starting from the VP of the seed polygons, every polygon is classified
    with ``classify_vh``; the VP is then recomputed from the 'V' polygons
    whose long side exceeds ``vp_min_len``. Iteration stops when fewer than
    three such polygons remain, when the VP moves less than 5 px, or after
    ``max_iter`` rounds.

    Returns:
        ``(vp_x, vp_y, n_v, orients, adiffs)``. With fewer than two seeds the
        VP is ``(None, None)``, ``n_v`` is 0 and every polygon is ``'?'``.
    """
    total = len(polys)
    orients = ['?'] * total
    adiffs = [90.0] * total
    if len(seed_indices) < 2:
        return None, None, 0, orients, adiffs

    def reclassify(px, py):
        # Refresh orientation and angle difference of every polygon in place.
        for k, poly in enumerate(polys):
            orients[k], adiffs[k] = classify_vh(poly, px, py, v_thresh, h_max_diff,
                                                x_horiz_thresh)

    vp_x, vp_y = compute_vanishing_point([polys[k] for k in seed_indices])
    for _ in range(max_iter):
        reclassify(vp_x, vp_y)
        v_cands = [k for k in range(total)
                   if orients[k] == 'V' and _long_axis_angle(polys[k])[3] > vp_min_len]
        if len(v_cands) < 3:
            break
        nx, ny = compute_vanishing_point([polys[k] for k in v_cands])
        shift = ((nx - vp_x) ** 2 + (ny - vp_y) ** 2) ** 0.5
        vp_x, vp_y = nx, ny
        if shift < 5.0:
            # Converged: classify once more against the final VP.
            reclassify(vp_x, vp_y)
            break
    return vp_x, vp_y, orients.count('V'), orients, adiffs
# ── Phase 2: 동적 V/H 분류 ──────────────────────────────────────────────
def classify_vh(pts, vp_x, vp_y, v_thresh, h_max_diff=75.0, x_horiz_thresh=10.0):
    """Classify a polygon as vertical, horizontal or unknown relative to the VP.

    Rules, in order:
      - aspect ratio below 1.3 (or short side under 1 px) -> ``'?'``
      - long axis within ``x_horiz_thresh`` degrees of the image X axis and
        aspect ratio >= 4 -> ``'?'`` (thin horizontal objects such as rails
        or wires)
      - angle to the VP direction < ``v_thresh`` -> ``'V'`` (pole)
      - angle < ``h_max_diff`` -> ``'H'`` (beam)
      - otherwise ``'?'``

    Returns:
        ``(orient, angle_diff_deg)``; the difference is 90.0 for the two
        early rejections.
    """
    long_angle_deg, cx, cy, long_side, short_side = _long_axis_angle(pts)
    if short_side < 1 or long_side / short_side < 1.3:
        return '?', 90.0

    def fold(angle):
        # Map an angle already reduced modulo 180 onto the range 0..90.
        return 180.0 - angle if angle > 90.0 else angle

    # Reject thin objects lying along the image X axis.
    tilt_from_x = fold(long_angle_deg % 180.0)
    if tilt_from_x < x_horiz_thresh and long_side / short_side >= 4.0:
        return '?', 90.0

    # Angle between the long axis and the direction towards the VP.
    expected = np.degrees(np.arctan2(vp_y - cy, vp_x - cx))
    diff = fold(abs(long_angle_deg - expected) % 180.0)
    if diff < v_thresh:
        return 'V', diff
    if diff < h_max_diff:
        return 'H', diff
    return '?', diff
# ── Phase 3: 폴리곤 접촉/교차 기반 그룹핑 ───────────────────────────────
def connectivity_groups(polys, orients, margin=30):
    """Group polygons whose bounding boxes touch or come within ``margin`` px.

    Connectivity ignores the H/V orientation; ``_cluster_polys`` provides the
    connected components. Inside each component the polygons classified 'H'
    and 'V' are listed separately ('?' polygons only contribute to
    connectivity). Groups are numbered from 1 in descending order of the
    summed contour area of their H and V polygons.

    Returns:
        List of ``{'id': int, 'H': [idx, ...], 'V': [idx, ...], 'area': float}``.
    """
    components = _cluster_polys(list(range(len(polys))), polys, margin=margin)

    groups = []
    for members in components:
        beams = sorted(i for i in members if orients[i] == 'H')
        posts = sorted(i for i in members if orients[i] == 'V')
        area = sum(cv2.contourArea(polys[i].astype(np.int32))
                   for i in beams + posts)
        # 'id' is a placeholder until the groups are ranked by area below.
        groups.append({'id': 0, 'H': beams, 'V': posts, 'area': area})

    # Largest group becomes G1.
    groups.sort(key=lambda x: x['area'], reverse=True)
    for rank, g in enumerate(groups, 1):
        g['id'] = rank
    return groups
# ── Phase 4: 라멘 구조 판정 ─────────────────────────────────────────────
def _cluster_polys(indices, polys, margin=60):
"""
인접 폴리곤을 클러스터링 → 물리적 객체(기둥) 단위 반환.
Returns: list of [idx, ...] clusters
"""
if not indices:
return []
n = len(indices)
parent = list(range(n))
def find(x):
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
bboxes = [(polys[i][:, 0].min(), polys[i][:, 1].min(),
polys[i][:, 0].max(), polys[i][:, 1].max()) for i in indices]
for a in range(n):
ax0, ay0, ax1, ay1 = bboxes[a]
for b in range(a + 1, n):
bx0, by0, bx1, by1 = bboxes[b]
if (ax1 + margin >= bx0 and bx1 + margin >= ax0 and
ay1 + margin >= by0 and by1 + margin >= ay0):
ra, rb = find(a), find(b)
if ra != rb:
parent[ra] = rb
from collections import defaultdict
clusters = defaultdict(list)
for i in range(n):
clusters[find(i)].append(indices[i])
return list(clusters.values())
def judge_raamen(group, polys, W, center_ratio=0.2, v_cluster_margin=60):
    """Decide whether a group forms a raamen (portal-frame) structure.

    - Beam reference: the largest H polygon, plus the second largest when its
      area is at least half of the largest.
    - Pole count: number of V clusters (``_cluster_polys``) whose mean x lies
      within the beam x-range (with tolerance) and in the left or right 35%
      end zone of the beam; clusters over the middle are ignored.
    - No H at all: two or more V polygons whose mean x is near the image
      centre are reported as ``'RAAMEN_CENTER'``.

    Returns:
        ``(verdict, n_poles)`` with verdict one of ``'RAAMEN'``,
        ``'RAAMEN_CENTER'``, ``'RAAMEN_OCCLUDED'``, ``'PARTIAL'`` or ``''``.
    """
    def area_of(idx):
        return cv2.contourArea(polys[idx].astype(np.int32))

    beams, posts = group['H'], group['V']

    # Without any H polygon only the "V polygons gathered near the image
    # centre" case can qualify.
    if not beams:
        if len(posts) >= 2:
            merged = np.vstack([polys[i] for i in posts])
            mean_x = float(merged[:, 0].mean())
            if abs(mean_x - W / 2) < W * center_ratio:
                return 'RAAMEN_CENTER', len(posts)
        return '', 0

    ranked = sorted(beams, key=area_of, reverse=True)
    main_hs = ranked[:1]
    if len(ranked) > 1 and area_of(ranked[1]) >= area_of(ranked[0]) * 0.5:
        main_hs.append(ranked[1])

    hx0 = int(min(polys[i][:, 0].min() for i in main_hs))
    hx1 = int(max(polys[i][:, 0].max() for i in main_hs))
    span = hx1 - hx0
    is_center = abs((hx0 + hx1) / 2.0 - W / 2) < W * center_ratio

    x_tol = max(span * 0.5, 50)
    left_zone = hx0 + span * 0.35   # right edge of the left end zone
    right_zone = hx1 - span * 0.35  # left edge of the right end zone

    pole_clusters = _cluster_polys(posts, polys, margin=v_cluster_margin)
    cluster_xs = [float(np.mean([polys[i][:, 0].mean() for i in cluster]))
                  for cluster in pole_clusters]
    valid_cxs = [x for x in cluster_xs
                 if hx0 - x_tol <= x <= hx1 + x_tol
                 and (x <= left_zone or x >= right_zone)]

    n_poles = len(valid_cxs)
    if n_poles == 0:
        return '', 0
    if n_poles == 1:
        return ('', 1) if is_center else ('RAAMEN_OCCLUDED', 1)
    if min(valid_cxs) <= hx0 + span * 0.4 and max(valid_cxs) >= hx1 - span * 0.4:
        return 'RAAMEN', n_poles
    return 'PARTIAL', n_poles
def _merge_poly_hull(indices, polys):
    """Return the convex hull of the selected polygons as ``[[x, y], ...]`` floats."""
    stacked = np.vstack([polys[i] for i in indices])
    hull = cv2.convexHull(stacked.astype(np.int32))
    points = []
    for vertex in hull:
        x, y = vertex[0][0], vertex[0][1]
        points.append([float(x), float(y)])
    return points
def group_detail(group, polys, W, center_ratio=0.2, v_cluster_margin=60):
    """Break a group down into beams, poles and attachments.

    The beam reference and the pole end-zone rule mirror ``judge_raamen``:
    main beams are the largest H polygon plus the second largest when its
    area is at least half of the largest; V clusters inside the beam x-range
    are poles when they sit in the left/right 35% end zone and attachments
    otherwise.

    Args:
        group: ``{'H': [...], 'V': [...]}`` polygon indices.
        polys: Sequence of ``(N, 2)`` point arrays.
        W, center_ratio: Accepted for signature parity with ``judge_raamen``;
            not used here.
        v_cluster_margin: Margin passed to ``_cluster_polys``.

    Returns:
        Dict with keys ``main_h`` (representative beam or ``None``),
        ``main_hs``, ``junk_h``, ``valid_pole_clusters`` and
        ``attach_clusters``. All keys are present even when the group has no
        H polygon.
    """
    hs, vs = group['H'], group['V']
    if not hs:
        # Same keys as the normal result so callers can index 'main_hs'
        # without checking for a missing beam first.
        return {'main_h': None, 'main_hs': [], 'junk_h': [],
                'valid_pole_clusters': [], 'attach_clusters': []}

    h_by_area = sorted(hs, key=lambda i: cv2.contourArea(polys[i].astype(np.int32)),
                       reverse=True)
    main_hs = [h_by_area[0]]
    if len(h_by_area) > 1:
        a0 = cv2.contourArea(polys[h_by_area[0]].astype(np.int32))
        a1 = cv2.contourArea(polys[h_by_area[1]].astype(np.int32))
        if a1 >= a0 * 0.5:
            main_hs.append(h_by_area[1])
    main_h = main_hs[0]  # representative beam
    junk_h = [i for i in hs if i not in main_hs]

    hx0 = int(min(polys[i][:, 0].min() for i in main_hs))
    hx1 = int(max(polys[i][:, 0].max() for i in main_hs))
    span = hx1 - hx0

    pole_clusters = _cluster_polys(vs, polys, margin=v_cluster_margin)
    x_tol = max(span * 0.5, 50)
    left_zone = hx0 + span * 0.35
    right_zone = hx1 - span * 0.35

    valid_pole_clusters, attach_clusters = [], []
    for cluster in pole_clusters:
        ccx = float(np.mean([polys[i][:, 0].mean() for i in cluster]))
        in_range = hx0 - x_tol <= ccx <= hx1 + x_tol
        in_end_zone = ccx <= left_zone or ccx >= right_zone
        if in_range and in_end_zone:
            valid_pole_clusters.append(cluster)
        elif in_range:
            attach_clusters.append(cluster)

    return {
        'main_h': main_h,
        'main_hs': main_hs,
        'junk_h': junk_h,
        'valid_pole_clusters': valid_pole_clusters,
        'attach_clusters': attach_clusters,
    }
# ── Visualisation constants ─────────────────────────────────────────────
# Colour tuples are passed straight to OpenCV drawing calls in render(),
# which interpret them as (B, G, R).
# NOTE(review): the original comments called 'V' orange and 'H' blue, which
# matches an RGB reading; as BGR, 'V' renders blue and 'H' renders
# orange-red. Confirm which was intended.
_VH_COLOR = {
'V': (255, 80, 0), # vertical pole (blue as BGR)
'H': ( 0, 80, 255), # horizontal beam (orange-red as BGR)
'?': (140, 140, 140), # unclassified (grey)
}
# Outline colour per verdict returned by judge_raamen().
_RAAMEN_COLOR = {
'RAAMEN': ( 0, 255, 0), # green
'RAAMEN_CENTER': ( 0, 255, 255), # yellow (centre region, H/V classification unreliable)
'RAAMEN_OCCLUDED': ( 0, 165, 255), # orange (occluded / partially detected)
'PARTIAL': (128, 128, 128), # grey
}
# ── 메인 렌더링 ─────────────────────────────────────────────────────────
def render(image_path, label_path, output_path, args,
           class_ids=None, class_names=None, epsilon=4.0, v_thresh=20.0,
           h_max_diff=75.0, vp_min_ar=2.5, vp_min_len=80.0, vp_outer_ratio=0.2,
           x_horiz_thresh=10.0):
    """Run the raamen detection pipeline and write an image plus a JSON file.

    Steps: load and simplify polygons, estimate the vanishing point from two
    seed strategies (keeping the one that yields more 'V' polygons), group
    polygons by bounding-box connectivity, judge each group with
    ``judge_raamen``, then draw the result and save an AnyLabeling-style JSON
    next to the output image.

    Args:
        image_path, label_path, output_path: ``Path`` objects.
        args: Parsed CLI namespace; ``args.margin`` and
            ``args.min_group_area`` are read here.
        class_ids, class_names: Filters forwarded to ``load_polygons``.
        epsilon: ``approxPolyDP`` tolerance in pixels.
        v_thresh, h_max_diff, x_horiz_thresh: Thresholds forwarded to
            ``classify_vh``.
        vp_min_ar, vp_min_len: Minimum aspect ratio and long-side length for
            a polygon to be a vanishing-point seed candidate.
        vp_outer_ratio: Minimum distance from the image centre, as a fraction
            of the shorter image side, for the radial seed set.

    Raises:
        ValueError: If the image cannot be decoded.
    """
    buf = np.fromfile(str(image_path), dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        # Fail with a clear message instead of an AttributeError on img.shape.
        raise ValueError(f"이미지 로드 실패: {image_path}")
    H, W = img.shape[:2]

    # ── Phase 1: parse and simplify ──────────────────────────────────────
    raw_polys, poly_abs_idx = load_polygons(label_path, W, H, class_ids, class_names)
    polys = [smooth_polygon(p, epsilon) for p in raw_polys]
    print(f" {len(polys)}개 폴리곤 파싱 (epsilon={epsilon})")
    img_cx, img_cy = W / 2.0, H / 2.0

    # ── Phase 1+2: try two VP seed sets, keep the VP with more V polygons ─
    elong_idx = []
    for i, pts in enumerate(polys):
        la, cx, cy, long_side, short_side = _long_axis_angle(pts)
        if short_side > 0 and long_side / short_side > vp_min_ar and long_side > vp_min_len:
            elong_idx.append(i)

    # Seed A: polygons near the dominant long-axis angle (histogram peak).
    seeds_A = []
    if len(elong_idx) >= 2:
        avals_all = [(i, _long_axis_angle(polys[i])[0] % 180.0) for i in elong_idx]
        hist, edges = np.histogram([a for _, a in avals_all], bins=12, range=(0.0, 180.0))
        pk = int(np.argmax(hist))
        peak_c = (edges[pk] + edges[pk + 1]) / 2.0
        seeds_A = [i for i, a in avals_all
                   if min(abs(a - peak_c), 180.0 - abs(a - peak_c)) < 20.0]

    # Seed B: polygons away from the centre whose long axis is radial.
    seeds_B = []
    for i in elong_idx:
        _, cx, cy, _, _ = _long_axis_angle(polys[i])
        dist = ((cx - img_cx) ** 2 + (cy - img_cy) ** 2) ** 0.5
        if dist > min(W, H) * vp_outer_ratio and _radial_cos_sim(polys[i], img_cx, img_cy) > 0.5:
            seeds_B.append(i)

    # Fallback VP when neither seed set is usable: far above the image centre.
    best_vp_x, best_vp_y = img_cx, -H * 3.0
    best_n_v = 0
    orients, adiffs = ['?'] * len(polys), [90.0] * len(polys)
    for label, seeds in [('지배각', seeds_A), ('radial', seeds_B)]:
        if len(seeds) < 2:
            continue
        vx, vy, nv, ors, ads = _estimate_vp_iterative(
            polys, seeds, v_thresh, h_max_diff, vp_min_len, x_horiz_thresh)
        print(f" VP [{label}]: ({vx:.0f}, {vy:.0f}) V={nv}")
        if nv > best_n_v:
            best_vp_x, best_vp_y, best_n_v = vx, vy, nv
            orients, adiffs = ors, ads
    vp_x, vp_y = best_vp_x, best_vp_y
    print(f" → 채택 VP: ({vp_x:.1f}, {vp_y:.1f})")
    print(f"\n [V/H 분류] threshold=±{v_thresh}° h_max=±{h_max_diff}°")
    for i in range(len(polys)):
        print(f" poly {i:>2d}: {orients[i]} diff={adiffs[i]:.1f}°")
    print(f" V:{orients.count('V')} H:{orients.count('H')} ?:{orients.count('?')}")

    # ── Phase 3: connectivity grouping ───────────────────────────────────
    groups = connectivity_groups(polys, orients, margin=args.margin)
    print(f"\n [그룹핑] 연결 컴포넌트 {len(groups)}개 (margin={args.margin}px)")
    for g in groups:
        print(f" G{g['id']}: H={g['H']} V={g['V']}")

    # ── Phase 4: raamen verdict per group ────────────────────────────────
    print(f"\n [라멘 판정]")
    for g in groups:
        verdict, n_poles = judge_raamen(g, polys, W)
        g['verdict'] = verdict
        g['n_poles'] = n_poles
        pole_str = f"{n_poles}poles" if n_poles else "-"
        print(f" G{g['id']}: H={g['H']} V={g['V']}{verdict or '-':18s} ({pole_str})")

    valid_items = [g for g in groups if g['verdict']]
    valid_items.sort(key=lambda x: x['area'], reverse=True)
    print(f"\n [최종 라멘 객체] {len(valid_items)}개 (면적순)")
    for g in valid_items:
        print(f" G{g['id']}: H={g['H']} V={g['V']}{g['verdict']:18s} ({g['n_poles']}poles) Area={g['area']:,.0f}")

    # Optional minimum-area filter.
    if args.min_group_area > 0:
        before = len(valid_items)
        valid_items = [g for g in valid_items if g['area'] >= args.min_group_area]
        print(f" [필터] 최소 면적 {args.min_group_area} 미만 제거: {before}{len(valid_items)}")

    # In every group, the polygon holding the lowest vertex is treated as a
    # pole: if it was classified H it is moved to V.
    for g in valid_items:
        all_idxs = g['H'] + g['V']
        bottom_vi = max(all_idxs, key=lambda i: polys[i][:, 1].max())
        if bottom_vi in g['H']:
            g['H'].remove(bottom_vi)
            g['V'].append(bottom_vi)
            orients[bottom_vi] = 'V'
        # Groups left without H: display the lowest polygon as V, the rest as H.
        if not g['H']:
            for vi in g['V']:
                orients[vi] = 'V' if vi == bottom_vi else 'H'

    # ── Visualisation ────────────────────────────────────────────────────
    # 1. Semi-transparent V/H colour overlay with an index+orientation label.
    for i, (pts, orient) in enumerate(zip(polys, orients)):
        color = _VH_COLOR[orient]
        pts_i = pts.astype(np.int32)
        ov = img.copy()
        cv2.fillPoly(ov, [pts_i], color)
        cv2.addWeighted(ov, 0.25, img, 0.75, 0, img)
        cv2.polylines(img, [pts_i], True, color, 2)
        cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
        lbl = f"{i}{orient}"
        cv2.putText(img, lbl, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 4)
        cv2.putText(img, lbl, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)

    # 2. Outline VP seed polygons. Both seed sets are outlined, whichever one
    #    produced the adopted VP.
    for i in seeds_A + [i for i in seeds_B if i not in seeds_A]:
        cv2.polylines(img, [polys[i].astype(np.int32)], True, (0, 220, 220), 3)

    # 3. Vanishing point: a circle when inside the image, otherwise an arrow
    #    pointing towards it.
    vp_ix, vp_iy = int(vp_x), int(vp_y)
    if 0 <= vp_ix < W and 0 <= vp_iy < H:
        cv2.circle(img, (vp_ix, vp_iy), 20, (0, 255, 255), 3)
        cv2.putText(img, "VP", (vp_ix + 25, vp_iy),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    else:
        ax_s, ay_s = W // 2, H // 5
        dx, dy = vp_x - img_cx, vp_y - img_cy
        n = (dx ** 2 + dy ** 2) ** 0.5
        if n > 0:
            ax_e = max(0, min(W - 1, int(ax_s + dx / n * 80)))
            ay_e = max(0, min(H - 1, int(ay_s + dy / n * 80)))
            cv2.arrowedLine(img, (ax_s, ay_s), (ax_e, ay_e), (0, 255, 255), 3)
            cv2.putText(img, f"VP({vp_ix},{vp_iy})", (ax_s + 5, ay_s - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)

    # 4. Highlight each accepted group: bounding box, group id and verdict.
    for g in valid_items:
        verdict = g['verdict']
        color = _RAAMEN_COLOR[verdict]
        all_idx = g['H'] + g['V']
        all_pts = np.vstack([polys[i] for i in all_idx]).astype(np.int32)
        x0 = all_pts[:, 0].min() - 15
        y0 = all_pts[:, 1].min() - 15
        x1 = all_pts[:, 0].max() + 15
        y1 = all_pts[:, 1].max() + 15
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 4)
        lbl = f"G{g['id']} {verdict} ({g['n_poles']}poles)"
        cv2.putText(img, lbl, (x0, y0 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 0), 6)
        cv2.putText(img, lbl, (x0, y0 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2)

    # Save the image, downscaled so the longer side is at most 4096 px.
    scale = min(1.0, 4096 / max(H, W))
    if scale < 1.0:
        img = cv2.resize(img, (int(W * scale), int(H * scale)))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    cv2.imencode(output_path.suffix, img)[1].tofile(str(output_path))
    print(f"\n{output_path}")

    # Save the AnyLabeling JSON. Coordinates stay in original image pixels.
    import json

    def _shape(label, points, gid, desc):
        return {"label": label, "score": None, "points": points,
                "group_id": gid, "description": desc,
                "shape_type": "polygon", "flags": None}

    shapes = []
    for g in valid_items:
        gid = g['id']
        verdict = g['verdict']
        desc_base = f"G{gid} {verdict}"
        if not g['H']:
            # No H in the group: lowest polygon = pole, the rest = beam.
            bottom_vi = max(g['V'], key=lambda i: polys[i][:, 1].max())
            for vi in g['V']:
                label = "raamen_pole" if vi == bottom_vi else "raamen_beam"
                shapes.append(_shape(label,
                                     [[float(p[0]), float(p[1])] for p in polys[vi]],
                                     gid, f"{desc_base} shape#{poly_abs_idx[vi]}"))
        else:
            # Regular case: V = pole, H = beam.
            for vi in g['V']:
                shapes.append(_shape("raamen_pole",
                                     [[float(p[0]), float(p[1])] for p in polys[vi]],
                                     gid, f"{desc_base} pole shape#{poly_abs_idx[vi]}"))
            for hi in g['H']:
                shapes.append(_shape("raamen_beam",
                                     [[float(p[0]), float(p[1])] for p in polys[hi]],
                                     gid, f"{desc_base} beam shape#{poly_abs_idx[hi]}"))

    anylabel_json = {
        "version": "3.3.9",
        "flags": {},
        "shapes": shapes,
        "imagePath": image_path.name,
        "imageData": None,
        "imageHeight": H,
        "imageWidth": W,
    }
    json_path = output_path.with_suffix(".json")
    json_path.write_text(json.dumps(anylabel_json, ensure_ascii=False, indent=2),
                         encoding="utf-8")
    print(f"{json_path}")
def main():
    """CLI entry point: parse options, pick a fresh output folder, run ``render``.

    The output image is placed in a folder named after the output file's stem
    (``<parent>/<stem>/<name>``); when that folder already exists a numeric
    suffix ``_1``, ``_2``, ... is appended until an unused name is found.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--image", required=True)
    ap.add_argument("--label", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--class-ids", default="",
                    help="포함할 클래스 ID, 콤마 구분 (.txt 전용)")
    ap.add_argument("--class-names", default="catenary_pole",
                    help="포함할 클래스 이름, 콤마 구분 (.json 전용, 기본: 'catenary_pole')")
    ap.add_argument("--epsilon", type=float, default=4.0,
                    help="approxPolyDP epsilon (기본 4.0px)")
    ap.add_argument("--v-thresh", type=float, default=20.0,
                    help="V/H 분류 각도 임계값 degrees (기본 20°)")
    ap.add_argument("--h-max-diff", type=float, default=75.0,
                    help="H(빔) 최대 각도 diff; 이 이상은 레일/전선 등으로 제외 (기본 75°)")
    ap.add_argument("--x-horiz-thresh", type=float, default=10.0,
                    help="X축 절대 수평 제외 임계값 degrees; AR≥4 AND 이 각도 이내 → 제외 (기본 10°)")
    ap.add_argument("--margin", type=int, default=30,
                    help="폴리곤 접촉 판정 margin px (기본 30)")
    ap.add_argument("--min-group-area", type=float, default=0,
                    help="라멘 그룹의 최소 면적 합계 (기본 0)")
    args = ap.parse_args()

    def parse_csv(text, convert):
        # Comma-separated option -> set of converted, non-blank items.
        return {convert(tok) for tok in text.split(',') if tok.strip()}

    class_ids = parse_csv(args.class_ids, int) if args.class_ids else None
    class_names = parse_csv(args.class_names, str.strip) if args.class_names else None

    out = Path(args.output)
    base_dir, stem = out.parent, out.stem
    folder = base_dir / stem          # e.g. output/0004_test
    attempt = 0
    while folder.exists():
        attempt += 1
        folder = base_dir / f"{stem}_{attempt}"
    folder.mkdir(parents=True, exist_ok=True)
    out = folder / out.name           # e.g. output/0004_test/0004_test.jpg
    print(f" [출력 폴더] {folder}")

    render(Path(args.image), Path(args.label), out, args,
           class_ids=class_ids, class_names=class_names,
           epsilon=args.epsilon, v_thresh=args.v_thresh,
           h_max_diff=args.h_max_diff, x_horiz_thresh=args.x_horiz_thresh)
if __name__ == "__main__":
main()

503
tools/labeling_server.py Normal file
View File

@@ -0,0 +1,503 @@
"""
Control Box 라벨링 서버 v2 — 전체 이미지 + bbox 오버레이 방식
사용법: python tools/labeling_server.py \
--json "data/역사이미지/slope/DJI_20260306113838_0004_everything.json"
[--reset] DB 초기화
브라우저: http://localhost:7001
"""
import argparse, json, sqlite3, sys
from collections import defaultdict
from pathlib import Path
import cv2, numpy as np
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse, Response
import uvicorn
# Two directory levels above this file.
ROOT = Path(__file__).parent.parent
# SQLite file that stores candidates and votes.
DB = ROOT / "labels" / "labeling.db"
# Minimum number of votes a candidate needs (used by /api/stats and /api/export).
MIN_VOTES = 3
# Minimum share of "true" votes for a candidate to be exported.
TRUE_RATIO = 0.6
MAX_DIM = 3000 # served image max dimension (px)
# Segment labels that load_candidates() accepts as control-box candidates.
CONTROL_LABELS = {
"small dark object on ballast", "manhole cover", "trackside enclosure",
"dark rectangle on ground", "small square box", "small metal cabinet",
"small cube on gravel", "compact trackside junction box",
"small square gray metal box beside rail",
"small near-square electrical enclosure on the ground",
}
app = FastAPI()
# JPEG bytes and image dimensions; all filled in by prepare_image().
_img_data: bytes = None
_orig_w = _orig_h = _disp_w = _disp_h = 0
# ── DB ────────────────────────────────────────────────────────────────────────
def get_db():
    """Open a new SQLite connection to ``DB`` whose rows support access by column name."""
    connection = sqlite3.connect(str(DB))
    connection.row_factory = sqlite3.Row
    return connection
def init_db(candidates: list):
    """Create the tables if needed and load ``candidates`` into an empty DB.

    When the ``candidates`` table already holds rows, the existing data is
    kept and nothing is inserted.

    Args:
        candidates: Dicts with keys ``json_idx``, ``label``, ``score``,
            ``bbox`` and ``image_path``; ``bbox`` is stored as JSON text.
    """
    DB.parent.mkdir(parents=True, exist_ok=True)
    con = get_db()
    # try/finally so the connection is closed even when a statement raises.
    try:
        con.executescript("""
CREATE TABLE IF NOT EXISTS candidates (
id INTEGER PRIMARY KEY,
json_idx INTEGER,
label TEXT,
score REAL,
bbox TEXT,
image_path TEXT
);
CREATE TABLE IF NOT EXISTS votes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
candidate_id INTEGER,
user TEXT,
vote INTEGER,
ts DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(candidate_id, user)
);
""")
        n = con.execute("SELECT COUNT(*) FROM candidates").fetchone()[0]
        if n > 0:
            print(f"DB 기존 데이터 유지. candidates={n}")
            return
        con.executemany(
            "INSERT INTO candidates(json_idx,label,score,bbox,image_path) VALUES(?,?,?,?,?)",
            [(c["json_idx"], c["label"], c["score"], json.dumps(c["bbox"]), c["image_path"])
             for c in candidates],
        )
        con.commit()
        print(f"DB 등록: {len(candidates)}")
    finally:
        con.close()
# ── 후보 로드 ─────────────────────────────────────────────────────────────────
def load_candidates(json_path: Path) -> list:
    """Read a segmentation JSON and return the control-box candidates.

    The source image is looked up next to the JSON file, using the JSON stem
    without ``_everything`` and the extensions ``.JPG``, ``.jpg``, ``.png``
    in that order. Segments are kept when their label is in
    ``CONTROL_LABELS`` and their bbox is at least 4 px on both sides.

    Returns:
        List of dicts with ``json_idx``, ``label``, ``score``, ``bbox`` and
        ``image_path``.

    Raises:
        FileNotFoundError: If no matching image file exists.
    """
    data = json.loads(json_path.read_text(encoding="utf-8"))
    stem = json_path.stem.replace("_everything", "")

    img_path = None
    for ext in (".JPG", ".jpg", ".png"):
        probe = json_path.parent / f"{stem}{ext}"
        if probe.exists():
            img_path = probe
            break
    if img_path is None:
        raise FileNotFoundError(f"원본 이미지 없음: {json_path.parent / stem}.*")

    segments = data.get("segments", [])
    candidates = []
    for idx, seg in enumerate(segments):
        label = seg.get("label", "").strip()
        if label not in CONTROL_LABELS:
            continue
        x0, y0, x1, y1 = [float(v) for v in seg["bbox"]]
        too_small = (x1 - x0) < 4 or (y1 - y0) < 4
        if too_small:
            continue
        candidates.append({
            "json_idx": idx,
            "label": label,
            "score": float(seg.get("score", 0)),
            "bbox": [x0, y0, x1, y1],
            "image_path": str(img_path),
        })

    segs_total = len(segments)
    print(f"필터링: {segs_total}개 → {len(candidates)}개 control_box 후보")
    return candidates
# ── 이미지 준비 (서버 시작 시 1회) ───────────────────────────────────────────
def prepare_image(img_path: Path):
    """Load the image once and cache a JPEG copy for serving.

    Stores the original size in ``_orig_w``/``_orig_h``, the served size in
    ``_disp_w``/``_disp_h`` and the encoded bytes in ``_img_data``. Images
    whose longer side exceeds ``MAX_DIM`` are downscaled.

    Raises:
        ValueError: If the image cannot be decoded.
    """
    global _img_data, _orig_w, _orig_h, _disp_w, _disp_h

    raw = np.fromfile(str(img_path), dtype=np.uint8)
    frame = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError(f"이미지 로드 실패: {img_path}")

    _orig_h, _orig_w = frame.shape[:2]
    longest = max(_orig_w, _orig_h)
    scale = min(1.0, MAX_DIM / longest)
    _disp_w = int(_orig_w * scale)
    _disp_h = int(_orig_h * scale)
    if scale < 1.0:
        frame = cv2.resize(frame, (_disp_w, _disp_h), interpolation=cv2.INTER_LANCZOS4)

    encoded = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 92])[1]
    _img_data = bytes(encoded)
    print(f"이미지 준비: {_orig_w}×{_orig_h}{_disp_w}×{_disp_h}")
# ── API ───────────────────────────────────────────────────────────────────────
@app.get("/image")
async def serve_image():
    """Return the JPEG bytes cached by ``prepare_image``."""
    payload = _img_data
    return Response(content=payload, media_type="image/jpeg")
@app.get("/api/image_info")
async def api_image_info():
    """Report the original and served image dimensions."""
    return dict(orig_w=_orig_w, orig_h=_orig_h, disp_w=_disp_w, disp_h=_disp_h)
@app.get("/api/candidates")
async def api_candidates(user: str = ""):
    """List every candidate together with the given user's vote, if any.

    Each item carries ``id``, ``bbox``, ``label``, ``score`` (rounded to two
    decimals), ``voted`` and ``is_true`` (``None`` when the user has not
    voted on that candidate).
    """
    con = get_db()
    rows = con.execute("""
SELECT c.id, c.bbox, c.label, c.score, v.vote
FROM candidates c
LEFT JOIN votes v ON c.id=v.candidate_id AND v.user=?
""", (user,)).fetchall()
    con.close()

    items = []
    for row in rows:
        vote = row["vote"]
        has_vote = vote is not None
        items.append({
            "id": row["id"],
            "bbox": json.loads(row["bbox"]),
            "label": row["label"],
            "score": round(row["score"], 2),
            "voted": has_vote,
            "is_true": (vote == 1) if has_vote else None,
        })
    return {"candidates": items}
@app.post("/api/vote")
async def api_vote(data: dict):
    """Record one user's votes for a batch of candidates.

    Expects ``user``, ``all_ids`` (every candidate voted on) and
    ``true_ids`` (the subset marked true). Each id in ``all_ids`` is stored
    with vote 1 when it is in ``true_ids`` and 0 otherwise; an existing vote
    by the same user on the same candidate is replaced.

    Returns:
        ``{"ok": True}`` on success, or ``{"ok": False, "error": ...}`` when
        ``user`` or ``all_ids`` is missing or empty.
    """
    user = data.get("user", "").strip()
    all_ids = data.get("all_ids", [])
    true_ids = set(data.get("true_ids", []))
    if not user or not all_ids:
        return {"ok": False, "error": "파라미터 오류"}

    con = get_db()
    # try/finally so the connection is closed even when an insert raises.
    try:
        con.executemany(
            "INSERT OR REPLACE INTO votes(candidate_id,user,vote) VALUES(?,?,?)",
            [(cid, user, 1 if cid in true_ids else 0) for cid in all_ids],
        )
        con.commit()
    finally:
        con.close()
    return {"ok": True}
@app.get("/api/stats")
async def api_stats():
    """Return overall labeling progress counters.

    ``total`` is the number of candidates, ``voted3plus`` the number with at
    least ``MIN_VOTES`` votes, ``users`` the number of distinct voters and
    ``total_votes`` the number of stored votes.
    """
    con = get_db()

    def scalar(sql):
        # First column of the first row of a single-value query.
        return con.execute(sql).fetchone()[0]

    total = scalar("SELECT COUNT(*) FROM candidates")
    voted3 = scalar(f"""
SELECT COUNT(*) FROM (
SELECT candidate_id FROM votes GROUP BY candidate_id HAVING COUNT(*) >= {MIN_VOTES}
)
""")
    users = scalar("SELECT COUNT(DISTINCT user) FROM votes")
    tvotes = scalar("SELECT COUNT(*) FROM votes")
    con.close()
    return {"total": total, "voted3plus": voted3, "users": users, "total_votes": tvotes}
@app.post("/api/export")
async def api_export():
    """Export confirmed candidates as YOLO bbox label files.

    A candidate is confirmed when it has at least ``MIN_VOTES`` votes and the
    share of true votes is at least ``TRUE_RATIO``. One ``.txt`` file per
    source image is written to ``labels/yolo_export`` with lines
    ``0 cx cy w h`` normalised by the image size.

    Returns:
        ``{"ok": True, "labels": <files written>, "dir": <output dir>,
        "skipped": [<image paths that could not be read>]}``.
    """
    con = get_db()
    try:
        confirmed = con.execute(f"""
SELECT c.id, c.bbox, c.image_path,
SUM(v.vote) AS true_v, COUNT(v.id) AS total_v
FROM candidates c JOIN votes v ON c.id=v.candidate_id
GROUP BY c.id
HAVING total_v >= {MIN_VOTES} AND (CAST(true_v AS REAL)/total_v) >= {TRUE_RATIO}
""").fetchall()
    finally:
        con.close()

    by_image = defaultdict(list)
    for row in confirmed:
        by_image[row["image_path"]].append(row)

    out_dir = ROOT / "labels" / "yolo_export"
    out_dir.mkdir(parents=True, exist_ok=True)

    count = 0
    skipped = []
    for img_path, rows in by_image.items():
        # The image is only needed for its size. A missing or undecodable
        # file is reported in "skipped" instead of aborting the export.
        try:
            buf = np.fromfile(img_path, dtype=np.uint8)
        except OSError:
            skipped.append(img_path)
            continue
        img = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
        if img is None:
            skipped.append(img_path)
            continue
        H, W = img.shape[:2]
        txt = out_dir / (Path(img_path).stem + ".txt")
        with open(txt, "w", encoding="utf-8") as f:
            for row in rows:
                x0, y0, x1, y1 = json.loads(row["bbox"])
                f.write(f"0 {((x0+x1)/2)/W:.6f} {((y0+y1)/2)/H:.6f} "
                        f"{(x1-x0)/W:.6f} {(y1-y0)/H:.6f}\n")
        count += 1
    return {"ok": True, "labels": count, "dir": str(out_dir), "skipped": skipped}
# ── HTML ──────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the single-page labeling UI held in the module-level ``HTML`` string."""
    return HTML
HTML = r"""<!DOCTYPE html>
<html lang="ko">
<head>
<meta charset="UTF-8">
<title>Control Box 라벨링</title>
<style>
*{box-sizing:border-box;margin:0;padding:0}
body{font-family:'Segoe UI',sans-serif;background:#1a1a2e;color:#e0e0e0;display:flex;flex-direction:column;height:100vh;overflow:hidden}
#toolbar{background:#16213e;border-bottom:1px solid #0f3460;padding:8px 14px;display:flex;align-items:center;gap:10px;flex-shrink:0;flex-wrap:wrap}
h1{color:#e94560;font-size:1rem;white-space:nowrap}
input[type=text]{background:#0f3460;border:1px solid #2a4a7a;color:#e0e0e0;padding:5px 10px;border-radius:4px;font-size:.85rem;width:130px}
input[type=text]:focus{outline:none;border-color:#e94560}
.btn{background:#e94560;color:#fff;border:none;padding:6px 12px;border-radius:5px;cursor:pointer;font-size:.82rem;font-weight:700;white-space:nowrap}
.btn:hover{background:#c73652}
.btn-sec{background:#0f3460;border:1px solid #2a4a7a;color:#e0e0e0}
.btn-sec:hover{background:#1a5299}
#stats{font-size:.75rem;color:#777;white-space:nowrap}
#legend{display:flex;gap:10px;align-items:center;font-size:.72rem;color:#888}
.leg{display:inline-flex;align-items:center;gap:3px}
.lb{width:12px;height:12px;border:2px solid;border-radius:2px;flex-shrink:0}
#wrap{flex:1;position:relative;overflow:hidden;cursor:crosshair}
canvas{display:block}
#tip{position:absolute;background:rgba(0,0,0,.85);color:#fff;font-size:.7rem;padding:3px 8px;border-radius:3px;pointer-events:none;display:none;max-width:300px}
#hint{position:absolute;bottom:8px;left:50%;transform:translateX(-50%);font-size:.72rem;color:#444;pointer-events:none;white-space:nowrap}
</style>
</head>
<body>
<div id="toolbar">
<h1>🎯 Control Box</h1>
<input type="text" id="uname" placeholder="이름/사번" autocomplete="off">
<button class="btn" onclick="init()">시작</button>
<button class="btn btn-sec" onclick="fitView()">맞춤(F)</button>
<span id="stats">—</span>
<div id="legend">
<span class="leg"><span class="lb" style="border-color:#ffaa00"></span>미투표</span>
<span class="leg"><span class="lb" style="border-color:#00cc66"></span>컨트롤박스 ✓</span>
<span class="leg"><span class="lb" style="border-color:#cc4444"></span>아님 ✗</span>
<span class="leg"><span class="lb" style="border-color:#444"></span>타인투표완료</span>
</div>
<div style="flex:1"></div>
<button class="btn btn-sec" onclick="doExport()" style="font-size:.75rem">YOLO 내보내기</button>
</div>
<div id="wrap">
<canvas id="cv"></canvas>
<div id="tip"></div>
<div id="hint">휠=줌 | 드래그=이동 | 클릭=YES/NO 토글 | F=맞춤</div>
</div>
<script>
const cv = document.getElementById('cv');
const ctx = cv.getContext('2d');
const wrap = document.getElementById('wrap');
let user = '', cands = [], imgScale = 1;
let origW = 0, origH = 0, dispW = 0, dispH = 0;
let sc = 1, ox = 0, oy = 0;
let isDrag = false, dragX = 0, dragY = 0, moved = false;
const img = new Image();
function resizeCv() {
cv.width = wrap.clientWidth;
cv.height = wrap.clientHeight;
render();
}
window.addEventListener('resize', resizeCv);
resizeCv();
async function init() {
user = document.getElementById('uname').value.trim();
if (!user) { alert('이름 입력 필요'); return; }
const info = await fetch('/api/image_info').then(r => r.json());
origW = info.orig_w; origH = info.orig_h;
dispW = info.disp_w; dispH = info.disp_h;
imgScale = dispW / origW;
img.onload = () => { fitView(); loadCands(); };
img.src = '/image?' + Date.now();
}
async function loadCands() {
const d = await fetch('/api/candidates?user=' + encodeURIComponent(user)).then(r => r.json());
cands = d.candidates.map(c => ({...c, _sel: c.voted ? c.is_true : undefined}));
updateStats();
render();
}
function updateStats() {
const myVotes = cands.filter(c => c._sel !== undefined).length;
const yes = cands.filter(c => c._sel === true).length;
const othersVoted = cands.filter(c => c.voted && c._sel === undefined).length;
document.getElementById('stats').textContent =
`후보 ${cands.length}개 | 내투표 ${myVotes}개 (YES:${yes}) | 타인완료 ${othersVoted}개`;
}
function fitView() {
if (!img.naturalWidth) return;
const ww = wrap.clientWidth, wh = wrap.clientHeight;
sc = Math.min(ww / dispW, wh / dispH) * 0.97;
ox = (ww - dispW * sc) / 2;
oy = (wh - dispH * sc) / 2;
render();
}
function render() {
const W = cv.width, H = cv.height;
ctx.fillStyle = '#0d0d1e';
ctx.fillRect(0, 0, W, H);
if (!img.naturalWidth) return;
ctx.save();
ctx.translate(ox, oy);
ctx.scale(sc, sc);
ctx.drawImage(img, 0, 0, dispW, dispH);
const lw = Math.max(1, 2 / sc);
for (const c of cands) {
const [x0, y0, x1, y1] = c.bbox.map(v => v * imgScale);
const w = x1 - x0, h = y1 - y0;
let color, fill = null;
if (c._sel === true) { color = '#00cc66'; fill = 'rgba(0,204,102,0.15)'; }
else if (c._sel === false) { color = '#cc4444'; fill = 'rgba(204,68,68,0.15)'; }
else if (c.voted) { color = '#444444'; }
else { color = c.score > 0.5 ? '#ffaa00' : '#ff7700aa'; }
ctx.lineWidth = lw;
ctx.strokeStyle = color;
if (fill) {
ctx.fillStyle = fill;
ctx.fillRect(x0, y0, w, h);
}
ctx.strokeRect(x0, y0, w, h);
}
ctx.restore();
}
// ── wheel zoom ────────────────────────────────────────────────────────────────
wrap.addEventListener('wheel', e => {
e.preventDefault();
const rect = wrap.getBoundingClientRect();
const mx = e.clientX - rect.left, my = e.clientY - rect.top;
const f = e.deltaY < 0 ? 1.15 : 1/1.15;
ox = mx - (mx - ox) * f;
oy = my - (my - oy) * f;
sc *= f;
render();
}, {passive: false});
// ── drag pan ──────────────────────────────────────────────────────────────────
wrap.addEventListener('mousedown', e => {
if (e.button !== 0) return;
isDrag = true; moved = false;
dragX = e.clientX; dragY = e.clientY;
wrap.style.cursor = 'grabbing';
});
window.addEventListener('mousemove', e => {
if (isDrag) {
const dx = e.clientX - dragX, dy = e.clientY - dragY;
if (Math.abs(dx) + Math.abs(dy) > 3) moved = true;
if (moved) { ox += dx; oy += dy; dragX = e.clientX; dragY = e.clientY; render(); }
}
// tooltip
const c = hitTest(e);
const tip = document.getElementById('tip');
if (c) {
const rect = wrap.getBoundingClientRect();
tip.style.display = 'block';
tip.style.left = (e.clientX - rect.left + 14) + 'px';
tip.style.top = (e.clientY - rect.top + 4) + 'px';
const status = c._sel === true ? '✅ YES' : c._sel === false ? '❌ NO' : c.voted ? '⬜ 타인완료' : '❓ 미투표';
tip.textContent = `[${c.id}] ${c.label} (${c.score}) — ${status}`;
} else {
tip.style.display = 'none';
}
});
window.addEventListener('mouseup', e => {
if (!isDrag) return;
wrap.style.cursor = 'crosshair';
isDrag = false;
if (!moved) handleClick(e);
});
// ── hit test ──────────────────────────────────────────────────────────────────
function imgCoords(e) {
const rect = wrap.getBoundingClientRect();
return [(e.clientX - rect.left - ox) / sc / imgScale,
(e.clientY - rect.top - oy) / sc / imgScale];
}
function hitTest(e) {
const [ix, iy] = imgCoords(e);
let best = null, bestArea = Infinity;
for (const c of cands) {
const [x0, y0, x1, y1] = c.bbox;
if (ix >= x0 && ix <= x1 && iy >= y0 && iy <= y1) {
const a = (x1-x0)*(y1-y0);
if (a < bestArea) { bestArea = a; best = c; }
}
}
return best;
}
async function handleClick(e) {
if (!user) return;
const c = hitTest(e);
if (!c) return;
// Toggle: undefined→YES→NO→YES...
c._sel = c._sel === true ? false : true;
render();
updateStats();
// Save immediately
await fetch('/api/vote', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({user, all_ids: [c.id], true_ids: c._sel ? [c.id] : []})
});
c.voted = true;
c.is_true = c._sel;
}
// ── keyboard ──────────────────────────────────────────────────────────────────
window.addEventListener('keydown', e => {
if (e.key === 'f' || e.key === 'F') fitView();
});
async function doExport() {
const r = await fetch('/api/export', {method: 'POST'}).then(r => r.json());
alert(`완료: ${r.labels}개 라벨\n경로: ${r.dir}`);
}
</script>
</body>
</html>"""
# ── 메인 ─────────────────────────────────────────────────────────────────────
def main():
    """Parse CLI options, seed the vote DB, prepare the image and start the server.

    Steps, in order: validate ``--json``; optionally delete the DB file
    (``--reset``); load candidates and initialise the DB; locate the source
    image next to the JSON (same stem without ``_everything``, extension
    ``.JPG`` / ``.jpg`` / ``.png``); prepare it; run uvicorn on ``--port``.
    Exits with status 1 when the JSON or the image is missing.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--json", required=True, help="*_everything.json 경로")
    parser.add_argument("--port", type=int, default=7001)
    parser.add_argument("--reset", action="store_true", help="DB 초기화 후 재로드")
    opts = parser.parse_args()

    json_path = Path(opts.json)
    if not json_path.exists():
        print(f"JSON 없음: {json_path}")
        sys.exit(1)

    if opts.reset and DB.exists():
        DB.unlink()
        print("DB 초기화 완료.")

    print("후보 로딩 중...")
    init_db(load_candidates(json_path))

    # Locate the source image: first existing extension wins.
    stem = json_path.stem.replace("_everything", "")
    img_path = None
    for ext in (".JPG", ".jpg", ".png"):
        candidate = json_path.parent / f"{stem}{ext}"
        if candidate.exists():
            img_path = candidate
            break
    if img_path is None:
        print(f"원본 이미지 없음: {json_path.parent / stem}.*")
        sys.exit(1)

    prepare_image(img_path)
    print(f"\n라벨링 서버 시작: http://localhost:{opts.port}")
    uvicorn.run(app, host="0.0.0.0", port=opts.port, log_level="warning")


if __name__ == "__main__":
    main()

90
tools/merge_tiles_vis.py Normal file
View File

@@ -0,0 +1,90 @@
"""
6×8 (또는 임의 그리드) 타일 라벨을 원본 이미지 좌표로 병합해 시각화.
사용:
python tools/merge_tiles_vis.py \
--orig data/역사이미지/slope/DJI_20260306113838_0004.JPG \
--labels output/autolabel/tile6x8/labels \
--output output/autolabel/tile6x8/merged_vis.jpg \
--cols 6 --rows 8
"""
import argparse
import cv2
import numpy as np
from pathlib import Path
# Class names indexed by the class id found in the label files.
CLASS_NAMES = ["catenary_pole", "bracket"]
# One colour per class, handed to the OpenCV drawing calls below
# (OpenCV interprets colour tuples in BGR order).
CLASS_COLORS = [(0, 200, 255), (255, 130, 0)]
def main():
    """Overlay per-tile polygon labels on the original image and save a preview.

    Label files are looked up as ``tile_rRR_cCC.txt`` (1-based row/column)
    under ``--labels``. Every line is read as ``class_id x1 y1 x2 y2 ...`` with
    coordinates normalised to the tile; they are scaled back to original-image
    pixels, alpha-filled and outlined. The last row/column of tiles takes the
    remainder of the image. A legend with per-class counts is drawn and the
    result is written to ``--output``, downscaled so the longest side is at
    most 4096 px.

    Malformed label lines (non-numeric values, an odd number of coordinates or
    fewer than three points) are reported and skipped instead of aborting.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--orig", required=True)
    parser.add_argument("--labels", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--cols", type=int, default=6)
    parser.add_argument("--rows", type=int, default=8)
    parser.add_argument("--alpha", type=float, default=0.3)
    args = parser.parse_args()

    # np.fromfile + imdecode keeps non-ASCII paths working.
    buf = np.fromfile(args.orig, dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure with None, not with an exception.
        raise SystemExit(f"이미지 디코딩 실패: {args.orig}")
    H, W = img.shape[:2]
    tw, th = W // args.cols, H // args.rows

    label_dir = Path(args.labels)
    counts = [0] * len(CLASS_NAMES)

    for r in range(args.rows):
        for c in range(args.cols):
            label_file = label_dir / f"tile_r{r+1:02d}_c{c+1:02d}.txt"
            if not label_file.exists():
                continue

            x0 = c * tw
            y0 = r * th
            # Edge tiles absorb the pixels left over by the integer division.
            tile_w = tw if c < args.cols - 1 else W - x0
            tile_h = th if r < args.rows - 1 else H - y0

            text = label_file.read_text(encoding="utf-8").strip()
            if not text:
                continue

            for line in text.splitlines():
                parts = line.split()
                if not parts:
                    continue
                try:
                    cls_id = int(parts[0])
                    coords = [float(v) for v in parts[1:]]
                except ValueError:
                    print(f"라벨 형식 오류, 건너뜀: {label_file.name}: {line!r}")
                    continue
                if len(coords) < 6 or len(coords) % 2:
                    print(f"좌표 개수 오류, 건너뜀: {label_file.name}: {line!r}")
                    continue

                pts = np.array(
                    [[coords[i] * tile_w + x0, coords[i + 1] * tile_h + y0]
                     for i in range(0, len(coords), 2)],
                    dtype=np.int32,
                )
                color = CLASS_COLORS[cls_id % len(CLASS_COLORS)]
                overlay = img.copy()
                cv2.fillPoly(overlay, [pts], color)
                cv2.addWeighted(overlay, args.alpha, img, 1 - args.alpha, 0, img)
                cv2.polylines(img, [pts], True, color, 3)
                if 0 <= cls_id < len(counts):
                    counts[cls_id] += 1

    # Legend
    for i, name in enumerate(CLASS_NAMES):
        y = 30 + i * 50
        cv2.rectangle(img, (15, y - 20), (55, y + 10), CLASS_COLORS[i], -1)
        cv2.putText(img, f"{name}: {counts[i]}",
                    (65, y), cv2.FONT_HERSHEY_SIMPLEX, 1.5, CLASS_COLORS[i], 3)

    total = sum(counts)
    # The total used to be glued to the first class name ("12catenary_pole=5").
    print(f"총 {total}개: "
          + ", ".join(f"{CLASS_NAMES[i]}={counts[i]}" for i in range(len(counts))))

    # Save a preview no larger than 4096 px on its longest side.
    scale = min(1.0, 4096 / max(H, W))
    vis = cv2.resize(img, (int(W * scale), int(H * scale)))
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    cv2.imencode(out_path.suffix, vis)[1].tofile(str(out_path))
    print(f"저장: {out_path}")


if __name__ == "__main__":
    main()

154
tools/post_merge_poles.py Normal file
View File

@@ -0,0 +1,154 @@
"""post_merge_poles.py — detect_all_objects.py 출력 JSON에서 catenary_pole 병합.
detecting은 한 번만, 병합 파라미터 조정 시 이 스크립트만 재실행.
Usage:
python tools/post_merge_poles.py INPUT.json [--x-overlap 0.30] [--y-gap 150]
python tools/post_merge_poles.py INPUT.json --inplace
python tools/post_merge_poles.py INPUT.json --output OUTPUT.json
"""
import argparse
import json
import sys
from pathlib import Path
import cv2
import numpy as np
def _poly_orient(points: list, H: int, W: int) -> str:
    """Classify the long-axis orientation of a polygon.

    The long side of the minimum-area rectangle is compared with the radial
    direction from the image centre to the rectangle centre.

    Returns:
        'V' when the long axis is aligned with the radial direction
            (|cosine similarity| > 0.7) - treated as a vertical pole,
        'H' otherwise - treated as a horizontal beam,
        '?' when undecidable: short side under 1 px, aspect ratio under 1.3,
            or the rectangle centre within 1 px of the image centre.
    """
    (cx, cy), (side_a, side_b), angle = cv2.minAreaRect(
        np.array(points, dtype=np.float32)
    )
    short_side = min(side_a, side_b)
    if short_side < 1:
        return '?'
    if max(side_a, side_b) / short_side < 1.3:
        return '?'

    # Angle of the longer rectangle side, in degrees.
    long_angle_deg = angle if side_a >= side_b else angle + 90
    theta = np.radians(long_angle_deg)
    axis_x = float(np.cos(theta))
    axis_y = float(np.sin(theta))

    # Unit vector from the image centre towards the rectangle centre.
    off_x = cx - W / 2.0
    off_y = cy - H / 2.0
    dist = (off_x ** 2 + off_y ** 2) ** 0.5
    if dist < 1:
        return '?'
    off_x, off_y = off_x / dist, off_y / dist

    alignment = abs(axis_x * off_x + axis_y * off_y)
    if alignment > 0.7:
        return 'V'
    return 'H'
def merge_poles(shapes: list, H: int, W: int,
                x_overlap_thresh: float = 0.30,
                y_gap_thresh: int = 150) -> list:
    """Merge pole polygons that were split across tile borders (V+V pairs only).

    Two shapes are merged when both are classified 'V' by ``_poly_orient``,
    their bounding boxes overlap in x by at least ``x_overlap_thresh``
    (intersection / union of the x ranges) and the vertical gap between the
    boxes is at most ``y_gap_thresh`` pixels.

    Args:
        shapes: shape dicts with a ``points`` list (and optionally ``score``).
        H, W: image height / width in pixels (size of the rasterisation mask).
        x_overlap_thresh: minimum x-range overlap ratio.
        y_gap_thresh: maximum vertical gap in pixels.

    Returns:
        A new list of shapes. Input with 0 or 1 shapes is returned unchanged.
    """
    if len(shapes) <= 1:
        return shapes

    orients = [_poly_orient(s["points"], H, W) for s in shapes]
    v_count = sum(1 for o in orients if o == 'V')
    h_count = sum(1 for o in orients if o == 'H')
    print(f"  orient: V={v_count}, H={h_count}, ?={len(orients)-v_count-h_count}")

    def get_bbox(s):
        xs = [p[0] for p in s["points"]]
        ys = [p[1] for p in s["points"]]
        return min(xs), min(ys), max(xs), max(ys)

    def x_overlap_ratio(b1, b2):
        ox = min(b1[2], b2[2]) - max(b1[0], b2[0])
        ux = max(b1[2], b2[2]) - min(b1[0], b2[0])
        return ox / ux if ux > 0 else 0.0

    def y_gap(b1, b2):
        return max(0.0, max(b1[1], b2[1]) - min(b1[3], b2[3]))

    def merge_two(s1, s2):
        # Rasterise both polygons and take the largest outer contour.
        # NOTE(review): when the two polygons do not touch (y gap > 0) they
        # form separate contours and only the larger one survives - confirm
        # this is the intended result for gapped pairs.
        mask = np.zeros((H, W), dtype=np.uint8)
        for s in (s1, s2):
            cv2.fillPoly(mask, [np.array(s["points"], dtype=np.int32)], 255)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return s1
        c = max(contours, key=cv2.contourArea)
        eps = 0.002 * cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, eps, True)
        merged = dict(s1)
        merged["points"] = [[float(p[0][0]), float(p[0][1])] for p in approx]
        merged["score"] = max(float(s1.get("score", 0)), float(s2.get("score", 0)))
        return merged

    merged_flags = [False] * len(shapes)
    result = []
    merged_count = 0

    for i in range(len(shapes)):
        if merged_flags[i]:
            continue
        cur = shapes[i]
        # Only 'V' shapes can absorb others, and a merged shape stays 'V',
        # so the orientation test is loop-invariant for the inner scan.
        if orients[i] == 'V':
            cb = get_bbox(cur)
            for j in range(i + 1, len(shapes)):
                if merged_flags[j] or orients[j] != 'V':
                    continue
                jb = get_bbox(shapes[j])
                if x_overlap_ratio(cb, jb) >= x_overlap_thresh and y_gap(cb, jb) <= y_gap_thresh:
                    cur = merge_two(cur, shapes[j])
                    cb = get_bbox(cur)
                    merged_flags[j] = True
                    merged_count += 1
        result.append(cur)

    # Before/after counts were printed back to back ("107개"); separate them.
    print(f"  병합: {len(shapes)}개 → {len(result)}개 ({merged_count}쌍 합침)")
    return result
def main():
    """CLI entry point: merge the shapes of one label inside an AnyLabeling JSON.

    Reads the input JSON, merges the shapes whose label equals ``--label``
    via ``merge_poles`` and writes the result to, in priority order: the
    input file (``--inplace``), ``--output``, or ``<input stem>_merged``
    next to the input. Exits with status 1 when the file is missing or the
    image size is not recorded in the JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="AnyLabeling JSON 파일")
    parser.add_argument("--output", default=None, help="출력 JSON (기본: INPUT_merged.json)")
    parser.add_argument("--inplace", action="store_true", help="원본 덮어쓰기")
    parser.add_argument("--x-overlap", type=float, default=0.30, help="x-range 겹침 비율 임계값")
    parser.add_argument("--y-gap", type=int, default=150, help="y-range 간격 임계값 (px)")
    parser.add_argument("--label", default="catenary_pole", help="병합 대상 라벨명")
    opts = parser.parse_args()

    src = Path(opts.input)
    if not src.exists():
        print(f"파일 없음: {src}", file=sys.stderr)
        sys.exit(1)

    data = json.loads(src.read_text(encoding="utf-8"))
    shapes = data.get("shapes", [])
    iW = data.get("imageWidth", 0)
    iH = data.get("imageHeight", 0)
    if iW == 0 or iH == 0:
        print("imageWidth/imageHeight 정보 없음", file=sys.stderr)
        sys.exit(1)

    # Split into merge targets and untouched shapes, keeping relative order.
    target, others = [], []
    for shape in shapes:
        bucket = target if shape.get("label") == opts.label else others
        bucket.append(shape)
    print(f"{opts.label}: {len(target)}개 → 병합 처리")

    data["shapes"] = others + merge_poles(target, iH, iW, opts.x_overlap, opts.y_gap)

    if opts.inplace:
        dst = src
    else:
        dst = Path(opts.output) if opts.output else src.with_stem(src.stem + "_merged")

    dst.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"저장: {dst}")


if __name__ == "__main__":
    main()

370
tools/rail_alignment_fit.py Normal file
View File

@@ -0,0 +1,370 @@
"""
철도 평면 선형 피팅 모듈
스켈레톤 좌표점 → 직선/원곡선/완화곡선 분류 및 피팅 → 매끄러운 폴리라인
선형 구성요소:
1. 직선 (Straight) : y = ax + b
2. 원곡선 (Circular) : R = 상수, Taubin 원 피팅
3. 완화곡선 (Transition) : y = x³/(6RL), 3차포물선 (KR 설계기준)
"""
import numpy as np
from scipy.signal import savgol_filter
from scipy.linalg import eig as scipy_eig
# ── 1. 곡률 계산 ──────────────────────────────────────────────────────────
def compute_curvature(pts, smooth_window=15):
    """
    Estimate curvature along a polyline from length-weighted direction vectors.

    Parameters
    ----------
    pts : (N, 2) array       point coordinates
    smooth_window : int      Savitzky-Golay window; smoothing is applied only
                             when N >= window >= 5 (an even window is reduced
                             by one)

    Returns
    -------
    curvatures  : (N,)  curvature kappa = d(theta)/ds
    arc_lengths : (N,)  cumulative chord length at each point
    angles      : (N,)  unwrapped heading angle [rad]

    With fewer than three points all three arrays are zeros.
    """
    pts = np.asarray(pts, dtype=float)
    n = len(pts)
    if n < 3:
        return np.zeros(n), np.zeros(n), np.zeros(n)

    # Chord vectors, their lengths (floored to avoid division by zero) and
    # unit directions.
    chords = np.diff(pts, axis=0)
    lens = np.maximum(np.linalg.norm(chords, axis=1), 1e-12)
    unit = chords / lens[:, np.newaxis]

    # Direction at each vertex: end points take the adjacent chord, interior
    # points the length-weighted mean of the two adjacent chords.
    directions = np.empty((n, 2))
    directions[0] = unit[0]
    directions[-1] = unit[-1]
    w_before = lens[:-1, np.newaxis]
    w_after = lens[1:, np.newaxis]
    directions[1:-1] = (w_before * unit[:-1] + w_after * unit[1:]) / (w_before + w_after)
    directions /= np.maximum(np.linalg.norm(directions, axis=1, keepdims=True), 1e-12)

    # Continuous heading angle.
    angles = np.unwrap(np.arctan2(directions[:, 1], directions[:, 0]))

    # Cumulative chord length.
    arc_lengths = np.zeros(n)
    arc_lengths[1:] = np.cumsum(lens)

    # Central difference of heading over arc length; end points copy their
    # neighbour. Spans shorter than 1e-10 keep curvature 0.
    curvatures = np.zeros(n)
    dtheta = angles[2:] - angles[:-2]
    ds = arc_lengths[2:] - arc_lengths[:-2]
    curvatures[1:-1] = np.divide(dtheta, ds, out=np.zeros(n - 2), where=ds > 1e-10)
    curvatures[0] = curvatures[1]
    curvatures[-1] = curvatures[-2]

    # Savitzky-Golay smoothing (cubic) to suppress noise.
    w = smooth_window
    if n >= w and w >= 5:
        if w % 2 == 0:
            w -= 1
        curvatures = savgol_filter(curvatures, w, 3)

    return curvatures, arc_lengths, angles
# ── 2. 구간 분류 ──────────────────────────────────────────────────────────
def segment_by_curvature(curvatures, arc_lengths,
                         kappa_straight=0.001,   # 1/1000 m -> R > 1000 m counts as straight
                         cv_thresh=0.35,         # coefficient of variation (std / mean) threshold
                         min_seg_len=3.0):       # minimum segment length [m]
    """
    Split a polyline into straight / circular / transition runs by curvature.

    Each point is classified from a sliding window of |curvature|:
    mean below ``kappa_straight`` -> 'straight'; otherwise a coefficient of
    variation below ``cv_thresh`` -> 'circular'; else 'transition'. Runs of
    equal type become segments, and segments shorter than ``min_seg_len``
    are absorbed by a neighbour (a short leading segment by its successor,
    any other short segment by its predecessor).

    Returns
    -------
    segments : list of (start_idx, end_idx, type_str), indices inclusive,
               covering 0 .. N-1 without gaps. Empty input gives ``[]``.
    """
    n = len(curvatures)
    if n == 0:
        return []
    kappa_abs = np.abs(curvatures)
    win = max(5, n // 15)
    half = win // 2

    def classify_window(k_arr):
        km = k_arr.mean()
        ks = k_arr.std()
        if km < kappa_straight:
            return 'straight'
        cv = ks / (km + 1e-12)
        if cv < cv_thresh:
            return 'circular'
        return 'transition'

    # Local type of every point.
    local_type = [
        classify_window(kappa_abs[max(0, i - half):min(n, i + half + 1)])
        for i in range(n)
    ]

    # Collapse consecutive points of the same type into segments.
    segments = []
    cur_start = 0
    for i in range(1, n):
        if local_type[i] != local_type[cur_start]:
            segments.append((cur_start, i - 1, local_type[cur_start]))
            cur_start = i
    segments.append((cur_start, n - 1, local_type[cur_start]))

    # Absorb segments that are too short. Repeat until a pass changes nothing.
    changed = True
    while changed and len(segments) > 1:
        changed = False
        kept = []
        carry_start = None  # start index of a short leading segment to prepend
        for i, (s, e, t) in enumerate(segments):
            if carry_start is not None:
                # Extend this segment backwards over the absorbed leader. The
                # extended segment is then processed normally (previously it
                # was skipped, which dropped its points from the result).
                s, carry_start = carry_start, None
            if arc_lengths[e] - arc_lengths[s] >= min_seg_len:
                kept.append((s, e, t))
            elif i == 0:
                carry_start = s
                changed = True
            elif kept:
                prev_start, _, prev_type = kept[-1]
                kept[-1] = (prev_start, e, prev_type)
                changed = True
            else:
                kept.append((s, e, t))
        segments = kept

    return segments
# ── 3. 개별 구간 피팅 ────────────────────────────────────────────────────
def fit_straight(pts, spacing=0.5):
    """Fit a straight line by PCA and return evenly spaced points on it.

    The line passes through the centroid along the first principal
    direction and spans the projections of the input points. The direction
    is oriented from the first input point towards the last one, so the
    output runs in the same sense as the input (the raw SVD sign is
    arbitrary and could otherwise reverse the segment).

    Returns a list of ``[x, y]`` lists with at least two points.
    """
    pts = np.asarray(pts, dtype=float)
    center = pts.mean(axis=0)
    _, _, Vt = np.linalg.svd(pts - center)
    direction = Vt[0]
    # Make the fitted line follow the input ordering.
    if direction @ (pts[-1] - pts[0]) < 0:
        direction = -direction
    t = (pts - center) @ direction
    t_min, t_max = t.min(), t.max()
    n_pts = max(2, int((t_max - t_min) / spacing) + 1)
    ts = np.linspace(t_min, t_max, n_pts)
    return [(center + ti * direction).tolist() for ti in ts]
def taubin_circle_fit(pts):
    """Fit a circle with Taubin's algebraic method.

    The fit is computed in centroid-relative coordinates. The solution
    corresponds to the smallest non-negative root of the cubic
    characteristic polynomial ``A3*x^3 + A2*x^2 + A1*x + A0`` (the root the
    reference Newton iteration started at 0 converges to).

    Returns ``(cx, cy, R)``, or ``None`` when no real root exists, the
    system is degenerate (e.g. collinear points) or root finding fails.
    """
    pts = np.asarray(pts, dtype=float)
    mean = pts.mean(axis=0)
    x = pts[:, 0] - mean[0]
    y = pts[:, 1] - mean[1]
    z = x**2 + y**2

    Mxx = (x * x).mean()
    Myy = (y * y).mean()
    Mxy = (x * y).mean()
    Mxz = (x * z).mean()
    Myz = (y * z).mean()
    Mzz = (z * z).mean()

    Mz = Mxx + Myy
    Cov_xy = Mxx * Myy - Mxy**2
    A3 = 4 * Mz
    A2 = -3 * Mz**2 - Mzz
    A1 = Mzz * Mz + 4 * Cov_xy * Mz - Mxz**2 - Myz**2 - Mz**3
    A0 = Mxz**2 * Myy + Myz**2 * Mxx - Mzz * Cov_xy - 2 * Mxz * Myz * Mxy + Mz**2 * Cov_xy

    try:
        roots = np.roots([A3, A2, A1, A0])
    except (np.linalg.LinAlgError, ValueError):
        # Raised for non-finite coefficients (e.g. NaN input).
        return None

    # Keep roots that are real up to round-off.
    real_mask = np.abs(roots.imag) <= 1e-9 * np.maximum(1.0, np.abs(roots.real))
    real_roots = roots.real[real_mask]
    if len(real_roots) == 0:
        return None

    # Smallest root, clamped at zero: for noise-free data it is zero up to
    # round-off and may come out slightly negative.
    xn = max(float(real_roots.min()), 0.0)

    det = xn**2 - xn * Mz + Cov_xy
    if abs(det) < 1e-12:
        return None

    xcen = (Mxz * (Myy - xn) - Myz * Mxy) / (2 * det) + mean[0]
    ycen = (Myz * (Mxx - xn) - Mxz * Mxy) / (2 * det) + mean[1]
    # Radius as the RMS distance of the points from the fitted centre.
    R = np.sqrt(np.mean((pts[:, 0] - xcen)**2 + (pts[:, 1] - ycen)**2))
    return xcen, ycen, R
def fit_circular(pts, spacing=0.5):
    """Fit a circular arc and return evenly spaced points on it.

    The circle comes from ``taubin_circle_fit``; when that fails the
    function falls back to ``fit_straight``. The arc runs from the polar
    angle of the first input point to that of the last one (angles are
    unwrapped along the input order).

    Returns a list of ``(x, y)`` tuples with at least two points.
    """
    circle = taubin_circle_fit(pts)
    if circle is None:
        return fit_straight(pts, spacing)
    cx, cy, R = circle

    xy = np.asarray(pts, dtype=float)
    polar = np.unwrap(np.arctan2(xy[:, 1] - cy, xy[:, 0] - cx))
    first, last = polar[0], polar[-1]

    arc_len = abs(last - first) * R
    count = max(2, int(arc_len / spacing) + 1)

    samples = []
    for phi in np.linspace(first, last, count):
        samples.append((cx + R * np.cos(phi), cy + R * np.sin(phi)))
    return samples
def fit_transition_cubic(pts, R_adj, spacing=0.5):
    """
    Fit a cubic-parabola transition curve  y = x^3 / (6 * R * L).

    The curve is built in a local frame whose origin is the first point and
    whose x axis is the principal direction of the points relative to that
    origin, oriented towards the last point. The bend side (sign of y) is
    taken from the data, so curves bending either way are reproduced. The
    result is mapped back to the input coordinate system.

    Parameters
    ----------
    pts   : (N, 2) point coordinates, ordered along the curve
    R_adj : radius used in the cubic (the caller passes the adjacent
            circular curve's radius or a curvature-based estimate)
    spacing : output point spacing along the local x axis

    Returns
    -------
    list of [x, y] lists. When the local length or ``R_adj`` is below 1e-5
    the input points are returned unchanged (as a list).
    """
    pts = np.asarray(pts, dtype=float)
    origin = pts[0].copy()

    # Local axes: the SVD sign is arbitrary, so orient x towards the end point.
    # Otherwise every projection can be <= 0 and the fit degenerates to the
    # raw points.
    _, _, Vt = np.linalg.svd(pts - origin)
    axis_x = Vt[0]
    if axis_x @ (pts[-1] - origin) < 0:
        axis_x = -axis_x
    axis_y = np.array([-axis_x[1], axis_x[0]])

    local = pts - origin
    lx = local @ axis_x
    ly = local @ axis_y

    L = lx.max()
    if L < 1e-5 or R_adj < 1e-5:
        return pts.tolist()

    # Sign of the least-squares coefficient of y = c * x^3 decides which side
    # of the local x axis the curve bends to.
    side = -1.0 if float(ly @ lx**3) < 0 else 1.0

    # Even sampling along the local x axis.
    n_pts = max(2, int(L / spacing) + 1)
    xs = np.linspace(0, L, n_pts)
    ys = side * xs**3 / (6.0 * R_adj * L)

    return [(origin + xi * axis_x + yi * axis_y).tolist()
            for xi, yi in zip(xs, ys)]
# ── 4. 전체 피팅 파이프라인 ───────────────────────────────────────────────
def fit_alignment(pts, spacing=0.5):
    """
    Main entry point: raw centreline points -> fitted, smooth polyline.

    Pipeline: curvature estimation -> segmentation into straight / circular
    / transition runs -> per-segment fit -> concatenation.

    Parameters
    ----------
    pts     : list of (x, y)  input coordinates
    spacing : float           spacing of the output points

    Returns
    -------
    smooth_pts : list of points on the fitted polyline
    seg_info   : list of dict, one per fitted segment (for debugging)

    Inputs with fewer than five points are returned as-is with an empty
    ``seg_info``.
    """
    pts = np.asarray(pts, dtype=float)
    n = len(pts)
    if n < 5:
        return pts.tolist(), []

    # Smoothing window: largest odd value <= n, capped at 21, at least 5.
    largest_odd = n if n % 2 == 1 else n - 1
    sw = max(5, min(21, largest_odd))
    curvatures, arc_lengths, _angles = compute_curvature(pts, sw)
    segments = segment_by_curvature(curvatures, arc_lengths)

    def mean_abs_curvature(first, last):
        return np.abs(curvatures[first:last + 1]).mean()

    # Radius of every circular segment, looked up by transition segments.
    seg_R = {}
    for idx, (s, e, t) in enumerate(segments):
        if t != 'circular':
            continue
        km = mean_abs_curvature(s, e)
        seg_R[idx] = 1.0 / km if km > 1e-6 else 1000.0

    smooth_pts = []
    seg_info = []

    for idx, (s, e, seg_type) in enumerate(segments):
        seg_pts = pts[s:e + 1]
        if len(seg_pts) < 2:
            smooth_pts.extend(seg_pts.tolist())
            continue

        km = mean_abs_curvature(s, e)
        R_est = 1.0 / km if km > 1e-6 else 9999.0

        if seg_type == 'straight':
            fitted = fit_straight(seg_pts, spacing)
        elif seg_type == 'circular':
            fitted = fit_circular(seg_pts, spacing)
        else:
            # Transition: prefer the radius of the next circular segment,
            # then the previous one, else the segment's own estimate.
            R_use = next(
                (seg_R[idx + step] for step in (1, -1) if idx + step in seg_R),
                R_est,
            )
            fitted = fit_transition_cubic(seg_pts, R_use, spacing)

        seg_info.append({
            'type': seg_type,
            'start_idx': s,
            'end_idx': e,
            'length_m': round(arc_lengths[e] - arc_lengths[s], 2),
            'R_m': round(R_est, 1),
            'n_pts_in': len(seg_pts),
            'n_pts_out': len(fitted),
        })

        # Drop the first fitted point at a joint to avoid a duplicate.
        smooth_pts.extend(fitted[1:] if smooth_pts and fitted else fitted)

    return smooth_pts, seg_info
# ── 5. 간단 테스트 ────────────────────────────────────────────────────────
if __name__ == '__main__':
    # Synthetic smoke test: a straight section followed by a circular arc.
    # Straight section
    straight = [(float(i)*0.5, 0.0) for i in range(40)]
    # Circular arc, R = 300 m, tangent to the end of the straight
    R = 300.0
    theta = np.linspace(0, np.pi/4, 60)
    cx, cy = straight[-1][0], straight[-1][1] + R
    circ = [(cx + R*np.sin(th), cy - R*np.cos(th)) for th in theta]

    pts_test = straight + circ[1:]
    smooth, info = fit_alignment(pts_test, spacing=1.0)

    print(f"입력: {len(pts_test)}점 → 출력: {len(smooth)}점")
    for s in info:
        # Input/output point counts used to be printed back to back.
        print(f"  [{s['type']:10s}]  L={s['length_m']:.1f}m  R={s['R_m']:.0f}m  "
              f"점: {s['n_pts_in']} → {s['n_pts_out']}")

View File

@@ -0,0 +1,175 @@
"""
2cm GSD 드론 TIF 전체에서 레일 중심선 추출 → DXF 저장
중심선 추출 방법:
masks.xy 폴리곤 → PCA 주축(길이 방향) 찾기
→ 주축 방향으로 슬라이싱 → 폭 방향 중앙값
→ 계단 현상 없는 매끄러운 중심선
"""
import numpy as np
import rasterio
from rasterio.windows import Window
from ultralytics import YOLO
from PIL import Image
import ezdxf
def polygon_to_centerline(xy_tile, spacing_px=10):
    """
    Reduce a polygon outline to centreline points by slicing along its main axis.

    The principal axis of the vertices is found by SVD; the vertices are
    binned along that axis and, for every bin holding at least two
    vertices, the midpoint between the extreme cross-axis offsets is
    emitted at the bin centre.

    Parameters
    ----------
    xy_tile    : list of (x, y)   polygon vertices
    spacing_px : int              slice width along the main axis

    Returns
    -------
    list of [x, y]  centreline points in the same coordinates as the input.
    Fewer than four vertices give ``[]``; a polygon shorter than
    ``spacing_px`` along its main axis gives its vertex centroid only.
    """
    verts = np.asarray(xy_tile, dtype=float)
    if len(verts) < 4:
        return []

    # Principal axes of the vertex cloud.
    center = verts.mean(axis=0)
    _, _, Vt = np.linalg.svd(verts - center, full_matrices=False)
    v_long, v_perp = Vt[0], Vt[1]

    rel = verts - center
    t = rel @ v_long        # position along the main axis
    s_all = rel @ v_perp    # offset across the main axis

    t_min, t_max = t.min(), t.max()
    if t_max - t_min < spacing_px:
        # Too short to slice: return the centroid only.
        return [center.tolist()]

    n_bins = max(2, int((t_max - t_min) / spacing_px) + 1)
    edges = np.linspace(t_min, t_max, n_bins + 1)

    centerline = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        inside = (t >= lo) & (t <= hi)
        if inside.sum() < 2:
            continue
        offsets = s_all[inside]
        t_c = (lo + hi) / 2
        # Geometric middle between the two outline edges.
        s_c = (offsets.max() + offsets.min()) / 2
        centerline.append((center + t_c * v_long + s_c * v_perp).tolist())

    return centerline
# ── Settings ──────────────────────────────────────
SRC_TIF = "drone_2cm/22)조치원(STA.127+570~131+300).tif"
MODEL_PT = "runs/segment/output/yolo_train_2cm/rail_seg_v2/weights/best.pt"
OUT_DXF = "output/rail_centerline_2cm_pca.dxf"
TILE = 1024
OVERLAP = 128
STEP = TILE - OVERLAP
CONF = 0.3
BLACK_THR = 0.4 # threshold on the share of (near-)black pixels in a tile
SPACING = 10 # centreline sampling interval [px]; 10 px ≈ 0.2 m
# ── Load model ────────────────────────────────────
print("모델 로드 중...")
model = YOLO(MODEL_PT)
# ── Initialise DXF ────────────────────────────────
doc = ezdxf.new()
msp = doc.modelspace()
doc.layers.add("RAIL_CENTERLINE", color=3)
# ── Open TIF ──────────────────────────────────────
src = rasterio.open(SRC_TIF)
W, H = src.width, src.height
transform = src.transform
print(f"이미지 크기: {W} x {H}")
def pixel_to_world(px, py):
    """Map pixel coordinates to world coordinates with the raster's affine transform.

    All six affine coefficients of the module-level ``transform`` are
    applied, so rasters with rotation/shear terms (``b``, ``d``) are mapped
    correctly; for an axis-aligned raster those terms are zero and the
    result equals the former ``c + px*a`` / ``f + py*e`` form.
    (The original note states the CRS is EPSG:5186.)
    """
    wx = transform.c + px * transform.a + py * transform.b
    wy = transform.f + px * transform.d + py * transform.e
    return wx, wy
# ── 타일 순회 ─────────────────────────────────────
import os  # script-local: only needed to create the output folder

total_polylines = 0
xs_range = range(0, W - TILE // 2, STEP)
ys_range = range(0, H - TILE // 2, STEP)
total_tiles = len(xs_range) * len(ys_range)
processed = 0
HALF_OVERLAP = OVERLAP // 2
print(f"총 타일 수(예상): {total_tiles}, 처리 시작...")

# Create the output folder before the long run so that saving cannot fail
# on a missing directory after all tiles have been processed.
out_dir = os.path.dirname(OUT_DXF)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

try:
    for ty in ys_range:
        for tx in xs_range:
            tw = min(TILE, W - tx)
            th = min(TILE, H - ty)
            if tw < 64 or th < 64:
                continue

            win = Window(tx, ty, tw, th)
            data = src.read([1, 2, 3], window=win)
            img_arr = np.transpose(data, (1, 2, 0)).astype(np.uint8)

            # Skip tiles that are mostly empty (near-black) raster padding.
            black_ratio = np.mean(np.all(img_arr < 10, axis=2))
            if black_ratio > BLACK_THR:
                processed += 1
                continue

            pil_img = Image.fromarray(img_arr)
            results = model.predict(pil_img, imgsz=TILE, conf=CONF,
                                    device='cuda:0', verbose=False)

            if not results or results[0].masks is None:
                processed += 1
                continue

            masks_data = results[0].masks.data.cpu().numpy()
            h_r, w_r = masks_data.shape[1], masks_data.shape[2]
            sx = tw / w_r
            sy = th / h_r

            for xy in results[0].masks.xy:
                if len(xy) < 4:
                    continue

                # Scale polygon coordinates to tile pixel coordinates.
                xy_tile = [(float(lx) * sx, float(ly) * sy) for lx, ly in xy]

                # Centreline by PCA slicing (tile pixel coordinates).
                cl_tile = polygon_to_centerline(xy_tile, spacing_px=SPACING)
                if len(cl_tile) < 2:
                    continue

                # Trim the overlap band and convert to world coordinates.
                # Each tile keeps its side of the overlap midline. Trimming
                # only the right/bottom band left the left/top half of every
                # overlap covered by two tiles.
                pts_world = []
                for lx, ly in cl_tile:
                    if tx + tw < W and lx > (tw - HALF_OVERLAP):
                        continue
                    if ty + th < H and ly > (th - HALF_OVERLAP):
                        continue
                    if tx > 0 and lx < HALF_OVERLAP:
                        continue
                    if ty > 0 and ly < HALF_OVERLAP:
                        continue
                    pts_world.append(pixel_to_world(tx + lx, ty + ly))

                if len(pts_world) >= 2:
                    msp.add_lwpolyline(pts_world,
                                       dxfattribs={"layer": "RAIL_CENTERLINE"})
                    total_polylines += 1

            processed += 1
            if processed % 500 == 0:
                pct = processed / total_tiles * 100
                print(f"  진행: {processed}/{total_tiles} ({pct:.1f}%)"
                      f"  |  폴리라인: {total_polylines}")
finally:
    # Release the raster handle even when inference fails midway.
    src.close()

doc.saveas(OUT_DXF)
print(f"\n완료: {total_polylines}개 폴리라인 → {OUT_DXF}")

234
tools/rail_to_dxf.py Normal file
View File

@@ -0,0 +1,234 @@
"""
rail_to_dxf.py
==============
X-AnyLabeling JSON 어노테이션에서 레일 중심선을 추출하여 Rhino용 DXF로 저장.
사용법:
python tools/rail_to_dxf.py <annotation.json> [output.dxf]
예시:
python tools/rail_to_dxf.py images/rail.json
python tools/rail_to_dxf.py images/rail.json output/rail_centerline.dxf
라벨 이름 (X-AnyLabeling에서 Finish 후 입력한 이름):
rail, railway_track, track, AUTOLABEL_OBJECT 자동 인식
다른 이름이면 스크립트 하단 TARGET_LABELS 수정
Rhino에서 사용:
1. DXF Import
2. RAIL_CENTERLINE 레이어 선택
3. Sweep2 또는 Rail Sweep으로 레일 단면 적용
"""
import json
import sys
import numpy as np
import cv2
from pathlib import Path
from skimage.morphology import skeletonize
# ─── Settings ─────────────────────────────────────────────────────────────────
# Shape labels treated as rail annotations.
TARGET_LABELS = [
    "rail",
    "railline",
    "railway_track",
    "track",
    "레일",
    "철로",
    "AUTOLABEL_OBJECT",
]
# Line-type shapes need no skeleton; they are written to the DXF directly.
LINE_SHAPE_TYPES = {"line", "linestrip", "lines"}
SMOOTH_WINDOW = 15 # centreline smoothing window (larger = smoother, 0 = disabled)
DOWNSAMPLE_STEP = 8 # keep every N-th polyline point (larger = fewer points)
DILATION_ITER = 3 # mask dilation iterations (helps connect thin rail masks)
# ──────────────────────────────────────────────────────────────────────────────
def polygon_to_mask(points, h, w):
    """Rasterise a polygon into an ``h`` x ``w`` uint8 mask (1 inside, 0 outside).

    Vertex coordinates are truncated to integers before filling.
    """
    vertices = np.array(
        [[int(pt[0]), int(pt[1])] for pt in points],
        dtype=np.int32,
    )
    canvas = np.zeros((h, w), dtype=np.uint8)
    cv2.fillPoly(canvas, [vertices], 1)
    return canvas
def extract_skeleton(mask):
    """Return the one-pixel-wide skeleton of ``mask`` as a uint8 array.

    The mask is first dilated with a 3x3 kernel for ``DILATION_ITER``
    iterations, then skeletonised.
    """
    grown = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=DILATION_ITER)
    skeleton = skeletonize(grown > 0)
    return skeleton.astype(np.uint8)
def order_skeleton(skeleton):
    """Order skeleton pixels into a path, starting from an end point.

    Pixels are 8-connected. The walk starts at a pixel with exactly one
    neighbour (or an arbitrary pixel when there is none) and repeatedly
    steps to an unvisited neighbour, preferring the one most aligned with
    the previous step. It stops when no unvisited neighbour remains, so
    pixels on side branches may be left out.

    Returns a list of ``(x, y)`` tuples; an empty skeleton gives ``[]``.
    """
    ys, xs = np.where(skeleton > 0)
    if len(xs) == 0:
        return []

    pt_set = set(zip(xs.tolist(), ys.tolist()))
    # 8-neighbourhood offsets, x-major order.
    offsets = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1)
               if not (dx == 0 and dy == 0)]

    def neighbors(pt):
        px, py = pt
        found = []
        for dx, dy in offsets:
            cand = (px + dx, py + dy)
            if cand in pt_set:
                found.append(cand)
        return found

    # End point (one neighbour) if available, otherwise any pixel.
    endpoints = [p for p in pt_set if len(neighbors(p)) == 1]
    start = endpoints[0] if endpoints else next(iter(pt_set))

    ordered = [start]
    visited = {start}
    current = start
    while True:
        options = [nb for nb in neighbors(current) if nb not in visited]
        if not options:
            break
        if len(ordered) >= 2:
            # Prefer the neighbour that continues the current heading.
            head_x = current[0] - ordered[-2][0]
            head_y = current[1] - ordered[-2][1]
            options.sort(
                key=lambda nb: -((nb[0] - current[0]) * head_x
                                 + (nb[1] - current[1]) * head_y)
            )
        current = options[0]
        visited.add(current)
        ordered.append(current)

    return ordered
def smooth_polyline(points, window):
    """Smooth a polyline with a centred moving average.

    Interior points are replaced by the mean of the ``2 * (window // 2) + 1``
    original points centred on them; the first and last ``window // 2``
    points are left untouched. When ``window`` is below 3 or the polyline
    has fewer than ``window`` points, ``points`` is returned unchanged.

    Returns a list of ``[x, y]`` lists when smoothing is applied.
    """
    if window < 3:
        return points
    if len(points) < window:
        return points

    source = np.array(points, dtype=float)
    reach = window // 2
    smoothed = source.copy()
    for centre in range(reach, len(source) - reach):
        lo, hi = centre - reach, centre + reach + 1
        smoothed[centre] = source[lo:hi].mean(axis=0)
    return smoothed.tolist()
def downsample(points, step):
    """Keep every ``step``-th point, always retaining the final point.

    ``points`` is returned unchanged when ``step`` is 1 or less, or when
    there are no more than ``step`` points.
    """
    if step <= 1 or len(points) <= step:
        return points

    kept = points[::step]
    last = points[-1]
    if list(kept[-1]) == list(last):
        return kept
    # The stride missed the end of the line: append it explicitly.
    return list(kept) + [last]
def to_dxf_coords(points, flip_y=True):
    """Convert image coordinates (y down) to DXF coordinates (y up).

    Returns a list of ``(x, y)`` float tuples; with ``flip_y`` the y value
    is negated, otherwise it is passed through.
    """
    converted = []
    for p in points:
        y = -p[1] if flip_y else p[1]
        converted.append((float(p[0]), float(y)))
    return converted
def process(json_path: str, dxf_path: str):
    """Convert labelled shapes in an annotation JSON into a DXF file.

    Shapes whose label is in ``TARGET_LABELS`` are used; when none match,
    every shape is used.  Line-type shapes (and any shape with exactly two
    points) are written directly to the ``RAIL_CENTERLINE`` layer.  Polygon
    shapes are written to ``RAIL_POLYGON`` and additionally rasterised,
    skeletonised, ordered, smoothed and downsampled into a centreline on
    ``RAIL_CENTERLINE``.

    Args:
        json_path: Annotation file with ``imageHeight``, ``imageWidth`` and
            ``shapes`` keys (1000 x 1000 is assumed when the size is absent).
        dxf_path: Destination DXF path; parent directories are created.
    """
    import ezdxf

    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)
    H = data.get("imageHeight", 1000)
    W = data.get("imageWidth", 1000)
    print(f"[이미지] {W} × {H} px")

    shapes = data.get("shapes", [])
    targets = [s for s in shapes if s.get("label") in TARGET_LABELS]
    if not targets:
        print("⚠ TARGET_LABELS 일치 없음 → 모든 shape 사용")
        targets = shapes

    type_counts = {}
    for s in targets:
        t = s.get("shape_type", "unknown")
        type_counts[t] = type_counts.get(t, 0) + 1
    print(f"[처리] {len(targets)}개: {type_counts}")

    doc = ezdxf.new("R2010")
    msp = doc.modelspace()
    doc.layers.add("RAIL_CENTERLINE", color=1)  # red: sweep path
    doc.layers.add("RAIL_POLYGON", color=3)     # green: original mask outline

    for i, shape in enumerate(targets):
        label = shape.get("label", f"shape_{i}")
        pts = shape.get("points", [])
        shape_type = shape.get("shape_type", "polygon")
        print(f"\n [{i+1}] label={label!r} type={shape_type!r} pts={len(pts)}")
        if len(pts) < 2:
            print(" → 포인트 부족, 건너뜀")
            continue

        # Drop consecutive near-duplicate points (a line drawn with the
        # polygon tool ends with a point almost identical to the previous one).
        if len(pts) >= 3:
            dedup = [pts[0]]
            for p in pts[1:]:
                if abs(p[0] - dedup[-1][0]) > 2 or abs(p[1] - dedup[-1][1]) > 2:
                    dedup.append(p)
            if len(dedup) < len(pts):
                print(f" 중복점 제거: {len(pts)}pt → {len(dedup)}pt")
                pts = dedup
            # De-duplication can collapse a shape to a single point; without
            # this re-check a line-type shape would be written as a
            # one-vertex polyline.
            if len(pts) < 2:
                print(" → 포인트 부족, 건너뜀")
                continue

        # -- Line shapes: write straight to DXF, no skeleton needed --
        if shape_type in LINE_SHAPE_TYPES or len(pts) == 2:
            cl_dxf = to_dxf_coords(pts)
            msp.add_lwpolyline(cl_dxf, dxfattribs={"layer": "RAIL_CENTERLINE"})
            print(f" → 라인 직접 출력 ({len(pts)}pt)")
            continue

        # -- Polygon shapes: mask -> skeleton -> centreline --
        if len(pts) < 3:
            print(" → 폴리곤 포인트 부족, 건너뜀")
            continue
        poly_dxf = to_dxf_coords(pts)
        poly_dxf.append(poly_dxf[0])  # close the outline
        msp.add_lwpolyline(poly_dxf, dxfattribs={"layer": "RAIL_POLYGON"})

        mask = polygon_to_mask(pts, H, W)
        area = int(mask.sum())
        print(f" 마스크 면적: {area} px²")
        if area < 20:
            print(" → 면적 너무 작음, 건너뜀")
            continue

        skel = extract_skeleton(mask)
        skel_n = int(skel.sum())
        print(f" 스켈레톤 픽셀: {skel_n}")
        if skel_n < 2:
            print(" → 스켈레톤 생성 실패, 건너뜀")
            continue

        ordered = order_skeleton(skel)
        smoothed = smooth_polyline(ordered, SMOOTH_WINDOW)
        final_pts = downsample(smoothed, DOWNSAMPLE_STEP)
        print(f" 중심선 포인트: {len(final_pts)}")
        if len(final_pts) < 2:
            continue
        cl_dxf = to_dxf_coords(final_pts)
        msp.add_lwpolyline(cl_dxf, dxfattribs={"layer": "RAIL_CENTERLINE"})

    Path(dxf_path).parent.mkdir(parents=True, exist_ok=True)
    doc.saveas(dxf_path)
    print(f"\n[완료] DXF 저장: {dxf_path}")
    print(" 레이어: RAIL_CENTERLINE(빨강) = Sweep 경로")
    print(" RAIL_POLYGON(초록) = 원본 마스크")
    print("\nRhino 사용법:")
    print(" 1. File → Import → DXF 선택")
    print(" 2. RAIL_CENTERLINE 레이어 선택")
    print(" 3. Rail 단면 커브 그리기 (레일 단면 약 60x60mm)")
    print(" 4. Sweep1 또는 Sweep2 명령으로 단면 × 중심선 → 3D 레일 생성")
if __name__ == "__main__":
    # CLI entry point: the annotation path is required; the DXF path
    # defaults to the annotation path with a ".dxf" suffix.
    if not sys.argv[1:]:
        print("사용법: python tools/rail_to_dxf.py <annotation.json> [output.dxf]")
        sys.exit(1)
    json_file = sys.argv[1]
    if len(sys.argv) >= 3:
        dxf_file = sys.argv[2]
    else:
        dxf_file = str(Path(json_file).with_suffix(".dxf"))
    process(json_file, dxf_file)

870
tools/railway_pipeline.py Normal file
View File

@@ -0,0 +1,870 @@
"""
railway_pipeline.py
===================
정사영상(GeoTIFF/PNG)에서 철도 시설물을 자동 검출하여
실좌표(UTM/WGS84) 기반 DXF + GeoJSON 출력.
이미지에 실제로 보이는 것을 그대로 검출 (표준 규격 기반 아님).
사용법:
python tools/railway_pipeline.py <image.tif|png> [output_dir]
예시:
python tools/railway_pipeline.py 경부선.tif output/
python tools/railway_pipeline.py 경부선.png output/ # PNG는 GSD=0.05m 가정
출력:
output/<입력파일명>_railway.dxf - 전체 시설물 (레이어별: RAIL 레일 중심선 / SLEEPER / POLE / CONTROL_BOX)
output/<입력파일명>_railway.geojson - GIS 활용용
output/<입력파일명>_railway_debug.jpg - 검출 결과 시각화
output/<입력파일명>_gabor_mask.jpg - Gabor 침목 패턴 마스크 (확인용)
"""
import sys
import json
import math
import numpy as np
import cv2
from pathlib import Path
from collections import defaultdict
# ─── 검출 파라미터 ─────────────────────────────────────────────────────────────
# Canny / probabilistic Hough / clustering parameters for rail detection.
RAIL = {
    "canny_low": 30,
    "canny_high": 80,
    "hough_threshold": 400,
    "min_len": 500,
    "max_gap": 30,
    "cluster_dist": 14,
    "min_total_len": 3000,
}

SLEEPER = {
    "search_width": 60,   # search width on each side of a rail (px)
    "min_area": 30,       # minimum contour area
    "max_area": 2000,     # maximum contour area
    "aspect_min": 2.0,    # minimum long/short side ratio (sleepers are elongated)
}

POLE = {
    "search_margin": 200,  # search range beside a rail (px)
    "min_radius": 3,       # minimum radius (px)
    "max_radius": 25,      # maximum radius (px)
    "min_dist": 40,        # minimum spacing between detections (px)
    "param1": 50,          # Hough circle parameters
    "param2": 25,
}

CBOX = {
    "search_margin": 300,  # search range beside a rail (px)
    "min_area": 100,
    "max_area": 8000,
    "aspect_min": 1.2,     # accepted rectangle aspect-ratio range
    "aspect_max": 6.0,
    "min_solidity": 0.75,  # contour area / convex hull area (closer to 1 = more rectangular)
}

GABOR = {
    # Sleeper spacing range in px of the downscaled image.  Author's estimate
    # for 5 cm GSD at 8000 px width (scale ~0.54): 600mm/50mm*0.54 ~ 6.5 px.
    "lambdas": [5, 7, 10],
    "sigma": 3.0,        # Gabor kernel width
    "gamma": 0.5,        # aspect ratio (1 = circular, 0.5 = elliptical)
    "n_angles": 12,      # number of orientations (0 to 165 degrees, 15 degree steps)
    "thresh_pct": 60,    # response threshold percentile (higher = stricter)
    "min_area": 4000,    # minimum connected-component area (drops small noise)
}

# Maximum processing width in px; wider images are downscaled to this width.
# Kept high so the sleeper pattern stays detectable.
SCALE_FACTOR = 8000
# ──────────────────────────────────────────────────────────────────────────────
# ══════════════════════════════════════════════════════════════════════════════
# 좌표 변환
# ══════════════════════════════════════════════════════════════════════════════
class GeoTransform:
    """Convert pixel coordinates of one raster into world coordinates.

    For ``.tif``/``.tiff`` inputs the affine transform and CRS are read with
    rasterio.  For any other input a world file (``.pgw``, then ``.wld``)
    next to the image is used when present.  When neither source is
    available the origin stays ``(0, 0)`` and the pixel size stays at the
    5 cm default.
    """

    def __init__(self, image_path: str):
        self.crs = None
        self.affine = None
        self.origin = (0.0, 0.0)
        self.pixel_size = (0.05, 0.05)  # default: 5 cm GSD
        self._load(image_path)

    def _load(self, path: str):
        """Populate ``crs``, ``affine``, ``origin`` and ``pixel_size`` from ``path``."""
        if path.lower().endswith(('.tif', '.tiff')):
            try:
                import rasterio
                with rasterio.open(path) as r:
                    self.crs = str(r.crs)
                    t = r.transform
                    self.affine = t
                    self.pixel_size = (abs(t.a), abs(t.e))
                    self.origin = (t.c, t.f)
                print(f" CRS: {self.crs}")
                print(f" 픽셀크기: {self.pixel_size[0]:.4f}m x {self.pixel_size[1]:.4f}m")
                print(f" 원점: ({self.origin[0]:.2f}, {self.origin[1]:.2f})")
            except Exception as e:
                # Best effort: keep the defaults when the raster cannot be read.
                print(f" [경고] rasterio 실패: {e} → 픽셀좌표 사용")
        else:
            # Non-TIFF input: look for a world file beside the image.
            wld = Path(path).with_suffix('.pgw')
            if not wld.exists():
                wld = Path(path).with_suffix('.wld')
            if wld.exists():
                vals = [
                    float(ln.strip())
                    for ln in wld.read_text().splitlines()
                    if ln.strip()
                ]
                if len(vals) >= 6:
                    # Values 1 and 4 are read as pixel sizes, 5 and 6 as the origin.
                    self.pixel_size = (abs(vals[0]), abs(vals[3]))
                    self.origin = (vals[4], vals[5])
                    print(f" World file 적용: origin={self.origin}")
            else:
                print(f" World file 없음 → 픽셀좌표 사용 (GSD=0.05m)")

    def px_to_world(self, px: float, py: float):
        """Map pixel ``(col, row)`` to world ``(x, y)``.

        Uses the raster's affine transform when one was loaded; otherwise
        ``x = origin_x + px * size_x`` and ``y = origin_y - py * size_y``
        (world Y grows upward while the pixel row grows downward).
        """
        if self.affine is not None:
            try:
                x, y = self.affine * (px, py)
                return float(x), float(y)
            except Exception:
                # Deliberate best effort: fall back to origin/pixel-size maths.
                pass
        x = self.origin[0] + px * self.pixel_size[0]
        y = self.origin[1] - py * self.pixel_size[1]
        return x, y

    def scale_coords(self, pts, scale: float):
        """Convert points measured on a downscaled image to world coordinates.

        Each point is divided by ``scale`` to get back to original pixel
        coordinates before the world conversion.
        """
        return [self.px_to_world(p[0] / scale, p[1] / scale) for p in pts]

    def scale_point(self, px: float, py: float, scale: float):
        """Single-point variant of :meth:`scale_coords`."""
        return self.px_to_world(px / scale, py / scale)
# ══════════════════════════════════════════════════════════════════════════════
# 이미지 로드 (한글 경로 대응)
# ══════════════════════════════════════════════════════════════════════════════
def _to_uint8(arr):
    """Scale ``arr`` so its maximum maps to 255 and return it as uint8.

    uint8 input is returned untouched.  When the maximum is not positive
    (e.g. an all-zero raster) an all-zero image is returned instead of
    dividing by zero.
    """
    if arr.dtype == np.uint8:
        return arr
    peak = arr.max()
    if peak <= 0:
        return np.zeros(arr.shape, dtype=np.uint8)
    return (arr / peak * 255).astype(np.uint8)


def load_image(path: str):
    """Load a TIFF or other image file as a BGR uint8 numpy array.

    TIFF files are read with rasterio (bands 1-3 as RGB, or band 1 as grey
    when fewer than three bands exist) and scaled to uint8.  Any other
    file, or a TIFF that rasterio fails to read, is decoded with
    ``cv2.imdecode`` from the raw file bytes, which keeps non-ASCII paths
    working.  Returns whatever ``cv2.imdecode`` returns in that case, which
    the caller checks against ``None``.
    """
    if path.lower().endswith(('.tif', '.tiff')):
        try:
            import rasterio
            with rasterio.open(path) as r:
                if r.count >= 3:
                    rgb = _to_uint8(np.dstack([r.read(i) for i in [1, 2, 3]]))
                    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
                band = _to_uint8(r.read(1))
                return cv2.cvtColor(band, cv2.COLOR_GRAY2BGR)
        except Exception as e:
            # Best effort: fall through to the byte-buffer decode below.
            print(f" rasterio 읽기 실패: {e}, PIL 시도...")
    # Non-TIFF input, or fallback after a rasterio failure.
    with open(path, 'rb') as f:
        data = f.read()
    arr = np.frombuffer(data, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    return img
# ══════════════════════════════════════════════════════════════════════════════
# 자갈도상 마스크
# ══════════════════════════════════════════════════════════════════════════════
def extract_ballast_mask(img):
    """Extract a mask of the ballast (track bed) region from a BGR image.

    Pixels with low saturation (S <= 45) and mid-to-high brightness
    (80 <= V <= 210) are selected in HSV space, cleaned with a
    morphological opening, dilated, and then only connected components
    covering at least 0.5% of the image (and at least 5000 px) are kept.

    Returns:
        uint8 mask of the same height and width as ``img`` (255 = ballast).
    """
    H, W = img.shape[:2]
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # Low saturation + mid/high value = gravel / grey surfaces.
    mask = cv2.inRange(hsv,
                       np.array([0, 0, 80], np.uint8),
                       np.array([180, 45, 210], np.uint8))
    # Remove small noise.
    k_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, k_open, iterations=2)
    # Dilate to connect ballast patches and reach the rail edges
    # (two iterations only, to avoid over-growing).
    k_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    mask = cv2.dilate(mask, k_dil, iterations=2)
    # Keep only large components (drops roads, buildings and other small areas).
    min_area = max(H * W * 0.005, 5000)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    # One lookup table indexed by label instead of a full-image scan per label.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    result = keep[labels].astype(np.uint8) * 255
    kept = int((result > 0).sum())
    total = H * W
    print(f" 자갈도상 마스크: {kept:,}px ({kept/total*100:.1f}%)")
    return result
# ══════════════════════════════════════════════════════════════════════════════
# Gabor 침목 패턴 마스크 + 레일 검출 (통합)
# ══════════════════════════════════════════════════════════════════════════════
def detect_rails_gabor(img):
    """Detect rail centrelines inside the region found by a Gabor filter bank.

    Stage 1: a bank of Gabor filters (``GABOR['n_angles']`` orientations x
    ``GABOR['lambdas']`` wavelengths) is applied to the grey image.  The
    maximum absolute response is thresholded at a percentile, cleaned
    morphologically, and only large connected components are kept.  The
    intent is to pick up the periodic sleeper texture, which roads lack.

    Stage 2: Canny edges restricted to that mask are passed to the
    probabilistic Hough transform.  Segments are grouped into 5-degree
    angle bins, clustered by the perpendicular distance of each segment's
    midpoint to a seed segment, and merged into polylines.  Polylines
    shorter than ``RAIL['min_total_len']`` are dropped.

    Returns:
        ``(rail_polylines, gabor_mask)``: polylines sorted longest first,
        and the uint8 mask (255 = kept region).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    H, W = gray.shape

    # -- Gabor filter bank --------------------------------------------------
    # Loop invariants hoisted: float image and (odd) kernel size.
    gray_f32 = gray.astype(np.float32)
    ksize = max(int(GABOR['sigma'] * 6) | 1, 7)
    response_max = np.zeros((H, W), dtype=np.float32)
    n_ang = GABOR['n_angles']
    for i in range(n_ang):
        theta = i * np.pi / n_ang  # 0 .. pi, covers every orientation
        for lam in GABOR['lambdas']:
            kernel = cv2.getGaborKernel(
                (ksize, ksize),
                GABOR['sigma'], theta, lam,
                GABOR['gamma'], 0, cv2.CV_32F
            )
            resp = np.abs(cv2.filter2D(gray_f32, cv2.CV_32F, kernel))
            response_max = np.maximum(response_max, resp)

    # -- Normalise response and threshold -----------------------------------
    resp_u8 = cv2.normalize(response_max, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    nz = resp_u8[resp_u8 > 0]
    thresh_val = int(np.percentile(nz, GABOR['thresh_pct'])) if len(nz) else 128
    _, binary = cv2.threshold(resp_u8, thresh_val, 255, cv2.THRESH_BINARY)

    # -- Morphological clean-up ---------------------------------------------
    k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, k5, iterations=2)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, k3, iterations=1)

    # -- Keep only large connected components -------------------------------
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    # One lookup table indexed by label instead of a full-image scan per label.
    keep = stats[:, cv2.CC_STAT_AREA] >= GABOR['min_area']
    keep[0] = False  # label 0 is the background
    gabor_mask = keep[labels].astype(np.uint8) * 255
    kept = int((gabor_mask > 0).sum())
    print(f" Gabor 마스크: {kept:,}px ({kept / (H * W) * 100:.1f}%)")

    # -- Hough rail detection inside the Gabor mask -------------------------
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    blurred = cv2.GaussianBlur(enhanced, (5, 5), 0)
    edges = cv2.Canny(blurred, RAIL['canny_low'], RAIL['canny_high'])
    edges = cv2.bitwise_and(edges, edges, mask=gabor_mask)
    print(f" 마스크 내 엣지: {int(edges.sum() // 255):,}px")
    lines_raw = cv2.HoughLinesP(
        edges, 1, np.pi / 180,
        threshold=RAIL['hough_threshold'],
        minLineLength=RAIL['min_len'],
        maxLineGap=RAIL['max_gap'],
    )
    if lines_raw is None:
        print(" Hough 검출 없음")
        return [], gabor_mask
    # Plain Python ints, so the products in perp_dist are not evaluated in
    # fixed-width NumPy integers.
    lines = [tuple(int(v) for v in l[0]) for l in lines_raw]
    print(f" Hough 검출: {len(lines)}")

    # -- Clustering -> polylines --------------------------------------------
    # Segments are grouped into 5-degree angle bins, then clustered within a
    # bin by perpendicular distance to the seed segment.  No forced joining
    # across bins, so turnout areas stay as separate segments.
    ATOL = RAIL['cluster_dist']

    def angle(x1, y1, x2, y2):
        return np.degrees(np.arctan2(y2 - y1, x2 - x1)) % 180

    def perp_dist(x1, y1, x2, y2, px, py):
        dx, dy = x2 - x1, y2 - y1
        L = math.hypot(dx, dy)
        if L == 0:
            return math.hypot(px - x1, py - y1)
        return abs(dy * px - dx * py + x2 * y1 - y2 * x1) / L

    angle_groups = defaultdict(list)
    for l in lines:
        # Orientation is modulo 180 degrees: fold bin 180 onto bin 0 so that
        # near-horizontal segments (e.g. 1 and 179 degrees) share one group.
        key = (round(angle(*l) / 5) * 5) % 180
        angle_groups[key].append(l)

    clusters = []
    for grp in angle_groups.values():
        used = [False] * len(grp)
        for i, li in enumerate(grp):
            if used[i]:
                continue
            cl = [li]
            used[i] = True
            for j, lj in enumerate(grp):
                if used[j]:
                    continue
                mxj = ((lj[0] + lj[2]) / 2, (lj[1] + lj[3]) / 2)
                if perp_dist(*li, *mxj) < ATOL:
                    cl.append(lj)
                    used[j] = True
            clusters.append(cl)

    def merge(cluster):
        """Sort all segment end points along the cluster's principal axis."""
        pts = []
        for x1, y1, x2, y2 in cluster:
            pts += [(x1, y1), (x2, y2)]
        arr = np.array(pts, float)
        mean = arr.mean(0)
        cov = np.cov((arr - mean).T)
        if cov.ndim < 2:
            direction = np.array([1.0, 0.0])
        else:
            ev, evec = np.linalg.eig(cov)
            direction = evec[:, np.argmax(ev)]
        proj = (arr - mean).dot(direction)
        sorted_pts = arr[np.argsort(proj)]
        merged = [sorted_pts[0]]
        for p in sorted_pts[1:]:
            # Skip points closer than 5 px to the previously kept point.
            if math.hypot(p[0] - merged[-1][0], p[1] - merged[-1][1]) > 5:
                merged.append(p)
        return merged

    def poly_len(poly):
        return sum(
            math.hypot(poly[i + 1][0] - poly[i][0], poly[i + 1][1] - poly[i][1])
            for i in range(len(poly) - 1)
        )

    result = []
    for cl in clusters:
        poly = merge(cl)
        if poly_len(poly) >= RAIL['min_total_len']:
            result.append(poly)
    result.sort(key=lambda p: -poly_len(p))
    return result, gabor_mask
# ══════════════════════════════════════════════════════════════════════════════
# 레일 검출 (Hough — fallback용, 주 검출은 detect_rails_gabor 사용)
# ══════════════════════════════════════════════════════════════════════════════
def detect_rails(img, ballast_mask=None):
    """Return rail centreline polylines found with the probabilistic Hough transform.

    Canny edges (optionally restricted to ``ballast_mask``) are passed to
    ``cv2.HoughLinesP``.  Segments are grouped into 5-degree angle bins,
    clustered by the perpendicular distance of each segment's midpoint to a
    seed segment, and merged into polylines.  Polylines shorter than
    ``RAIL['min_total_len']`` are dropped.

    Returns:
        List of polylines sorted longest first; empty when nothing is found.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    blurred = cv2.GaussianBlur(enhanced, (5, 5), 0)
    edges = cv2.Canny(blurred, RAIL['canny_low'], RAIL['canny_high'])
    # Apply the ballast mask: removes edges outside the track area.
    if ballast_mask is not None:
        edges = cv2.bitwise_and(edges, edges, mask=ballast_mask)
        print(f" 마스크 적용 후 엣지: {int(edges.sum()//255):,}px")
    lines_raw = cv2.HoughLinesP(edges, 1, np.pi/180,
                                threshold=RAIL['hough_threshold'],
                                minLineLength=RAIL['min_len'],
                                maxLineGap=RAIL['max_gap'])
    if lines_raw is None:
        return []
    lines = [tuple(l[0]) for l in lines_raw]

    def angle(x1, y1, x2, y2):
        return np.degrees(np.arctan2(y2 - y1, x2 - x1)) % 180

    def midpoint(x1, y1, x2, y2):
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    def perp_dist(x1, y1, x2, y2, px, py):
        dx, dy = x2 - x1, y2 - y1
        L = math.hypot(dx, dy)
        if L == 0:
            return math.hypot(px - x1, py - y1)
        return abs(dy * px - dx * py + x2 * y1 - y2 * x1) / L

    # No direction filter: the mask provides the spatial constraint, and all
    # orientations are allowed so curved sections are included.
    print(f" Hough 검출: {len(lines)}")

    # Group by angle, then cluster by distance within each group.
    ATOL = RAIL['cluster_dist']
    angle_groups = defaultdict(list)
    for l in lines:
        # Orientation is modulo 180 degrees: fold bin 180 onto bin 0 so that
        # near-horizontal segments (e.g. 1 and 179 degrees) share one group.
        key = (round(angle(*l) / 5) * 5) % 180
        angle_groups[key].append(l)

    clusters = []
    for grp in angle_groups.values():
        used = [False] * len(grp)
        for i, li in enumerate(grp):
            if used[i]:
                continue
            cl = [li]
            used[i] = True
            for j, lj in enumerate(grp):
                if used[j]:
                    continue
                if perp_dist(*li, *midpoint(*lj)) < ATOL:
                    cl.append(lj)
                    used[j] = True
            clusters.append(cl)

    def merge(cluster):
        """Sort all segment end points along the cluster's principal axis."""
        pts = []
        for x1, y1, x2, y2 in cluster:
            pts += [(x1, y1), (x2, y2)]
        arr = np.array(pts, float)
        mean = arr.mean(0)
        cov = np.cov((arr - mean).T)
        if cov.ndim < 2:
            direction = np.array([1.0, 0.0])
        else:
            ev, evec = np.linalg.eig(cov)
            direction = evec[:, np.argmax(ev)]
        proj = (arr - mean).dot(direction)
        sorted_pts = arr[np.argsort(proj)]
        merged = [sorted_pts[0]]
        for p in sorted_pts[1:]:
            # Skip points closer than 5 px to the previously kept point.
            if math.hypot(p[0] - merged[-1][0], p[1] - merged[-1][1]) > 5:
                merged.append(p)
        return merged

    def poly_len(poly):
        return sum(math.hypot(poly[i+1][0] - poly[i][0], poly[i+1][1] - poly[i][1])
                   for i in range(len(poly) - 1))

    result = []
    for cl in clusters:
        poly = merge(cl)
        if poly_len(poly) >= RAIL['min_total_len']:
            result.append(poly)
    result.sort(key=lambda p: -poly_len(p))
    return result
# ══════════════════════════════════════════════════════════════════════════════
# 침목 검출
# ══════════════════════════════════════════════════════════════════════════════
def detect_sleepers(img, rail_polys, gabor_mask=None):
    """Detect elongated blobs (sleeper candidates) in a band around the rails.

    The grey image is Otsu-thresholded, restricted to a band drawn along
    each rail polyline, closed with a 3x3 kernel, and its external
    contours are filtered by area and by the aspect ratio of their
    minimum-area rectangle.  ``gabor_mask`` is accepted but not used.

    Returns:
        List of dicts with ``center``, ``size``, ``angle`` and ``area``
        (pixel units of ``img``).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Search corridor: each rail polyline drawn with a thick stroke.
    corridor = np.zeros(gray.shape, np.uint8)
    stroke = SLEEPER['search_width'] * 2
    for rail in rail_polys:
        path = np.array([[int(v[0]), int(v[1])] for v in rail])
        cv2.polylines(corridor, [path], False, 255, stroke)

    in_corridor = cv2.bitwise_and(binary, binary, mask=corridor)
    closed = cv2.morphologyEx(
        in_corridor,
        cv2.MORPH_CLOSE,
        cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
    )
    contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    found = []
    for contour in contours:
        area = cv2.contourArea(contour)
        if area < SLEEPER['min_area'] or area > SLEEPER['max_area']:
            continue
        (cx, cy), (w, h), angle = cv2.minAreaRect(contour)
        if w == 0 or h == 0:
            continue
        if max(w, h) / min(w, h) < SLEEPER['aspect_min']:
            continue
        found.append({
            'center': (float(cx), float(cy)),
            'size': (float(w), float(h)),
            'angle': float(angle),
            'area': float(area),
        })
    return found
# ══════════════════════════════════════════════════════════════════════════════
# 전철주 검출
# ══════════════════════════════════════════════════════════════════════════════
def detect_poles(img, rail_polys):
    """Detect circular pole candidates near the rails with the Hough circle transform.

    The blurred grey image is restricted to a band drawn along each rail
    polyline and passed to ``cv2.HoughCircles`` using the ``POLE`` parameters.

    Returns:
        List of dicts with ``center`` and ``radius`` (pixel units of ``img``).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    smooth = cv2.GaussianBlur(gray, (9, 9), 2)

    # Search corridor around the rails.
    corridor = np.zeros(gray.shape, np.uint8)
    stroke = POLE['search_margin'] * 2
    for rail in rail_polys:
        path = np.array([[int(v[0]), int(v[1])] for v in rail])
        cv2.polylines(corridor, [path], False, 255, stroke)
    search_area = cv2.bitwise_and(smooth, smooth, mask=corridor)

    circles = cv2.HoughCircles(
        search_area, cv2.HOUGH_GRADIENT, dp=1,
        minDist=POLE['min_dist'],
        param1=POLE['param1'],
        param2=POLE['param2'],
        minRadius=POLE['min_radius'],
        maxRadius=POLE['max_radius'],
    )
    if circles is None:
        return []
    return [
        {'center': (float(cx), float(cy)), 'radius': float(r)}
        for (cx, cy, r) in circles[0]
    ]
# ══════════════════════════════════════════════════════════════════════════════
# 컨트롤박스 검출
# ══════════════════════════════════════════════════════════════════════════════
def detect_control_boxes(img, rail_polys):
    """Detect rectangular objects (control box candidates) near the rails.

    Canny edges restricted to a band along each rail polyline are closed
    morphologically.  External contours are then filtered by area,
    solidity (contour area / convex hull area) and the aspect ratio of
    their minimum-area rectangle, all per the ``CBOX`` parameters.

    Returns:
        List of dicts with ``center``, ``size``, ``angle`` and ``area``
        (pixel units of ``img``).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(cv2.GaussianBlur(gray, (5, 5), 0), 30, 90)

    # Search corridor around the rails.
    corridor = np.zeros(gray.shape, np.uint8)
    stroke = CBOX['search_margin'] * 2
    for rail in rail_polys:
        path = np.array([[int(v[0]), int(v[1])] for v in rail])
        cv2.polylines(corridor, [path], False, 255, stroke)
    corridor_edges = cv2.bitwise_and(edges, edges, mask=corridor)

    closed = cv2.morphologyEx(
        corridor_edges,
        cv2.MORPH_CLOSE,
        cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
        iterations=2,
    )
    contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    found = []
    for contour in contours:
        area = cv2.contourArea(contour)
        if area < CBOX['min_area'] or area > CBOX['max_area']:
            continue
        hull_area = cv2.contourArea(cv2.convexHull(contour))
        if hull_area == 0:
            continue
        if area / hull_area < CBOX['min_solidity']:
            continue
        (cx, cy), (w, h), angle = cv2.minAreaRect(contour)
        if w == 0 or h == 0:
            continue
        aspect = max(w, h) / min(w, h)
        if aspect < CBOX['aspect_min'] or aspect > CBOX['aspect_max']:
            continue
        found.append({
            'center': (float(cx), float(cy)),
            'size': (float(w), float(h)),
            'angle': float(angle),
            'area': float(area),
        })
    return found
# ══════════════════════════════════════════════════════════════════════════════
# DXF 출력
# ══════════════════════════════════════════════════════════════════════════════
def _rect_outline(cx, cy, w, h, angle_deg):
    """Return the closed outline (5 vertices) of a rotated rectangle.

    Computed in Python floats (double precision) so that large world
    coordinates are not rounded to float32.
    """
    a = math.radians(angle_deg)
    ux, uy = math.cos(a) * w / 2.0, math.sin(a) * w / 2.0     # half-width axis
    vx, vy = -math.sin(a) * h / 2.0, math.cos(a) * h / 2.0    # half-height axis
    corners = [
        (cx - ux - vx, cy - uy - vy),
        (cx + ux - vx, cy + uy - vy),
        (cx + ux + vx, cy + uy + vy),
        (cx - ux + vx, cy - uy + vy),
    ]
    return corners + [corners[0]]


def save_dxf(output_path, rails_world, sleepers_world, poles_world, boxes_world):
    """Write all detections to one DXF file, one layer per object class.

    Every layer uses the same world coordinates as the rail polylines:
    ``world_center`` is used as-is, not mirrored in Y, so sleepers, poles
    and boxes line up with the rails.

    Rectangle sizes come from ``world_w`` / ``world_h`` when an item
    carries them.  Otherwise the previous behaviour is kept: ``size``
    divided by ``scale``, i.e. original-image pixels.

    The rotation angle is negated for output.
    NOTE(review): this assumes ``angle`` was measured in image space
    (Y down) on a north-up raster while the DXF is Y up; confirm.

    Args:
        output_path: Destination DXF path; parent directories are created.
        rails_world: Polylines as sequences of ``(x, y)``.
        sleepers_world, boxes_world: Dicts with ``world_center``, ``size``,
            ``angle`` and optionally ``scale``, ``world_w``, ``world_h``.
        poles_world: Dicts with ``world_center`` and optionally
            ``world_radius`` (defaults to 1.0).
    """
    import ezdxf
    doc = ezdxf.new("R2010")
    msp = doc.modelspace()

    # Layer definitions.
    doc.layers.add("RAIL", color=1)         # red
    doc.layers.add("SLEEPER", color=3)      # green
    doc.layers.add("POLE", color=5)         # blue
    doc.layers.add("CONTROL_BOX", color=6)  # magenta

    def add_rects(items, layer):
        for item in items:
            cx, cy = item['world_center']
            w, h = item['size']
            scale = item.get('scale', 1.0)
            w_out = item.get('world_w', w / scale)
            h_out = item.get('world_h', h / scale)
            msp.add_lwpolyline(
                _rect_outline(float(cx), float(cy), w_out, h_out, -item['angle']),
                dxfattribs={"layer": layer},
            )

    # Rail centrelines.
    for poly in rails_world:
        if len(poly) >= 2:
            msp.add_lwpolyline(poly, dxfattribs={"layer": "RAIL"})

    # Sleepers (rotated rectangles).
    add_rects(sleepers_world, "SLEEPER")

    # Poles (circles).
    for p in poles_world:
        cx, cy = p['world_center']
        r = p.get('world_radius', 1.0)
        msp.add_circle((cx, cy), r, dxfattribs={"layer": "POLE"})

    # Control boxes (rotated rectangles).
    add_rects(boxes_world, "CONTROL_BOX")

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    doc.saveas(output_path)
    print(f" DXF: {output_path}")
# ══════════════════════════════════════════════════════════════════════════════
# GeoJSON 출력
# ══════════════════════════════════════════════════════════════════════════════
def save_geojson(output_path, rails_world, sleepers_world, poles_world, boxes_world):
    """Write all detections to a GeoJSON FeatureCollection.

    Rails become ``LineString`` features; sleepers, poles and control boxes
    become ``Point`` features at their ``world_center``.  ``.geojson`` is
    appended to ``output_path`` when it does not already end with it, and
    parent directories are created.
    """
    def feature(geometry_type, coordinates, properties):
        return {
            "type": "Feature",
            "geometry": {"type": geometry_type, "coordinates": coordinates},
            "properties": properties,
        }

    features = [
        feature("LineString", [[x, y] for x, y in poly], {"type": "rail", "id": idx})
        for idx, poly in enumerate(rails_world)
    ]
    for idx, sleeper in enumerate(sleepers_world):
        cx, cy = sleeper['world_center']
        features.append(feature(
            "Point", [cx, cy],
            {"type": "sleeper", "id": idx, "angle": sleeper['angle']},
        ))
    for idx, pole in enumerate(poles_world):
        cx, cy = pole['world_center']
        features.append(feature(
            "Point", [cx, cy],
            {"type": "pole", "id": idx, "radius_m": pole.get('world_radius', 0)},
        ))
    for idx, box in enumerate(boxes_world):
        cx, cy = box['world_center']
        features.append(feature(
            "Point", [cx, cy],
            {"type": "control_box", "id": idx,
             "width_m": box.get('world_w', 0), "height_m": box.get('world_h', 0)},
        ))

    if not output_path.endswith('.geojson'):
        output_path += '.geojson'
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    collection = {"type": "FeatureCollection", "features": features}
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(collection, f, ensure_ascii=False, indent=2)
    print(f" GeoJSON: {output_path}")
# ══════════════════════════════════════════════════════════════════════════════
# 시각화
# ══════════════════════════════════════════════════════════════════════════════
def save_debug(img, rails, sleepers, poles, boxes, output_path, ballast_mask=None):
    """Render the detections on a copy of ``img`` and save it as a JPEG.

    Rails are drawn as red polylines, sleepers as green rotated rectangles,
    poles as blue circles and control boxes as magenta rotated rectangles
    (BGR colours).  When ``ballast_mask`` is given, its non-zero area is
    tinted first.  The image is encoded in memory and written through a
    Python file handle.
    """
    vis = img.copy()

    # Semi-transparent tint over the masked area.
    if ballast_mask is not None:
        tinted = vis.copy()
        tinted[ballast_mask > 0] = (0, 200, 200)
        cv2.addWeighted(tinted, 0.2, vis, 0.8, 0, vis)

    def draw_rotated_rects(items, color, thickness):
        for item in items:
            centre = (int(item['center'][0]), int(item['center'][1]))
            rect = (centre, (item['size'][0], item['size'][1]), item['angle'])
            corners = cv2.boxPoints(rect).astype(int)
            cv2.drawContours(vis, [corners], 0, color, thickness)

    # Rails.
    for rail in rails:
        path = np.array([[int(v[0]), int(v[1])] for v in rail])
        cv2.polylines(vis, [path], False, (0, 0, 255), 2)

    # Sleepers.
    draw_rotated_rects(sleepers, (0, 255, 0), 1)

    # Poles.
    for pole in poles:
        centre = (int(pole['center'][0]), int(pole['center'][1]))
        cv2.circle(vis, centre, int(pole['radius']), (255, 0, 0), 2)

    # Control boxes.
    draw_rotated_rects(boxes, (255, 0, 255), 2)

    # Legend: one colour swatch and label per row, 28 px apart.
    legend = [
        ("RAIL", (0, 0, 255)),
        ("SLEEPER", (0, 255, 0)),
        ("POLE", (255, 0, 0)),
        ("CONTROL_BOX", (255, 0, 255)),
    ]
    for row, (name, color) in enumerate(legend):
        top = 10 + row * 28
        cv2.rectangle(vis, (10, top), (24, top + 14), color, -1)
        cv2.putText(vis, name, (30, top + 12),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    with open(str(output_path), 'wb') as f:
        _, encoded = cv2.imencode('.jpg', vis, [cv2.IMWRITE_JPEG_QUALITY, 85])
        f.write(encoded.tobytes())
    print(f" 시각화: {output_path}")
# ══════════════════════════════════════════════════════════════════════════════
# 메인 파이프라인
# ══════════════════════════════════════════════════════════════════════════════
def process(image_path: str, output_dir: str):
    """Run the full detection pipeline on one image and write all outputs.

    Steps: load the geo-transform and the image, downscale to at most
    ``SCALE_FACTOR`` px wide, detect rails (Gabor + Hough), sleepers, poles
    and control boxes on the downscaled image, convert every detection to
    world coordinates, then write a DXF, a GeoJSON, a Gabor mask JPEG and a
    debug overlay JPEG into ``output_dir``.

    Exits the process with status 1 when the image cannot be loaded.
    """
    print(f"\n{'='*60}")
    print(f"입력: {image_path}")
    print(f"{'='*60}")

    # Coordinate transform.
    print("\n[좌표계]")
    geo = GeoTransform(image_path)

    # Image load.
    print("\n[이미지 로드]")
    img = load_image(image_path)
    if img is None:
        print("오류: 이미지를 열 수 없습니다.")
        sys.exit(1)
    H, W = img.shape[:2]
    print(f" 크기: {W} x {H} px")

    # Downscale for processing.
    scale = 1.0
    if W > SCALE_FACTOR:
        scale = SCALE_FACTOR / W
        small = cv2.resize(img, (int(W*scale), int(H*scale)))
        print(f" 축소: {int(W*scale)} x {int(H*scale)} (x{scale:.2f})")
    else:
        small = img.copy()

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    stem = Path(image_path).stem

    # World units per original pixel (the X pixel size is used for both axes).
    ps = geo.pixel_size[0]

    # -- Gabor sleeper-pattern mask + rail centrelines ----------------------
    print("\n[1] Gabor 침목 패턴 → 레일 검출...")
    rails_px, gabor_mask = detect_rails_gabor(small)
    print(f" 검출: {len(rails_px)}")

    # Save the Gabor mask for inspection.
    mask_path = out / f"{stem}_gabor_mask.jpg"
    with open(str(mask_path), 'wb') as _f:
        _, _buf = cv2.imencode('.jpg', gabor_mask)
        _f.write(_buf.tobytes())
    print(f" Gabor 마스크: {mask_path}")

    # Rails -> world coordinates.
    rails_world = []
    for poly in rails_px:
        world_pts = geo.scale_coords(poly, scale)
        rails_world.append(world_pts)

    # -- Sleepers -----------------------------------------------------------
    print("\n[2] 침목 검출...")
    sleepers_px = detect_sleepers(small, rails_px, gabor_mask)
    print(f" 검출: {len(sleepers_px)}")
    sleepers_world = []
    for s in sleepers_px:
        cx, cy = s['center']
        wx, wy = geo.scale_point(cx, cy, scale)
        sleepers_world.append({
            **s,
            'world_center': (wx, wy),
            # Size in world units, same formula as the control boxes below.
            # The per-sleeper pixel size used to be computed and discarded.
            'world_w': s['size'][0] * ps / scale,
            'world_h': s['size'][1] * ps / scale,
            'scale': scale,
        })

    # -- Poles --------------------------------------------------------------
    print("\n[3] 전철주 검출...")
    poles_px = detect_poles(small, rails_px)
    print(f" 검출: {len(poles_px)}")
    poles_world = []
    for p in poles_px:
        cx, cy = p['center']
        wx, wy = geo.scale_point(cx, cy, scale)
        r_m = p['radius'] * geo.pixel_size[0] / scale
        poles_world.append({
            **p,
            'world_center': (wx, wy),
            'world_radius': r_m,
        })

    # -- Control boxes ------------------------------------------------------
    print("\n[4] 컨트롤박스 검출...")
    boxes_px = detect_control_boxes(small, rails_px)
    print(f" 검출: {len(boxes_px)}")
    boxes_world = []
    for b in boxes_px:
        cx, cy = b['center']
        wx, wy = geo.scale_point(cx, cy, scale)
        w_m = b['size'][0] * ps / scale
        h_m = b['size'][1] * ps / scale
        boxes_world.append({
            **b,
            'world_center': (wx, wy),
            'world_w': w_m,
            'world_h': h_m,
            'scale': scale,
        })

    # -- Summary ------------------------------------------------------------
    print(f"\n{'='*60}")
    print(f"검출 결과 요약")
    print(f" 레일: {len(rails_world):>5}")
    print(f" 침목: {len(sleepers_world):>5}")
    print(f" 전철주: {len(poles_world):>5}")
    print(f" 컨트롤박스: {len(boxes_world):>5}")
    print(f"{'='*60}")

    # -- Outputs ------------------------------------------------------------
    print("\n[출력]")
    save_dxf(
        str(out / f"{stem}_railway.dxf"),
        rails_world, sleepers_world, poles_world, boxes_world
    )
    save_geojson(
        str(out / f"{stem}_railway.geojson"),
        rails_world, sleepers_world, poles_world, boxes_world
    )
    save_debug(
        small, rails_px, sleepers_px, poles_px, boxes_px,
        out / f"{stem}_railway_debug.jpg",
        ballast_mask=gabor_mask
    )
    print(f"\n완료.")
    print(f" DXF 레이어: RAIL(빨강) / SLEEPER(초록) / POLE(파랑) / CONTROL_BOX(마젠타)")
    print(f" Rhino: Import DXF → 레이어별 3D 모델 교체")
if __name__ == "__main__":
    # CLI entry point: the image path is required; the output directory
    # defaults to "output".
    if not sys.argv[1:]:
        print("사용법: python tools/railway_pipeline.py <image.tif|png> [output_dir]")
        sys.exit(1)
    img_path = sys.argv[1]
    if len(sys.argv) >= 3:
        output_dir = sys.argv[2]
    else:
        output_dir = "output"
    process(img_path, output_dir)

View File

@@ -0,0 +1,535 @@
"""SAM 원본 폴리곤의 skeleton을 까만색으로 이미지에 겹쳐 표시.
union mask 기준 skeleton을 그리므로, 겹쳐있는 fragment들의 합쳐진 골격이 보임.
사용:
python tools/render_skeleton_overlay.py \
--image data/역사이미지/slope/DJI_20260306113842_0006.JPG \
--label output/autolabel/raw_0006/labels/DJI_20260306113842_0006.txt \
--output output/autolabel/raw_0006/vis_skeleton.jpg
"""
import argparse
import base64
import requests
import cv2
import numpy as np
from pathlib import Path
# Base URL of the SAM3 inference server; "/v1/predict" is appended per request.
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent in the "model" field of each request.
SAM3_MODEL_ID = "segment_anything_3"
# Sub-part categories for pole structures.  "prompt" strings are joined into
# one text prompt, "keywords" are matched against returned labels, and
# "color_bgr" is the overlay colour.
SUB_CATEGORIES = [
    {"name": "pole_shaft", "name_kr": "전주기둥",
     "prompt": "vertical steel pole shaft, cylindrical mast, tubular pole body",
     "keywords": ["pole", "shaft", "mast", "tubular"],
     "color_bgr": [0, 200, 255]},
    {"name": "bracket_arm", "name_kr": "브라켓/암",
     "prompt": "horizontal cantilever arm, bracket arm, cross arm, pole cross arm",
     "keywords": ["bracket", "cantilever", "cross arm", "arm"],
     "color_bgr": [255, 100, 0]},
    {"name": "insulator", "name_kr": "애자",
     "prompt": "ceramic insulator, glass insulator, suspension insulator, string of insulators",
     "keywords": ["insulator"],
     "color_bgr": [0, 255, 255]},
    {"name": "wire", "name_kr": "가선/조가선",
     "prompt": "overhead contact wire, catenary wire, messenger wire, electric cable",
     "keywords": ["wire", "cable", "catenary"],
     "color_bgr": [200, 200, 255]},
]
def _sam3_call(crop_bgr: np.ndarray, prompt: str, conf: float = 0.15) -> list:
    """Send an image crop and a text prompt to the SAM3 server.

    The crop is downscaled so its longer side is at most 1280 px, encoded
    as JPEG, and posted as base64.  Returned polygon points are scaled back
    to the coordinates of the crop as passed in.

    Returns:
        The polygon shapes from the response, or an empty list when the
        server reports failure or any error occurs (the error is printed).
    """
    height, width = crop_bgr.shape[:2]
    longest = max(height, width)
    scale = 1.0
    if longest > 1280:
        scale = 1280 / longest
        crop_bgr = cv2.resize(crop_bgr, (int(width * scale), int(height * scale)))

    _, encoded = cv2.imencode(".jpg", crop_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    b64 = base64.b64encode(encoded).decode()
    request_body = {
        "model": SAM3_MODEL_ID,
        "image": b64,
        "params": {
            "text_prompt": prompt,
            "conf_threshold": conf,
            "show_masks": True,
            "show_boxes": False,
        },
    }
    try:
        reply = requests.post(f"{SAM3_SERVER}/v1/predict", json=request_body, timeout=120)
        reply.raise_for_status()
        decoded = reply.json()
        if not decoded.get("success"):
            return []
        shapes = [
            item if isinstance(item, dict) else item.dict()
            for item in decoded.get("data", {}).get("shapes", [])
        ]
        polygons = [item for item in shapes if item.get("shape_type") == "polygon"]
        if scale < 1.0:
            # Undo the downscale so points refer to the crop as passed in.
            factor = 1.0 / scale
            for item in polygons:
                item["points"] = [[x * factor, y * factor] for x, y in item["points"]]
        return polygons
    except Exception as e:
        print(f" SAM3 오류: {e}")
        return []
def _label_to_subcat(label_str: str):
    """Return the first ``SUB_CATEGORIES`` entry with a keyword in the label.

    Matching is a case-insensitive substring test; ``None`` is returned
    when no category matches.
    """
    lowered = label_str.lower()
    return next(
        (
            cat
            for cat in SUB_CATEGORIES
            if any(keyword in lowered for keyword in cat["keywords"])
        ),
        None,
    )
def _subdetect_group(img: np.ndarray, group_polys: list, pad: int = 80) -> int:
    """Run SAM3 sub-part detection on the padded bounding box of a polygon group.

    The union bounding box of ``group_polys`` (grown by ``pad`` px and
    clipped to the image) is cropped and sent to SAM3 with the combined
    prompt of all sub-categories.  Each returned polygon is blended,
    outlined and labelled on ``img`` in place, and the crop rectangle is
    drawn last.

    Returns:
        Number of shapes returned by SAM3.
    """
    stacked = np.vstack(group_polys)
    xs, ys = stacked[:, 0], stacked[:, 1]
    img_h, img_w = img.shape[:2]
    x0 = max(0, int(xs.min()) - pad)
    y0 = max(0, int(ys.min()) - pad)
    x1 = min(img_w, int(xs.max()) + pad)
    y1 = min(img_h, int(ys.max()) + pad)

    crop = img[y0:y1, x0:x1].copy()
    prompt = ", ".join(cat["prompt"] for cat in SUB_CATEGORIES)
    shapes = _sam3_call(crop, prompt)

    for shape in shapes:
        cat = _label_to_subcat(shape.get("label", ""))
        if cat:
            color = tuple(cat["color_bgr"])
            name = cat["name"]
        else:
            color = (128, 128, 128)
            name = "?"
        # Shift crop-relative points back into full-image coordinates.
        outline = np.array(
            [[int(px + x0), int(py + y0)] for px, py in shape["points"]],
            dtype=np.int32,
        )
        filled = img.copy()
        cv2.fillPoly(filled, [outline], color)
        cv2.addWeighted(filled, 0.30, img, 0.70, 0, img)
        cv2.polylines(img, [outline], True, color, 2)
        anchor = (int(outline[:, 0].mean()), int(outline[:, 1].mean()))
        cv2.putText(img, name, anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

    cv2.rectangle(img, (x0, y0), (x1, y1), (255, 200, 0), 2)
    return len(shapes)
def render(image_path: Path, label_path: Path, output_path: Path,
           selected_indices=None, branch_radius: int = 10, class_ids=None,
           subdetect: bool = False, subdetect_polys: set = None):
    """Group YOLO polygon labels by shared skeleton and write a visualisation.

    Pipeline:
      1. Load the image and the YOLO segmentation label (normalised polygons).
      2. Classify every selected polygon as 'V' / 'H' / '?' by comparing its
         long axis with the radial direction from the image centre.
      3. Merge all polygons into one mask, thin it to a skeleton, prune short
         spurs, and treat each connected skeleton component as one group.
      4. Per group: find branch/end points, optionally add "virtual" end
         points, connect end points starting from the top-most one, and
         optionally run the sub-detector on the group.
      5. Write an image where polygons are coloured by group.

    Args:
        image_path: Image to read.
        label_path: YOLO polygon label file; a missing file means no polygons.
        output_path: Destination image; parent directories are created.
        selected_indices: Optional iterable of polygon indices (positions in
            the label after class filtering) to keep. ``None`` keeps all.
        branch_radius: Accepted for CLI compatibility; not used by this
            function.
        class_ids: Optional container of class ids to keep. ``None`` keeps all.
        subdetect: Run the sub-detector on every group.
        subdetect_polys: Optional set of polygon indices; when given, only the
            groups containing these polygons are sub-detected.

    Raises:
        ValueError: If the image cannot be decoded.
    """
    buf = np.fromfile(str(image_path), dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure with None; fail with a clear message
        # instead of an AttributeError on the next line.
        raise ValueError(f"failed to decode image: {image_path}")
    img_orig = img.copy()
    H, W = img.shape[:2]

    text = label_path.read_text(encoding="utf-8").strip() if label_path.exists() else ""
    if class_ids is not None:
        text = "\n".join(
            ln for ln in text.splitlines()
            if ln.split() and int(ln.split()[0]) in class_ids
        )

    all_polygons = []
    for line in text.splitlines():
        parts = line.split()
        if not parts:
            continue
        coords = list(map(float, parts[1:]))
        # Stop at len - 1 so a dangling x without a y is ignored rather than
        # raising IndexError on a malformed line.
        pts = np.array(
            [[coords[i] * W, coords[i + 1] * H] for i in range(0, len(coords) - 1, 2)],
            dtype=np.int32,
        )
        if len(pts) >= 3:
            all_polygons.append(pts)

    if selected_indices is not None:
        keep_idx = [i for i in sorted(selected_indices) if 0 <= i < len(all_polygons)]
    else:
        keep_idx = list(range(len(all_polygons)))
    polygons = [all_polygons[i] for i in keep_idx]
    print(f" total polygons in label: {len(all_polygons)}, "
          f"selected: {len(polygons)} (indices: {keep_idx})")

    # Vertical/horizontal test: compare the polygon's long axis with the
    # radial direction (image centre -> polygon centre). In oblique drone
    # shots vertical posts appear to lean radially away from the centre.
    img_cx, img_cy = W / 2.0, H / 2.0

    def _poly_orient(pts):
        """Return (BGR colour, 'V' | 'H' | '?', |cos| alignment) for a polygon."""
        rect = cv2.minAreaRect(pts)
        (rx, ry), (rw, rh), angle = rect
        if min(rw, rh) < 1:
            return (160, 160, 160), '?', 0.0
        ar = max(rw, rh) / min(rw, rh)
        if ar < 1.3:
            # Too close to square to have a meaningful long axis.
            return (160, 160, 160), '?', 0.0
        # Direction vector of the long axis.
        long_angle_deg = angle if rw >= rh else angle + 90
        lx = np.cos(np.radians(long_angle_deg))
        ly = np.sin(np.radians(long_angle_deg))
        # Radial direction: image centre -> polygon centre.
        rdx, rdy = rx - img_cx, ry - img_cy
        radial_norm = (rdx**2 + rdy**2) ** 0.5
        if radial_norm < 1:
            return (160, 160, 160), '?', 0.0
        rdx, rdy = rdx / radial_norm, rdy / radial_norm
        # Alignment of long axis and radial direction (|cos theta|):
        #   -> 1: long axis is radial        = vertical post
        #   -> 0: long axis is perpendicular = horizontal beam
        cos_sim = abs(lx * rdx + ly * rdy)
        if cos_sim > 0.7:
            return (255, 80, 0), 'V', cos_sim
        return (0, 80, 255), 'H', cos_sim

    orient_map = {}  # original polygon index -> 'V' / 'H' / '?'
    print("\n [폴리곤 수직/수평 분류]")
    for orig_idx, pts in zip(keep_idx, polygons):
        color, orient, cos_sim = _poly_orient(pts)
        orient_map[orig_idx] = orient
        print(f" poly {orig_idx:>2d}: {orient} cos_sim={cos_sim:.3f}")
        overlay = img.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.25, img, 0.75, 0, img)
        cv2.polylines(img, [pts], True, color, 2)
        cx = int(pts[:, 0].mean())
        cy = int(pts[:, 1].mean())
        label = f"{orig_idx}{orient}"
        # White halo first, coloured text on top, for legibility.
        cv2.putText(img, label, (cx, cy),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 255), 5)
        cv2.putText(img, label, (cx, cy),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2)

    # Mask of all polygons merged -> skeleton.
    merged = np.zeros((H, W), dtype=np.uint8)
    for pts in polygons:
        cv2.fillPoly(merged, [pts], 255)
    # Dilate with an 11x11 kernel to bridge small gaps between adjacent
    # polygons, then erode with a 7x7 kernel to shrink back most of the way.
    merged_dilated = cv2.dilate(merged, np.ones((11, 11), dtype=np.uint8))
    merged_eroded = cv2.erode(merged_dilated, np.ones((7, 7), dtype=np.uint8))
    skel = cv2.ximgproc.thinning(merged_eroded)
    # Thicken the skeleton for visibility.
    skel_thick = cv2.dilate(skel, np.ones((5, 5), dtype=np.uint8))
    img[skel_thick > 0] = (0, 0, 0)

    # Spur pruning: remove short arms so only real branches remain.
    _box3 = np.ones((3, 3), dtype=np.float32)

    def _prune_spurs(sk, min_arm_px):
        """Cut the skeleton at branch pixels and delete arms shorter than min_arm_px.

        Returns (pruned skeleton, centroids of the removed arms).
        """
        s_i = (sk > 0).astype(np.int32)
        sp2 = np.pad(s_i, 1, mode='constant')
        h2, w2 = s_i.shape
        # The 8 neighbours in circular order, starting at the top-left.
        n2 = [sp2[0:h2, 0:w2], sp2[0:h2, 1:w2+1], sp2[0:h2, 2:w2+2],
              sp2[1:h2+1, 2:w2+2], sp2[2:h2+2, 2:w2+2], sp2[2:h2+2, 1:w2+1],
              sp2[2:h2+2, 0:w2], sp2[1:h2+1, 0:w2]]
        # Crossing number: count of 0 -> 1 transitions around each pixel.
        cn2 = np.zeros((h2, w2), dtype=np.int32)
        for i2 in range(8):
            cn2 += (1 - n2[i2]) * n2[(i2 + 1) % 8]
        cn2 *= s_i
        arm_sk = sk.copy()
        arm_sk[cn2 >= 3] = 0  # remove branch pixels so arms become separate components
        n_arm, arm_lbl, arm_stats, arm_cents = cv2.connectedComponentsWithStats(arm_sk)
        sk_out = sk.copy()
        spur_cs = []
        for a_id in range(1, n_arm):
            if arm_stats[a_id, cv2.CC_STAT_AREA] < min_arm_px:
                sk_out[arm_lbl == a_id] = 0
                spur_cs.append((int(arm_cents[a_id, 0]), int(arm_cents[a_id, 1])))
        return sk_out, spur_cs

    skel_pruned, spur_centroids = _prune_spurs(skel, min_arm_px=30)
    print(f" spur pruning: {len(spur_centroids)}개 arm 제거 (min_arm_px=30)")

    # Degree of every skeleton pixel = 3x3 window sum minus the pixel itself.
    skel_pf = (skel_pruned > 0).astype(np.float32)
    nbr_p = cv2.filter2D(skel_pf, cv2.CV_32F, _box3) - skel_pf
    deg_p = (nbr_p * skel_pf).astype(np.int32)
    branch_mask_global = ((deg_p >= 3) & (skel_pruned > 0)).astype(np.uint8) * 255
    endpoint_mask_global = ((deg_p == 1) & (skel_pruned > 0)).astype(np.uint8) * 255

    def cluster_centroids(mask, cluster_radius):
        """Merge nearby mask pixels by dilation and return one (x, y) per cluster."""
        if not mask.any():
            return []
        dilated = cv2.dilate(mask, np.ones((cluster_radius, cluster_radius), np.uint8))
        num, _, _, cents = cv2.connectedComponentsWithStats(dilated)
        return [(int(cents[i, 0]), int(cents[i, 1])) for i in range(1, num)]

    def node_dist(a, b):
        """Euclidean distance between two (x, y) nodes."""
        return ((a[0]-b[0])**2 + (a[1]-b[1])**2) ** 0.5

    # Grouping = connected components of the pruned skeleton.
    num_comp, comp_labels = cv2.connectedComponents(skel_pruned)

    # Diagnostic: which component(s) each polygon touches. Keys are ORIGINAL
    # polygon indices so they agree with orient_map, the CLI options and the
    # final group colouring even when only a subset of polygons is selected.
    print(f"\n [Diagnostic] skeleton components: {num_comp - 1}")
    poly_to_cids: dict = {}
    for orig_idx, pts in zip(keep_idx, polygons):
        poly_mask = np.zeros((H, W), dtype=np.uint8)
        cv2.fillPoly(poly_mask, [pts], 255)
        in_skel = (skel_pruned > 0) & (poly_mask > 0)
        if in_skel.any():
            cids = sorted(set(int(c) for c in np.unique(comp_labels[in_skel]) if c > 0))
            poly_to_cids[orig_idx] = cids
        else:
            print(f" poly {orig_idx:>2d}: (skeleton 없음 — 너무 작거나 erosion됨)")

    # Reverse mapping: component -> polygon indices.
    cid_to_polys: dict = {}
    for poly_i, cids in poly_to_cids.items():
        for cid in cids:
            cid_to_polys.setdefault(cid, []).append(poly_i)

    print(f"\n [그룹 → 폴리곤 목록] (전체 {len(cid_to_polys)}그룹)")
    for cid in sorted(cid_to_polys.keys()):
        n_px = int((comp_labels == cid).sum())
        polys = cid_to_polys[cid]
        orients = [f"{pi}({orient_map.get(pi, '?')})" for pi in polys]
        has_v = any(orient_map.get(pi) == 'V' for pi in polys)
        has_h = any(orient_map.get(pi) == 'H' for pi in polys)
        # A group holding both a vertical and a horizontal member is flagged.
        tag = " ★라멘?" if (has_v and has_h) else ""
        print(f" Comp {cid:>2d} ({n_px:>5d}px): {orients}{tag}")
    print()

    # When subdetect_polys is given, only the components containing those
    # polygons are sub-detected.
    do_subdetect = subdetect or bool(subdetect_polys)
    target_cids: set = set()
    if subdetect_polys:
        for pi in subdetect_polys:
            for cid_ in poly_to_cids.get(pi, []):
                target_cids.add(cid_)
        print(f" sub-detect 대상 그룹: {sorted(target_cids)} (poly {sorted(subdetect_polys)} 기준)\n")

    # component id -> polygon point arrays belonging to it (for sub-detect).
    # Reuses poly_to_cids instead of rasterising every polygon a second time.
    comp_to_polys: dict = {}
    if do_subdetect:
        for orig_idx, pts in zip(keep_idx, polygons):
            for cid_ in poly_to_cids.get(orig_idx, []):
                comp_to_polys.setdefault(cid_, []).append(pts)

    total_branches = 0
    total_endpoints = 0
    total_groups = 0
    total_lines = 0
    for cid in range(1, num_comp):
        comp_mask = (comp_labels == cid)
        if int(comp_mask.sum()) < 50:
            continue  # skip tiny skeleton fragments
        total_groups += 1
        comp_skel = (skel_pruned * comp_mask.astype(np.uint8)).astype(np.uint8)
        comp_branch = cv2.bitwise_and(branch_mask_global, comp_skel)
        comp_endpoint = cv2.bitwise_and(endpoint_mask_global, comp_skel)
        branches = cluster_centroids(comp_branch, 7)
        endpoints = cluster_centroids(comp_endpoint, 5)

        # Virtual end points: only for components with exactly 3 end points
        # (a frame where one end point is missing).
        #   1) take the skeleton bbox -> 4 sides (top/bottom/left/right)
        #   2) assign every end point to its nearest side
        #   3) find sides that received no end point
        #   4) for each such side use the most extreme branch point in that
        #      direction as a virtual end point
        virtual_endpoints = []
        if len(endpoints) == 3 and branches:
            ys_arr, xs_arr = np.where(comp_skel > 0)
            if len(ys_arr) > 0:
                xmin, xmax = int(xs_arr.min()), int(xs_arr.max())
                ymin, ymax = int(ys_arr.min()), int(ys_arr.max())
                occupied = set()
                for ex, ey in endpoints:
                    d_top = ey - ymin
                    d_bot = ymax - ey
                    d_left = ex - xmin
                    d_right = xmax - ex
                    side, _ = min(
                        (('top', d_top), ('bottom', d_bot),
                         ('left', d_left), ('right', d_right)),
                        key=lambda x: x[1])
                    occupied.add(side)
                all_sides = {'top', 'bottom', 'left', 'right'}
                # Sorted so the result does not depend on set iteration
                # order, which varies between runs for strings.
                missing_sides = sorted(all_sides - occupied)
                for ms in missing_sides:
                    if ms == 'right':
                        target = max(branches, key=lambda b: b[0])
                    elif ms == 'left':
                        target = min(branches, key=lambda b: b[0])
                    elif ms == 'top':
                        target = min(branches, key=lambda b: b[1])
                    else:  # 'bottom'
                        target = max(branches, key=lambda b: b[1])
                    # Reject candidates within 80 px of an existing node.
                    too_close = any((target[0] - px) ** 2 + (target[1] - py) ** 2 < 80 ** 2
                                    for px, py in endpoints + virtual_endpoints)
                    if not too_close:
                        virtual_endpoints.append(target)
                print(f" virtual endpoint: missing={list(missing_sides)}, "
                      f"added={virtual_endpoints}")

        # Combined end point list (real + virtual).
        endpoints_all = endpoints + virtual_endpoints
        total_branches += len(branches)
        total_endpoints += len(endpoints_all)

        # Topology based on the top-most node.
        if len(endpoints_all) >= 2:
            # 1. Top-most node T (smallest y, ties broken by x).
            top_i = min(range(len(endpoints_all)),
                        key=lambda i: (endpoints_all[i][1], endpoints_all[i][0]))
            T = endpoints_all[top_i]
            other_idx = [i for i in range(len(endpoints_all)) if i != top_i]
            connected_idx = {top_i}
            lines_to_draw = []
            if len(other_idx) == 1:
                j = other_idx[0]
                lines_to_draw.append((T, endpoints_all[j]))
                connected_idx.add(j)
                longer_i = j
            else:
                # 2. Connect T only to the left-most and right-most nodes.
                left_i = min(other_idx, key=lambda i: endpoints_all[i][0])
                right_i = max(other_idx, key=lambda i: endpoints_all[i][0])
                lines_to_draw.append((T, endpoints_all[left_i]))
                lines_to_draw.append((T, endpoints_all[right_i]))
                connected_idx.update([left_i, right_i])
                # 3. Pick the longer of the two edges.
                d_l = node_dist(T, endpoints_all[left_i])
                d_r = node_dist(T, endpoints_all[right_i])
                longer_i = left_i if d_l >= d_r else right_i
            # 4. Chaining: from the far end of the longer edge keep linking
            #    to the nearest unconnected node that lies below (larger y).
            current_i = longer_i
            while True:
                cur = endpoints_all[current_i]
                below = [(i, node_dist(cur, endpoints_all[i]))
                         for i in range(len(endpoints_all))
                         if i not in connected_idx and endpoints_all[i][1] > cur[1]]
                if not below:
                    break
                next_i = min(below, key=lambda x: x[1])[0]
                lines_to_draw.append((cur, endpoints_all[next_i]))
                connected_idx.add(next_i)
                current_i = next_i
            for p1, p2 in lines_to_draw:
                cv2.line(img, p1, p2, (0, 255, 255), 5)
                total_lines += 1
                print(f" line {p1}→{p2}: {node_dist(p1,p2):.0f}px")

        # Node markers, drawn after the lines so they sit on top.
        for cx, cy in branches:
            cv2.circle(img, (cx, cy), 14, (0, 220, 0), 3)  # junction
        for cx, cy in endpoints:
            cv2.circle(img, (cx, cy), 18, (0, 255, 0), 4)  # real end point
        for cx, cy in virtual_endpoints:
            cv2.circle(img, (cx, cy), 18, (0, 140, 255), 4)  # virtual end point

        should_sub = do_subdetect and cid in comp_to_polys
        if should_sub and target_cids:
            should_sub = cid in target_cids
        if should_sub:
            n = _subdetect_group(img, comp_to_polys[cid])
            print(f" 그룹 {cid}: sub-detect {n}개")

    # Centroids of removed spur arms.
    for cx, cy in spur_centroids:
        cv2.circle(img, (cx, cy), 6, (180, 120, 180), 2)

    print(f" {len(polygons)} polygons, skeleton 픽셀 {int((skel > 0).sum())}개")
    print(f" 그룹(skeleton component) {total_groups}개")
    print(f" 분기점 {total_branches}개, 끝점 {total_endpoints}개, 연결선 {total_lines}개")

    # Group visualisation: polygons coloured per component on the ORIGINAL
    # image.
    # NOTE(review): restarting from img_orig discards every overlay drawn
    # above (skeleton, lines, node markers, sub-detect results); only the
    # group colouring reaches the output file. Confirm this is intended.
    _PALETTE = [
        (255, 80, 80), (80, 220, 80), (80, 120, 255), (220, 200, 50),
        (220, 80, 220), (60, 220, 220), (255, 160, 60), (100, 200, 120),
        (180, 80, 255), (255, 140, 100),
    ]
    # A polygon touching several components keeps the last one listed.
    _poly2comp = {pi: cid for cid, pis in cid_to_polys.items() for pi in pis}
    _idx2pos = {orig: i for i, orig in enumerate(keep_idx)}
    img = img_orig.copy()
    for orig_idx, pts in zip(keep_idx, polygons):
        cid = _poly2comp.get(orig_idx)
        col = _PALETTE[(cid - 1) % len(_PALETTE)] if cid else (140, 140, 140)
        ov = img.copy()
        cv2.fillPoly(ov, [pts], col)
        cv2.addWeighted(ov, 0.35, img, 0.65, 0, img)
        cv2.polylines(img, [pts], True, col, 3)
        cx, cy = int(pts[:, 0].mean()), int(pts[:, 1].mean())
        cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 4)
        cv2.putText(img, str(orig_idx), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2)
    for cid, pis in cid_to_polys.items():
        col = _PALETTE[(cid - 1) % len(_PALETTE)]
        valid = [_idx2pos[pi] for pi in pis if pi in _idx2pos]
        if not valid:
            continue
        # Group label at the mean of the member polygon centroids.
        cx = int(sum(int(polygons[i][:, 0].mean()) for i in valid) / len(valid))
        cy = int(sum(int(polygons[i][:, 1].mean()) for i in valid) / len(valid))
        cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 255), 7)
        cv2.putText(img, f"C{cid}", (cx, cy + 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, col, 3)

    # Cap the longest side at 4096 px before writing.
    h, w = img.shape[:2]
    scale = min(1.0, 4096 / max(h, w))
    if scale < 1.0:
        img = cv2.resize(img, (int(w * scale), int(h * scale)))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    cv2.imencode(output_path.suffix, img)[1].tofile(str(output_path))
    print(f" → {output_path}")
def main():
    """Parse the command line and run ``render`` once."""

    def _int_set_or_none(raw):
        """Turn '1, 2,3' into {1, 2, 3}; a blank string means "no filter" (None)."""
        if not raw.strip():
            return None
        return {int(token.strip()) for token in raw.split(',') if token.strip()}

    parser = argparse.ArgumentParser()
    for required_flag in ("--image", "--label", "--output"):
        parser.add_argument(required_flag, required=True)
    parser.add_argument("--polygons", type=str, default="",
                        help="comma-separated polygon indices (예: '10,11,12,26,27'). 비우면 전체")
    parser.add_argument("--class-ids", type=str, default="",
                        help="포함할 클래스 ID (예: '1' or '1,3'). 비우면 전체")
    parser.add_argument("--branch-radius", type=int, default=10,
                        help="branch 판정 원 반지름 px (기본 10)")
    parser.add_argument("--subdetect", action="store_true",
                        help="모든 그룹 sub-detect (SAM3 서버 필요)")
    parser.add_argument("--subdetect-polys", type=str, default="",
                        help="지정 polygon이 속한 그룹만 sub-detect (예: '2,27')")
    opts = parser.parse_args()

    render(
        Path(opts.image),
        Path(opts.label),
        Path(opts.output),
        _int_set_or_none(opts.polygons),
        branch_radius=opts.branch_radius,
        class_ids=_int_set_or_none(opts.class_ids),
        subdetect=opts.subdetect,
        subdetect_polys=_int_set_or_none(opts.subdetect_polys),
    )


if __name__ == "__main__":
    main()

777
tools/sam3_autolabel.py Normal file
View File

@@ -0,0 +1,777 @@
"""
SAM 3.1 텍스트 프롬프트 자동 라벨링
=====================================================
알고리즘:
A = 전철주 프롬프트(prompts/pole.txt) 검출 → 높이 > min_pole_h 필터
B = 철로 프롬프트(prompts/rail.txt) 검출 → zone mask 생성
결과 = A 중 B zone 내/근처에 있는 것만 유지
사용법:
python tools/sam3_autolabel.py --input data/역사이미지/slope/
python tools/sam3_autolabel.py --input data/역사이미지/ --output output/autolabel/
사전 조건:
- start_server.bat 실행 (SAM 3.1 서버)
"""
import argparse
import base64
from pathlib import Path
import cv2
import numpy as np
import requests
# Base URL of the inference server; requests are POSTed to "<SAM3_SERVER>/v1/predict".
SAM3_SERVER = "http://localhost:8000"
# Value sent as the "model" field of every prediction request.
SAM3_MODEL_ID = "segment_anything_3"
# "prompts" directory located next to the parent directory of this file.
_PROMPT_DIR = Path(__file__).parent.parent / "prompts"
# Class names indexed by the "_cls_id" stored on each shape.
CLASS_NAMES = ["catenary_pole", "bracket"]
# Per-class drawing colours handed to the cv2 drawing calls (indexed by class id modulo length).
_CLS_COLORS = [(0, 200, 255), (255, 130, 0)]
def _load_prompt(filename: str, default: str) -> str:
    """Read a prompt file from ``_PROMPT_DIR`` and join its entries with ", ".

    Blank lines and comment lines (first non-blank character is ``#``) are
    ignored. Lines are stripped before the comment check so that indented
    comments are skipped too.

    Args:
        filename: File name inside ``_PROMPT_DIR``.
        default: Value returned when the file is missing or has no entries.

    Returns:
        The comma-joined prompt entries, or ``default``.
    """
    path = _PROMPT_DIR / filename
    if path.exists():
        stripped = (raw.strip() for raw in path.read_text(encoding="utf-8").splitlines())
        lines = [entry for entry in stripped if entry and not entry.startswith("#")]
        if lines:
            return ", ".join(lines)
    return default
def _get_prompts() -> tuple:
    """Return the (pole_prompt, rail_prompt, bracket_prompt) text prompts.

    Each prompt comes from its file via ``_load_prompt``; the built-in text
    below is the fallback passed to it.
    """
    prompt_sources = (
        ("pole.txt",
         "railway catenary pole, overhead line support pole, catenary mast"),
        ("rail.txt", "railroad, railway"),
        ("bracket.txt",
         "catenary pole top, pole cross arm, horizontal cantilever beam, bracket arm"),
    )
    pole, rail, bracket = (_load_prompt(fname, fallback)
                           for fname, fallback in prompt_sources)
    return pole, rail, bracket
# ── 이미지 인코딩 ──────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray, max_size: int = 2048) -> tuple:
    """JPEG-encode an image as base64 text, downscaling it first if needed.

    When ``max_size`` is positive and the longest side exceeds it, the image
    is resized so the longest side is about ``max_size``.

    Returns:
        (base64 string, scale) where ``scale`` is the resize factor applied
        (1.0 when the image was not resized).
    """
    height, width = image_bgr.shape[:2]
    longest = max(height, width)
    if max_size > 0 and longest > max_size:
        scale = max_size / longest
        image_bgr = cv2.resize(image_bgr, (int(width * scale), int(height * scale)))
    else:
        scale = 1.0
    jpeg_bytes = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])[1]
    return base64.b64encode(jpeg_bytes).decode("utf-8"), scale
# ── SAM 3.1 텍스트 프롬프트 세그멘테이션 ──────────────────────────────────────
def sam3_text_segment(image_bgr: np.ndarray, text_prompt: str,
                      conf_threshold: float = 0.25,
                      sam_max_size: int = 2048) -> list:
    """Ask the SAM3 server to segment ``image_bgr`` with a text prompt.

    The image is sent (possibly downscaled, see ``encode_image``) and polygon
    points in the reply are scaled back to the coordinates of the input image.

    Returns:
        A list of shape dicts. Any request or server-side failure is printed
        and results in an empty list (best effort, never raises for those).
    """
    b64, scale = encode_image(image_bgr, sam_max_size)
    request_params = {
        "text_prompt": text_prompt,
        "conf_threshold": conf_threshold,
        "show_masks": True,
        "show_boxes": False,
    }
    payload = {"model": SAM3_MODEL_ID, "image": b64, "params": request_params}
    try:
        reply = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=180)
        reply.raise_for_status()
        body = reply.json()
        if not body.get("success"):
            print(f" SAM3 오류: {body.get('error', {}).get('message', 'unknown')}")
            return []
        # Entries that are not plain dicts are converted through .dict().
        shapes = [
            item if isinstance(item, dict) else item.dict()
            for item in body.get("data", {}).get("shapes", [])
        ]
        if scale < 1.0:
            # Undo the downscale so points refer to the original image.
            inv = 1.0 / scale
            for item in shapes:
                if item.get("shape_type") != "polygon":
                    continue
                item["points"] = [[x * inv, y * inv] for x, y in item["points"]]
        return shapes
    except Exception as e:
        print(f" SAM3 오류: {e}")
        return []
def shapes_to_pairs(shapes: list) -> list:
    """Convert polygon shapes into ``(det, shape)`` pairs.

    ``det`` is ``(label, score, x1, y1, x2, y2)`` where the box is the
    axis-aligned bounding box of the polygon points. Shapes that are not
    polygons or have fewer than three points are dropped.
    """

    def _as_pair(shape):
        points = shape.get("points", [])
        xs = [pt[0] for pt in points]
        ys = [pt[1] for pt in points]
        label = shape.get("label", "")
        score = float(shape.get("score", 0.0))
        return (label, score, min(xs), min(ys), max(xs), max(ys)), shape

    return [
        _as_pair(shape)
        for shape in shapes
        if shape.get("shape_type") == "polygon" and len(shape.get("points", [])) >= 3
    ]
# ── NMS ───────────────────────────────────────────────────────────────────────
def nms(pairs: list, iou_thresh: float = 0.5) -> list:
    """Greedy box non-maximum suppression over ``(det, shape)`` pairs.

    Pairs are visited from highest to lowest score; a pair is removed when
    the IoU of its bounding box with an already kept box is not below
    ``iou_thresh``. Labels are not taken into account.

    Returns:
        The surviving pairs, highest score first.
    """
    if not pairs:
        return []
    boxes = np.array([det[2:6] for det, _ in pairs])
    scores = np.array([det[1] for det, _ in pairs])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    remaining = scores.argsort()[::-1]
    kept_idx = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        kept_idx.append(best)
        if rest.size == 0:
            break
        # Intersection rectangle of the best box with every other candidate.
        left = np.maximum(boxes[best, 0], boxes[rest, 0])
        top = np.maximum(boxes[best, 1], boxes[rest, 1])
        right = np.minimum(boxes[best, 2], boxes[rest, 2])
        bottom = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.maximum(0, right - left) * np.maximum(0, bottom - top)
        # The small epsilon avoids division by zero for degenerate boxes.
        iou = inter / (areas[best] + areas[rest] - inter + 1e-6)
        remaining = rest[iou < iou_thresh]
    return [pairs[i] for i in kept_idx]
# ── 높이 필터 ─────────────────────────────────────────────────────────────────
def filter_by_size(pairs: list, min_h: int, debug: bool = False) -> list:
    """Keep only detections that are tall enough and not excessively wide.

    A pair is dropped when
      * its bbox height is below ``min_h``, or
      * its polygon (if it has points) is lower than ``0.7 * min_h``, or
      * its polygon is more than three times as wide as it is high.
    With ``debug`` set, every drop is printed together with its reason.
    """
    survivors = []
    for det, shape in pairs:
        label, conf, x1, y1, x2, y2 = det
        box_h = y2 - y1
        if box_h < min_h:
            if debug:
                print(f" DROP [{label} {conf:.2f}] h={box_h:.0f} (최소:{min_h})")
            continue

        points = shape.get("points", [])
        if points:
            row_vals = [pt[1] for pt in points]
            col_vals = [pt[0] for pt in points]
            mask_h = max(row_vals) - min(row_vals)
            mask_w = max(col_vals) - min(col_vals)
            if mask_h < min_h * 0.7:
                if debug:
                    print(f" DROP [{label} {conf:.2f}] sam_h={mask_h:.0f} (최소:{min_h*0.7:.0f})")
                continue
            if mask_w > mask_h * 3:
                if debug:
                    print(f" DROP [{label} {conf:.2f}] 너무 넓음 w{mask_w:.0f}>h{mask_h:.0f}x3")
                continue

        survivors.append(((label, conf, x1, y1, x2, y2), shape))
    return survivors
# ── 철로 존 마스크 ─────────────────────────────────────────────────────────────
def build_zone_mask(shapes: list, H: int, W: int, margin: int,
                    gap_close: int = 500,
                    lateral_expand_ratio: float = 0.0) -> np.ndarray:
    """Rasterise rail polygons into a binary zone mask and grow it.

    Steps, each skipped when its parameter is not positive:
      1. fill every polygon into an ``H x W`` mask,
      2. morphological closing with a ``gap_close x 30`` rectangle to bridge
         horizontal gaps (rail hidden behind beams),
      3. dilation with a disc of radius ``margin``,
      4. extra left/right dilation by ``W * lateral_expand_ratio`` pixels to
         catch poles at the edges.
    """
    zone = np.zeros((H, W), dtype=np.uint8)
    for shape in shapes:
        cv2.fillPoly(zone, [np.array(shape["points"], dtype=np.int32)], 255)

    # Nothing drawn: every later step would be a no-op on an empty mask.
    if not zone.any():
        return zone

    if gap_close > 0:
        closing_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (gap_close, 30))
        zone = cv2.morphologyEx(zone, cv2.MORPH_CLOSE, closing_kernel)

    if margin > 0:
        # Disc-shaped structuring element of radius ``margin``.
        dy, dx = np.ogrid[-margin:margin + 1, -margin:margin + 1]
        disc = (dx * dx + dy * dy <= margin * margin).astype(np.uint8)
        zone = cv2.dilate(zone, disc)

    if lateral_expand_ratio > 0:
        reach = int(W * lateral_expand_ratio)
        if reach > 0:
            row_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (reach * 2 + 1, 1))
            zone = cv2.dilate(zone, row_kernel)
    return zone
def pole_in_zone(pole_det: tuple, zone_mask: np.ndarray) -> bool:
    """Tell whether a pole's centre column crosses the zone mask.

    The test looks at the single pixel column at the bbox's horizontal centre
    over the rows ``y1..y2`` instead of just the centre point, because a pole
    stands near the rail and stretches up/down from there.
    """
    y_top, y_bottom = pole_det[3], pole_det[5]
    x_mid = (pole_det[2] + pole_det[4]) / 2
    rows, cols = zone_mask.shape
    col = int(max(0, min(x_mid, cols - 1)))  # clamp into the image
    row_start = int(max(0, y_top))
    row_stop = int(min(rows, y_bottom + 1))
    column_slice = zone_mask[row_start:row_stop, col]
    return bool(column_slice.any())
def _filter_pole_has_beam(pole_pairs: list, beam_pairs: list) -> list:
"""A ∩ C: 전철주 bbox와 겹치는 Beam/Arm이 하나라도 있는 기둥만 유지.
Beam/Arm bbox와 기둥 bbox의 IoU > 0 (단순 교차)으로 판단."""
beam_boxes = [(x1, y1, x2, y2)
for (_, _, x1, y1, x2, y2), _ in beam_pairs]
def overlaps_any_beam(pole_det):
_, _, px1, py1, px2, py2 = pole_det
for bx1, by1, bx2, by2 in beam_boxes:
ix1 = max(px1, bx1); iy1 = max(py1, by1)
ix2 = min(px2, bx2); iy2 = min(py2, by2)
if ix2 > ix1 and iy2 > iy1:
return True
return False
return [(d, s) for d, s in pole_pairs if overlaps_any_beam(d)]
def filter_long_beams(beam_pairs: list, rail_shapes: list,
                      min_beam_width: int = 400,
                      ratio_thresh: float = 2.0,
                      rail_x_margin: int = 500,
                      debug: bool = False) -> list:
    """Keep only the long beams of station-type portal frames.

    A beam is kept when all of the following hold:
      1) bbox width >= ``min_beam_width`` (street lights are narrow),
      2) width / height >= ``ratio_thresh`` (must be horizontal),
      3) its x-range lies within the detected rail x-range extended by
         ``rail_x_margin`` on both sides (excludes roads).
    Without any rail polygon nothing is kept.
    """
    if not rail_shapes:
        if debug:
            print(f" [Beam filter] rail 미검출 → 빔 zone 보강 비활성")
        return []

    # Horizontal extent covered by all rail polygons.
    rail_extents = []
    for rail in rail_shapes:
        if rail.get("shape_type") != "polygon":
            continue
        rail_cols = [pt[0] for pt in rail["points"]]
        rail_extents += [min(rail_cols), max(rail_cols)]
    if not rail_extents:
        return []
    x_lo = min(rail_extents) - rail_x_margin
    x_hi = max(rail_extents) + rail_x_margin

    long_beams = []
    for det, shape in beam_pairs:
        label, conf, x1, y1, x2, y2 = det
        width = x2 - x1
        height = max(1, y2 - y1)  # guard against division by zero
        aspect = width / height
        if width < min_beam_width:
            if debug:
                print(f" DROP beam [{label} {conf:.2f}] w={width:.0f} (최소:{min_beam_width})")
        elif aspect < ratio_thresh:
            if debug:
                print(f" DROP beam [{label} {conf:.2f}] w/h={aspect:.2f} (최소:{ratio_thresh})")
        elif x2 < x_lo or x1 > x_hi:
            if debug:
                print(f" DROP beam [{label} {conf:.2f}] rail x-범위 벗어남")
        else:
            long_beams.append(((label, conf, x1, y1, x2, y2), shape))
    return long_beams
def add_beams_to_zone(zone_mask: np.ndarray, beam_pairs: list,
                      downward_extend: int = 400, side_margin: int = 100) -> np.ndarray:
    """Paint the bbox of every (filtered) long beam into ``zone_mask``.

    Each box is widened by ``side_margin`` on both sides and stretched
    ``downward_extend`` pixels below the beam, to reinforce the zone where the
    rail is hidden by the beam. The mask is modified in place and returned.
    """
    rows, cols = zone_mask.shape
    for det, _ in beam_pairs:
        x1, y1, x2, y2 = det[2:6]
        col_from = max(0, int(x1) - side_margin)
        col_to = min(cols, int(x2) + side_margin)
        row_from = max(0, int(y1))
        row_to = min(rows, int(y2) + downward_extend)
        zone_mask[row_from:row_to, col_from:col_to] = 255
    return zone_mask
def _mask_to_largest_polygon(mask: np.ndarray, min_area: int = 50) -> list:
    """Return the largest outer contour of a binary mask as polygon points.

    The contour is simplified with ``approxPolyDP`` using a tolerance of 0.2 %
    of its perimeter.

    Returns:
        ``[[x, y], ...]`` as floats, or ``None`` when the mask has no contour
        or the largest contour's area is below ``min_area``.
    """
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    if not contours:
        return None
    biggest = max(contours, key=cv2.contourArea)
    if cv2.contourArea(biggest) < min_area:
        return None
    tolerance = 0.002 * cv2.arcLength(biggest, True)
    simplified = cv2.approxPolyDP(biggest, tolerance, True)
    return [[float(vertex[0][0]), float(vertex[0][1])] for vertex in simplified]
def split_pole_arm(pts: list, H: int, W: int,
                   arm_min_factor: float = 0.5) -> tuple:
    """Split a catenary-pole polygon into a vertical column and a horizontal arm.

    Method (row-width analysis):
      1. rasterise the polygon into a mask,
      2. measure the mask width of every row,
      3. column width = median of the narrower half of the non-empty rows,
      4. column centre x = median midpoint of the rows that are about as
         narrow as the column,
      5. column mask = mask inside a vertical band around that centre,
      6. arm mask = mask outside the band, keeping only rows whose remaining
         width is at least ``max(20, column width * arm_min_factor)``.

    Returns:
        ``(column_pts, arm_pts)``. ``arm_pts`` is ``None`` when no usable arm
        is found; when the split cannot be done at all the input ``pts`` is
        returned unchanged as the column.
    """
    polygon = np.array(pts, dtype=np.int32)
    if len(polygon) < 3:
        return pts, None

    mask = np.zeros((H, W), dtype=np.uint8)
    cv2.fillPoly(mask, [polygon], 255)

    # Polygon bounding box, clamped to the image.
    poly_x, poly_y = polygon[:, 0], polygon[:, 1]
    left = max(0, int(poly_x.min()))
    right = min(W - 1, int(poly_x.max()))
    top = max(0, int(poly_y.min()))
    bottom = min(H - 1, int(poly_y.max()))
    window = mask[top:bottom + 1, left:right + 1]

    row_widths = (window > 0).sum(axis=1).astype(np.float32)
    filled_widths = row_widths[row_widths > 0]
    if filled_widths.size < 5:
        return pts, None

    # Column width = median over the narrower half of the non-empty rows.
    ascending = np.sort(filled_widths)
    column_w = float(np.median(ascending[:max(1, len(ascending) // 2)]))
    if column_w < 5:
        return pts, None

    # Column centre x = median midpoint of the narrow rows.
    narrow_threshold = column_w * 1.3
    midpoints = []
    for row_pixels, width in zip(window, row_widths):
        if not (0 < width <= narrow_threshold):
            continue
        filled_cols = np.where(row_pixels > 0)[0]
        if len(filled_cols) > 0:
            midpoints.append((filled_cols[0] + filled_cols[-1]) / 2)
    if not midpoints:
        return pts, None

    column_cx = int(np.median(midpoints)) + left
    column_half = max(4, int(column_w * 0.7))
    band_from = max(0, column_cx - column_half)
    band_to = min(W, column_cx + column_half + 1)

    column_mask = np.zeros_like(mask)
    column_mask[:, band_from:band_to] = mask[:, band_from:band_to]
    arm_mask = mask.copy()
    arm_mask[:, band_from:band_to] = 0

    # Arm: keep only rows that extend far enough sideways.
    arm_min_w = max(20, int(column_w * arm_min_factor))
    arm_row_widths = (arm_mask > 0).sum(axis=1)
    arm_mask[arm_row_widths < arm_min_w, :] = 0

    column_pts = _mask_to_largest_polygon(column_mask)
    arm_pts = _mask_to_largest_polygon(arm_mask) if arm_mask.any() else None

    # Discard an arm smaller than 5 % of the column area.
    if arm_pts and column_pts:
        column_area = cv2.contourArea(np.array(column_pts, dtype=np.int32))
        arm_area = cv2.contourArea(np.array(arm_pts, dtype=np.int32))
        if arm_area < column_area * 0.05:
            arm_pts = None
    return (column_pts or pts), arm_pts
# ── 후처리: 합치기 + Skeleton 분리 (column / beam) ────────────────────────────
def _compute_skeleton(mask: np.ndarray) -> np.ndarray:
    """Thin a binary mask to its skeleton with ``cv2.ximgproc.thinning``.

    ``ximgproc`` ships with opencv-contrib-python.
    """
    skeleton = cv2.ximgproc.thinning(mask)
    return skeleton
def _find_branch_points(skel: np.ndarray) -> np.ndarray:
    """Return a 0/255 mask of skeleton pixels having three or more 8-neighbours."""
    on = (skel > 0).astype(np.uint8)
    # The 3x3 window sum includes the centre pixel; subtract it to get the
    # neighbour count.
    ones_3x3 = np.ones((3, 3), dtype=np.float32)
    window_sum = cv2.filter2D(on.astype(np.float32), cv2.CV_32F, ones_3x3)
    neighbour_count = window_sum - on
    is_branch = (neighbour_count >= 3) & (on == 1)
    return is_branch.astype(np.uint8) * 255
def _compute_degree_map(skel: np.ndarray) -> np.ndarray:
    """Return, per pixel, the number of 8-neighbours on the skeleton.

    The pixel itself is not counted, and pixels off the skeleton get 0.
    """
    on = (skel > 0).astype(np.float32)
    window_sum = cv2.filter2D(on, cv2.CV_32F, np.ones((3, 3), dtype=np.float32))
    neighbours = window_sum - on
    # Multiplying by ``on`` zeroes every non-skeleton pixel.
    return (neighbours * on).astype(np.int32)
def _split_one_component(mask: np.ndarray) -> tuple:
    """Split one connected component into column / beam masks by topology.

    Frame topology used:
      - beam: both ends meet a branch point -> segment has no free end point,
      - column: one end at a branch, the other free -> segment has at least
        one free end point.

    Components the rule does not fit (single segment, no branch point, only
    beams or only columns) are not treated as frames.

    Returns:
        ``(column_mask, beam_mask)`` as 0/255 masks, or ``(None, None)`` for a
        non-frame component so the caller can discard it.
    """
    skel = _compute_skeleton(mask)
    if not skel.any():
        return None, None

    # Degree map of the untouched skeleton.
    degree = _compute_degree_map(skel)
    on_skeleton = skel > 0
    free_ends = (degree == 1) & on_skeleton
    branch_mask = ((degree >= 3) & on_skeleton).astype(np.uint8) * 255
    if not branch_mask.any():
        # No branch at all = a single line/curve, not a frame.
        return None, None

    # Cut the skeleton around the branch points to separate the segments.
    cut_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    around_branches = cv2.dilate(branch_mask, cut_kernel)
    segments = cv2.bitwise_and(skel, cv2.bitwise_not(around_branches))
    n_labels, seg_labels = cv2.connectedComponents(segments)

    col_seed = np.zeros_like(mask)   # segments with a free end point
    beam_seed = np.zeros_like(mask)  # segments without one
    found_column = False
    found_beam = False
    for seg_id in range(1, n_labels):
        segment = seg_labels == seg_id
        if segment.sum() < 10:
            continue  # ignore tiny fragments
        segment_u8 = segment.astype(np.uint8) * 255
        if (free_ends & segment).any():
            col_seed = cv2.bitwise_or(col_seed, segment_u8)
            found_column = True
        else:
            beam_seed = cv2.bitwise_or(beam_seed, segment_u8)
            found_beam = True

    # A frame needs both an interior (beam) and an edge (column) segment.
    if not (found_column and found_beam):
        return None, None

    # Every mask pixel goes to whichever seed is nearer (ties -> column).
    col_dist = cv2.distanceTransform(255 - col_seed, cv2.DIST_L2, 5)
    beam_dist = cv2.distanceTransform(255 - beam_seed, cv2.DIST_L2, 5)
    inside = mask > 0
    nearer_column = col_dist <= beam_dist
    col_region = (inside & nearer_column).astype(np.uint8) * 255
    beam_region = (inside & ~nearer_column).astype(np.uint8) * 255
    return col_region, beam_region
def _extract_polygons(mask: np.ndarray, min_area: int = 200,
                      epsilon_factor: float = 0.002) -> list:
    """Turn every outer contour of a mask into a simplified polygon.

    Contours with an area below ``min_area`` are skipped. Each remaining
    contour is simplified with a tolerance of ``epsilon_factor`` times its
    perimeter and kept if it still has at least three vertices.
    """
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    polygons = []
    for contour in contours:
        if cv2.contourArea(contour) >= min_area:
            tolerance = epsilon_factor * cv2.arcLength(contour, True)
            outline = cv2.approxPolyDP(contour, tolerance, True)
            vertices = [[float(v[0][0]), float(v[0][1])] for v in outline]
            if len(vertices) >= 3:
                polygons.append(vertices)
    return polygons
def skeleton_split_columns_beams(pole_pairs: list, H: int, W: int,
                                 min_component_area: int = 1000,
                                 min_polygon_area: int = 200) -> tuple:
    """Merge pole polygons and split every connected blob into columns and beams.

    All polygon shapes are drawn into one mask. Each connected component with
    at least ``min_component_area`` pixels is handed to
    ``_split_one_component``; components it rejects are counted as discarded.

    Returns:
        ``(column_polygons, beam_polygons)`` extracted from the accumulated
        column and beam masks.
    """
    if not pole_pairs:
        return [], []

    merged = np.zeros((H, W), dtype=np.uint8)
    for _, shape in pole_pairs:
        if shape.get("shape_type") != "polygon":
            continue
        cv2.fillPoly(merged, [np.array(shape["points"], dtype=np.int32)], 255)
    if not merged.any():
        return [], []

    # No closing applied: the polygon shapes are used exactly as detected.
    n_labels, label_map = cv2.connectedComponents(merged)
    column_total = np.zeros_like(merged)
    beam_total = np.zeros_like(merged)
    n_kept = 0
    n_dropped = 0
    for comp_id in range(1, n_labels):
        member = label_map == comp_id
        if int(member.sum()) < min_component_area:
            continue
        column_part, beam_part = _split_one_component(member.astype(np.uint8) * 255)
        if column_part is None or beam_part is None:
            n_dropped += 1
            continue
        column_total = cv2.bitwise_or(column_total, column_part)
        beam_total = cv2.bitwise_or(beam_total, beam_part)
        n_kept += 1
    print(f" [Topo] 라멘 위상 통과: {n_kept}개, 비라멘 폐기: {n_dropped}개")

    return (_extract_polygons(column_total, min_area=min_polygon_area),
            _extract_polygons(beam_total, min_area=min_polygon_area))
# ── 저장 ──────────────────────────────────────────────────────────────────────
def save_yolo_label(label_path: Path, pairs: list, W: int, H: int):
    """Write polygon shapes as YOLO segmentation label lines.

    Each line is ``<cls_id> x1 y1 x2 y2 ...`` with coordinates divided by the
    image width/height and printed with six decimals. The class id is read
    from the shape's ``"_cls_id"`` entry (0 when absent). Non-polygon shapes
    and polygons with fewer than three points are skipped.
    """

    def _format(shape):
        normalised = " ".join(
            f"{x/W:.6f} {y/H:.6f}" for x, y in shape.get("points", [])
        )
        return f"{shape.get('_cls_id', 0)} {normalised}"

    rows = [
        _format(shape)
        for (_label, _conf, *_box), shape in pairs
        if shape.get("shape_type") == "polygon" and len(shape.get("points", [])) >= 3
    ]
    label_path.write_text("\n".join(rows), encoding="utf-8")
def save_vis(vis_path: Path, image_bgr: np.ndarray, pairs: list,
             rail_shapes: list = None, long_beam_pairs: list = None):
    """Write a visualisation image with detections, rail zone and long beams.

    Detections are drawn as translucent class-coloured polygons with a
    "<class name> <score>" caption, rail polygons as translucent red areas,
    and long beams as magenta boxes. Images whose longest side exceeds
    2048 px are downscaled before writing.
    """
    vis = image_bgr.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX

    def _draw_translucent(points, fill, fill_alpha, outline):
        """Blend a filled polygon into ``vis`` and stroke its outline."""
        contour = np.array(points, dtype=np.int32)
        layer = vis.copy()
        cv2.fillPoly(layer, [contour], fill)
        cv2.addWeighted(layer, fill_alpha, vis, 1 - fill_alpha, 0, vis)
        cv2.polylines(vis, [contour], True, outline, 2)

    for (label, conf, x1, y1, x2, y2), shape in pairs:
        cls_id = shape.get("_cls_id", 0)
        color = _CLS_COLORS[cls_id % len(_CLS_COLORS)]
        if shape.get("shape_type") == "polygon":
            _draw_translucent(shape["points"], color, 0.3, color)
        caption_at = (int(x1), max(0, int(y1) - 5))
        cv2.putText(vis, f"{CLASS_NAMES[cls_id]} {conf:.2f}", caption_at,
                    font, 0.5, color, 1)

    # Rail zone (red, translucent).
    for rail in rail_shapes or []:
        if rail.get("shape_type") == "polygon":
            _draw_translucent(rail["points"], (0, 0, 200), 0.25, (0, 0, 255))

    # Long beams (magenta bbox).
    for (label, conf, x1, y1, x2, y2), _ in long_beam_pairs or []:
        cv2.rectangle(vis, (int(x1), int(y1)), (int(x2), int(y2)),
                      (255, 0, 255), 2)
        cv2.putText(vis, f"beam {conf:.2f}", (int(x1), max(0, int(y1) - 5)),
                    font, 0.5, (255, 0, 255), 1)

    height, width = vis.shape[:2]
    longest = max(height, width)
    if longest > 2048:
        scale = 2048 / longest
        vis = cv2.resize(vis, (int(width * scale), int(height * scale)))
    cv2.imencode(vis_path.suffix, vis)[1].tofile(str(vis_path))
# ── 타일 검출 ─────────────────────────────────────────────────────────────────
def _detect_tiled(image_bgr: np.ndarray, prompt: str, conf: float,
                  tile_size: int, overlap: float) -> list:
    """Run text-prompt detection tile by tile and merge the results with NMS.

    Tiles of up to ``tile_size`` pixels are laid out on a grid with a stride
    of ``tile_size * (1 - overlap)`` (at least 1). Polygon points returned for
    a tile are shifted by the tile origin into full-image coordinates.
    """
    H, W = image_bgr.shape[:2]
    stride = max(1, int(tile_size * (1 - overlap)))
    tiles = [
        (x0, y0, min(x0 + tile_size, W), min(y0 + tile_size, H))
        for y0 in range(0, H, stride)
        for x0 in range(0, W, stride)
    ]
    n_tiles = len(tiles)
    print(f" → {n_tiles}개 타일")

    collected = []
    for tile_no, (tx0, ty0, tx1, ty1) in enumerate(tiles, 1):
        tile_shapes = sam3_text_segment(image_bgr[ty0:ty1, tx0:tx1], prompt, conf)
        for shape in tile_shapes:
            if shape.get("shape_type") != "polygon":
                continue
            shape["points"] = [[x + tx0, y + ty0] for x, y in shape["points"]]
        collected += shapes_to_pairs(tile_shapes)
        if tile_no % 5 == 0:
            print(f" 타일 {tile_no}/{n_tiles}, 누적 {len(collected)}개")
    return nms(collected)
# ── 이미지 처리 ───────────────────────────────────────────────────────────────
def process_image(image_path: Path,
                  out_label_dir: Path, out_vis_dir: Path,
                  sam3_conf: float,
                  tile_size: int, tile_overlap: float,
                  min_pole_h: int = 160,
                  rail_margin: int = 300,
                  rail_gap_close: int = 500,
                  beam_min_width: int = 400,
                  beam_ratio: float = 2.0,
                  beam_downward: int = 400,
                  rail_lateral: float = 0.0,
                  raw_polygons: bool = False) -> int:
    """Auto-label one image and write its YOLO label plus a preview image.

    Pipeline:
      A. Detect poles with the pole prompt (tiled when the image is larger
         than ``tile_size``), then filter them with ``filter_by_size`` using
         ``min_h=min_pole_h``.
      B. Detect rail polygons with the rail prompt.
      Zone. Unless ``rail_margin`` is negative, keep only the poles accepted
         by ``pole_in_zone`` against the mask from ``build_zone_mask``; if no
         rail polygon was found, every pole is discarded.
      Post. With ``raw_polygons`` the surviving SAM polygons are stored as
         class 0; otherwise ``skeleton_split_columns_beams`` splits them into
         column polygons (class 0) and beam polygons (class 1).

    ``beam_min_width``, ``beam_ratio`` and ``beam_downward`` are accepted but
    not read anywhere in this function.

    Returns:
        Number of (detection, shape) pairs saved; 0 when the image cannot be
        decoded or nothing is left to save (no files are written then).
    """
    # np.fromfile + imdecode instead of cv2.imread to read the file bytes directly.
    raw_bytes = np.fromfile(str(image_path), dtype=np.uint8)
    image_bgr = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
    if image_bgr is None:
        print(f" 이미지 로드 실패: {image_path}")
        return 0
    H, W = image_bgr.shape[:2]
    pole_prompt, rail_prompt, _unused_bracket_prompt = _get_prompts()

    # --- A: pole detection --------------------------------------------------
    needs_tiling = tile_size > 0 and max(W, H) > tile_size
    if needs_tiling:
        print(f" [A] 타일 {tile_size}px: '{pole_prompt}'")
        pole_pairs = _detect_tiled(image_bgr, pole_prompt, sam3_conf, tile_size, tile_overlap)
    else:
        print(f" [A] 전철주 검출: '{pole_prompt}'")
        whole_image_shapes = sam3_text_segment(image_bgr, pole_prompt, sam3_conf)
        pole_pairs = nms(shapes_to_pairs(whole_image_shapes))

    # --- A: height filter ---------------------------------------------------
    n_before = len(pole_pairs)
    pole_pairs = filter_by_size(pole_pairs, min_h=min_pole_h, debug=True)
    print(f" [A] 높이 필터(h>={min_pole_h}): {n_before}{len(pole_pairs)}")
    for _det, pole_shape in pole_pairs:
        pole_shape["_cls_id"] = 0

    # --- B: rail detection --------------------------------------------------
    print(f" [B] 철로 검출: '{rail_prompt}'")
    rail_shapes = []
    for candidate in sam3_text_segment(image_bgr, rail_prompt, sam3_conf):
        if candidate.get("shape_type") == "polygon":
            rail_shapes.append(candidate)

    # [C] The beam/arm query is intentionally not used: the "long beams" it
    # returned turned out to be parts of the rail.

    # --- A ∩ Zone filter (zone = rail mask + lateral expansion) --------------
    if rail_margin >= 0:
        if not rail_shapes:
            print(f" ! [B] 철로 미검출 → 전체 제외")
            pole_pairs = []
        else:
            zone_mask = build_zone_mask(rail_shapes, H, W, rail_margin, rail_gap_close,
                                        lateral_expand_ratio=rail_lateral)
            if rail_lateral > 0:
                print(f" [Zone] rail + 측면 {rail_lateral*100:.0f}%")
            n_before = len(pole_pairs)
            inside_zone = []
            for det, pole_shape in pole_pairs:
                if pole_in_zone(det, zone_mask):
                    inside_zone.append((det, pole_shape))
                    continue
                label, conf, x1, y1, x2, y2 = det
                print(f" DROP zone [{label} {conf:.2f}] bbox=({x1:.0f},{y1:.0f})-({x2:.0f},{y2:.0f})")
            pole_pairs = inside_zone
            print(f" [A∩Zone] 필터(+{rail_margin}px): {n_before}{len(pole_pairs)}")

    # --- post-processing ------------------------------------------------------
    if raw_polygons:
        # Keep the SAM polygons untouched, all stored as class 0.
        all_pairs = []
        for det, pole_shape in pole_pairs:
            pole_shape["_cls_id"] = 0
            pole_shape["label"] = "catenary_pole"
            all_pairs.append((det, pole_shape))
        print(f" [Raw] SAM 폴리곤 {len(all_pairs)}개 그대로 저장 (split 미적용)")
    else:
        def _as_pairs(polys, label, cls_id):
            """Wrap bare polygons as (detection tuple, shape dict) pairs with score 1.0."""
            wrapped = []
            for poly in polys:
                xs = [p[0] for p in poly]
                ys = [p[1] for p in poly]
                det = (label, 1.0, min(xs), min(ys), max(xs), max(ys))
                wrapped.append((det, {"shape_type": "polygon", "points": poly,
                                      "_cls_id": cls_id, "label": label}))
            return wrapped

        # Skeleton-based split into vertical columns and horizontal beams.
        column_polys, beam_polys = skeleton_split_columns_beams(pole_pairs, H, W)
        column_pairs = _as_pairs(column_polys, "catenary_pole", 0)
        arm_pairs = _as_pairs(beam_polys, "bracket", 1)
        print(f" [Split] column {len(column_pairs)}개, beam {len(arm_pairs)}")
        all_pairs = column_pairs + arm_pairs

    if not all_pairs:
        return 0

    stem = image_path.stem
    save_yolo_label(out_label_dir / (stem + ".txt"), all_pairs, W, H)
    save_vis(out_vis_dir / (stem + "_vis.jpg"), image_bgr, all_pairs,
             rail_shapes, None)
    return len(all_pairs)
# ── classes.txt 저장 ──────────────────────────────────────────────────────────
def save_classes(out_dir: Path):
    """Write ``CLASS_NAMES`` to ``out_dir/classes.txt``, newline-separated, as UTF-8."""
    classes_file = out_dir / "classes.txt"
    classes_file.write_text("\n".join(CLASS_NAMES), encoding="utf-8")
# ── 메인 ─────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point for SAM 3.1 text-prompt auto-labelling.

    Parses the options, collects the input images (a single file, or every
    file in a directory whose suffix is listed in ``--ext``), writes
    ``classes.txt`` and then runs ``process_image`` on each image, printing
    the total number of labels produced.

    Exits through ``parser.error`` (usage message, status 2) when ``--input``
    does not exist, instead of failing later with a traceback while the image
    is being read.
    """
    parser = argparse.ArgumentParser(description="SAM 3.1 텍스트 프롬프트 자동 라벨링")
    parser.add_argument("--input", required=True, help="이미지 파일 또는 디렉터리")
    parser.add_argument("--output", default="output/autolabel")
    parser.add_argument("--sam3_conf", type=float, default=0.20)
    parser.add_argument("--ext", default=".jpg,.jpeg,.png,.tif")
    parser.add_argument("--tile_size", type=int, default=1280,
                        help="타일 크기 px (0=타일 없음, 기본:1280)")
    parser.add_argument("--tile_overlap", type=float, default=0.2)
    parser.add_argument("--min_pole_h", type=int, default=160,
                        help="전철주 최소 높이 px (기본:160, 드론 거리 변화 마진 포함)")
    parser.add_argument("--rail_margin", type=int, default=300,
                        help="철로 mask 팽창 거리 px (-1=비활성, 기본:300)")
    parser.add_argument("--rail_gap_close", type=int, default=500,
                        help="beam 차폐 갭 보정 수평 closing 폭 px (0=비활성, 기본:500)")
    parser.add_argument("--beam_min_width", type=int, default=400,
                        help="긴 빔 최소 폭 px - 가로등/짧은 arm 차단 (기본:400, 역사형만)")
    parser.add_argument("--beam_ratio", type=float, default=2.0,
                        help="긴 빔 폭/높이 최소 비율 (기본:2.0)")
    parser.add_argument("--beam_downward", type=int, default=400,
                        help="빔 zone 아래 확장 px (기본:400)")
    parser.add_argument("--rail_lateral", type=float, default=0.0,
                        help="rail zone 측면(좌우) 확장 비율 (이미지 폭 기준, 예:0.10=10%%)")
    parser.add_argument("--raw_polygons", action="store_true",
                        help="후처리 skeleton split 없이 SAM 원본 폴리곤만 저장")
    args = parser.parse_args()

    input_path = Path(args.input)
    # Fail early with a usage error rather than a traceback inside process_image.
    if not input_path.exists():
        parser.error(f"입력 경로 없음: {input_path}")

    out_dir = Path(args.output)
    out_label = out_dir / "labels"
    out_vis = out_dir / "vis"
    out_label.mkdir(parents=True, exist_ok=True)
    out_vis.mkdir(parents=True, exist_ok=True)

    # Normalise the extension list: ignore empty tokens and accept "jpg" as
    # well as ".jpg", because Path.suffix always carries the leading dot.
    exts = set()
    for token in args.ext.split(","):
        token = token.strip().lower()
        if token:
            exts.add(token if token.startswith(".") else f".{token}")

    if input_path.is_dir():
        images = sorted((p for p in input_path.iterdir() if p.suffix.lower() in exts),
                        key=lambda p: p.name)
    else:
        # A single file is processed regardless of its extension.
        images = [input_path]

    pole_prompt, rail_prompt, _ = _get_prompts()
    print(f"처리 대상: {len(images)}")
    print(f"전철주 프롬프트: {pole_prompt}")
    print(f"철로 프롬프트: {rail_prompt}")
    print(f"타일: {args.tile_size}px 겹침: {args.tile_overlap*100:.0f}%")
    print(f"출력: {out_dir}\n")
    save_classes(out_dir)

    total = 0
    for i, img_path in enumerate(images, 1):
        print(f"[{i}/{len(images)}] {img_path.name}")
        total += process_image(
            img_path, out_label, out_vis,
            args.sam3_conf,
            args.tile_size, args.tile_overlap,
            min_pole_h=args.min_pole_h,
            rail_margin=args.rail_margin,
            rail_gap_close=args.rail_gap_close,
            beam_min_width=args.beam_min_width,
            beam_ratio=args.beam_ratio,
            beam_downward=args.beam_downward,
            rail_lateral=args.rail_lateral,
            raw_polygons=args.raw_polygons,
        )
    print(f"\n완료: 총 {total}개 라벨 생성")
    print(f"라벨: {out_label}")
    print(f"시각화: {out_vis}")


if __name__ == "__main__":
    main()

278
tools/sam3_batch_label.py Normal file
View File

@@ -0,0 +1,278 @@
"""
SAM3 배치 자동 레이블링 파이프라인
===================================
SAM3 서버에 텍스트 프롬프트로 이미지 배치 처리 →
X-AnyLabeling 호환 JSON annotation 파일 자동 생성
사용법:
# 기본 (sample/rail 폴더, 서버 localhost:8000)
python tools/sam3_batch_label.py
# 폴더 지정
python tools/sam3_batch_label.py --input sample/rail --output output/labels
# conf 조정 (낮출수록 더 많이 검출, 오탐도 증가)
python tools/sam3_batch_label.py --conf 0.20
SAM3 서버 실행:
cd X-AnyLabeling-Server
uvicorn app.main:app --host 0.0.0.0 --port 8000
"""
import argparse
import base64
import json
import sys
from pathlib import Path
import cv2
import numpy as np
import requests
# Base URL of the SAM3 inference server; main() replaces it with --server.
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent in the "model" field of every /v1/predict request.
MODEL_ID = "segment_anything_3"
# Detection targets: English text prompt sent to SAM3 -> Korean label stored in the annotation
TARGETS = {
"pole": "전철주_세로",
"catenary arm": "전철주_가로",
"junction box": "통신박스",
"electrical box": "전기박스",
"fence": "펜스",
}
# Visualization colors keyed by Korean label (BGR)
COLORS = {
"전철주_세로": (0, 200, 255),
"전철주_가로": (0, 100, 255),
"통신박스": (255, 180, 0),
"전기박스": (100, 255, 200),
"펜스": (0, 255, 100),
}
def encode_image(image_bgr: np.ndarray) -> str:
    """Encode an image as PNG and return the bytes as a base64 text string."""
    # PNG is lossless, so the server receives the full-resolution pixels.
    _ok, png_bytes = cv2.imencode(".png", image_bgr)
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
def sam3_text_predict(image_bgr: np.ndarray, text_prompt: str, conf: float) -> list:
    """Segment ``image_bgr`` with a SAM3 text prompt and return the shape list.

    Posts the encoded image to ``{SAM3_SERVER}/v1/predict`` (60 s timeout).
    Any exception raised while calling the server or decoding its reply is
    printed and swallowed.

    Returns:
        ``data["shapes"]`` from the reply when the reply reports success;
        otherwise an empty list.
    """
    request_params = {
        "text_prompt": text_prompt,
        "show_masks": True,
        "show_boxes": False,
        "conf_threshold": conf,
        "epsilon_factor": 0.002,
    }
    payload = {
        "model": MODEL_ID,
        "image": encode_image(image_bgr),
        "params": request_params,
    }
    url = f"{SAM3_SERVER}/v1/predict"
    try:
        resp = requests.post(url, json=payload, timeout=60)
        resp.raise_for_status()
        data = resp.json()
        if data.get("success"):
            return data.get("data", {}).get("shapes", [])
    except Exception as e:
        print(f" [ERROR] SAM3 호출 실패: {e}")
    return []
def process_image(image_path: Path, conf: float,
                  vis_dir: Path | None) -> tuple[dict, dict] | dict:
    """Detect every class in ``TARGETS`` on one image and build its annotation.

    Args:
        image_path: Image file to process.
        conf: Confidence threshold forwarded to ``sam3_text_predict``.
        vis_dir: Directory for the ``<stem>_vis.jpg`` preview, or ``None`` to
            skip the preview.

    Returns:
        ``(annotation, class_counts)`` on success, where ``annotation`` is the
        X-AnyLabeling JSON dict and ``class_counts`` maps each Korean label to
        the number of shapes found for it.  An empty dict (falsy) is returned
        when the image cannot be decoded, which is what the caller tests for.
    """
    # np.fromfile + imdecode so that paths containing Korean characters load.
    buf = np.fromfile(str(image_path), dtype=np.uint8)
    image_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if image_bgr is None:
        return {}
    h, w = image_bgr.shape[:2]

    all_shapes = []
    class_counts = {}
    for eng_label, kor_label in TARGETS.items():
        shapes = sam3_text_predict(image_bgr, eng_label, conf)
        # Replace the label returned by the server with the Korean label.
        for s in shapes:
            s["label"] = kor_label
        all_shapes.extend(shapes)
        if shapes:
            # Accumulate: several prompts may map to the same Korean label
            # (e.g. "catenary pole" and "concrete pole"), so plain assignment
            # would drop the earlier prompt's count.
            class_counts[kor_label] = class_counts.get(kor_label, 0) + len(shapes)

    # X-AnyLabeling JSON format
    annotation = {
        "version": "3.3.9",
        "flags": {},
        "shapes": [
            {
                "label": s["label"],
                "points": s["points"],
                "group_id": None,
                "shape_type": s.get("shape_type", "polygon"),
                "flags": {},
                # "or 0" also covers an explicit null score, which float() rejects.
                "score": round(float(s.get("score") or 0), 4),
            }
            for s in all_shapes
        ],
        "imagePath": image_path.name,
        "imageData": None,
        "imageHeight": h,
        "imageWidth": w,
    }

    # Preview image
    if vis_dir is not None:
        vis = draw_vis(image_bgr, all_shapes)
        vis_path = vis_dir / f"{image_path.stem}_vis.jpg"
        # imencode + tofile mirrors the Unicode-safe read above.
        cv2.imencode(".jpg", vis)[1].tofile(str(vis_path))
    return annotation, class_counts
def draw_vis(image_bgr: np.ndarray, shapes: list) -> np.ndarray:
    """Render the detected shapes and a colour legend onto a copy of the image.

    Polygon fills are painted on a separate overlay that is blended into the
    result once (30 % overlay / 70 % outlines-and-text layer); the legend is
    drawn afterwards.  The input image is not modified.

    NOTE(review): labels and legend entries are Korean strings drawn with
    ``cv2.putText`` and a Hershey font — presumably they do not render as
    readable glyphs; confirm against the OpenCV documentation.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    white = (255, 255, 255)
    canvas = image_bgr.copy()
    fill_layer = image_bgr.copy()

    for shape in shapes:
        outline = np.array(shape["points"], dtype=np.int32)
        # Drop the duplicated closing vertex of an explicitly closed polygon.
        if len(outline) > 1 and np.array_equal(outline[0], outline[-1]):
            outline = outline[:-1]
        label = shape.get("label", "unknown")
        color = COLORS.get(label, (200, 200, 200))
        cv2.fillPoly(fill_layer, [outline], color)
        cv2.polylines(canvas, [outline], True, color, 2)
        # Label text near the vertex centroid.
        cx, cy = (int(v) for v in outline.mean(axis=0))
        cv2.putText(canvas, label, (cx - 30, cy),
                    font, 0.5, white, 1, cv2.LINE_AA)

    cv2.addWeighted(fill_layer, 0.3, canvas, 0.7, 0, canvas)

    # Legend: one colour swatch + name per entry, 20 px apart.
    for row, (kor, color) in enumerate(COLORS.items()):
        y = 25 + 20 * row
        cv2.rectangle(canvas, (10, y - 12), (24, y + 2), color, -1)
        cv2.putText(canvas, kor, (28, y), font, 0.5, white, 1)
    return canvas
def main():
    """Command-line entry point for SAM3 batch labelling.

    Parses the options, optionally restricts ``TARGETS`` to the requested
    ``--classes``, checks that the server answers, then processes every input
    image: one JSON annotation per image (plus an optional preview) and a
    per-class summary at the end.

    Exits with status 1 when no requested class key is known, the server is
    unreachable, the input path does not exist, or no image is found.
    """
    global SAM3_SERVER, TARGETS

    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="sample/rail")
    parser.add_argument("--output", default="output/sam3_labels",
                        help="JSON annotation 저장 폴더")
    parser.add_argument("--vis", default="output/sam3_vis",
                        help="시각화 이미지 저장 폴더 ('' 로 비활성화)")
    parser.add_argument("--conf", type=float, default=0.20,
                        help="SAM3 confidence threshold (기본 0.20)")
    parser.add_argument("--classes", nargs="+",
                        help="처리할 클래스 키 목록 (기본: 전체). 예: catenary_pole catenary_arm")
    parser.add_argument("--server", default="http://localhost:8000")
    args = parser.parse_args()

    SAM3_SERVER = args.server
    if args.classes:
        # CLI key -> (English prompt, Korean label)
        key_map = {
            "catenary_pole": ("catenary pole", "전철주_세로"),
            "concrete_pole": ("concrete pole", "전철주_세로"),
            "catenary_arm": ("catenary arm", "전철주_가로"),
            "junction_box": ("junction box", "통신박스"),
            "electrical_box": ("electrical box", "전기박스"),
            "fence": ("fence", "펜스"),
        }
        # Report mistyped keys instead of dropping them silently.
        unknown = [k for k in args.classes if k not in key_map]
        if unknown:
            print(f"[WARN] 알 수 없는 클래스 키 무시: {unknown}")
        TARGETS = {key_map[k][0]: key_map[k][1] for k in args.classes if k in key_map}
        if not TARGETS:
            # With no valid key every image would yield an empty annotation.
            print(f"[ERROR] 유효한 클래스 없음. 사용 가능: {list(key_map)}")
            sys.exit(1)

    # Server check: any HTTP response counts as "reachable"; the status code
    # is not inspected.
    try:
        requests.get(f"{SAM3_SERVER}/health", timeout=5)
        print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
    except Exception:
        print(f"[ERROR] SAM3 서버 연결 실패: {SAM3_SERVER}")
        print(" 서버를 먼저 실행하세요:")
        print(" cd X-AnyLabeling-Server && uvicorn app.main:app --port 8000")
        sys.exit(1)

    # Input collection
    input_path = Path(args.input)
    if input_path.is_file():
        images = [input_path]
    elif input_path.is_dir():
        images = sorted(
            list(input_path.glob("*.jpg")) +
            list(input_path.glob("*.jpeg")) +
            list(input_path.glob("*.png"))
        )
    else:
        print(f"[ERROR] 경로 없음: {input_path}")
        sys.exit(1)
    if not images:
        print(f"[ERROR] 이미지 없음: {input_path}")
        sys.exit(1)

    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)
    vis_dir = None
    if args.vis:
        vis_dir = Path(args.vis)
        vis_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n처리 대상: {len(images)}")
    print(f"검출 클래스: {list(TARGETS.values())}")
    print(f"SAM3 conf: {args.conf}")
    print(f"annotation 저장: {out_dir}")
    if vis_dir:
        print(f"시각화 저장: {vis_dir}")
    print()

    total_counts: dict = dict.fromkeys(TARGETS.values(), 0)
    processed = 0
    # The progress index comes from enumerate so it keeps advancing after a
    # skipped image; `processed` only counts successes.
    for idx, img_path in enumerate(images, 1):
        print(f"[{idx}/{len(images)}] {img_path.name}")
        result = process_image(img_path, args.conf, vis_dir)
        if not result:
            print(f" [SKIP] 처리 실패")
            continue
        annotation, class_counts = result
        n = len(annotation["shapes"])
        print(f"{n}개 객체 검출: {class_counts if class_counts else '없음'}")

        # Save the JSON annotation next to the other outputs.
        json_path = out_dir / f"{img_path.stem}.json"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, ensure_ascii=False, indent=2)
        for k, v in class_counts.items():
            total_counts[k] = total_counts.get(k, 0) + v
        processed += 1

    # Summary
    print("\n" + "="*50)
    print(f"완료: {processed}/{len(images)}장 처리")
    print("\n클래스별 총 검출 수:")
    for cls, cnt in total_counts.items():
        avg = cnt / max(processed, 1)
        bar = "#" * min(cnt, 30)
        print(f" {cls:12s}: {cnt:4d}개 평균 {avg:.1f}/장 {bar}")
    print(f"\nJSON 저장: {out_dir.resolve()}")
    if vis_dir:
        print(f"시각화: {vis_dir.resolve()}")
    print("\n다음 단계:")
    print(" X-AnyLabeling → Open Image Folder → annotation 폴더 선택")
    print(" → 오탐/미탐 수동 수정 → Export → YOLO11-seg 학습")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,291 @@
"""
SAM3.1 탐색 모드 (Discovery Sweep)
이미지를 타일로 분할 후 넓은 탐색용 text_prompt로 SAM3.1 호출 →
나온 segment들을 시각화 + 라벨 빈도 집계 → text_prompt 후보 결정.
이 SAM3.1 서버는 텍스트 grounding 방식이라 빈 prompt는 작동하지 않음.
대신 매우 넓은 "탐색 프롬프트"로 이미지에 존재하는 객체를 일괄 검출한다.
사용법:
python tools/sam3_everything_explore.py \\
--input "data/역사구간/1.회덕역/..." \\
--cols 8 --rows 6
사전 조건: SAM3.1 서버 실행 (start_server.bat)
"""
import argparse
import base64
import json
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import cv2
import numpy as np
import requests
# Base URL of the SAM3.1 inference server that sam3_everything() posts to.
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent in the "model" field of each predict request.
SAM3_MODEL_ID = "segment_anything_3"
# ── 이미지 인코딩 ─────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray, max_size: int = 1280) -> tuple:
    """Downscale if needed, JPEG-encode (quality 90) and base64-encode an image.

    Args:
        image_bgr: Image array; its first two dimensions are height and width.
        max_size: Longest allowed edge in pixels before downscaling.

    Returns:
        ``(base64_text, scale)`` where ``scale`` is the resize factor that was
        applied (1.0 when the image was sent at its original size).
    """
    height, width = image_bgr.shape[:2]
    longest_edge = max(height, width)
    scale = 1.0
    if longest_edge > max_size:
        scale = max_size / longest_edge
        new_size = (int(width * scale), int(height * scale))
        image_bgr = cv2.resize(image_bgr, new_size)
    _ok, jpeg_bytes = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    return base64.b64encode(jpeg_bytes).decode("utf-8"), scale
# Broad discovery prompt: a comma-separated vocabulary of elements commonly
# seen at railway sites. It is the default prompt of sam3_everything().
DISCOVERY_PROMPT = (
"railroad track, railway rail, "
"catenary pole, overhead line pole, electric pole, "
"overhead wire, catenary wire, power line cable, "
"railway sleeper, concrete tie, "
"guardrail, highway barrier, road fence, "
"bridge, viaduct, overpass, "
"vegetation, tree, bush, grass, "
"building, structure, roof, wall, "
"vehicle, car, truck, "
"road, asphalt, pavement, "
"slope, embankment, retaining wall, "
"noise barrier, sound wall, "
"signal, sign board"
)
# ── SAM3.1 discovery sweep 호출 ───────────────────────────────────────────────
def sam3_everything(tile_bgr: np.ndarray, conf: float, prompt: str = DISCOVERY_PROMPT) -> list:
    """Query SAM3.1 with a broad prompt and return the polygon shapes it finds.

    The tile is sent through ``encode_image``, which may downscale it; polygon
    points in the reply are scaled back by the inverse factor so they are in
    the coordinates of ``tile_bgr``.  Any exception is printed and swallowed.

    Returns:
        List of shape dicts whose ``shape_type`` is ``"polygon"``; empty when
        the reply does not report success or the request fails.
    """
    b64, scale = encode_image(tile_bgr)
    request_body = {
        "model": SAM3_MODEL_ID,
        "image": b64,
        "params": {
            "text_prompt": prompt,
            "conf_threshold": conf,
            "show_masks": True,
            "show_boxes": False,
        },
    }
    try:
        r = requests.post(f"{SAM3_SERVER}/v1/predict", json=request_body, timeout=120)
        r.raise_for_status()
        reply = r.json()
        if not reply.get("success"):
            return []
        # Accept plain dicts as well as objects exposing .dict().
        shapes = [
            item if isinstance(item, dict) else item.dict()
            for item in reply.get("data", {}).get("shapes", [])
        ]
        polygons = [s for s in shapes if s.get("shape_type") == "polygon"]
        if scale < 1.0:
            # Undo the downscaling applied before upload.
            inv = 1.0 / scale
            for poly in polygons:
                poly["points"] = [[x * inv, y * inv] for x, y in poly["points"]]
        return polygons
    except Exception as e:
        print(f" [SAM3 오류] {e}")
        return []
# ── NMS ───────────────────────────────────────────────────────────────────────
def _bbox(pts):
    """Return ``(x_min, y_min, x_max, y_max)`` for a sequence of ``[x, y, ...]`` points."""
    x_vals = [pt[0] for pt in pts]
    y_vals = [pt[1] for pt in pts]
    return min(x_vals), min(y_vals), max(x_vals), max(y_vals)


def nms_shapes(shapes: list, iou_thresh: float = 0.4) -> list:
    """Greedy non-maximum suppression on the bounding boxes of the shapes.

    Shapes are visited from the highest ``score`` down (a missing score counts
    as 0.5).  Each visited shape is kept and every remaining shape whose
    bounding-box IoU with it is ``>= iou_thresh`` is discarded.  Labels are
    not consulted.

    Returns:
        The kept shape dicts (same objects), in descending score order.
    """
    if not shapes:
        return []

    boxes = np.array([_bbox(s["points"]) for s in shapes], dtype=np.float32)
    confidences = np.array([float(s.get("score", 0.5)) for s in shapes])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    remaining = confidences.argsort()[::-1]
    selected = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        selected.append(best)
        if rest.size == 0:
            break
        # Intersection rectangle of the best box with every other candidate.
        left = np.maximum(boxes[best, 0], boxes[rest, 0])
        upper = np.maximum(boxes[best, 1], boxes[rest, 1])
        right = np.minimum(boxes[best, 2], boxes[rest, 2])
        lower = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.maximum(0, right - left) * np.maximum(0, lower - upper)
        # The small constant avoids division by zero for degenerate boxes.
        iou = inter / (areas[best] + areas[rest] - inter + 1e-6)
        remaining = rest[iou < iou_thresh]
    return [shapes[i] for i in selected]
# ── 타일 분할 + 병렬 검출 ─────────────────────────────────────────────────────
def detect_everything_tiled(image_bgr, cols, rows, overlap, conf, workers, prompt):
    """Split the image into a ``cols`` x ``rows`` grid and run SAM3 on each cell.

    Every cell is widened on all sides by ``overlap`` times the cell size
    (clipped to the image) and sent to ``sam3_everything`` from a thread pool
    of ``workers`` threads.  Returned polygon points are shifted into
    full-image coordinates.

    Returns:
        All shapes from all tiles, in the order the tiles finished.
    """
    H, W = image_bgr.shape[:2]
    cell_w, cell_h = W / cols, H / rows
    pad_x, pad_y = int(cell_w * overlap), int(cell_h * overlap)

    windows = []
    for r in range(rows):
        y0 = max(0, int(r * cell_h) - pad_y)
        y1 = min(H, int((r + 1) * cell_h) + pad_y)
        for c in range(cols):
            x0 = max(0, int(c * cell_w) - pad_x)
            x1 = min(W, int((c + 1) * cell_w) + pad_x)
            windows.append((x0, y0, x1, y1))
    total = len(windows)

    def run_tile(window):
        """Detect on one window and move its polygons to image coordinates."""
        x0, y0, x1, y1 = window
        found = sam3_everything(image_bgr[y0:y1, x0:x1], conf, prompt)
        for s in found:
            s["points"] = [[px + x0, py + y0] for px, py in s["points"]]
        return found

    all_shapes = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = [pool.submit(run_tile, window) for window in windows]
        for finished, fut in enumerate(as_completed(pending), 1):
            all_shapes.extend(fut.result())
            # "\r" keeps the progress on a single console line.
            print(f" 타일 {finished:02d}/{total} 완료, 누적 {len(all_shapes)}", end="\r")
    print()
    return all_shapes
# ── 시각화 ────────────────────────────────────────────────────────────────────
def draw_everything(image_bgr, shapes, cols, rows):
    """Draw the tile grid and every detected polygon on a copy of the image.

    Each polygon gets a colour drawn from a generator seeded with 42, is
    blended in at 30 % opacity, outlined, and annotated with its label when
    it has one.  The total segment count is written in the top-left corner.

    Note: a full copy of the canvas is made for every shape before blending.
    """
    canvas = image_bgr.copy()
    H, W = canvas.shape[:2]

    # Tile boundaries (nominal grid, without the overlap padding).
    x_edges = [int(c * W / cols) for c in range(cols + 1)]
    y_edges = [int(r * H / rows) for r in range(rows + 1)]
    for r in range(rows):
        for c in range(cols):
            cv2.rectangle(canvas, (x_edges[c], y_edges[r]),
                          (x_edges[c + 1], y_edges[r + 1]), (60, 60, 60), 1)

    rng = np.random.default_rng(42)
    for shape in shapes:
        points = shape["points"]
        outline = np.array(points, dtype=np.int32)
        color = tuple(int(v) for v in rng.integers(80, 255, size=3))

        filled = canvas.copy()
        cv2.fillPoly(filled, [outline], color)
        cv2.addWeighted(filled, 0.30, canvas, 0.70, 0, canvas)
        cv2.polylines(canvas, [outline], True, color, 1)

        # Label text at the vertex centroid, only when a label is present.
        label = shape.get("label", "")
        if not label:
            continue
        cx = int(np.mean([p[0] for p in points]))
        cy = int(np.mean([p[1] for p in points]))
        cv2.putText(canvas, label, (cx, cy),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)

    cv2.putText(canvas, f"total segments: {len(shapes)}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    return canvas
# ── 라벨 분석 → text_prompt 후보 출력 ────────────────────────────────────────
def analyze_labels(shapes):
    """Print a label-frequency table and a text-prompt candidate line.

    Labels are stripped of surrounding whitespace; shapes with a missing,
    ``None`` or blank label are ignored.  The 30 most frequent labels are
    listed with a ``#`` bar (capped at 40 characters), and the 10 most
    frequent ones are joined into a candidate prompt.

    Args:
        shapes: Shape dicts; only their ``"label"`` entry is read.
    """
    # "or ''" tolerates an explicit None label, which has no .strip().
    cleaned = ((s.get("label") or "").strip() for s in shapes)
    labels = [lb for lb in cleaned if lb]
    if not labels:
        print("\n[라벨 없음] 탐색 prompt에서 segment를 반환하지 않았습니다.")
        return

    counter = Counter(labels)
    # Visible separator rule; an empty string repeated 50 times prints nothing.
    rule = "-" * 50
    print(f"\n{rule}")
    print(f"검출된 라벨 종류: {len(counter)}개 (총 segment {len(shapes)}개)")
    print(rule)
    for label, cnt in counter.most_common(30):
        bar = "#" * min(cnt, 40)
        print(f" {label:35s} {cnt:4d} {bar}")
    print(rule)

    top_labels = [lb for lb, _ in counter.most_common(10)]
    print(f"\n[text_prompt 후보]")
    print(f' "{", ".join(top_labels)}"')
# ── 메인 ─────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point of the SAM3.1 discovery sweep.

    Loads one image, runs the tiled discovery detection, applies NMS, prints
    the label statistics, and writes an overlay JPEG plus a JSON file with the
    label counts and per-segment label/score/bbox next to it.
    """
    import time

    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--input", required=True, help="입력 이미지")
    ap.add_argument("--output", default=None, help="출력 이미지 (기본: 입력명_everything.jpg)")
    ap.add_argument("--cols", type=int, default=8, help="가로 타일 수 (기본 8)")
    ap.add_argument("--rows", type=int, default=6, help="세로 타일 수 (기본 6)")
    ap.add_argument("--overlap", type=float, default=0.10, help="타일 중복 비율 (기본 0.10)")
    ap.add_argument("--conf", type=float, default=0.10, help="신뢰도 임계값 (기본 0.10)")
    ap.add_argument("--workers", type=int, default=4, help="병렬 스레드 수 (기본 4)")
    ap.add_argument("--nms", type=float, default=0.40, help="NMS IoU 임계값 (기본 0.40)")
    ap.add_argument("--prompt-extra", default="", help="DISCOVERY_PROMPT 뒤에 추가할 어휘 (콤마 구분)")
    args = ap.parse_args()

    # Optional user vocabulary is appended to the built-in discovery prompt.
    prompt = DISCOVERY_PROMPT
    if args.prompt_extra.strip():
        prompt += ", " + args.prompt_extra.strip(", ")

    img_path = Path(args.input)
    # np.fromfile + imdecode reads the file bytes directly instead of cv2.imread.
    file_bytes = np.fromfile(str(img_path), dtype=np.uint8)
    image_bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image_bgr is None:
        print(f"이미지 로드 실패: {img_path}")
        return

    H, W = image_bgr.shape[:2]
    n_tiles = args.cols * args.rows
    print(f"이미지 : {W}×{H}")
    print(f"타일 그리드: {args.cols}×{args.rows}={n_tiles}")
    print(f"conf={args.conf} overlap={args.overlap*100:.0f}% workers={args.workers}\n")

    prompt_items = prompt.split(",")
    print(f"탐색 프롬프트 ({len(prompt_items)}개 항목):")
    for item in prompt_items:
        print(f" · {item.strip()}")
    print()

    t0 = time.time()
    shapes = detect_everything_tiled(
        image_bgr, args.cols, args.rows, args.overlap,
        args.conf, args.workers, prompt
    )
    print(f"검출 {len(shapes)}개 → NMS(iou={args.nms})...")
    shapes = nms_shapes(shapes, iou_thresh=args.nms)
    print(f"NMS 후 {len(shapes)}개 ({time.time()-t0:.0f}초)\n")
    analyze_labels(shapes)

    vis = draw_everything(image_bgr, shapes, args.cols, args.rows)
    vis_h, vis_w = vis.shape[:2]
    longest = max(vis_h, vis_w)
    if longest > 4096:
        # Cap the preview at 4096 px on its longest edge.
        shrink = 4096 / longest
        vis = cv2.resize(vis, (int(vis_w*shrink), int(vis_h*shrink)))

    if args.output:
        out_path = Path(args.output)
    else:
        out_path = img_path.parent / (img_path.stem + "_everything.jpg")
    cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 93])[1].tofile(str(out_path))
    print(f"\n저장: {out_path}")

    # Also store the label data as JSON for later analysis.
    json_path = out_path.with_suffix(".json")
    segments = [
        {"label": s.get("label",""), "score": s.get("score",0),
         "bbox": list(_bbox(s["points"]))}
        for s in shapes
    ]
    label_counts = dict(Counter(s.get("label", "(no label)") for s in shapes))
    label_data = {
        "total_segments": len(shapes),
        "label_counts": label_counts,
        "segments": segments,
    }
    json_path.write_text(json.dumps(label_data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"라벨 데이터: {json_path}")


if __name__ == "__main__":
    main()

219
tools/sam3_receipt_ocr.py Normal file
View File

@@ -0,0 +1,219 @@
import argparse
import sys
import os
import time
import json
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
# Make the sibling X-AnyLabeling-Server checkout importable so that the
# `sam3` package can be loaded locally. The server root is prepended first and
# its app/models directory second, so app/models ends up ahead of it.
_tools_dir = os.path.dirname(__file__)
server_path = os.path.abspath(os.path.join(_tools_dir, "..", "X-AnyLabeling-Server"))
models_path = os.path.join(server_path, "app", "models")
for _extra_path in (server_path, models_path):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# PaddleOCR will be imported inside the function to avoid errors if not installed
def run_paddleocr(image_np):
    """Run Korean PaddleOCR on an image array and return the raw OCR result.

    Three environment flags are set before ``paddleocr`` is imported to switch
    off oneDNN and the PIR API, which otherwise raise ``NotImplementedError``
    on PaddlePaddle 3.x for Windows.  PaddleOCR is imported here so that the
    module can be loaded without it installed.

    NOTE(review): ``use_textline_orientation``, ``use_gpu`` and ``cls=True``
    presumably belong to different PaddleOCR major versions — confirm against
    the installed release.
    """
    import os as _os
    _os.environ.update({
        "FLAGS_use_mkldnn": "0",
        "PADDLE_WITH_MKLDNN": "0",
        "FLAGS_enable_pir_api": "0",  # the new PIR API is what triggers the error
    })
    from paddleocr import PaddleOCR

    # The GPU is used whenever torch reports CUDA; otherwise PaddleOCR runs on CPU.
    gpu_available = torch.cuda.is_available()
    engine = PaddleOCR(use_textline_orientation=True, lang='korean', use_gpu=gpu_available)
    # .ocr() rather than .predict() for better compatibility.
    return engine.ocr(image_np, cls=True)
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an ``(n_per_side**2, 2)`` array of cell-centre points in the unit square.

    Coordinates run from ``1 / (2 n)`` to ``1 - 1 / (2 n)``; rows are ordered
    with x varying fastest.
    """
    half_cell = 1.0 / (2 * n_per_side)
    centres = np.linspace(half_cell, 1 - half_cell, n_per_side)
    xs, ys = np.meshgrid(centres, centres)
    return np.column_stack((xs.ravel(), ys.ravel()))
def mask_iou(mask1, mask2):
    """Return intersection-over-union of two masks, or 0 when both are empty."""
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0
    intersection = np.logical_and(mask1, mask2).sum()
    return intersection / union
def get_receipt_mask(image_bgr, model_path, points_per_side=32):
    """Find the largest confident SAM3 mask in the image.

    A SAM3 model is built from ``model_path`` and prompted once per point of a
    ``points_per_side`` x ``points_per_side`` grid.  For each point the
    best-scoring mask is kept when its score exceeds 0.8; the kept mask with
    the largest pixel count is returned.

    Returns:
        The selected mask array, or ``None`` when no mask passed the score test.
    """
    print("Loading SAM3 Model for receipt detection...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_sam3_image_model(
        bpe_path=os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz"),
        device=device,
        checkpoint_path=model_path,
    )
    processor = Sam3Processor(model, confidence_threshold=0.7, device=device)

    pil_image = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    state = processor.set_image(pil_image)

    grid_points = build_point_grid(points_per_side)
    print(f"Sampling {len(grid_points)} points...")
    candidates = []
    for nx, ny in grid_points:
        # One positive point prompt at a time.
        # NOTE(review): the grid coordinates lie in (0, 1); assumes the
        # processor expects normalized points — confirm.
        processor.reset_all_prompts(state)
        state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
        if "masks" not in state or len(state["masks"]) == 0:
            continue
        best_idx = torch.argmax(state["scores"])
        mask = state["masks"][best_idx].cpu().numpy()
        if state["scores"][best_idx].item() > 0.8:
            candidates.append(mask)

    if not candidates:
        return None
    # Simple heuristic: a receipt is usually the biggest object in the frame,
    # so take the mask covering the most pixels (first one wins on ties).
    return max(candidates, key=lambda m: m.sum())
def crop_and_warp(image, mask):
    """Cut the masked object out of ``image``, rectifying it when it is a quadrilateral.

    The largest external contour of the mask is simplified with
    ``approxPolyDP`` (epsilon = 2 % of its perimeter).  If exactly four
    corners result, the quadrilateral is perspective-warped to an upright
    rectangle; otherwise the image is blacked out outside the mask and cropped
    to the contour's bounding box.

    Returns:
        The warped or cropped image; the unchanged ``image`` when the mask is
        empty or has no contour.
    """
    # Ensure the mask is 2-D.
    mask = np.squeeze(mask)
    if mask.ndim != 2:
        print(f"Warning: Mask has unexpected dimensions {mask.shape}, trying to flatten...")
        if mask.ndim == 3:
            mask = mask[0]

    binary = (mask > 0).astype(np.uint8) * 255
    if np.sum(binary) == 0:
        print("Warning: Mask is empty.")
        return image

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return image
    outline = max(contours, key=cv2.contourArea)
    quad = cv2.approxPolyDP(outline, 0.02 * cv2.arcLength(outline, True), True)

    if len(quad) != 4:
        print("Could not find 4 clear corners, just cropping to bounding box.")
        x, y, w, h = cv2.boundingRect(outline)
        # Black out everything outside the mask before cropping.
        blacked = cv2.bitwise_and(image, image, mask=binary)
        return blacked[y:y+h, x:x+w]

    print("Found 4 corners, applying perspective transform...")
    corners = quad.reshape(4, 2)
    # Corner ordering: x+y is smallest at top-left and largest at
    # bottom-right; y-x is smallest at top-right and largest at bottom-left.
    coord_sum = corners.sum(axis=1)
    coord_diff = np.diff(corners, axis=1)
    rect = np.array([
        corners[np.argmin(coord_sum)],
        corners[np.argmin(coord_diff)],
        corners[np.argmax(coord_sum)],
        corners[np.argmax(coord_diff)],
    ], dtype="float32")
    tl, tr, br, bl = rect

    def _side(p, q):
        """Euclidean length of the edge from q to p."""
        return np.sqrt(((p[0] - q[0]) ** 2) + ((p[1] - q[1]) ** 2))

    maxWidth = max(int(_side(br, bl)), int(_side(tr, tl)))
    maxHeight = max(int(_side(tr, br)), int(_side(tl, bl)))
    dst = np.array([
        [0, 0],
        [maxWidth - 1, 0],
        [maxWidth - 1, maxHeight - 1],
        [0, maxHeight - 1]], dtype="float32")
    transform = cv2.getPerspectiveTransform(rect, dst)
    return cv2.warpPerspective(image, transform, (maxWidth, maxHeight))
def main():
    """Command-line entry point: isolate a receipt with SAM3, then OCR it.

    Steps: read the input image, find the receipt mask, crop/warp the receipt,
    save the processed image, run PaddleOCR on it, dump the OCR result as JSON
    and print the recognised text.  Returns early (after a message) when the
    image cannot be read or no mask is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output_dir", default="output/ocr", help="Output directory")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Read image (np.fromfile + imdecode so non-ASCII paths work)
    buf = np.fromfile(args.input, dtype=np.uint8)
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if image is None:
        print(f"Failed to read {args.input}")
        return
    model_path = os.path.join(server_path, "sam3.pt")

    # 1. SAM3 Masking
    mask = get_receipt_mask(image, model_path)
    if mask is None:
        print("No receipt-like object found.")
        return

    # 2. Cropping / Warping
    processed_img = crop_and_warp(image, mask)

    # Save processed image for debugging. imencode + tofile is used instead of
    # cv2.imwrite so the write is as Unicode-path-safe as the read above.
    processed_path = os.path.join(args.output_dir, "processed_receipt.jpg")
    cv2.imencode(".jpg", processed_img)[1].tofile(processed_path)
    print(f"Saved processed image to {processed_path}")

    # 3. PaddleOCR
    print("Running PaddleOCR...")
    ocr_results = run_paddleocr(processed_img)

    # 4. Save results
    output_json = os.path.join(args.output_dir, "ocr_results.json")
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(ocr_results, f, ensure_ascii=False, indent=2)
    print(f"OCR results saved to {output_json}")

    # Print summary
    if ocr_results:
        print("\n--- OCR Extracted Text ---")
        for page in ocr_results:
            if page is None:
                continue
            # New v3.x format: list of dicts with 'rec_text' key
            if isinstance(page, dict):
                text = page.get('rec_text', '')
                score = page.get('rec_score', 0)
                if text:
                    print(f"{text} (conf: {score:.2f})")
            elif isinstance(page, list):
                # List format: each line is [box, (text, score)]
                for line in page:
                    if line and isinstance(line, list) and len(line) >= 2:
                        print(line[1][0])
        print("--------------------------")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,185 @@
import argparse
import sys
import os
import time
from pathlib import Path
import cv2
import numpy as np
import torch
# Put the sibling X-AnyLabeling-Server checkout (and its app/models directory)
# at the front of sys.path so `sam3` and `app` can be imported locally.
server_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server")
)
models_path = os.path.join(server_path, "app", "models")
# Insertion order matters: models_path is inserted last, so it is searched first.
for _import_root in (server_path, models_path):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from app.models.segment_anything_3 import SegmentAnything3
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generate an evenly spaced 2D point grid inside [0, 1] x [0, 1].

    Returns an array of shape ``(n_per_side**2, 2)`` holding ``(x, y)`` pairs
    located at the centres of an ``n_per_side`` x ``n_per_side`` cell grid,
    with x varying fastest.
    """
    margin = 1.0 / (2 * n_per_side)
    axis = np.linspace(margin, 1 - margin, n_per_side)
    grid_x, grid_y = np.meshgrid(axis, axis)
    return np.stack((grid_x.reshape(-1), grid_y.reshape(-1)), axis=-1)
def mask_iou(mask1, mask2):
    """Compute IoU of two masks; returns 0 when their union is empty."""
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()
    return overlap / combined if combined != 0 else 0
def mask_to_polygon(mask, epsilon_factor=0.001):
    """Convert a mask to the polygon of its largest external contour.

    Args:
        mask: Mask array; singleton dimensions are squeezed and values > 0
            count as foreground.
        epsilon_factor: Simplification tolerance as a fraction of the contour
            perimeter; values <= 0 keep the raw contour.

    Returns:
        List of ``[x, y]`` float pairs, or an empty list when no contour exists.
    """
    binary = (np.squeeze(mask) > 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    outline = max(contours, key=cv2.contourArea)
    if epsilon_factor > 0:
        tolerance = epsilon_factor * cv2.arcLength(outline, True)
        outline = cv2.approxPolyDP(outline, tolerance, True)
    # Contour arrays are (N, 1, 2); flatten to N (x, y) rows.
    return [[float(x), float(y)] for x, y in outline.reshape(-1, 2)]
def segment_everything(image_bgr, model_path, points_per_side=32, conf_thresh=0.8, nms_thresh=0.5):
    """Segment an image by prompting a local SAM3 model on a regular point grid.

    For every grid point the best-scoring mask is kept when its score exceeds
    ``conf_thresh``.  The kept masks go through greedy mask-IoU NMS (a mask is
    dropped when its IoU with an already accepted one exceeds ``nms_thresh``)
    and are converted to polygons.

    Returns:
        List of ``{"polygon": [[x, y], ...], "score": float}`` dicts in
        descending score order; empty when no mask passed the score test.
    """
    print("Loading SAM3 Model locally...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_sam3_image_model(
        bpe_path=os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz"),
        device=device,
        checkpoint_path=model_path,
    )
    if device == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    processor = Sam3Processor(model, confidence_threshold=conf_thresh, device=device)

    # The processor is fed a PIL image in RGB order.
    from PIL import Image
    pil_image = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

    print("Computing image embedding...")
    started = time.time()
    state = processor.set_image(pil_image)
    print(f"Image embedding done in {time.time() - started:.2f}s")

    grid_points = build_point_grid(points_per_side)
    n_points = len(grid_points)
    print(f"Generated {n_points} grid points for sampling.")

    masks = []
    scores = []
    started = time.time()
    for i, (nx, ny) in enumerate(grid_points):
        if i % 100 == 0:
            print(f" Processed {i}/{n_points} points...")
        processor.reset_all_prompts(state)
        state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
        if "masks" not in state or len(state["masks"]) == 0:
            continue
        # Only the highest-scoring mask of this prompt is considered.
        best_idx = torch.argmax(state["scores"])
        mask = state["masks"][best_idx].cpu().numpy()
        score = state["scores"][best_idx].item()
        if score > conf_thresh:
            masks.append(mask)
            scores.append(score)
    print(f"Grid prediction done in {time.time() - started:.2f}s")
    print(f"Found {len(masks)} raw masks.")
    if not masks:
        return []

    # Simple greedy NMS on mask IoU, best score first.
    print("Applying NMS...")
    keep = []
    for idx in np.argsort(scores)[::-1]:
        candidate = masks[idx]
        if not any(mask_iou(candidate, masks[k]) > nms_thresh for k in keep):
            keep.append(idx)
    print(f"Kept {len(keep)} masks after NMS.")

    # Convert to polygons; masks without a contour are skipped.
    results = []
    for idx in keep:
        poly = mask_to_polygon(masks[idx])
        if poly:
            results.append({"polygon": poly, "score": scores[idx]})
    return results
def main():
    """CLI entry point: grid-prompt SAM3 on one image and save a colour overlay.

    Arguments: ``--input`` (image path), ``--output`` (visualisation path,
    always written as JPEG data) and ``--points`` (grid points per side).
    Images whose longest side exceeds 1024 px are downscaled before
    segmentation; the overlay is drawn on that downscaled copy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output", required=True, help="Output vis image path")
    parser.add_argument("--points", type=int, default=32, help="Points per side")
    args = parser.parse_args()

    # Read via np.fromfile + cv2.imdecode, mirroring the unicode-safe write
    # used for the output below.
    buf = np.fromfile(args.input, dtype=np.uint8)
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if image is None:
        print(f"Failed to read {args.input}")
        return

    # Shrink image if too large just to make NMS faster
    h, w = image.shape[:2]
    max_dim = 1024
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        image_proc = cv2.resize(image, (int(w * scale), int(h * scale)))
    else:
        image_proc = image.copy()

    model_path = os.path.join(server_path, "sam3.pt")
    results = segment_everything(
        image_proc, model_path,
        points_per_side=args.points, conf_thresh=0.7, nms_thresh=0.7,
    )

    vis = image_proc.copy()
    # Private generator: same colour sequence as np.random.seed(42) followed
    # by np.random.randint, without touching the global NumPy RNG state.
    rng = np.random.RandomState(42)
    for res in results:
        pts = np.array(res["polygon"], dtype=np.int32)
        color = rng.randint(0, 255, (3,)).tolist()
        overlay = vis.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.4, vis, 0.6, 0, vis)
        cv2.polylines(vis, [pts], True, color, 1)

    # Fix unicode paths in output
    is_success, im_buf_arr = cv2.imencode(".jpg", vis)
    if not is_success:
        # Previously this failure was silent and the script looked successful.
        print(f"Failed to encode visualization for {args.output}")
        return
    im_buf_arr.tofile(args.output)
    print(f"Saved visualization to {args.output}")
if __name__ == "__main__":
main()

72
tools/show_tiles.py Normal file
View File

@@ -0,0 +1,72 @@
"""
이미지를 cols×rows 타일로 분할하고 번호를 좌상단에 표시.
사용법:
python tools/show_tiles.py --input <이미지> --cols 8 --rows 6 --overlap 0.10
"""
import argparse
from pathlib import Path
import cv2
import numpy as np
def main():
    """Draw a numbered ``cols`` x ``rows`` tile grid on an image and save it.

    Each tile gets a rectangle on its non-overlapping boundary and its
    1-based index (row-major) near the top-left corner. The result is
    limited to 4096 px on its longest side and written as JPEG to
    ``--output`` or, by default, ``<input stem>_tiles.jpg`` next to the input.

    ``--overlap`` is accepted for CLI compatibility; the padded tile extents
    it would define are not drawn (the original computed them but never
    used them).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True)
    ap.add_argument("--output", default=None)
    ap.add_argument("--cols", type=int, default=8)
    ap.add_argument("--rows", type=int, default=6)
    ap.add_argument("--overlap", type=float, default=0.10)
    args = ap.parse_args()
    if args.cols < 1 or args.rows < 1:
        # Zero would divide by zero below; negatives would draw nothing.
        ap.error("--cols and --rows must be >= 1")

    img_path = Path(args.input)
    buf = np.fromfile(str(img_path), dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if img is None:
        # Without this check an unreadable file crashed on img.shape.
        raise SystemExit(f"Failed to read image: {img_path}")

    H, W = img.shape[:2]
    base_w = W / args.cols
    base_h = H / args.rows

    # Label metrics depend only on the image size, so compute them once.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = max(0.8, min(W, H) / 3000)
    thickness = max(2, int(font_scale * 3))

    vis = img.copy()
    for r in range(args.rows):
        for c in range(args.cols):
            idx = r * args.cols + c + 1
            # Tile boundary (reference line = boundary without overlap)
            bx0 = int(c * base_w)
            by0 = int(r * base_h)
            bx1 = min(W, int((c + 1) * base_w))
            by1 = min(H, int((r + 1) * base_h))
            cv2.rectangle(vis, (bx0, by0), (bx1, by1), (0, 200, 255), 3)

            # Tile number near the top-left corner, on a black background box
            tx, ty = bx0 + 10, by0 + int(base_h * 0.12)
            label = str(idx)
            (tw, th), _ = cv2.getTextSize(label, font, font_scale * 2, thickness)
            cv2.rectangle(vis, (tx - 4, ty - th - 4), (tx + tw + 4, ty + 4), (0, 0, 0), -1)
            cv2.putText(vis, label, (tx, ty),
                        font, font_scale * 2, (0, 200, 255), thickness, cv2.LINE_AA)

    # Limit output size
    h, w = vis.shape[:2]
    if max(h, w) > 4096:
        s = 4096 / max(h, w)
        vis = cv2.resize(vis, (int(w * s), int(h * s)))

    out = (Path(args.output) if args.output
           else img_path.parent / (img_path.stem + "_tiles.jpg"))
    ok, enc = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 92])
    if not ok:
        raise SystemExit(f"Failed to encode output image: {out}")
    enc.tofile(str(out))
    print(f"저장: {out} ({args.cols}×{args.rows}={args.cols*args.rows}타일)")
if __name__ == "__main__":
main()

237
tools/video_sam3_segment.py Normal file
View File

@@ -0,0 +1,237 @@
"""
드론 영상에서 프레임 추출 → SAM3 서버로 모든 객체 세그멘테이션
Usage:
1. SAM3 서버 시작: start_server.bat
2. python tools/video_sam3_segment.py
"""
import base64
import cv2
import json
import numpy as np
import requests
import sys
from pathlib import Path
# === Settings ===
VIDEO_PATH = Path("sample/rail.mp4")
OUTPUT_DIR = Path("output/video_segmentation")
SAM3_URL = "http://localhost:8000"
MODEL_ID = "segment_anything_3"
FRAME_INTERVAL = 30  # one-second spacing for a 30 fps video
# Prompts for railway facilities plus general objects
PROMPTS = [
    "catenary pole",  # electric-railway pole
    "junction box",  # electrical box
    "utility box",  # communications box
    "rail track",  # rail
    "fence",  # fence
    "cable",  # wire / cable
    "sign",  # sign
    "building",  # building
    "vegetation",  # vegetation
]
# Per-class colours (BGR)
COLORS = {
    "catenary pole": (0, 0, 255),  # red
    "junction box": (0, 165, 255),  # orange
    "utility box": (0, 255, 255),  # yellow
    "rail track": (255, 0, 0),  # blue
    "fence": (255, 0, 255),  # magenta
    "cable": (0, 255, 0),  # green
    "sign": (255, 255, 0),  # cyan
    "building": (128, 128, 255),  # light red
    "vegetation": (0, 128, 0),  # dark green
}
def extract_frames(video_path: Path, interval: int) -> list[tuple[int, np.ndarray]]:
    """Read a video and return every ``interval``-th frame.

    Args:
        video_path: Video file to open with OpenCV.
        interval: Keep frames whose 0-based index is a multiple of this
            value; must be >= 1.

    Returns:
        List of ``(frame_index, frame)`` tuples in playback order.

    Raises:
        ValueError: If ``interval`` is smaller than 1.

    Exits the process with status 1 when the video cannot be opened.
    """
    if interval < 1:
        # interval == 0 used to surface as a ZeroDivisionError in the loop.
        raise ValueError(f"interval must be >= 1, got {interval}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"ERROR: Cannot open video: {video_path}")
        sys.exit(1)

    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Guard the informational duration: a reported FPS of 0 would
        # otherwise raise ZeroDivisionError before any frame is read.
        duration = total / fps if fps > 0 else 0.0
        print(f"Video: {video_path.name} | {total} frames @ {fps:.0f}fps | {duration:.1f}s")

        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % interval == 0:
                frames.append((idx, frame))
            idx += 1
    finally:
        # Release the capture handle even if reading raises.
        cap.release()

    print(f"Extracted {len(frames)} frames (interval={interval})")
    return frames
def encode_frame(frame: np.ndarray) -> str:
    """Return the frame as a base64 string of its JPEG (quality 95) encoding."""
    jpeg_params = [cv2.IMWRITE_JPEG_QUALITY, 95]
    encoded = cv2.imencode(".jpg", frame, jpeg_params)[1]
    return base64.b64encode(encoded).decode("utf-8")
def predict_sam3(image_b64: str, text_prompt: str) -> dict:
    """POST one text-prompted prediction request to the SAM3 server.

    Returns the decoded JSON response on success. A connection error prints
    a hint and exits the process with status 1; any other failure is
    printed and reported as ``{"success": False}``.
    """
    request_body = {
        "model": MODEL_ID,
        "image": image_b64,
        "params": {"text_prompt": text_prompt, "conf_threshold": 0.25},
    }
    endpoint = f"{SAM3_URL}/v1/predict"
    try:
        response = requests.post(endpoint, json=request_body, timeout=60)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.ConnectionError:
        # The server is unreachable, so no further request can succeed.
        print(f"ERROR: SAM3 서버에 연결할 수 없습니다. start_server.bat을 먼저 실행하세요.")
        sys.exit(1)
    except Exception as err:
        print(f" ERROR [{text_prompt}]: {err}")
        return {"success": False}
def draw_shapes_on_frame(frame: np.ndarray, shapes: list, prompt: str) -> np.ndarray:
    """Draw segmentation shapes for one prompt onto ``frame`` (in place).

    Polygons (>= 3 points) and rectangles (exactly 2 points) are filled on a
    copy of the frame and outlined on the frame itself; each shape is
    labelled at its first point. After all shapes are drawn the filled copy
    is blended into the frame at 30 % opacity.

    Args:
        frame: BGR image, modified in place.
        shapes: Dicts with ``points`` and optionally ``shape_type``
            (default ``"polygon"``), ``label`` and ``score``.
        prompt: Selects the colour from ``COLORS`` and is the fallback label.

    Returns:
        The same ``frame`` object.
    """
    overlay = frame.copy()
    color = COLORS.get(prompt, (200, 200, 200))
    for shape in shapes:
        points = np.array(shape["points"], dtype=np.int32)
        shape_type = shape.get("shape_type", "polygon")
        if shape_type == "polygon" and len(points) >= 3:
            cv2.fillPoly(overlay, [points], color)
            cv2.polylines(frame, [points], True, color, 2)
        elif shape_type == "rectangle" and len(points) == 2:
            cv2.rectangle(overlay, tuple(points[0]), tuple(points[1]), color, -1)
            cv2.rectangle(frame, tuple(points[0]), tuple(points[1]), color, 2)

        # Label text. Compare against None so a score of exactly 0.0 is
        # still shown; a plain truthiness test dropped it.
        label = shape.get("label", prompt)
        score = shape.get("score")
        text = f"{label}" + (f" {score:.2f}" if score is not None else "")
        if len(points) > 0:
            tx, ty = int(points[0][0]), int(points[0][1]) - 5
            cv2.putText(frame, text, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # Blend the semi-transparent filled overlay into the frame
    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
    return frame
def create_layer_image(frame_shape: tuple, shapes: list, prompt: str) -> np.ndarray:
    """Build a 4-channel (BGR + alpha) layer holding one class's polygons.

    Only shapes whose ``shape_type`` is ``"polygon"`` with at least three
    points are drawn. They are filled with the class colour, with alpha 180
    inside the polygons and 0 elsewhere.
    """
    height, width = frame_shape[:2]
    fill_color = COLORS.get(prompt, (200, 200, 200))
    color_plane = np.zeros((height, width, 3), dtype=np.uint8)
    alpha_plane = np.zeros((height, width), dtype=np.uint8)

    for shape in shapes:
        outline = np.array(shape["points"], dtype=np.int32)
        if shape.get("shape_type") != "polygon" or len(outline) < 3:
            continue
        cv2.fillPoly(color_plane, [outline], fill_color)
        cv2.fillPoly(alpha_plane, [outline], 180)

    # Append alpha as the fourth channel.
    return np.concatenate([color_plane, alpha_plane[:, :, np.newaxis]], axis=2)
def main():
    """Segment sampled video frames with every prompt and save the results.

    Checks the SAM3 server's ``/health`` endpoint, extracts frames from
    ``VIDEO_PATH``, sends each frame to the server once per entry in
    ``PROMPTS``, and writes per-frame originals, composites and per-class
    layers plus a JSON summary and a colour legend into ``OUTPUT_DIR``.
    Exits with status 1 when the server cannot be reached.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    # 1. Check server status
    print("=== SAM3 서버 연결 확인 ===")
    try:
        health = requests.get(f"{SAM3_URL}/health", timeout=5)
        print(f"Server status: {health.json()}")
    except requests.exceptions.ConnectionError:
        print("ERROR: SAM3 서버가 실행 중이 아닙니다!")
        print(" → start_server.bat을 먼저 실행하세요.")
        sys.exit(1)
    # 2. Extract frames
    print("\n=== 프레임 추출 ===")
    frames = extract_frames(VIDEO_PATH, FRAME_INTERVAL)
    # 3. Segment every frame with every prompt
    all_results = {}
    for frame_idx, (fidx, frame) in enumerate(frames):
        print(f"\n=== Frame {fidx} ({frame_idx+1}/{len(frames)}) ===")
        # The frame is encoded once and reused for all prompts.
        frame_b64 = encode_frame(frame)
        frame_results = {}
        composite = frame.copy()
        for prompt in PROMPTS:
            print(f" Segmenting: {prompt}...", end=" ")
            result = predict_sam3(frame_b64, prompt)
            if result.get("success") and result.get("data", {}).get("shapes"):
                shapes = result["data"]["shapes"]
                n = len(shapes)
                print(f"{n} objects found")
                frame_results[prompt] = shapes
                # Draw onto the composite image
                composite = draw_shapes_on_frame(composite, shapes, prompt)
                # Save the individual layer (PNG with alpha)
                layer = create_layer_image(frame.shape, shapes, prompt)
                layer_name = prompt.replace(" ", "_")
                layer_path = OUTPUT_DIR / f"frame_{fidx:04d}_layer_{layer_name}.png"
                cv2.imwrite(str(layer_path), layer)
            else:
                print("→ no objects")
        # Save the composite image
        composite_path = OUTPUT_DIR / f"frame_{fidx:04d}_composite.jpg"
        cv2.imwrite(str(composite_path), composite)
        # Save the original frame (for comparison)
        original_path = OUTPUT_DIR / f"frame_{fidx:04d}_original.jpg"
        cv2.imwrite(str(original_path), frame)
        # Only per-prompt detection counts go into the summary.
        all_results[f"frame_{fidx}"] = {
            "frame_index": fidx,
            "detections": {
                prompt: len(shapes)
                for prompt, shapes in frame_results.items()
            },
        }
    # 4. Save the result summary
    summary_path = OUTPUT_DIR / "segmentation_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    # 5. Create the legend image (one 30 px row per class)
    legend = np.zeros((len(PROMPTS) * 30 + 20, 300, 3), dtype=np.uint8)
    for i, (prompt, color) in enumerate(COLORS.items()):
        y = i * 30 + 20
        cv2.rectangle(legend, (10, y - 12), (30, y + 5), color, -1)
        cv2.putText(legend, prompt, (40, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    cv2.imwrite(str(OUTPUT_DIR / "legend.jpg"), legend)
    print(f"\n=== 완료 ===")
    print(f"결과 저장: {OUTPUT_DIR}/")
    print(f" - frame_XXXX_original.jpg : 원본 프레임")
    print(f" - frame_XXXX_composite.jpg : 전체 세그멘테이션 합성")
    print(f" - frame_XXXX_layer_*.png : 클래스별 개별 레이어 (투명 배경)")
    print(f" - segmentation_summary.json : 검출 요약")
    print(f" - legend.jpg : 색상 범례")
if __name__ == "__main__":
main()

618
tools/web_ui.py Normal file
View File

@@ -0,0 +1,618 @@
"""
Railway Detection Web UI
사용법: python tools/web_ui.py
브라우저: http://localhost:7000
"""
import asyncio
import base64
import json
import os
import queue
import subprocess
import sys
import threading
import uuid
from pathlib import Path
import cv2
import numpy as np
try:
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
import uvicorn
except ImportError:
print("FastAPI/uvicorn 설치 필요: pip install fastapi uvicorn")
sys.exit(1)
app = FastAPI()
jobs: dict = {} # job_id -> {queue, status, result_path}
ROOT = Path(__file__).parent.parent # 프로젝트 루트
# ── 미리보기 이미지 생성 ───────────────────────────────────────────────────────
def make_preview(image_path: str, cols: int, rows: int, selected_rows: list) -> str:
    """Render the tile-grid preview and return it as a base64 JPEG string.

    The image is resized to 2000 px wide, overlaid with a ``cols`` x ``rows``
    grid, and every tile is labelled with its 1-based row-major index. Tiles
    in rows listed in ``selected_rows`` (1-based) are tinted green; each row
    also gets a ``Row<N>`` caption at its left edge.

    Raises:
        ValueError: If the image cannot be decoded.
    """
    raw = np.fromfile(image_path, dtype=np.uint8)
    img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"이미지 로드 실패: {image_path}")

    src_h, src_w = img.shape[:2]
    target_w = 2000
    scale = target_w / src_w
    vis = cv2.resize(img, (target_w, int(src_h * scale)))

    canvas_h, canvas_w = vis.shape[:2]
    tile_w = canvas_w / cols
    tile_h = canvas_h / rows
    font = cv2.FONT_HERSHEY_SIMPLEX
    label_scale = max(0.35, tile_w / 120)

    for r in range(rows):
        row_num = r + 1
        picked = row_num in selected_rows
        top = int(r * tile_h)
        bottom = min(canvas_h, int((r + 1) * tile_h))

        for c in range(cols):
            left = int(c * tile_w)
            right = min(canvas_w, int((c + 1) * tile_w))
            if picked:
                # Semi-transparent green fill plus a bright outline.
                tint = vis.copy()
                cv2.rectangle(tint, (left, top), (right, bottom), (0, 200, 0), -1)
                cv2.addWeighted(tint, 0.25, vis, 0.75, 0, vis)
                cv2.rectangle(vis, (left, top), (right, bottom), (0, 255, 0), 2)
            else:
                cv2.rectangle(vis, (left, top), (right, bottom), (180, 180, 180), 1)

            # Tile index centred in the tile.
            label = str(r * cols + c + 1)
            (text_w, text_h), _ = cv2.getTextSize(label, font, label_scale, 1)
            cx, cy = (left + right) // 2, (top + bottom) // 2
            cv2.putText(vis, label, (cx - text_w // 2, cy + text_h // 2),
                        font, label_scale, (255, 255, 255), 1, cv2.LINE_AA)

        # Row caption
        caption_color = (0, 255, 0) if picked else (160, 160, 160)
        cv2.putText(vis, f"Row{row_num}", (4, top + 18),
                    font, 0.55, caption_color, 2, cv2.LINE_AA)

    _, enc = cv2.imencode(".jpg", vis, [cv2.IMWRITE_JPEG_QUALITY, 88])
    return base64.b64encode(enc).decode("utf-8")
def tiles_from_rows(selected_rows: list, cols: int) -> str:
    """Translate selected 1-based row numbers into a tile-selection string.

    Row ``r`` covers tiles ``(r - 1) * cols + 1`` through ``r * cols``.
    Duplicate row numbers are ignored, so they neither repeat tile ids nor
    break the range compaction.

    Returns:
        ``"all"`` when no rows are selected, ``"first-last"`` when the tiles
        form one contiguous run, otherwise a comma-separated list of ids.
    """
    if not selected_rows:
        return "all"
    ids = [
        tile
        for row in sorted(set(selected_rows))
        for tile in range((row - 1) * cols + 1, row * cols + 1)
    ]
    # Contiguous means the id span equals the number of ids.
    if ids and ids[-1] - ids[0] + 1 == len(ids):
        return f"{ids[0]}-{ids[-1]}"
    return ",".join(str(t) for t in ids)
# ── API ───────────────────────────────────────────────────────────────────────
@app.post("/api/preview")
async def api_preview(data: dict):
    """Return the tile-grid preview for the requested image.

    Reads ``image_path`` (required), ``cols`` (default 8), ``rows``
    (default 6) and ``selected_rows`` (default ``[]``) from the body.
    Responds with ``{"ok": True, "image": <base64 JPEG>}`` or, on any
    failure, ``{"ok": False, "error": <message>}``.
    """
    try:
        image_path = data["image_path"]
        cols = data.get("cols", 8)
        rows = data.get("rows", 6)
        selected = data.get("selected_rows", [])
        encoded = make_preview(image_path, cols, rows, selected)
    except Exception as err:
        # Errors go back in the payload so the UI can show them.
        return {"ok": False, "error": str(err)}
    return {"ok": True, "image": encoded}
@app.post("/api/detect")
async def api_detect(data: dict):
    """Start ``tools/detect_all_objects.py`` in a background thread.

    A job entry is registered under a short random id. The worker thread
    forwards each non-empty output line of the subprocess to the job queue
    as ``("log", line)``, remembers the path from the last line starting
    with ``저장:``, and finishes with exactly one ``("done", ...)`` or
    ``("error", ...)`` event.

    Returns:
        ``{"job_id": <id>}`` immediately; progress is read from the queue.
    """
    job_id = str(uuid.uuid4())[:8]
    q: queue.Queue = queue.Queue()
    jobs[job_id] = {"queue": q, "status": "running", "result_path": None, "proc": None}

    def run():
        # Everything, including command construction, runs inside the try.
        # A bad payload (e.g. missing "image_path") therefore still queues a
        # terminal "error" event instead of killing the thread silently and
        # leaving the progress stream waiting forever.
        try:
            tiles_str = tiles_from_rows(data.get("selected_rows", []), data.get("cols", 8))
            cmd = [
                sys.executable, "-u", str(ROOT / "tools" / "detect_all_objects.py"),
                "--input", data["image_path"],
                "--categories", data.get("categories", "configs/railway_zone.json"),
                "--tiles", tiles_str,
                "--cols", str(data.get("cols", 8)),
                "--rows", str(data.get("rows", 6)),
                "--overlap", str(data.get("overlap", 0.20)),
                "--conf", str(data.get("conf", 0.20)),
                "--workers", str(data.get("workers", 8)),
                "--save-json",
            ]
            env = {**os.environ, "PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
            # On Windows the child gets its own process group.
            flags = subprocess.CREATE_NEW_PROCESS_GROUP if os.name == "nt" else 0
            proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                text=True, encoding="utf-8", errors="replace",
                env=env, cwd=str(ROOT), creationflags=flags,
            )
            jobs[job_id]["proc"] = proc
            for raw in proc.stdout:
                # Carriage returns are treated as line breaks so each
                # progress update is forwarded separately.
                for line in raw.replace("\r", "\n").split("\n"):
                    line = line.strip()
                    if not line:
                        continue
                    q.put(("log", line))
                    if line.startswith("저장:"):
                        jobs[job_id]["result_path"] = line.replace("저장:", "").strip()
            proc.wait()
            if jobs[job_id]["status"] == "stopped":
                q.put(("error", "사용자가 중지했습니다"))
            elif proc.returncode == 0:
                q.put(("done", jobs[job_id]["result_path"] or "완료"))
                jobs[job_id]["status"] = "done"
            else:
                q.put(("error", f"종료코드 {proc.returncode}"))
                jobs[job_id]["status"] = "error"
        except Exception as e:
            q.put(("error", str(e)))
            jobs[job_id]["status"] = "error"

    threading.Thread(target=run, daemon=True).start()
    return {"job_id": job_id}
@app.post("/api/stop/{job_id}")
async def api_stop(job_id: str):
    """Mark a job as stopped and terminate its subprocess if still running.

    On Windows the process tree is killed with ``taskkill /F /T``;
    elsewhere ``terminate()`` is called on the process.
    """
    job = jobs.get(job_id)
    if not job:
        return {"ok": False, "error": "not found"}

    # Set the flag first so the worker thread reports a user stop.
    job["status"] = "stopped"
    proc = job.get("proc")
    if proc is None or proc.poll() is not None:
        # Nothing was started yet, or it has already exited.
        return {"ok": True}

    if os.name != "nt":
        proc.terminate()
    else:
        kill_cmd = ["taskkill", "/F", "/T", "/PID", str(proc.pid)]
        subprocess.call(kill_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return {"ok": True}
@app.get("/api/progress/{job_id}")
async def api_progress(job_id: str):
    """Stream a job's queued events as Server-Sent Events.

    Each queued ``(type, msg)`` pair is sent as a JSON ``data:`` event, and
    the stream ends after the first ``done`` or ``error`` event. While the
    queue is empty a comment line (``: ping``) is sent every 0.25 s.
    Unknown job ids get a 404 JSON response.
    """
    if job_id not in jobs:
        return JSONResponse({"error": "not found"}, status_code=404)

    async def event_gen():
        pending = jobs[job_id]["queue"]
        finished = False
        while not finished:
            try:
                kind, text = pending.get_nowait()
            except queue.Empty:
                # Nothing to send yet: wait briefly, then emit a ping.
                await asyncio.sleep(0.25)
                yield ": ping\n\n"
                continue
            yield f"data: {json.dumps({'type': kind, 'msg': text})}\n\n"
            finished = kind in ("done", "error")

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(event_gen(), media_type="text/event-stream",
                             headers=sse_headers)
@app.get("/api/result/{job_id}")
async def api_result(job_id: str):
    """Return a job's result image as base64 JPEG together with its path.

    The image is read from the job's recorded ``result_path`` and
    downscaled so its longest side is at most 3000 px. Responds with
    ``{"ok": False, "error": ...}`` when there is no result file or it
    cannot be decoded.
    """
    result_path = jobs.get(job_id, {}).get("result_path")
    if not (result_path and Path(result_path).exists()):
        return {"ok": False, "error": "결과 없음"}

    img = cv2.imdecode(np.fromfile(result_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        return {"ok": False, "error": "이미지 로드 실패"}

    h, w = img.shape[:2]
    longest = max(h, w)
    if longest > 3000:
        s = 3000 / longest
        img = cv2.resize(img, (int(w * s), int(h * s)))

    _, enc = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 90])
    encoded = base64.b64encode(enc).decode()
    return {"ok": True, "image": encoded, "path": result_path}
@app.post("/api/open-folder")
async def api_open_folder(data: dict):
    """Open the folder for ``data["path"]`` (default ``output/detect``).

    When the path is an existing file its parent directory is used. The
    folder is opened with ``os.startfile`` on Windows only; on other
    systems nothing is opened. Always responds with ``{"ok": True}``.
    """
    target = Path(data.get("path", "output/detect"))
    if target.is_file():
        folder = target.parent
    else:
        folder = target
    if os.name == "nt":
        os.startfile(str(folder.resolve()))
    return {"ok": True}
# ── HTML ──────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the single-page UI held in the module-level ``HTML`` string."""
    return HTML
HTML = r"""<!DOCTYPE html>
<html lang="ko">
<head>
<meta charset="UTF-8">
<title>Railway Detection UI</title>
<style>
*{box-sizing:border-box;margin:0;padding:0}
body{font-family:'Segoe UI',sans-serif;background:#1a1a2e;color:#e0e0e0;height:100vh;display:flex;flex-direction:column;overflow:hidden}
header{background:#16213e;padding:10px 20px;border-bottom:2px solid #0f3460;flex-shrink:0}
header h1{font-size:1.1rem;color:#e94560;letter-spacing:1px}
.tabs{display:flex;background:#16213e;border-bottom:1px solid #0f3460;flex-shrink:0}
.tab-btn{padding:9px 26px;background:none;border:none;color:#888;cursor:pointer;font-size:0.92rem;border-bottom:3px solid transparent;transition:all .2s}
.tab-btn.active{color:#e94560;border-bottom-color:#e94560}
.tab-content{display:none;flex:1;overflow:hidden;min-height:0}
.tab-content.active{display:flex}
/* 입력 탭 */
#tab-input{flex-direction:row}
.controls{width:320px;min-width:260px;background:#16213e;padding:14px;overflow-y:auto;border-right:1px solid #0f3460;display:flex;flex-direction:column;gap:12px;flex-shrink:0}
.preview-area{flex:1;background:#0d0d1f;overflow:hidden;position:relative;cursor:grab;min-height:0}
.preview-area.drag{cursor:grabbing}
.preview-area img{position:absolute;transform-origin:0 0;user-select:none;pointer-events:none}
.preview-area .hint{position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);color:#444;font-size:.9rem}
.preview-hud{position:absolute;top:8px;right:10px;font-size:.74rem;color:#555;background:rgba(0,0,0,.35);padding:3px 8px;border-radius:4px;pointer-events:none;z-index:2}
.preview-fit{position:absolute;top:8px;left:10px;font-size:.78rem;background:#0f3460;color:#e0e0e0;border:1px solid #2a4a7a;padding:4px 10px;border-radius:4px;cursor:pointer;z-index:2}
.preview-fit:hover{background:#1a5299}
.fg{display:flex;flex-direction:column;gap:4px}
.fg>label{font-size:.75rem;color:#888;font-weight:600;text-transform:uppercase;letter-spacing:.5px}
.path-row{display:flex;gap:6px}
.path-row input{flex:1;min-width:0}
input[type=text],input[type=number]{background:#0f3460;border:1px solid #2a4a7a;color:#e0e0e0;padding:6px 9px;border-radius:4px;font-size:.88rem;width:100%}
input[type=text]:focus,input[type=number]:focus{outline:none;border-color:#e94560}
.row-grid{display:grid;grid-template-columns:repeat(3,1fr);gap:6px}
.row-lbl{display:flex;align-items:center;gap:6px;padding:7px 9px;background:#0f3460;border:1px solid #2a4a7a;border-radius:4px;cursor:pointer;transition:all .15s;font-size:.85rem}
.row-lbl:hover{border-color:#0c6}
.row-lbl.on{border-color:#0f6;background:#0d3320;color:#0f6}
.row-lbl input{width:14px;height:14px;cursor:pointer;accent-color:#0f6}
.params-grid{display:grid;grid-template-columns:1fr 1fr;gap:8px}
.pi{display:flex;flex-direction:column;gap:3px}
.pi label{font-size:.72rem;color:#888}
.btn-load{background:#0f3460;color:#e0e0e0;border:1px solid #2a4a7a;padding:6px 13px;border-radius:4px;cursor:pointer;font-size:.84rem;white-space:nowrap;transition:background .2s}
.btn-load:hover{background:#1a5299}
.btn-run{background:#e94560;color:#fff;border:none;padding:11px;border-radius:6px;cursor:pointer;font-size:.98rem;font-weight:700;width:100%;margin-top:2px;transition:background .2s}
.btn-run:hover{background:#c73652}
.btn-run:disabled{background:#555;cursor:not-allowed}
.btn-stop{background:#555;color:#fff;border:none;padding:11px;border-radius:6px;cursor:pointer;font-size:.98rem;font-weight:700;width:100%;margin-top:2px;display:none}
.btn-stop.visible{display:block}
.btn-stop:hover{background:#c73652}
/* 결과 탭 */
#tab-result{flex-direction:column}
.res-toolbar{background:#16213e;padding:9px 14px;display:flex;align-items:center;gap:10px;border-bottom:1px solid #0f3460;flex-shrink:0}
.btn-act{background:#0f3460;color:#e0e0e0;border:1px solid #2a4a7a;padding:6px 13px;border-radius:4px;cursor:pointer;font-size:.84rem;transition:background .2s}
.btn-act:hover{background:#1a5299}
#result-path{font-size:.78rem;color:#888;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;flex:1;min-width:0}
.viewer{flex:1;overflow:hidden;cursor:grab;position:relative;background:#0a0a1a;min-height:0}
.viewer.drag{cursor:grabbing}
.viewer img{position:absolute;transform-origin:0 0;user-select:none;pointer-events:none}
.viewer-hint{display:flex;align-items:center;justify-content:center;height:100%;color:#333;font-size:.95rem}
/* 진행 */
.prog-section{background:#16213e;padding:8px 16px;border-top:1px solid #0f3460;flex-shrink:0}
.prog-track{background:#0a0a1a;border-radius:4px;height:7px;overflow:hidden;margin-bottom:5px}
.prog-fill{height:100%;background:linear-gradient(90deg,#e94560,#0f6);width:0%;transition:width .4s;border-radius:4px}
.prog-log{font-size:.76rem;color:#777;white-space:nowrap;overflow:hidden;text-overflow:ellipsis}
</style>
</head>
<body>
<header><h1>&#x1F6E4; Railway Detection UI</h1></header>
<div class="tabs">
<button class="tab-btn active" id="tbtn-input" onclick="switchTab('input')">입력</button>
<button class="tab-btn" id="tbtn-result" onclick="switchTab('result')">결과</button>
</div>
<div id="tab-input" class="tab-content active">
<div class="controls">
<div class="fg">
<label>이미지 경로</label>
<div class="path-row">
<input type="text" id="image-path" placeholder="경로 입력 후 Enter 또는 버튼">
<button class="btn-load" onclick="loadPreview()">미리보기</button>
</div>
</div>
<div class="fg">
<label>Row 선택 (철도 구역)</label>
<div class="row-grid" id="row-grid"></div>
</div>
<div class="fg">
<label>Categories JSON</label>
<input type="text" id="categories" value="configs/railway_zone.json">
</div>
<div class="fg">
<label>파라미터</label>
<div class="params-grid">
<div class="pi"><label>Cols</label><input type="number" id="cols" value="8" min="1" onchange="loadPreview()"></div>
<div class="pi"><label>Rows</label><input type="number" id="rows" value="6" min="1" onchange="initRows();loadPreview()"></div>
<div class="pi"><label>Overlap</label><input type="number" id="overlap" value="0.20" step="0.05" min="0" max="0.5"></div>
<div class="pi"><label>Conf</label><input type="number" id="conf" value="0.20" step="0.05" min="0.05" max="1"></div>
<div class="pi"><label>Workers</label><input type="number" id="workers" value="8" min="1" max="32"></div>
</div>
</div>
<button class="btn-run" id="btn-run" onclick="startDetect()">&#x25B6; 검출 시작</button>
<button class="btn-stop" id="btn-stop" onclick="stopDetect()">&#x25A0; 중지</button>
</div>
<div class="preview-area" id="preview-area">
<button class="preview-fit" onclick="resetPreviewZoom()">화면 맞춤</button>
<span class="preview-hud">휠: 확대/축소 &nbsp;|&nbsp; 드래그: 이동</span>
<span class="hint" id="preview-hint">이미지 경로 입력 후 미리보기 클릭</span>
<img id="preview-img" style="display:none">
</div>
</div>
<div id="tab-result" class="tab-content">
<div class="res-toolbar">
<button class="btn-act" onclick="openFolder()">&#x1F4C1; 출력 폴더 열기</button>
<span id="result-path">결과 없음</span>
<button class="btn-act" onclick="resetZoom()" style="margin-left:auto">화면 맞춤</button>
<span style="font-size:.76rem;color:#555">휠: 확대/축소 &nbsp;|&nbsp; 드래그: 이동</span>
</div>
<div class="viewer" id="viewer">
<div class="viewer-hint" id="viewer-hint">검출 완료 후 결과가 표시됩니다</div>
<img id="result-img" style="display:none">
</div>
</div>
<div class="prog-section">
<div class="prog-track"><div class="prog-fill" id="prog-fill"></div></div>
<div class="prog-log" id="prog-log">대기 중</div>
</div>
<script>
let resultPath = null;
let sc = 1, tx = 0, ty = 0, dragging = false, sx, sy;
// ── 탭 전환
function switchTab(name) {
document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
document.querySelectorAll('.tab-btn').forEach(el => el.classList.remove('active'));
document.getElementById('tab-' + name).classList.add('active');
document.getElementById('tbtn-' + name).classList.add('active');
}
// ── Row 체크박스
function initRows() {
const numRows = parseInt(document.getElementById('rows').value) || 6;
const grid = document.getElementById('row-grid');
grid.innerHTML = '';
for (let r = 1; r <= numRows; r++) {
const lbl = document.createElement('label');
lbl.className = 'row-lbl';
lbl.innerHTML = `<input type="checkbox" value="${r}" onchange="onRowChange(this)"><span>Row ${r}</span>`;
grid.appendChild(lbl);
}
}
function onRowChange(cb) {
cb.parentElement.classList.toggle('on', cb.checked);
loadPreview();
}
function getSelectedRows() {
return Array.from(document.querySelectorAll('#row-grid input:checked')).map(c => parseInt(c.value));
}
// ── 미리보기
async function loadPreview() {
const imgPath = document.getElementById('image-path').value.trim();
if (!imgPath) return;
const hint = document.getElementById('preview-hint');
const pimg = document.getElementById('preview-img');
hint.style.display = 'block';
hint.style.color = '';
hint.textContent = '로딩 중...';
try {
const res = await post('/api/preview', {
image_path: imgPath,
cols: +document.getElementById('cols').value,
rows: +document.getElementById('rows').value,
selected_rows: getSelectedRows()
});
if (res.ok) {
const keepZoom = pimg.style.display === 'block';
pimg.onload = () => { if (!keepZoom) resetPreviewZoom(); else applyPreviewT(); };
pimg.src = 'data:image/jpeg;base64,' + res.image;
pimg.style.display = 'block';
hint.style.display = 'none';
} else {
hint.style.color = '#e94560';
hint.textContent = '오류: ' + res.error;
}
} catch(e) {
hint.style.color = '#e94560';
hint.textContent = String(e);
}
}
// ── 검출 시작
// Id of the most recently started detection job. stopDetect() reads this to
// know which job to cancel; it was previously never declared or assigned,
// so pressing the stop button threw a ReferenceError.
let currentJobId = null;
async function startDetect() {
  const rows = getSelectedRows();
  if (!rows.length) { alert('Row를 하나 이상 선택하세요'); return; }
  const imgPath = document.getElementById('image-path').value.trim();
  if (!imgPath) { alert('이미지 경로를 입력하세요'); return; }
  document.getElementById('btn-run').disabled = true;
  document.getElementById('btn-stop').classList.add('visible');
  document.getElementById('prog-fill').style.width = '0%';
  document.getElementById('prog-log').textContent = '시작 중...';
  try {
    const res = await post('/api/detect', {
      image_path: imgPath,
      selected_rows: rows,
      cols: +document.getElementById('cols').value,
      rows: +document.getElementById('rows').value,
      overlap: +document.getElementById('overlap').value,
      conf: +document.getElementById('conf').value,
      workers: +document.getElementById('workers').value,
      categories: document.getElementById('categories').value.trim()
    });
    if (!res.job_id) throw new Error('no job_id in response');
    currentJobId = res.job_id;
    listenProgress(res.job_id);
  } catch (e) {
    // Without this, a failed request left the run button disabled for good.
    document.getElementById('prog-log').textContent = String(e);
    setRunning(false);
  }
}
// ── SSE 진행
function listenProgress(jobId) {
const es = new EventSource(`/api/progress/${jobId}`);
es.onmessage = async e => {
const {type, msg} = JSON.parse(e.data);
document.getElementById('prog-log').textContent = msg || '';
if (type === 'log') {
const m = msg.match(/타일\s+(\d+)\/(\d+)/);
if (m) {
document.getElementById('prog-fill').style.width =
Math.round(+m[1] / +m[2] * 100) + '%';
}
} else if (type === 'done') {
es.close();
document.getElementById('prog-fill').style.width = '100%';
document.getElementById('prog-log').textContent = '완료! 결과 탭으로 이동합니다.';
setRunning(false);
await showResult(jobId);
} else if (type === 'error') {
es.close();
document.getElementById('prog-log').textContent = msg;
setRunning(false);
}
};
es.onerror = () => { es.close(); setRunning(false); };
}
// ── 결과 표시
async function showResult(jobId) {
const res = await fetch(`/api/result/${jobId}`).then(r => r.json());
if (!res.ok) return;
resultPath = res.path;
document.getElementById('result-path').textContent = res.path;
const img = document.getElementById('result-img');
img.src = 'data:image/jpeg;base64,' + res.image;
img.style.display = 'block';
document.getElementById('viewer-hint').style.display = 'none';
img.onload = resetZoom;
switchTab('result');
}
// ── Preview Zoom / Pan
const previewArea = document.getElementById('preview-area');
const pimg = document.getElementById('preview-img');
let psc = 1, ptx = 0, pty = 0, pdragging = false, psx, psy;
previewArea.addEventListener('wheel', e => {
if (pimg.style.display !== 'block') return;
e.preventDefault();
const rect = previewArea.getBoundingClientRect();
const mx = e.clientX - rect.left, my = e.clientY - rect.top;
const f = e.deltaY < 0 ? 1.15 : 0.87;
const ns = Math.min(Math.max(0.05, psc * f), 40);
ptx = mx - (mx - ptx) * (ns / psc);
pty = my - (my - pty) * (ns / psc);
psc = ns;
applyPreviewT();
}, {passive: false});
previewArea.addEventListener('mousedown', e => {
if (pimg.style.display !== 'block') return;
if (e.target.classList.contains('preview-fit')) return;
pdragging = true; psx = e.clientX - ptx; psy = e.clientY - pty;
previewArea.classList.add('drag');
});
document.addEventListener('mousemove', e => {
if (!pdragging) return;
ptx = e.clientX - psx; pty = e.clientY - psy; applyPreviewT();
});
document.addEventListener('mouseup', () => { pdragging = false; previewArea.classList.remove('drag'); });
function applyPreviewT() { pimg.style.transform = `translate(${ptx}px,${pty}px) scale(${psc})`; }
function resetPreviewZoom() {
const vw = previewArea.clientWidth, vh = previewArea.clientHeight;
if (!pimg.naturalWidth) return;
psc = Math.min(vw / pimg.naturalWidth, vh / pimg.naturalHeight) * 0.98;
ptx = (vw - pimg.naturalWidth * psc) / 2;
pty = (vh - pimg.naturalHeight * psc) / 2;
applyPreviewT();
}
// ── Result Zoom / Pan
const viewer = document.getElementById('viewer');
const rimg = document.getElementById('result-img');
viewer.addEventListener('wheel', e => {
e.preventDefault();
const rect = viewer.getBoundingClientRect();
const mx = e.clientX - rect.left, my = e.clientY - rect.top;
const f = e.deltaY < 0 ? 1.15 : 0.87;
const ns = Math.min(Math.max(0.05, sc * f), 40);
tx = mx - (mx - tx) * (ns / sc);
ty = my - (my - ty) * (ns / sc);
sc = ns;
applyT();
}, {passive: false});
viewer.addEventListener('mousedown', e => {
dragging = true; sx = e.clientX - tx; sy = e.clientY - ty;
viewer.classList.add('drag');
});
document.addEventListener('mousemove', e => {
if (!dragging) return;
tx = e.clientX - sx; ty = e.clientY - sy; applyT();
});
document.addEventListener('mouseup', () => { dragging = false; viewer.classList.remove('drag'); });
function applyT() { rimg.style.transform = `translate(${tx}px,${ty}px) scale(${sc})`; }
// Fit the result image inside the viewer (scaled by 0.95 to leave a margin)
// and centre it.
function resetZoom() {
  const vw = viewer.clientWidth, vh = viewer.clientHeight;
  const iw = rimg.naturalWidth, ih = rimg.naturalHeight;
  // Do nothing while the image has no intrinsic size or the viewer has no
  // layout box. A zero dimension would set sc to 0 (or NaN), and the wheel
  // handler divides by sc, so the next zoom step would produce Infinity/NaN
  // translations and the image could not be recovered by zooming.
  if (!iw || !ih || !vw || !vh) return;
  sc = Math.min(vw / iw, vh / ih) * 0.95;
  tx = (vw - iw * sc) / 2;
  ty = (vh - ih * sc) / 2;
  applyT();
}
// ── 중지
// NOTE(review): currentES is presumably the active EventSource for progress
// streaming; it is only declared here, so confirm against its users.
let currentES = null;

// Ask the server to stop the running job and show that the request was sent.
// Does nothing when no job id is set.
async function stopDetect() {
  if (currentJobId) {
    await post('/api/stop/' + currentJobId, {});
    const logEl = document.getElementById('prog-log');
    logEl.textContent = '중지 요청됨...';
  }
}
// Switch the toolbar between idle and running: the run button is disabled and
// the stop button gets the 'visible' class while `on` is true.
function setRunning(on) {
  const runBtn = document.getElementById('btn-run');
  const stopBtn = document.getElementById('btn-stop');
  runBtn.disabled = on;
  stopBtn.classList.toggle('visible', on);
}
// ── 폴더 열기
// Ask the server to open the result folder, falling back to the default
// output directory when no result path is set.
async function openFolder() {
  const target = resultPath || 'output/detect';
  await post('/api/open-folder', {path: target});
}
// ── 유틸
// POST `data` as a JSON body to `url` and resolve with the parsed JSON reply.
// The HTTP status is not inspected; the body is parsed whatever the status.
async function post(url, data) {
  const response = await fetch(url, {
    method: 'POST',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify(data),
  });
  return response.json();
}
// ── 초기화
initRows();
// Pressing Enter in the image path field loads the preview.
document.getElementById('image-path').addEventListener('keydown', (ev) => {
  if (ev.key !== 'Enter') return;
  loadPreview();
});
</script>
</body>
</html>"""
if __name__ == "__main__":
    # Script entry point: print the UI address, then serve ``app`` with uvicorn.
    # NOTE(review): the banner advertises localhost, but host="0.0.0.0" listens
    # on every network interface; confirm that exposing the UI beyond this
    # machine is intended.
    print("Railway Detection UI 시작: http://localhost:7000")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7000,
        log_level="warning",
    )

View File

@@ -0,0 +1,333 @@
"""
YOLO-World + SAM3 반자동 레이블링 파이프라인
=============================================
YOLO-World로 bbox 검출 → SAM3로 polygon mask 생성 → 시각화 저장
사용법:
python tools/yoloworld_sam3_pipeline.py --input sample/rail --output output/labeled
python tools/yoloworld_sam3_pipeline.py --input sample/rail/frame_00000.jpg
SAM3 서버가 실행 중이어야 합니다:
cd X-AnyLabeling-Server && uvicorn app.main:app --host 0.0.0.0 --port 8000
"""
import argparse
import base64
import json
import sys
from pathlib import Path
import cv2
import numpy as np
import requests
# Detection targets, used as YOLO-World text prompts. The list order matters:
# detect_with_yoloworld() maps a predicted class index back to a name by
# indexing into this list.
TARGET_CLASSES = [
    "catenary pole",  # electrification mast (vertical post)
    "catenary arm",   # electrification mast (horizontal arm)
    "junction box",   # communication / electrical box
    "fence",          # fence
]

# Drawing colour per class, in BGR channel order.
CLASS_COLORS = {
    "catenary pole": (0, 200, 255),  # orange
    "catenary arm": (0, 100, 255),   # red
    "junction box": (255, 180, 0),   # blue
    "fence": (0, 255, 100),          # green
}

# Base URL of the SAM3 inference server; main() overwrites it from --server.
SAM3_SERVER = "http://localhost:8000"
# Model identifier sent in the "model" field of every predict request.
MODEL_ID = "segment_anything_3"
# ─────────────────────────────────────────────────────────────────────────────
# YOLO-World 초기화
# ─────────────────────────────────────────────────────────────────────────────
def load_yolo_world(model_size: str = "s"):
    """Create a YOLO-World detector restricted to ``TARGET_CLASSES``.

    Args:
        model_size: Size letter inserted into the weight file name
            ``yolov8{model_size}-worldv2.pt``.

    Returns:
        The ``YOLOWorld`` instance after ``set_classes(TARGET_CLASSES)``.
    """
    # Imported here rather than at module level, so ultralytics is only
    # required once a model is actually requested.
    from ultralytics import YOLOWorld

    weights = f"yolov8{model_size}-worldv2.pt"
    print(f"[YOLO-World] 모델 로드: {weights}")

    detector = YOLOWorld(weights)
    detector.set_classes(TARGET_CLASSES)
    print(f"[YOLO-World] 검출 클래스: {TARGET_CLASSES}")
    return detector
def detect_with_yoloworld(model, image_bgr: np.ndarray, conf: float = 0.15):
    """Run YOLO-World on one image and flatten the boxes into plain tuples.

    Args:
        model: Object exposing ``predict(image, conf=..., verbose=...)``.
        image_bgr: Image passed straight to ``model.predict``.
        conf: Confidence threshold forwarded to ``model.predict``.

    Returns:
        A list of ``(x1, y1, x2, y2, score, class_name)`` tuples. It is empty
        when the model returns no results or no boxes.
    """
    results = model.predict(image_bgr, conf=conf, verbose=False)
    if not results or len(results) == 0:
        return []

    # Only the first result is inspected (a single image is passed in).
    boxes = results[0].boxes
    if boxes is None or len(boxes) == 0:
        return []

    detections = []
    for box in boxes:
        left, top, right, bottom = box.xyxy[0].cpu().numpy()
        score = float(box.conf[0].cpu())
        cls_idx = int(box.cls[0].cpu())
        # Indices beyond the prompt list get a synthetic placeholder name.
        if cls_idx < len(TARGET_CLASSES):
            cls_name = TARGET_CLASSES[cls_idx]
        else:
            cls_name = f"class_{cls_idx}"
        detections.append(
            (float(left), float(top), float(right), float(bottom), score, cls_name)
        )
    return detections
# ─────────────────────────────────────────────────────────────────────────────
# SAM3 서버 호출
# ─────────────────────────────────────────────────────────────────────────────
def encode_image(image_bgr: np.ndarray) -> str:
    """Return the image as a base64 string of its JPEG (quality 95) encoding.

    The success flag returned by ``cv2.imencode`` is not checked.
    """
    jpeg_params = [cv2.IMWRITE_JPEG_QUALITY, 95]
    _ok, jpeg_buf = cv2.imencode(".jpg", image_bgr, jpeg_params)
    encoded = base64.b64encode(jpeg_buf)
    return encoded.decode("utf-8")
def sam3_segment(image_bgr: np.ndarray, boxes: list, conf_threshold: float = 0.25):
    """Send bounding boxes to the SAM3 server and return its polygon shapes.

    The call is best-effort: any connection, HTTP or response-format problem
    is reported on stdout and produces an empty list instead of an exception.

    Args:
        image_bgr: Source image; it is JPEG/base64 encoded with ``encode_image``.
        boxes: ``[(x1, y1, x2, y2, score, class_name), ...]``. Only the first
            four values of each entry are sent.
        conf_threshold: Value sent as ``params.conf_threshold``.

    Returns:
        The ``data.shapes`` list from the server response. The original notes
        describe each entry as ``{"label": str, "points": [[x, y], ...],
        "score": float}``; confirm against the server. Returns ``[]`` when
        ``boxes`` is empty, on any error, or when the response carries no
        shapes.
    """
    # Nothing to prompt with: skip the JPEG encode and the HTTP round trip.
    if not boxes:
        return []

    marks = [
        {
            "type": "rectangle",
            # NOTE(review): presumably 1 marks a positive prompt; confirm
            # against the server's mark schema.
            "label": 1,
            "data": [b[0], b[1], b[2], b[3]],
        }
        for b in boxes
    ]
    payload = {
        "model": MODEL_ID,
        "image": encode_image(image_bgr),
        "params": {
            "marks": marks,
            "show_masks": True,
            "show_boxes": False,
            "conf_threshold": conf_threshold,
            "epsilon_factor": 0.002,
        },
    }

    try:
        resp = requests.post(f"{SAM3_SERVER}/v1/predict", json=payload, timeout=60)
        resp.raise_for_status()
        data = resp.json()
    except requests.exceptions.ConnectionError:
        print(f" [ERROR] SAM3 서버에 연결할 수 없습니다: {SAM3_SERVER}")
        return []
    except Exception as e:  # deliberate catch-all: segmentation is best-effort
        print(f" [ERROR] SAM3 호출 실패: {e}")
        return []

    # A JSON body that is not an object (list, string, null) is also an error.
    if not isinstance(data, dict) or data.get("status") != "success":
        print(f" [ERROR] SAM3 응답 오류: {data}")
        return []

    # Tolerate "data": null and "shapes": null so callers always get a list.
    result = data.get("data")
    if not isinstance(result, dict):
        return []
    return result.get("shapes") or []
# ─────────────────────────────────────────────────────────────────────────────
# 시각화
# ─────────────────────────────────────────────────────────────────────────────
def draw_results(image_bgr: np.ndarray, detections: list, shapes: list) -> np.ndarray:
    """Render SAM3 polygons, YOLO-World boxes and a colour legend on a copy.

    Args:
        image_bgr: Source image; it is copied, never modified.
        detections: ``[(x1, y1, x2, y2, score, class_name), ...]``.
        shapes: Dicts read via ``"points"`` (list of ``[x, y]``) and
            ``"label"``. Shapes with missing, empty or non-2D points are
            skipped.

    Returns:
        A new image with translucent polygon fills, polygon outlines, labelled
        boxes and the legend.
    """
    vis = image_bgr.copy()
    overlay = image_bgr.copy()

    # Polygon masks: filled on the overlay, outlined on the visible image.
    for shape in shapes:
        pts = np.array(shape.get("points") or [], dtype=np.int32)
        # fillPoly needs an (N, 2) array with N >= 1; skip anything else rather
        # than abort the whole image on one malformed shape.
        if pts.ndim != 2 or pts.shape[0] == 0 or pts.shape[1] != 2:
            continue
        # Drop the duplicated closing vertex when first point == last point.
        if len(pts) > 1 and np.array_equal(pts[0], pts[-1]):
            pts = pts[:-1]
        label = shape.get("label", "unknown")
        color = CLASS_COLORS.get(label, (200, 200, 200))
        cv2.fillPoly(overlay, [pts], color)
        cv2.polylines(vis, [pts], True, color, 2)

    # Blend the filled overlay into the outlined image (35/65 weighting).
    cv2.addWeighted(overlay, 0.35, vis, 0.65, 0, vis)

    # Bounding boxes with a "<class> <score>" banner.
    for x1, y1, x2, y2, score, cls_name in detections:
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        color = CLASS_COLORS.get(cls_name, (200, 200, 200))
        cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
        label_text = f"{cls_name} {score:.2f}"
        (tw, th), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
        # The banner normally sits just above the box. For boxes touching the
        # top edge it is clamped to y=0 so the text stays inside the image.
        banner_top = max(y1 - th - 6, 0)
        cv2.rectangle(vis, (x1, banner_top),
                      (x1 + tw + 4, banner_top + th + 6), color, -1)
        cv2.putText(vis, label_text, (x1 + 2, banner_top + th + 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)

    # Legend: one colour swatch and name per known class, stacked top-left.
    legend_y = 20
    for cls_name, color in CLASS_COLORS.items():
        cv2.rectangle(vis, (10, legend_y), (25, legend_y + 15), color, -1)
        cv2.putText(vis, cls_name, (30, legend_y + 13),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        legend_y += 22
    return vis
# ─────────────────────────────────────────────────────────────────────────────
# 단일 이미지 처리
# ─────────────────────────────────────────────────────────────────────────────
def process_image(yolo_model, image_path: Path, output_dir: Path,
                  yolo_conf: float, sam3_conf: float, skip_sam3: bool) -> dict:
    """Detect, optionally segment, visualise and summarise a single image.

    Args:
        yolo_model: Detector passed to ``detect_with_yoloworld``.
        image_path: Image file to read.
        output_dir: Directory receiving ``<stem>_result.jpg``. It is not
            created here.
        yolo_conf: Confidence threshold for the YOLO-World step.
        sam3_conf: Confidence threshold forwarded to ``sam3_segment``.
        skip_sam3: When true, only boxes are drawn and no SAM3 request is made.

    Returns:
        ``{"image": name, "detections": [...], "masks": count}``, or ``{}``
        when the image cannot be read.
    """
    image_bgr = cv2.imread(str(image_path))
    if image_bgr is None:
        print(f" [SKIP] 이미지 읽기 실패: {image_path}")
        return {}
    h, w = image_bgr.shape[:2]
    print(f"\n 이미지: {image_path.name} ({w}x{h})")

    # Step 1: YOLO-World detection.
    detections = detect_with_yoloworld(yolo_model, image_bgr, conf=yolo_conf)
    print(f" YOLO-World: {len(detections)}개 검출")
    for d in detections:
        print(f" [{d[5]}] conf={d[4]:.3f} bbox=({d[0]:.0f},{d[1]:.0f},{d[2]:.0f},{d[3]:.0f})")

    # Step 2: SAM3 masks, requested only when there is something to segment.
    shapes = []
    if not skip_sam3 and detections:
        print(f" SAM3: {len(detections)}개 bbox → mask 요청 중...")
        shapes = sam3_segment(image_bgr, detections, conf_threshold=sam3_conf)
        print(f" SAM3: {len(shapes)}개 mask 반환")
        for s in shapes:
            pts_count = len(s.get("points", []))
            print(f" [{s.get('label')}] score={s.get('score', 0):.3f} points={pts_count}")

    # Step 3: save the visualisation. cv2.imwrite reports failure through its
    # return value, so only announce the file when it was actually written.
    vis = draw_results(image_bgr, detections, shapes)
    output_path = output_dir / f"{image_path.stem}_result.jpg"
    if cv2.imwrite(str(output_path), vis):
        print(f" 저장: {output_path}")
    else:
        print(f" [ERROR] 결과 저장 실패: {output_path}")

    return {
        "image": image_path.name,
        "detections": [
            {"class": d[5], "conf": round(d[4], 3),
             "bbox": [round(d[0]), round(d[1]), round(d[2]), round(d[3])]}
            for d in detections
        ],
        "masks": len(shapes),
    }
# ─────────────────────────────────────────────────────────────────────────────
# 메인
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point for the YOLO-World + SAM3 pipeline.

    Parses the arguments, checks the SAM3 server (unless ``--skip-sam3``),
    collects the input images, runs ``process_image`` on each one, writes
    ``summary.json`` into the output folder and prints overall statistics.

    Exits with status 1 when the input path does not exist or holds no
    images, and with status 0 when the user declines to continue after a
    failed server check.
    """
    # Local import: only the per-class tally below needs it.
    from collections import Counter

    parser = argparse.ArgumentParser(description="YOLO-World + SAM3 파이프라인")
    parser.add_argument("--input", default="sample/rail",
                        help="이미지 파일 또는 폴더 경로")
    parser.add_argument("--output", default="output/yoloworld_sam3",
                        help="결과 저장 폴더")
    parser.add_argument("--model-size", default="s", choices=["s", "m", "l", "x"],
                        help="YOLO-World 모델 크기 (s/m/l/x)")
    parser.add_argument("--yolo-conf", type=float, default=0.10,
                        help="YOLO-World 검출 임계값 (기본 0.10)")
    parser.add_argument("--sam3-conf", type=float, default=0.20,
                        help="SAM3 마스크 임계값 (기본 0.20)")
    parser.add_argument("--skip-sam3", action="store_true",
                        help="SAM3 건너뛰고 YOLO-World bbox만 시각화")
    parser.add_argument("--server", default="http://localhost:8000",
                        help="SAM3 서버 주소")
    args = parser.parse_args()

    # sam3_segment() reads the module-level address, so publish the CLI value.
    global SAM3_SERVER
    SAM3_SERVER = args.server

    # Check that the SAM3 server is reachable.
    if not args.skip_sam3:
        try:
            resp = requests.get(f"{SAM3_SERVER}/health", timeout=5)
            if resp.status_code == 200:
                print(f"[OK] SAM3 서버 연결: {SAM3_SERVER}")
            else:
                print(f"[WARN] SAM3 서버 응답 이상 (status={resp.status_code})")
        except Exception:
            print(f"[WARN] SAM3 서버 연결 실패 ({SAM3_SERVER}). --skip-sam3 로 bbox만 볼 수 있음.")
            try:
                ans = input("계속 진행하시겠습니까? (y/N): ").strip().lower()
            except EOFError:
                # Non-interactive stdin: take the default answer ("N").
                ans = ""
            if ans != "y":
                sys.exit(0)

    # Resolve the input into a list of image files.
    input_path = Path(args.input)
    if input_path.is_file():
        image_files = [input_path]
    elif input_path.is_dir():
        # Compare suffixes case-insensitively so .JPG / .PNG files are found
        # on case-sensitive file systems as well.
        image_suffixes = {".jpg", ".jpeg", ".png"}
        image_files = sorted(
            p for p in input_path.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
    else:
        print(f"[ERROR] 입력 경로를 찾을 수 없습니다: {input_path}")
        sys.exit(1)
    if not image_files:
        print(f"[ERROR] 이미지 파일 없음: {input_path}")
        sys.exit(1)
    print(f"\n{len(image_files)}개 이미지 처리 예정")

    # Create the output folder.
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load YOLO-World once and reuse it for every image.
    yolo_model = load_yolo_world(args.model_size)

    # Process the images; unreadable ones return {} and are left out.
    summary = []
    for img_path in image_files:
        result = process_image(
            yolo_model, img_path, output_dir,
            yolo_conf=args.yolo_conf,
            sam3_conf=args.sam3_conf,
            skip_sam3=args.skip_sam3,
        )
        if result:
            summary.append(result)

    # Save the per-image summary.
    summary_path = output_dir / "summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    # Print overall statistics.
    print("\n" + "="*50)
    print("처리 완료 요약")
    print("="*50)
    total_det = sum(len(r["detections"]) for r in summary)
    total_mask = sum(r["masks"] for r in summary)
    # max(..., 1) keeps the averages defined when no image was processed.
    print(f"처리 이미지: {len(summary)}")
    print(f"YOLO 검출: {total_det}개 (평균 {total_det/max(len(summary),1):.1f}/장)")
    print(f"SAM3 마스크: {total_mask}개 (평균 {total_mask/max(len(summary),1):.1f}/장)")

    # Per-class tally, most frequent first (ties keep first-seen order).
    class_counts = Counter(d["class"] for r in summary for d in r["detections"])
    if class_counts:
        print("\n클래스별 검출 수:")
        for cls, cnt in class_counts.most_common():
            print(f" {cls:20s}: {cnt}")

    print(f"\n결과 저장: {output_dir.resolve()}")
    print(f"요약 JSON: {summary_path.resolve()}")
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()